# Owner(s): ["oncall: quantization"]

from .common import AOMigrationTestCase


class TestAOMigrationQuantization(AOMigrationTestCase):
    r"""Modules and functions related to the
    `torch/quantization` migration to `torch/ao/quantization`.
    """

    def test_function_import_quantize(self):
        function_list = [
            "_convert",
            "_observer_forward_hook",
            "_propagate_qconfig_helper",
            "_remove_activation_post_process",
            "_remove_qconfig",
            "_add_observer_",
            "add_quant_dequant",
            "convert",
            "_get_observer_dict",
            "_get_unique_devices_",
            "_is_activation_post_process",
            "prepare",
            "prepare_qat",
            "propagate_qconfig_",
            "quantize",
            "quantize_dynamic",
            "quantize_qat",
            "_register_activation_post_process_hook",
            "swap_module",
        ]
        self._test_function_import("quantize", function_list)

    def test_function_import_stubs(self):
        function_list = [
            "QuantStub",
            "DeQuantStub",
            "QuantWrapper",
        ]
        self._test_function_import("stubs", function_list)

    def test_function_import_quantize_jit(self):
        function_list = [
            "_check_is_script_module",
            "_check_forward_method",
            "script_qconfig",
            "script_qconfig_dict",
            "fuse_conv_bn_jit",
            "_prepare_jit",
            "prepare_jit",
            "prepare_dynamic_jit",
            "_convert_jit",
            "convert_jit",
            "convert_dynamic_jit",
            "_quantize_jit",
            "quantize_jit",
            "quantize_dynamic_jit",
        ]
        self._test_function_import("quantize_jit", function_list)

    def test_function_import_fake_quantize(self):
        function_list = [
            "_is_per_channel",
            "_is_per_tensor",
            "_is_symmetric_quant",
            "FakeQuantizeBase",
            "FakeQuantize",
            "FixedQParamsFakeQuantize",
            "FusedMovingAvgObsFakeQuantize",
            "default_fake_quant",
            "default_weight_fake_quant",
            "default_fixed_qparams_range_neg1to1_fake_quant",
            "default_fixed_qparams_range_0to1_fake_quant",
            "default_per_channel_weight_fake_quant",
            "default_histogram_fake_quant",
            "default_fused_act_fake_quant",
            "default_fused_wt_fake_quant",
            "default_fused_per_channel_wt_fake_quant",
            "_is_fake_quant_script_module",
            "disable_fake_quant",
            "enable_fake_quant",
            "disable_observer",
            "enable_observer",
        ]
        self._test_function_import("fake_quantize", function_list)

    def test_function_import_fuse_modules(self):
        function_list = [
            "_fuse_modules",
            "_get_module",
            "_set_module",
            "fuse_conv_bn",
            "fuse_conv_bn_relu",
            "fuse_known_modules",
            "fuse_modules",
            "get_fuser_method",
        ]
        self._test_function_import("fuse_modules", function_list)

    def test_function_import_quant_type(self):
        function_list = [
            "QuantType",
            "_get_quant_type_to_str",
        ]
        self._test_function_import("quant_type", function_list)

    def test_function_import_observer(self):
        function_list = [
            "_PartialWrapper",
            "_with_args",
            "_with_callable_args",
            "ABC",
            "ObserverBase",
            "_ObserverBase",
            "MinMaxObserver",
            "MovingAverageMinMaxObserver",
            "PerChannelMinMaxObserver",
            "MovingAveragePerChannelMinMaxObserver",
            "HistogramObserver",
            "PlaceholderObserver",
            "RecordingObserver",
            "NoopObserver",
            "_is_activation_post_process",
            "_is_per_channel_script_obs_instance",
            "get_observer_state_dict",
            "load_observer_state_dict",
            "default_observer",
            "default_placeholder_observer",
            "default_debug_observer",
            "default_weight_observer",
            "default_histogram_observer",
            "default_per_channel_weight_observer",
            "default_dynamic_quant_observer",
            "default_float_qparams_observer",
        ]
        self._test_function_import("observer", function_list)

    def test_function_import_qconfig(self):
        function_list = [
            "QConfig",
            "default_qconfig",
            "default_debug_qconfig",
            "default_per_channel_qconfig",
            "QConfigDynamic",
            "default_dynamic_qconfig",
            "float16_dynamic_qconfig",
            "float16_static_qconfig",
            "per_channel_dynamic_qconfig",
            "float_qparams_weight_only_qconfig",
            "default_qat_qconfig",
            "default_weight_only_qconfig",
            "default_activation_only_qconfig",
            "default_qat_qconfig_v2",
            "get_default_qconfig",
            "get_default_qat_qconfig",
            "_assert_valid_qconfig",
            "QConfigAny",
            "_add_module_to_qconfig_obs_ctr",
            "qconfig_equals",
        ]
        self._test_function_import("qconfig", function_list)

    def test_function_import_quantization_mappings(self):
        function_list = [
            "no_observer_set",
            "get_default_static_quant_module_mappings",
            "get_static_quant_module_class",
            "get_dynamic_quant_module_class",
            "get_default_qat_module_mappings",
            "get_default_dynamic_quant_module_mappings",
            "get_default_qconfig_propagation_list",
            "get_default_compare_output_module_list",
            "get_default_float_to_quantized_operator_mappings",
            "get_quantized_operator",
            "_get_special_act_post_process",
            "_has_special_act_post_process",
        ]
        dict_list = [
            "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS",
            "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS",
            "DEFAULT_QAT_MODULE_MAPPINGS",
            "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS",
            # "_INCLUDE_QCONFIG_PROPAGATE_LIST",
            "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
            "DEFAULT_MODULE_TO_ACT_POST_PROCESS",
        ]
        self._test_function_import("quantization_mappings", function_list)
        self._test_dict_import("quantization_mappings", dict_list)

    def test_function_import_fuser_method_mappings(self):
        function_list = [
            "fuse_conv_bn",
            "fuse_conv_bn_relu",
            "fuse_linear_bn",
            "get_fuser_method",
        ]
        dict_list = ["_DEFAULT_OP_LIST_TO_FUSER_METHOD"]
        self._test_function_import("fuser_method_mappings", function_list)
        self._test_dict_import("fuser_method_mappings", dict_list)

    def test_function_import_utils(self):
        function_list = [
            "activation_dtype",
            "activation_is_int8_quantized",
            "activation_is_statically_quantized",
            "calculate_qmin_qmax",
            "check_min_max_valid",
            "get_combined_dict",
            "get_qconfig_dtypes",
            "get_qparam_dict",
            "get_quant_type",
            "get_swapped_custom_module_class",
            "getattr_from_fqn",
            "is_per_channel",
            "is_per_tensor",
            "weight_dtype",
            "weight_is_quantized",
            "weight_is_statically_quantized",
        ]
        self._test_function_import("utils", function_list)
