# Owner(s): ["oncall: quantization"]
import copy
import unittest
from typing import List

import torch
import torch._export
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import QuantizationAnnotation, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
from torch.fx import Node
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS


class TestHelperModules:
    """Namespace holding the small eager modules used by the tests in this file."""

    class Conv2dWithObsSharingOps(torch.nn.Module):
        """Conv2d -> AdaptiveAvgPool2d -> Hardtanh -> view(-1, 3) -> Linear."""

        def __init__(self) -> None:
            super().__init__()
            # Creation order is kept stable: it determines both the submodule
            # registration order and the order in which random initial weights
            # are drawn for ``conv`` and ``linear``.
            self.conv = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
            self.linear = torch.nn.Linear(in_features=3, out_features=3)

        def forward(self, x):
            """Run the input through conv, pooling, hardtanh, a reshape and linear."""
            conv_out = self.conv(x)
            pooled = self.adaptive_avg_pool2d(conv_out)
            clamped = self.hardtanh(pooled)
            # Collapse to rows of 3 features so the result can feed ``linear``.
            flattened = clamped.view(-1, 3)
            return self.linear(flattened)


def _tag_partitions(
    backend_name: str, op_name: str, annotated_partitions: List[List[Node]]
):
    """Write a ``quantization_tag`` into the ``meta`` of every partition node.

    All nodes of the partition at position ``i`` in ``annotated_partitions``
    receive the tag ``"<backend_name>_<op_name>_<i>"``.

    Raises:
        AssertionError: if a node already has a ``quantization_tag`` entry.
    """
    for position, members in enumerate(annotated_partitions):
        partition_tag = "_".join((backend_name, op_name, str(position)))
        for node in members:
            assert "quantization_tag" not in node.meta, f"{node} is already tagged"
            node.meta["quantization_tag"] = partition_tag


# Quantize / dequantize / choose_qparams overloads from the
# ``quantized_decomposed`` op namespace. ``_test_metadata_porting`` records the
# ``quantization_tag`` of a ``call_function`` node only when the node's target
# is a member of this set.
_QUANT_OPS = {
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
    torch.ops.quantized_decomposed.choose_qparams.tensor,
}


# TODO: rename to TestPortMetadataPass to align with the util name?
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestMetaDataPorting(QuantizationTestCase):
    def _test_quant_tag_preservation_through_decomp(
        self, model, example_inputs, from_node_to_tags
    ):
        """Check that ``torch.export.export`` keeps the expected quantization tags.

        ``model`` is exported with ``example_inputs``. For every
        ``(origin, tag)`` pair in ``from_node_to_tags``, each exported node
        whose ``from_node`` metadata lists ``origin`` (compared against element
        ``[1]`` of each metadata entry) must carry ``tag`` under
        ``meta["quantization_tag"]``.

        All offending nodes are collected, so the failure message names every
        node that lost its tag instead of only the first one encountered.

        Raises:
            ValueError: if a node's ``from_node`` metadata is not a list.
        """
        ep = torch.export.export(model, example_inputs)
        untagged_targets = []
        # The mapping drives the outer loop, so an empty mapping inspects no
        # nodes at all.
        for from_node, tag in from_node_to_tags.items():
            for n in ep.graph_module.graph.nodes:
                from_node_meta = n.meta.get("from_node", None)
                if from_node_meta is None:
                    continue
                if not isinstance(from_node_meta, list):
                    raise ValueError(
                        f"from_node metadata is of type {type(from_node_meta)}, but expected list"
                    )
                if not any(meta[1] == from_node for meta in from_node_meta):
                    continue
                node_tag = n.meta.get("quantization_tag", None)
                if node_tag is None or tag != node_tag:
                    untagged_targets.append(str(n.target))
        self.assertTrue(
            not untagged_targets,
            "Decomposition did not preserve quantization tag for "
            + ", ".join(untagged_targets),
        )

    def _test_metadata_porting(
        self,
        model,
        example_inputs,
        quantizer,
        node_tags=None,
    ) -> torch.fx.GraphModule:
        """Quantize ``model`` with ``quantizer`` and verify tags on the converted graph.

        The model is switched to eval mode, deep-copied, captured with
        ``capture_pre_autograd_graph``, prepared, run once on
        ``example_inputs`` for calibration, and converted. The converted graph
        is then scanned: for every node carrying ``meta["quantization_tag"]``
        that is either a ``get_attr`` node or a ``call_function`` node whose
        target is in ``_QUANT_OPS``, the tag is recorded under the key
        ``"get_attr"`` or the node's target respectively.

        Args:
            model: eager module under test.
            example_inputs: tuple of positional inputs used for capture,
                calibration and one run of the converted module.
            quantizer: quantizer handed to ``prepare_pt2e``.
            node_tags: mapping from key (``"get_attr"`` or an op in
                ``_QUANT_OPS``) to the set of tags expected for that key.
                ``None`` is treated as an empty mapping, i.e. no tagged node
                is expected.

        Returns:
            The converted graph module.

        Raises:
            ValueError: if two ``call_function`` nodes with the same target
                carry the same tag.
        """
        expected_tags = {} if node_tags is None else node_tags
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = torch._export.capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)

        # Run the converted module once; only successful execution matters,
        # the output itself is not inspected.
        m(*example_inputs)

        recorded_node_tags = {}
        for n in m.graph.nodes:
            if "quantization_tag" not in n.meta:
                continue
            if n.op == "call_function" and n.target in _QUANT_OPS:
                key = n.target
            elif n.op == "get_attr":
                key = "get_attr"
            else:
                continue

            tags_for_key = recorded_node_tags.setdefault(key, set())
            # A tag may be shared by several get_attr nodes, but must be unique
            # among call_function nodes with the same target.
            if n.op == "call_function" and n.meta["quantization_tag"] in tags_for_key:
                raise ValueError(
                    f"{key} {n.format_node()} has tag {n.meta['quantization_tag']} that "
                    "is associated with another node of the same type"
                )
            tags_for_key.add(n.meta["quantization_tag"])

        self.assertEqual(set(recorded_node_tags.keys()), set(expected_tags.keys()))
        for k, v in recorded_node_tags.items():
            self.assertEqual(v, expected_tags[k])
        return m

    def test_simple_metadata_porting(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Check quantization tags on conv2d, avgpool and linear are correctly set,
        then check that the avgpool and linear tags are still present after
        exporting the converted model.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                """Annotate and tag linear, conv and adaptive_avg_pool2d partitions."""
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config
                )
                _tag_partitions(backend_string, "linear", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                    gm, quantization_config
                )
                _tag_partitions(
                    backend_string, "adaptive_avg_pool2d", annotated_partitions
                )
                # Return the annotated module, as the declared return type says.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {
            "BackendA_conv2d_0",
            "BackendA_linear_0",
        }
        quantize_per_tensor_tags = {
            "BackendA_conv2d_0",
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_linear_0",
        }
        dequantize_per_tensor_tags = {
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_conv2d_0",
            "BackendA_linear_0",
        }
        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
        }
        m = self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

        # Tags expected on exported nodes that originate from these aten ops.
        from_node_to_tags = {
            torch.ops.aten.adaptive_avg_pool2d.default: "BackendA_adaptive_avg_pool2d_0",
            torch.ops.aten.linear.default: "BackendA_linear_0",
        }
        self._test_quant_tag_preservation_through_decomp(
            m, example_inputs, from_node_to_tags
        )

    def test_metadata_porting_with_no_quant_inbetween(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Don't quantize avgpool
        Check quantization tags on conv2d and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                """Annotate and tag linear and conv partitions only."""
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config
                )
                _tag_partitions(backend_string, "linear", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                # Return the annotated module, as the declared return type says.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    @unittest.skip("Temporarily disabled")
    def test_metadata_porting_for_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Quantize all except linear.
        Quantize linear with dynamic quantization
        Check quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                """Statically annotate conv/avgpool and dynamically annotate linear."""
                backend_string = "BackendA"
                # static quantization
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                    gm, quantization_config
                )
                _tag_partitions(
                    backend_string, "adaptive_avg_pool2d", annotated_partitions
                )

                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
                # Return the annotated module, as the declared return type says.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        # TODO: add get_attr_tags when the test is re-enabled
        # Kept as an (empty) set so it has the same type as the recorded tags.
        get_attr_tags = set()
        quantize_per_tensor_tags = {
            "BackendA_conv2d_0",
            "BackendA_adaptive_avg_pool2d_0",
        }
        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_tensor_tags = {
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_conv2d_0",
        }
        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_channel_tags = {
            "BackendA_conv2d_0",
            "BackendA_linear_dynamic_0",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_metadata_porting_for_two_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Quantize linear and conv with dynamic quantization (avgpool is not
        annotated)
        Check quantization tags on conv2d and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                """Dynamically annotate and tag conv and linear partitions."""
                backend_string = "BackendA"

                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["conv"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
                # Return the annotated module, as the declared return type says.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        choose_qparams_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        quantize_per_tensor_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        dequantize_per_tensor_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        dequantize_per_channel_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_metadata_porting_for_dq_no_static_q(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Don't quantize anything except linear.
        Quantize linear with dynamic quantization
        Check quantization tags on linear are correctly set and that no other
        tags are recorded
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                """Dynamically annotate and tag linear partitions only."""
                backend_string = "BackendA"
                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
                # Return the annotated module, as the declared return type says.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {"BackendA_linear_dynamic_0"}
        choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_no_metadata_porting(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Annotate linear, conv2d and avgpool but never assign a quantization tag
        Check that no tags are recorded on the converted graph
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                """Annotate linear, conv and adaptive_avg_pool2d without tagging."""
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)
                # Return the annotated module, as the declared return type says.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        # No tags were assigned, so none may show up after conversion.
        node_tags = {}
        m = self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

        from_node_to_tags = {}
        self._test_quant_tag_preservation_through_decomp(
            m, example_inputs, from_node_to_tags
        )

    def test_no_metadata_porting_through_unknown_ops(self):
        """
        Model under test
        matmul -> add -> relu
        matmul has get_attr as first input, but the quantization_tag should not be
        propagated to add even if it's part of a chain that ends at get_attr
        """

        class MatmulWithConstInput(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_parameter("w", torch.nn.Parameter(torch.rand(8, 16)))

            def forward(self, x, y):
                x = torch.matmul(self.w, x)
                z = x + y
                return torch.nn.functional.relu(z)

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                """Annotate every call_function node and tag it with its target name.

                The same tag is also written onto any ``get_attr`` node that
                is a direct argument of the annotated node.
                """
                qconfig = get_symmetric_quantization_config()
                for n in gm.graph.nodes:
                    if n.op != "call_function":
                        continue

                    n.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map={n.args[0]: qconfig.input_activation},
                        output_qspec=qconfig.output_activation,
                    )

                    tag = str(n.target)
                    n.meta["quantization_tag"] = tag
                    for arg in n.args:
                        # Arguments that are not graph nodes have no ``op``.
                        if isinstance(arg, Node) and arg.op == "get_attr":
                            arg.meta["quantization_tag"] = tag
                # Return the annotated module, as the declared return type says.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(16, 24), torch.randn(8, 24))
        get_attr_tags = {"aten.matmul.default"}
        quantize_per_tensor_tags = {
            "aten.matmul.default",
            "aten.add.Tensor",
            "aten.relu.default",
        }
        dequantize_per_tensor_tags = {
            "aten.matmul.default",
            "aten.add.Tensor",
            "aten.relu.default",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
        }
        self._test_metadata_porting(
            MatmulWithConstInput(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )
