# Owner(s): ["oncall: quantization"]

import copy
import math

import torch
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat as nnqat
import torch.ao.nn.qat.dynamic as nnqatd
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.backends.mkldnn
import torch.nn as nn
import torch.testing._internal.hypothesis_utils as hu

from hypothesis import given, strategies as st
from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.ao.quantization import (
    convert,
    default_embedding_qat_qconfig,
    default_qat_qconfig,
    default_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
    DeQuantStub,
    FixedQParamsFakeQuantize,
    FusedMovingAvgObsFakeQuantize,
    get_default_qat_qconfig,
    get_embedding_qat_module_mappings,
    get_embedding_static_quant_module_mappings,
    NoopObserver,
    prepare,
    prepare_qat,
    quantize_qat,
    QuantStub,
)
from torch.ao.quantization.qconfig import qconfig_equals
from torch.nn import BatchNorm2d, Conv2d, init, ReLU
from torch.nn.modules.utils import _pair
from torch.testing._internal.common_quantization import (
    DeFusedEmbeddingBagLinear,
    ManualConvLinearQATModel,
    ManualConvLinearSymmQATModel,
    ManualDropoutQATModel,
    ManualEmbeddingBagLinear,
    ManualLinearDynamicQATModel,
    ManualLinearQATModel,
    QuantizationTestCase,
    QuantStubModel,
    test_only_eval_fn,
    test_only_train_fn,
    TwoLayerLinearModel,
)

from torch.testing._internal.common_quantized import (
    override_qengines,
    override_quantized_engine,
    supported_qengines,
)

from torch.testing._internal.common_utils import skipIfNoXNNPACK

from functools import reduce

hu.assert_deadline_disabled()

class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
    """
    Conv-BN fusion implemented with explicit folding. Useful
    for verifying numerical equivalence with the non-folded version.
    """
    def __init__(self,
                 # ConvNd args
                 in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups,
                 bias,
                 padding_mode,
                 # BatchNormNd args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
                                         stride, padding, dilation, transposed,
                                         output_padding, groups, False, padding_mode)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.eps = eps
        self.momentum = momentum
        self.freeze_bn = freeze_bn if self.training else True
        self.num_features = out_channels
        self.gamma = nn.Parameter(torch.empty(out_channels))
        self.beta = nn.Parameter(torch.empty(out_channels))
        self.affine = True
        self.track_running_stats = True
        self.running_mean = nn.Buffer(torch.zeros(out_channels))
        self.running_var = nn.Buffer(torch.ones(out_channels))
        self.num_batches_tracked = nn.Buffer(torch.tensor(0, dtype=torch.long))
        self.activation_post_process = self.qconfig.activation()
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

    def reset_running_stats(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.num_batches_tracked.zero_()

    def reset_bn_parameters(self):
        self.reset_running_stats()
        init.uniform_(self.gamma)
        init.zeros_(self.beta)
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def reset_parameters(self):
        super().reset_parameters()
        # A hack to avoid resetting on undefined parameters
        if hasattr(self, 'gamma'):
            self.reset_bn_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        return self

    def _forward(self, input):
        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in the ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and not self.freeze_bn and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # we use running statistics from the previous batch, so this is an
        # approximation of the approach mentioned in the whitepaper, but we only
        # need to do one convolution in this case instead of two
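        # Folding sketch (ignoring weight fake quantization): with
        #   scale_factor = gamma / sqrt(running_var + eps)
        # the convolution below uses the folded weight W * scale_factor (per output
        # channel), so conv(x, W_folded) == scale_factor * conv(x, W). This lets us
        # recover the unfolded output later by dividing by scale_factor.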
        running_std = torch.sqrt(self.running_var + self.eps)
        scale_factor = self.gamma / running_std
        scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1])
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
        else:
            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype)
        conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), zero_bias)

        if self.training and not self.freeze_bn:
            # recovering original conv to get original batch_mean and batch_var
            if self.bias is not None:
                conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1])
            else:
                conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
            batch_mean = torch.mean(conv_orig, dim=[0, 2, 3])
            batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False)
            n = float(conv_orig.numel() / conv_orig.size()[1])
            unbiased_batch_var = batch_var * (n / (n - 1))
            batch_rstd = torch.ones_like(batch_var, memory_format=torch.contiguous_format) / torch.sqrt(batch_var + self.eps)

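            # Re-apply batch normalization using batch statistics:
            #   y = gamma * (conv_orig - batch_mean) / sqrt(batch_var + eps) + beta
            # written below in terms of batch_rstd = 1 / sqrt(batch_var + eps).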
            conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + \
                (self.beta - self.gamma * batch_rstd * batch_mean).reshape([1, -1, 1, 1])
            self.running_mean = exponential_average_factor * batch_mean.detach() + \
                (1 - exponential_average_factor) * self.running_mean
            self.running_var = exponential_average_factor * unbiased_batch_var.detach() + \
                (1 - exponential_average_factor) * self.running_var
        else:
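            # With frozen BN (or eval), BN reduces to a fixed affine transform based on
            # the running statistics, so it folds into the bias:
            #   y = conv(x, W_folded) + beta + gamma * (bias - running_mean) / running_std
            # (with bias treated as 0 when the conv has no bias).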
            if self.bias is None:
                conv = conv + (self.beta - self.gamma * self.running_mean /
                               running_std).reshape([1, -1, 1, 1])
            else:
                conv = conv + (self.gamma * (self.bias - self.running_mean) / running_std + self.beta).reshape([1, -1, 1, 1])
        return conv

    def extra_repr(self):
        # TODO(jerryzh): extend
        return super().extra_repr()

    def forward(self, input):
        return self.activation_post_process(self._forward(input))

    @classmethod
    def from_float(cls, mod, qconfig=None):
        r"""Create a qat module from a float module or qparams_dict
            Args: `mod` a float module, either produced by torch.ao.quantization utilities
            or directly from user
        """
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        if not qconfig:
            assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
            assert mod.qconfig, 'Input float module must have a valid qconfig'
            qconfig = mod.qconfig
        conv, bn = mod[0], mod[1]
        qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
                         conv.stride, conv.padding, conv.dilation,
                         conv.groups, conv.bias is not None,
                         conv.padding_mode,
                         bn.eps, bn.momentum,
                         False,
                         qconfig)
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.gamma = bn.weight
        qat_convbn.beta = bn.bias
        qat_convbn.running_mean = bn.running_mean
        qat_convbn.running_var = bn.running_var
        qat_convbn.num_batches_tracked = bn.num_batches_tracked
        return qat_convbn

class _ReferenceConvBn2d(_ReferenceConvBnNd, nn.Conv2d):
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvBn2d

    def __init__(self,
                 # ConvNd args
                 in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=None,
                 padding_mode='zeros',
                 # BatchNorm2d args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        _ReferenceConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
                                    padding, dilation, False, _pair(0), groups, bias, padding_mode,
                                    eps, momentum, freeze_bn, qconfig)

class TestQuantizeEagerQAT(QuantizationTestCase):
    def setUp(self):
        super().setUp()

        self.embed_linear_data_train = [[torch.randint(0, 10, (12, 12), dtype=torch.long),
                                         torch.randn((12, 1), dtype=torch.float)]
                                        for _ in range(2)]
        self.embed_data = [[torch.randint(0, 10, (12, 1))]]


    def test_manual(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualLinearQATModel(qengine)
                model = prepare_qat(model)
                self.checkObservers(model)
                test_only_train_fn(model, self.train_data)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.fc2), nnq.Linear)
                    test_only_eval_fn(model, self.calib_data)
                    self.checkScriptable(model, self.calib_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn,
                                     [self.train_data])
                checkQuantized(model)

    def test_dropout(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualDropoutQATModel(qengine)
                model = prepare_qat(model)
                self.checkObservers(model)
                test_only_train_fn(model, self.train_data)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.dropout), nnq.Dropout)
                    test_only_eval_fn(model, self.calib_data)
                    self.checkScriptable(model, self.calib_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = quantize_qat(ManualDropoutQATModel(qengine), test_only_train_fn,
                                     [self.train_data])
                checkQuantized(model)

    def test_eval_only_fake_quant(self):
        r"""Using FakeQuant in evaluation only mode,
        this is useful for estimating accuracy loss when we quantize the
        network
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualLinearQATModel(qengine)

                model = prepare_qat(model)
                self.checkObservers(model)

                model.eval()
                test_only_eval_fn(model, self.calib_data)

    def test_conv_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualConvLinearQATModel()

                model = prepare_qat(model)
                self.checkObservers(model)

                test_only_train_fn(model, self.img_data_2d_train)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.conv), nnq.Conv2d)
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.fc2), nnq.Linear)
                    test_only_eval_fn(model, self.img_data_2d)
                    self.checkScriptable(model, self.img_data_2d)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualConvLinearQATModel()
                model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
                checkQuantized(model)

    @skipIfNoXNNPACK
    def test_conv_linear_symm(self):
        r"""Same as test_conv_linear but with Symmetric quantization.
        Supported only with qengine=qnnpack, which uses symmetric
        kernels from xnnpack library."""
        for qengine in supported_qengines:
            if qengine != 'qnnpack':
                continue
            with override_quantized_engine(qengine):
                model = ManualConvLinearSymmQATModel()

                model = prepare_qat(model)
                self.checkObservers(model)

                test_only_train_fn(model, self.img_data_2d_train)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.conv), nnq.Conv2d)
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.fc2), nnq.Linear)
                    test_only_eval_fn(model, self.img_data_2d)
                    self.checkScriptable(model, self.img_data_2d)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualConvLinearSymmQATModel()
                model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
                checkQuantized(model)

    def test_dynamic_qat_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                # Dynamic QAT without memoryless observers should fail
                with self.assertRaisesRegex(ValueError,
                                            "Dynamic QAT requires a memoryless observer." +
                                            "This means a MovingAverage observer with averaging constant equal to 1"
                                            ):
                    model = ManualLinearDynamicQATModel(default_qat_qconfig)
                    model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})

                model = ManualLinearDynamicQATModel()
                model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})
                self.assertEqual(type(model.fc1), nnqatd.Linear)
                self.assertEqual(type(model.fc2), nnqatd.Linear)
                self.checkObservers(model)
                test_only_train_fn(model, self.train_data)
                model = convert(model, mapping={nnqatd.Linear: nnqd.Linear})
                self.assertEqual(type(model.fc1), nnqd.Linear)
                self.assertEqual(type(model.fc2), nnqd.Linear)
                test_only_eval_fn(model, self.calib_data)
                self.checkScriptable(model, self.calib_data)
                self.checkNoQconfig(model)

    def test_defused_embedding_bag_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = DeFusedEmbeddingBagLinear().train()
                model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
                self.checkObservers(model)

                test_only_train_fn(model, self.embed_linear_data_train)
                # make sure activation_post_process is inserted after Linear.
                self.assertEqual(type(model.linear.activation_post_process), FusedMovingAvgObsFakeQuantize)
                # make sure that Embedding has a noop for activation.
                self.assertEqual(type(model.emb.activation_post_process), NoopObserver)
                # make sure that FakeQuant zero_points are correct dtype
                self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32)
                self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32)

                model = convert(model, mapping=get_embedding_static_quant_module_mappings())

                def checkQuantized(model):
                    # make sure Embedding is now a QuantizedEmbedding
                    self.assertEqual(type(model.emb), nn.quantized.Embedding)
                    # make sure Linear is now a QuantizedLinear
                    self.assertEqual(type(model.linear), nn.quantized.Linear)

                    test_only_eval_fn(model, self.embed_data)
                    self.checkScriptable(model, self.embed_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)


    def test_embedding_bag_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualEmbeddingBagLinear().train()
                model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
                self.checkObservers(model)

                test_only_train_fn(model, self.embed_linear_data_train)
                # make sure no activation_post_process is inserted for EmbeddingBag
                self.assertFalse(hasattr(model, "activation_post_process"))
                # make sure that FakeQuant zero_points are correct dtype
                self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32)
                self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32)
                model = convert(model, mapping=get_embedding_static_quant_module_mappings())

                def checkQuantized(model):
                    # Make sure EmbeddingBag is now a quantized EmbeddingBag.
                    self.assertEqual(type(model.emb), nn.quantized.EmbeddingBag)
                    # Also test that Linear has been quantized.
                    self.assertEqual(type(model.linear), nnq.Linear)

                    test_only_eval_fn(model, self.embed_data)
                    self.checkScriptable(model, self.embed_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualEmbeddingBagLinear()

    def test_train_save_load_eval(self):
        r"""Test QAT flow of creating a model, doing QAT and saving the quantized state_dict
        During eval, we first call prepare_qat and conver on the model and then load the state_dict
        and compare results against original model
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
                model = prepare_qat(model)

                fq_state_dict = model.state_dict()

                test_only_train_fn(model, self.train_data)
                model = convert(model)

                quant_state_dict = model.state_dict()

                x = torch.rand(2, 5, dtype=torch.float)
                ref = model(x)

                # Create model again for eval. Check result using quantized state_dict
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
                torch.ao.quantization.prepare_qat(model, inplace=True)
                new_state_dict = model.state_dict()

                # Check to make sure the model after prepare_qat has the same state_dict as original.
                self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys()))

                torch.ao.quantization.convert(model, inplace=True)
                model.eval()
                model.load_state_dict(quant_state_dict)
                out = model(x)
                self.assertEqual(ref, out)

                # Check that a model created using prepare has the same state_dict keys as the quantized state_dict
                model = TwoLayerLinearModel()
                model.eval()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                torch.ao.quantization.prepare(model, inplace=True)
                torch.ao.quantization.convert(model, inplace=True)
                self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys()))
                model.eval()
                model.load_state_dict(quant_state_dict)
                out = model(x)
                self.assertEqual(ref, out)

    @override_qengines
    def test_forward_hooks_preserved(self):
        r"""Test QAT on preserving pre forward and post forward hooks of original model
        """
        qengine = torch.backends.quantized.engine
        model = QuantStubModel()
        counter = {
            'pre_forwards': 0,
            'forwards': 0,
        }

        def fw_pre_hook(h_module, input):
            counter['pre_forwards'] += 1

        def fw_hook(h_module, input, output):
            counter['forwards'] += 1

        model.fc.register_forward_pre_hook(fw_pre_hook)
        model.fc.register_forward_hook(fw_hook)

        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        model = prepare_qat(model)

        def checkHooksIsPresent(model, before_convert=True):
            forward_hooks = 1
            if before_convert:
                self.assertEqual(len(model.quant._forward_hooks.values()), 1,
                                 "Quantization observer hook has disappeared")
                forward_hooks = 2
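            # Before convert, prepare_qat attaches an observer forward hook to fc in
            # addition to the user hook, so two forward hooks are expected; after
            # convert only the user hook should remain.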
            self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
            self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
            self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
                             "Extra pre forward hooks have appeared on a layer")
            self.assertEqual(len(model.fc._forward_hooks.values()), forward_hooks,
                             "Extra post forward hooks have appeared on a layer")

        checkHooksIsPresent(model, True)
        x = torch.rand(2, 5, dtype=torch.float)
        model(x)
        torch.ao.quantization.convert(model, inplace=True)
        checkHooksIsPresent(model, False)

    def test_add_scalar_uses_input_qparams(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.ff = torch.ao.nn.quantized.FloatFunctional()

            def forward(self, x):
                x = self.quant(x)
                x = self.ff.add_scalar(x, 1.0)
                return x

        m = M()
        m.qconfig = torch.ao.quantization.default_qconfig
        mp = torch.ao.quantization.prepare_qat(m)
        mp(torch.randn(4, 4))
        mq = torch.ao.quantization.convert(mp)
        res = mq(torch.randn(4, 4))
        eps = 1e-5
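        # add_scalar should reuse the input's quantization parameters, so the
        # output scale is expected to match the input (quant stub) scale.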
        self.assertTrue(torch.abs(mq.quant.scale - res.q_scale()) < eps)

    def test_mul_scalar_uses_input_qparams(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.ff = torch.ao.nn.quantized.FloatFunctional()

            def forward(self, x):
                x = self.quant(x)
                x = self.ff.mul_scalar(x, 2.0)
                return x

        m = M()
        m.qconfig = torch.ao.quantization.default_qconfig
        mp = torch.ao.quantization.prepare_qat(m)
        mp(torch.randn(4, 4))
        mq = torch.ao.quantization.convert(mp)
        res = mq(torch.randn(4, 4))
        eps = 1e-5
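        # mul_scalar should reuse the input's quantization parameters scaled by the
        # scalar, so the output scale is expected to be input scale * 2.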
        self.assertTrue(torch.abs(mq.quant.scale * 2 - res.q_scale()) < eps)

    @override_qengines
    def test_qat_embedding_bag_errors(self):
        default_qat_qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)

        # Test constructor parameters checks here.
        with self.assertRaisesRegex(AssertionError,
                                    "qconfig must be provided for QAT module"):
            nnqat.EmbeddingBag(10, 5, qconfig=None)

        with self.assertRaisesRegex(AssertionError,
                                    "Embedding Bag weights requires a qscheme of " +
                                    "torch.per_channel_affine_float_qparams"):
            nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig)

        # Test from_float checks here.
        embed = nn.Embedding(10, 5)
        with self.assertRaisesRegex(AssertionError,
                                    "qat.EmbeddingBag.from_float only works for EmbeddingBag"):
            nnqat.EmbeddingBag.from_float(embed)
        embed_bag = nn.EmbeddingBag(10, 5)
        with self.assertRaisesRegex(AssertionError,
                                    "Input float module must have qconfig defined"):
            nnqat.EmbeddingBag.from_float(embed_bag)
        embed_bag.qconfig = None
        with self.assertRaisesRegex(AssertionError,
                                    "Input float module must have a valid qconfig"):
            nnqat.EmbeddingBag.from_float(embed_bag)
        embed_bag.qconfig = default_qat_qconfig
        with self.assertRaisesRegex(AssertionError,
                                    "Embedding Bag weights requires a qscheme of " +
                                    "torch.per_channel_affine_float_qparams"):
            nnqat.EmbeddingBag.from_float(embed_bag)

    def test_embedding_qat_qconfig_equal(self):
        # Embedding QAT uses a NoopObserver class for activation
        # and a FakeQuant for weight; make sure that qconfig comparison
        # works properly for a mix of partial functions and classes in
        # a qconfig.
        model = ManualEmbeddingBagLinear().train()
        model = prepare_qat(model)

        self.assertTrue(qconfig_equals(model.emb.qconfig,
                                       default_embedding_qat_qconfig))

class TestQuantizeEagerQATNumerics(QuantizationTestCase):
    def _test_activation_convert_numerics_impl(self, Act, data):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.act = Act()
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.act(x)
                x = self.dequant(x)
                return x

        m = M().train()
        m.qconfig = default_qat_qconfig
        m = prepare_qat(m)
        before_convert = m(data)
        m = convert(m)
        after_convert = m(data)
        self.assertEqual(before_convert, after_convert)

    def test_fixed_qparam_ops(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.tanh = torch.nn.Tanh()
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.sigmoid(x)
                x = self.hardsigmoid(x)
                x = self.tanh(x)
                x = self.dequant(x)
                return x

        m = M().train()
        m.qconfig = default_qat_qconfig
        m = prepare_qat(m)
        for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
            self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize)
        data = torch.randn(1, 3, 2, 4)
        before_convert = m(data)
        m = convert(m)
        after_convert = m(data)
        self.assertEqual(before_convert, after_convert)
        # make sure activation post process is removed
        for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
            # verify fake quant module is removed
            self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process'))
            # verify that hooks are removed
            self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)

        # make sure no fake quantize module is inserted for eval mode

        def checkNoFQModule(m):
            for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
                self.assertFalse(hasattr(getattr(m, attr), "activation_post_process"))
                self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)

        m = M().eval()
        m.qconfig = default_qconfig
        m = prepare(m)
        checkNoFQModule(m)
        m = convert(m)
        checkNoFQModule(m)

    def test_leaky_relu(self):
        data = torch.randn(1, 3, 2, 4)
        self._test_activation_convert_numerics_impl(nn.LeakyReLU, data)

    def test_relu(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(x)
                return x

        m = M().train()
        m.qconfig = default_qconfig
        m = prepare_qat(m)
        # make sure no activation_post_process is inserted for relu
        self.assertFalse(hasattr(m, "activation_post_process"))
        m = convert(m)
        # make sure the ReLU module is not changed
        self.assertEqual(type(m.relu), nn.ReLU)

    @given(batch_size=st.integers(2, 4),
           input_channels_per_group=st.sampled_from([2, 3, 4]),
           height=st.integers(5, 10),
           width=st.integers(5, 10),
           output_channels_per_group=st.sampled_from([2, 3]),
           groups=st.integers(1, 3),
           kernel_h=st.integers(1, 3),
           kernel_w=st.integers(1, 3),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 1),
           padding_mode=st.sampled_from(['zeros', 'circular']),
           use_relu=st.booleans(),
           eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
           momentum=st.sampled_from([0.1, 0.2, 0.3]),
           freeze_bn=st.booleans(),
           zero_gamma=st.booleans(),
           has_bias=st.booleans(),
           use_slow_fusion=st.booleans())
    def test_conv_bn_relu(
            self,
            batch_size,
            input_channels_per_group,
            height,
            width,
            output_channels_per_group,
            groups,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            pad_h,
            pad_w,
            dilation,
            padding_mode,
            use_relu,
            eps,
            momentum,
            freeze_bn,
            zero_gamma,
            has_bias,
            use_slow_fusion,
    ):
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        dilation_h = dilation_w = dilation

        conv_op = Conv2d(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            has_bias,
            padding_mode
        ).to(dtype=torch.double)
        bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double)
        relu_op = ReLU()

        cls = ConvBnReLU2d if use_relu else ConvBn2d
        qat_op = cls(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            has_bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn=True,
            qconfig=default_qat_qconfig
        ).to(dtype=torch.double)
        qat_op._enable_slow_path_for_better_numerical_stability = use_slow_fusion
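        # The slow fusion path trades speed for numerical stability (it avoids the
        # approximate single-conv folding used by the default fast path).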

        # the approximate fusion will not work if bn.weight contains zeros
        if zero_gamma and use_slow_fusion:
            torch.nn.init.zeros_(qat_op.bn.weight)

        qat_op.apply(torch.ao.quantization.disable_fake_quant)
        if freeze_bn:
            qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
        else:
            qat_op.apply(torch.ao.nn.intrinsic.qat.update_bn_stats)

        # align inputs and internal parameters
        input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
        conv_op.weight = torch.nn.Parameter(qat_op.weight.detach())
        if has_bias:
            conv_op.bias = torch.nn.Parameter(qat_op.bias.detach())
        bn_op.running_mean = qat_op.bn.running_mean.clone()
        bn_op.running_var = qat_op.bn.running_var.clone()
        bn_op.weight = torch.nn.Parameter(qat_op.bn.weight.detach())
        bn_op.bias = torch.nn.Parameter(qat_op.bn.bias.detach())

        def compose(functions):
            # reverse so functions are applied in natural reading order:
            # compose([f, g])(x) == g(f(x))
            return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)

        if not use_relu:
            def relu_op(x):  # noqa: F811
                return x

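        # When BN stats are frozen, BN reduces to a fixed affine transform based on the
        # running statistics, so the reference below applies
        #   y = (x - running_mean) * gamma / sqrt(running_var + eps) + beta
        # directly instead of calling bn_op.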
        if freeze_bn:
            def ref_op(x):
                x = conv_op(x)
                x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
                    (bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \
                    .reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
                x = relu_op(x)
                return x
        else:
            ref_op = compose([conv_op, bn_op, relu_op])

        input_clone = input.clone().detach().requires_grad_()
        for i in range(2):
            result_ref = ref_op(input)
            result_actual = qat_op(input_clone)
            self.assertEqual(result_ref, result_actual)

            # backward
            dout = torch.randn(result_ref.size(), dtype=torch.double)
            loss = (result_ref - dout).sum()
            loss.backward()
            input_grad_ref = input.grad.cpu()
            weight_grad_ref = conv_op.weight.grad.cpu()
            gamma_grad_ref = bn_op.weight.grad.cpu()
            beta_grad_ref = bn_op.bias.grad.cpu()
            running_mean_ref = bn_op.running_mean
            running_var_ref = bn_op.running_var
            num_batches_tracked_ref = bn_op.num_batches_tracked
            loss = (result_actual - dout).sum()
            loss.backward()
            input_grad_actual = input_clone.grad.cpu()
            weight_grad_actual = qat_op.weight.grad.cpu()
            gamma_grad_actual = qat_op.bn.weight.grad.cpu()
            beta_grad_actual = qat_op.bn.bias.grad.cpu()
            running_mean_actual = qat_op.bn.running_mean
            running_var_actual = qat_op.bn.running_var
            num_batches_tracked_actual = qat_op.bn.num_batches_tracked
            precision = 1e-10
            self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
            self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
            self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
            self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
            self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
            self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
            self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)

    @given(batch_size=st.integers(2, 4),
           input_channels_per_group=st.sampled_from([2, 3, 4]),
           height=st.integers(5, 10),
           width=st.integers(5, 10),
           output_channels_per_group=st.sampled_from([2, 3]),
           groups=st.integers(1, 3),
           kernel_h=st.integers(1, 3),
           kernel_w=st.integers(1, 3),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 1),
           padding_mode=st.sampled_from(['zeros', 'circular']),
           eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
           momentum=st.sampled_from([0.1, 0.2, 0.3]),
           freeze_bn=st.booleans(),
           bias=st.booleans())
    def test_conv_bn_folded_vs_unfolded(
            self,
            batch_size,
            input_channels_per_group,
            height,
            width,
            output_channels_per_group,
            groups,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            pad_h,
            pad_w,
            dilation,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            bias,
    ):
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        dilation_h = dilation_w = dilation

        qat_op = ConvBn2d(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            bias,  # bias
            padding_mode,
            eps,
            momentum,
            freeze_bn=freeze_bn,
            qconfig=default_qat_qconfig
        ).to(dtype=torch.double)

        qat_ref_op = _ReferenceConvBn2d(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            bias,  # bias
            padding_mode,
            eps,
            momentum,
            freeze_bn=freeze_bn,
            qconfig=default_qat_qconfig
        ).to(dtype=torch.double)

        qat_op.apply(torch.ao.quantization.disable_fake_quant)
        qat_ref_op.apply(torch.ao.quantization.disable_fake_quant)

        # align inputs and internal parameters
        qat_ref_op.weight = torch.nn.Parameter(qat_op.weight.detach().clone())
        qat_ref_op.running_mean = qat_op.bn.running_mean.clone()
        qat_ref_op.running_var = qat_op.bn.running_var.clone()
        qat_ref_op.gamma = torch.nn.Parameter(qat_op.bn.weight.detach().clone())
        qat_ref_op.beta = torch.nn.Parameter(qat_op.bn.bias.detach().clone())
        if qat_op.bias is not None:
            qat_ref_op.bias = torch.nn.Parameter(qat_op.bias.detach().clone())

        lr = 0.01
        qat_op_optim = torch.optim.SGD(qat_op.parameters(), lr=lr)
        qat_ref_op_optim = torch.optim.SGD(qat_ref_op.parameters(), lr=lr)

        for i in range(5):

            # make sure that calling model.train() does not override the
            # bn freeze setting
            qat_op.train()
            qat_ref_op.train()

            qat_op_optim.zero_grad()
            qat_ref_op_optim.zero_grad()

            input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
            input_clone = input.clone().detach().requires_grad_()

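            # Exercise every phase of the QAT schedule: train with BN updating,
            # then freeze BN stats (i > 2), and finally also disable observers (i > 3).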
            if i > 2:
                qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
                qat_ref_op.freeze_bn_stats()

            if i > 3:
                qat_op.apply(torch.ao.quantization.disable_observer)
                qat_ref_op.apply(torch.ao.quantization.disable_observer)

            result_ref = qat_ref_op(input)
            result_actual = qat_op(input_clone)
            self.assertEqual(result_ref, result_actual)

            # backward
            dout = torch.randn(result_ref.size(), dtype=torch.double) + 10.0

            loss = (result_ref - dout).sum()
            loss.backward()
            input_grad_ref = input.grad.cpu()
            weight_grad_ref = qat_ref_op.weight.grad.cpu()
            gamma_grad_ref = qat_ref_op.gamma.grad.cpu()
            beta_grad_ref = qat_ref_op.beta.grad.cpu()
            running_mean_ref = qat_ref_op.running_mean
            running_var_ref = qat_ref_op.running_var
            num_batches_tracked_ref = qat_ref_op.num_batches_tracked

            loss = (result_actual - dout).sum()
            loss.backward()
            input_grad_actual = input_clone.grad.cpu()
            weight_grad_actual = qat_op.weight.grad.cpu()
            gamma_grad_actual = qat_op.bn.weight.grad.cpu()
            beta_grad_actual = qat_op.bn.bias.grad.cpu()
            running_mean_actual = qat_op.bn.running_mean
            running_var_actual = qat_op.bn.running_var
            num_batches_tracked_actual = qat_op.bn.num_batches_tracked

            precision = 1e-5
            self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
            self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
            self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
            self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
            self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
            self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
            self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)

            qat_op_optim.step()
            qat_ref_op_optim.step()

    @override_qengines
    def test_linear_bn_numerics(self):
        qengine = torch.backends.quantized.engine
        m_ref = nn.Sequential(
            nn.Linear(4, 4),
            nn.BatchNorm1d(4),
        )
        m_ref_copy = copy.deepcopy(m_ref)
        m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
        qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        m_ref_copy[0].qconfig = qconfig
        m = nniqat.LinearBn1d.from_float(m_ref_copy[0])

        # without fake_quants, fused QAT module should match fp32 module
        m.apply(torch.ao.quantization.disable_fake_quant)
        data = torch.randn(4, 4)
        r1 = m_ref(data)
        r2 = m(data)
        self.assertTrue(torch.allclose(r1, r2))

    @skipIfNoXNNPACK
    @override_qengines
    def test_linear_bn_symm_numerics(self):
        qengine = torch.backends.quantized.engine
        if qengine != "qnnpack":
            return  # Only qnnpack supports symmetric quantization
        m_ref = nn.Sequential(
            nn.Linear(4, 4),
            nn.BatchNorm1d(4),
        )
        m_ref_copy = copy.deepcopy(m_ref)
        m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
        qconfig = default_symmetric_qnnpack_qat_qconfig
        m_ref_copy[0].qconfig = qconfig
        m = nniqat.LinearBn1d.from_float(m_ref_copy[0])

        # without fake_quants, fused QAT module should match fp32 module
        m.apply(torch.ao.quantization.disable_fake_quant)
        data = torch.randn(4, 4)
        r1 = m_ref(data)
        r2 = m(data)
        self.assertTrue(torch.allclose(r1, r2))

    @override_qengines
    def test_linear_bn_workflow(self):
        qengine = torch.backends.quantized.engine
        m = nn.Sequential(
            QuantStub(),
            nn.Linear(4, 4),
            nn.BatchNorm1d(4),
        )
        data = torch.randn(4, 4)
        m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        m = torch.ao.quantization.fuse_modules_qat(m, [['1', '2']])
        mp = prepare_qat(m)
        mp(data)
        mq = convert(mp)
        self.assertTrue(type(mq[1]) == nnq.Linear)
        self.assertTrue(type(mq[2]) == nn.Identity)


    @skipIfNoXNNPACK
    @override_qengines
    def test_linear_precomputed_fake_quant(self):
        qengine = torch.backends.quantized.engine
        if qengine != "qnnpack":
            return  # Only qnnpack supports symmetric quantization
        m_ref = nn.Linear(4, 4)

        m_ref_copy = copy.deepcopy(m_ref)
        qconfig = default_qconfig
        m_ref_copy.qconfig = qconfig
        weight_post_process = copy.deepcopy(qconfig.weight())
        activation = copy.deepcopy(qconfig.activation())
        activation(torch.randn(4, 4))
        m_ref_copy.activation_post_process = activation
        m_ref_copy = nnq.Linear.from_float(m_ref_copy)
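        # Quantize the original module again, this time with manually set
        # (precomputed) weight observer ranges; the resulting weight scale is
        # expected to differ from the one observed from the weights above.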
        weight_post_process = qconfig.weight()
        weight_post_process.min_val = torch.tensor(-1)
        weight_post_process.max_val = torch.tensor(1)
        m_ref.weight_post_process = weight_post_process
        m_ref.activation_post_process = activation
        m_ref.qconfig = qconfig
        m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True)
        self.assertTrue(m_ref._weight_bias()[0].q_scale() != m_ref_copy._weight_bias()[0].q_scale())


if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")
