# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import unittest
from typing import Optional

import torch
from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn
from executorch.backends.xnnpack.test.tester import Quantize, Tester
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig


class Conv2d(torch.nn.Module):
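    """Configurable single-Conv2d module; get_inputs() returns a matching
    random input tensor."""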
    def __init__(
        self,
        in_channels=2,
        out_channels=1,
        kernel_size=(3, 3),
        stride=(2, 2),
        padding=(1, 1),
        dilation=(1, 1),
        groups=1,
        bias=True,
        padding_mode="zeros",
        batches=1,
        width=8,
        height=8,
        dtype=torch.float,
    ):
        super().__init__()
        self.batches = batches
        self.width = width
        self.height = height
        self.in_channels = in_channels
        self.dtype = dtype

        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        ).to(dtype)

    def forward(self, x):
        return self.conv(x)

    def get_inputs(self):
        return (
            torch.randn(self.batches, self.in_channels, self.height, self.width).to(
                self.dtype
            ),
        )


class Conv2dSeq(torch.nn.Module):
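    """Two convolutions in sequence, for testing multi-conv lowering."""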
    def __init__(self):
        super().__init__()
        self.first = torch.nn.Conv2d(
            in_channels=1,
            out_channels=3,
            kernel_size=(3, 3),
            padding=1,
            bias=False,
        )
        self.second = torch.nn.Conv2d(
            in_channels=3,
            out_channels=2,
            kernel_size=(3, 3),
            padding=1,
            bias=False,
        )

    def forward(self, x):
        y = self.first(x)
        return self.second(y)

    def get_inputs(self):
        return (torch.randn(1, 1, 3, 3),)


class Conv2dBatchNorm(torch.nn.Module):
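    """Conv -> BN -> Hardtanh twice in sequence. Note that a single BatchNorm
    module is shared by both convolutions; both convs produce 2 channels, so
    the shared module is shape-compatible with either output."""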
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            2,
            2,
            (2, 2),
            bias=False,
            padding=[1, 1],
            stride=[4, 4],
        )
        self.bn = randomize_bn(2)
        self.hardtanh = torch.nn.Hardtanh()
        self.conv2 = torch.nn.Conv2d(
            2,
            2,
            (2, 2),
            bias=False,
            padding=[1, 1],
            stride=[4, 4],
        )

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn(y)
        y = self.hardtanh(y)
        y = self.conv2(y)
        y = self.bn(y)
        y = self.hardtanh(y)
        return y

    def get_inputs(self):
        return (torch.randn(2, 2, 4, 4),)


class Conv2dPermute(torch.nn.Module):
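    """Convolution followed by a permute with a parameterized dimension order."""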
    def __init__(self, permute_order):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            2,
            2,
            (2, 2),
            bias=False,
            padding=[2, 2],
            stride=[2, 2],
        )
        self.permute_order = permute_order

    def forward(self, x):
        result = self.conv(x)
        # permute_order is parameterized; callers exercise all 24 orders, so
        # the result is not necessarily channels-last.
        permuted = torch.permute(result, self.permute_order)
        return permuted

    def get_inputs(self):
        return (torch.randn(2, 2, 4, 4),)


class TestConv2d(unittest.TestCase):
    def _test(
        self,
        m: torch.nn.Module,
        quant_config: Optional[QuantizationConfig] = None,
        conv_count=1,
        dtype: torch.dtype = torch.float,
    ):
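        """
        Optionally quantize the module, then export it, lower it with
        to_edge_transform_and_lower, verify that every convolution (and any
        batch norm) was consumed by the XNNPACK delegate, and finally
        serialize, run, and compare outputs with a quantization tolerance
        of 1.
        """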
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
        #  torch.nn.modules.module.Module]` is not a function.
        tester = Tester(m.eval(), m.get_inputs())

        if quant_config is not None:
            tester = tester.quantize(Quantize(quantization_config=quant_config))
            tester.check(["torch.ops.quantized_decomposed"])

        (
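            # After lowering, no convolution or batch norm ops should remain
            # outside of the delegate.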
            tester.export()
            .check_count({"torch.ops.aten.conv2d": conv_count})
            .to_edge_transform_and_lower()
            .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default"
                ]
            )
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(qtol=1)
        )

    def test_fp16_conv2d(self) -> None:
        for has_bias in (True, False):
            self._test(Conv2d(bias=has_bias, dtype=torch.float16))

    def test_fp32_conv2d(self) -> None:
        for has_bias in (True, False):
            self._test(Conv2d(bias=has_bias))

    def test_fp32_conv2d_permute(self) -> None:
        for perm_order in itertools.permutations([0, 1, 2, 3]):
            self._test(Conv2dPermute(perm_order))

    def test_qs8_conv2d(self) -> None:
        for has_bias in (True, False):
            self._test(
                Conv2d(bias=has_bias), quant_config=get_symmetric_quantization_config()
            )

    def test_qs8_conv2d_per_channel(self) -> None:
        self._test(
            Conv2d(),
            quant_config=get_symmetric_quantization_config(is_per_channel=True),
        )

    def test_fp32_conv2d_seq(self) -> None:
        self._test(Conv2dSeq(), conv_count=2)

    def test_qs8_conv2d_seq(self) -> None:
        self._test(
            Conv2dSeq(), conv_count=2, quant_config=get_symmetric_quantization_config()
        )

    def test_fp32_conv2d_single_int_params(self):
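        # Pass scalar kernel_size/stride/dilation instead of tuples, plus the
        # string padding mode "valid".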
        self._test(
            Conv2d(
                kernel_size=3,
                stride=2,
                padding="valid",
                dilation=1,
            )
        )

    def test_fp32_conv2d_depthwise(self):
        # Depthwise convolution requirements:
        # - groups must equal in_channels
        # - out_channels must be a positive multiple of in_channels
        self._test(Conv2d(groups=2, in_channels=2, out_channels=6))

    def test_qs8_conv2d_depthwise(self):
        self._test(
            Conv2d(groups=2, in_channels=2, out_channels=6),
            quant_config=get_symmetric_quantization_config(),
        )

    def test_fp32_conv2d_bn(self):
        class Conv2dBatchNorm(torch.nn.Module):
            def __init__(self, in_features: int, out_features: int, kernel_size):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)
                self.bn = randomize_bn(out_features)
                self.in_features = in_features
                self.kernel_size = kernel_size

            def forward(self, x):
                y = self.conv2d(x)
                y = self.bn(y)
                return y

            def get_inputs(self):
                return (
                    torch.randn(
                        2,
                        self.in_features,
                        self.kernel_size[0] * 2,
                        self.kernel_size[1] * 2,
                    ),
                )

        self._test(Conv2dBatchNorm(in_features=2, out_features=2, kernel_size=(2, 2)))

    def test_fp32_conv2d_bn_hardtanh_mean_sequence(self):
        """
        This test makes sure that we can fuse batchnorm and hardtanh
        even with inserting copy nodes at some spots in the graph to change
        memory format
        """

        class Conv2dBatchNormHardTanh(torch.nn.Module):
            def __init__(self, in_channels: int, out_channels: int, kernel_size):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=[1, 1],
                    stride=[2, 2],
                )
                self.in_channels = in_channels
                self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
                self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)

            def forward(self, x):
                x = self.conv(x)
                x = self.native_batchnorm(x)
                x = self.hardtanh(x)
                x = torch.mean(x, (-1, -2), keepdim=True)
                return x

            def get_inputs(self):
                return (torch.randn(2, self.in_channels, 8, 8),)

        self._test(
            Conv2dBatchNormHardTanh(in_channels=2, out_channels=1, kernel_size=(2, 2))
        )

    def test_qs8_conv2d_bn(self):
        self._test(
            Conv2dBatchNorm(),
            quant_config=get_symmetric_quantization_config(),
            conv_count=2,
        )

    def test_qs8_conv2d_relu(self):
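        # Conv followed by ReLU; with symmetric quantization the pair should
        # be delegated together (XNNPACK can absorb the ReLU as an output
        # clamp on the convolution).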
        class ConvReLU(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    2,
                    2,
                    (2, 2),
                    bias=False,
                    padding=[1, 1],
                    stride=[4, 4],
                )
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                y = self.conv1(x)
                y = self.relu(y)
                return y

            def get_inputs(self):
                return (torch.randn(2, 2, 4, 4),)

        self._test(
            ConvReLU(),
            quant_config=get_symmetric_quantization_config(),
        )

    def test_qs8_conv2d_dw_relu(self):
        # Depthwise convolution requirements:
        # - groups must equal in_channels
        # - out_channels must be a positive multiple of in_channels
        groups = 2
        stride = [2, 2]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = groups
        out_channels = 3 * in_channels
        width = 8
        height = 8
        batches = 1

        class ModelConvReLU(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=(3, 3),
                    stride=stride,
                    padding=padding,
                    groups=groups,
                    dilation=dilation,
                    bias=True,
                )
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                y = self.conv1(x)
                y = self.relu(y)
                return y

            def get_inputs(self):
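                # Scale inputs beyond [-1, 1] to exercise a wider quantized
                # range (the multiplier itself is arbitrary).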
                return (torch.randn(batches, in_channels, height, width) * 11,)

        for per_channel_quant in (False, True):
            model = ModelConvReLU()
            self._test(
                model,
                quant_config=get_symmetric_quantization_config(
                    is_per_channel=per_channel_quant
                ),
            )

    def test_qs8_conv2d_relu_seq(self):
        class ConvReLUSeq(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = torch.nn.Sequential(
                    torch.nn.Conv2d(1, 1, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(1, 64, 1),
                    torch.nn.ReLU(),
                )

            def forward(self, x):
                return self.model(x)

            def get_inputs(self):
                return (torch.randn(1, 1, 1, 1),)

        self._test(
            ConvReLUSeq(),
            quant_config=get_symmetric_quantization_config(),
            conv_count=2,
        )

    def test_qs8_conv2d_relu_multi_users(self):
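        # conv1's output feeds both the ReLU and the final add, so conv+relu
        # fusion must preserve the intermediate result for its second user.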
        class Conv2dReluMultiUsers(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 64, 1)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                conv_default = self.conv1(x)
                y = self.relu(conv_default)
                conv_default_2 = self.conv2(y)
                return conv_default + conv_default_2

            def get_inputs(self):
                return (torch.randn(1, 1, 1, 1),)

        self._test(
            Conv2dReluMultiUsers(),
            quant_config=get_symmetric_quantization_config(),
            conv_count=2,
        )
