# Owner(s): ["oncall: quantization"]

import torch
import torch.nn as nn
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_quantization import skipIfNoFBGEMM

from torch.ao.quantization import default_qconfig
from torch.ao.quantization import QuantWrapper
import torch.ao.ns._numeric_suite as ns

from torch.ao.quantization._correct_bias import (
    _supported_modules,
    _supported_modules_quantized,
    bias_correction,
    get_module,
    get_param,
    parent_child_names
)

import copy


class TestBiasCorrectionEager(QuantizationTestCase):
    def compute_sqnr(self, x, y):
        Ps = torch.norm(x)
        Pn = torch.norm(x - y)
        return 20 * torch.log10(Ps / Pn)

    def correct_artificial_bias_quantize(self, float_model, img_data):
        ''' Adding artificial bias and testing if bias persists after bias
            correction. This test case changes the bias of a quantized submodule
        '''
        artificial_model = copy.deepcopy(float_model)
        artificial_model.qconfig = default_qconfig
        torch.ao.quantization.prepare(artificial_model, inplace=True)
        for data in img_data:
            artificial_model(data[0])
        torch.ao.quantization.convert(artificial_model, inplace=True)

        # manually changing bias
        for name, submodule in artificial_model.named_modules():
            if type(submodule) in _supported_modules:
                x = get_param(submodule, 'bias')
                weight = get_param(submodule, 'weight')
                if x is not None:
                    submodule.set_weight_bias(weight, x.data * 3)

        bias_correction(float_model, artificial_model, img_data, target_modules=_supported_modules_quantized)

        # Trims off the shadow module,
        for name, submodule in artificial_model.named_modules():
            if isinstance(submodule, ns.Shadow):
                parent_name, child_name = parent_child_names(name)
                parent = get_module(artificial_model, parent_name)
                parent._modules[child_name] = submodule.orig_module

        for name, artificial_submodule in artificial_model.named_modules():
            if type(artificial_submodule) in _supported_modules_quantized:
                submodule = get_module(float_model, name)
                float_bias = get_param(submodule, 'bias')
                artificial_bias = get_param(artificial_submodule, 'bias')

                self.assertTrue(self.compute_sqnr(float_bias, artificial_bias) > 30,
                                "Correcting quantized bias produced too much noise, sqnr score too low")

    @skipIfNoFBGEMM
    def test_linear_chain(self):
        class LinearChain(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(3, 4)
                self.linear2 = nn.Linear(4, 5)
                self.linear3 = nn.Linear(5, 6)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = self.linear3(x)
                return x
        float_model = QuantWrapper(LinearChain())
        img_data = [(torch.rand(10, 3, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long))
                    for _ in range(50)]
        self.correct_artificial_bias_quantize(float_model, img_data)

    @skipIfNoFBGEMM
    def test_conv_chain(self):
        class ConvChain(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2d1 = nn.Conv2d(3, 4, 5, 5)
                self.conv2d2 = nn.Conv2d(4, 5, 5, 5)
                self.conv2d3 = nn.Conv2d(5, 6, 5, 5)

            def forward(self, x):
                x = self.conv2d1(x)
                x = self.conv2d2(x)
                x = self.conv2d3(x)
                return x
        float_model = QuantWrapper(ConvChain())
        img_data = [(torch.rand(10, 3, 125, 125, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long))
                    for _ in range(50)]
        self.correct_artificial_bias_quantize(float_model, img_data)
