# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from typing import Callable, List

import torch
from torch._ops import OpOverload
from torch._subclasses import FakeTensor

from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)

from torch.export import export_for_training
from torch.fx import Graph, Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
    SubgraphMatcherWithNameNodeMap,
)

from .qconfig import QuantizationConfig


# Registry from an aten op overload to the per-op annotator that ``annotate``
# calls for ``call_function`` nodes with that target. Entries are added by the
# ``register_annotator`` decorator.
OP_TO_ANNOTATOR: dict = dict()


def annotate(graph: Graph, quant_config: QuantizationConfig) -> None:
    """Attach quantization annotations to every relevant node of ``graph``.

    Multi-node patterns (RMSNorm, then op + ReLU fusion) are handled first.
    The per-op annotators that run afterwards skip nodes already marked as
    annotated, so they do not override what the pattern passes decided.

    Args:
        graph: The FX graph to annotate in place.
        quant_config: Supplies the activation / weight quantization specs.
    """
    _annotate_rmsnorm_pattern(graph, quant_config)
    _annotate_fused_activation_pattern(graph, quant_config)

    for node in graph.nodes:
        if node.op == "placeholder":
            annotate_placeholder(node, quant_config)
            continue
        if node.op != "call_function":
            continue
        # Ops without a registered annotator are left untouched.
        per_op_annotator = OP_TO_ANNOTATOR.get(node.target)
        if per_op_annotator is not None:
            per_op_annotator(node, quant_config)


def register_annotator(ops: List[OpOverload]):
    """Build a decorator that registers an annotator for every op in ``ops``.

    The decorated function is stored in ``OP_TO_ANNOTATOR`` under each op. A
    later registration for the same op replaces an earlier one. The function
    is returned unchanged, so the decorated name stays bound to it.

    Args:
        ops: The aten op overloads the decorated annotator handles.

    Returns:
        A decorator that takes an ``annotator_fn(node, quant_config)``,
        registers it, and returns it.
    """

    def decorator(annotator_fn: Callable) -> Callable:
        for op in ops:
            OP_TO_ANNOTATOR[op] = annotator_fn
        # Without this return the decorated module-level name (for example
        # ``annotate_affine_ops``) would be rebound to None.
        return annotator_fn

    return decorator


def _is_annotated(node: Node):
    """
    Given a list of nodes (that represents an operator pattern),
    return True if any of the node
    is annotated, otherwise return False
    """
    KEY = "quantization_annotation"
    return KEY in node.meta and node.meta[KEY]._annotated


def _mark_as_annotated(nodes: List[Node]):
    KEY = "quantization_annotation"
    for node in nodes:
        if KEY not in node.meta:
            node.meta[KEY] = QuantizationAnnotation()
        node.meta[KEY]._annotated = True


def _is_float_activation_tensor(node: Node):
    if not isinstance(node, Node):
        return False
    if "val" not in node.meta:
        return False
    if not isinstance(node.meta["val"], FakeTensor):
        return False
    return node.meta["val"].dtype == torch.float32


def _annotate_fused_activation_pattern(
    graph: Graph, quant_config: QuantizationConfig
) -> None:
    """Annotate ``producer -> relu/relu6`` chains as a single unit.

    For every relu / relu_ / relu6 node whose first argument is a
    ``call_function`` node with exactly one user:

    * conv1d / conv2d / linear producers get their weight (``args[1]``)
      annotated with ``quant_config.weight``, and the activation's output
      with ``quant_config.activation``;
    * elementwise add / sub / rsub / mul / div producers only get the
      activation's output annotated with ``quant_config.activation``.

    The producer itself receives no output spec here (presumably so the pair
    is quantized as one fused op — confirm against the backend). All touched
    nodes are marked as annotated, so the per-op annotators skip them later.

    Args:
        graph: The FX graph to annotate in place.
        quant_config: Supplies the activation / weight quantization specs.
    """
    # Built once per call instead of once per graph node. Tuples keep the
    # same ``in`` semantics (identity/equality scan) as the original lists.
    activation_ops = (
        torch.ops.aten.relu.default,
        torch.ops.aten.relu_.default,
        torch.ops.aten.relu6.default,
    )
    affine_ops = (
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.linear.default,
    )
    arithmetic_ops = (
        torch.ops.aten.add.Scalar,
        torch.ops.aten.add.Tensor,
        torch.ops.aten.add_.Scalar,
        torch.ops.aten.add_.Tensor,
        torch.ops.aten.div.Scalar,
        torch.ops.aten.div.Tensor,
        torch.ops.aten.div_.Scalar,
        torch.ops.aten.div_.Tensor,
        torch.ops.aten.divide.Scalar,
        torch.ops.aten.divide.Tensor,
        torch.ops.aten.mul.Scalar,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.mul_.Scalar,
        torch.ops.aten.mul_.Tensor,
        torch.ops.aten.rsub.Scalar,
        torch.ops.aten.rsub.Tensor,
        torch.ops.aten.sub.Scalar,
        torch.ops.aten.sub.Tensor,
        torch.ops.aten.sub_.Scalar,
        torch.ops.aten.sub_.Tensor,
    )

    for relu_node in graph.nodes:
        if relu_node.op != "call_function" or relu_node.target not in activation_ops:
            continue

        producer_node = relu_node.args[0]
        # Only fuse when the activation is the producer's sole consumer;
        # presumably other users would still need the producer's own output.
        if (
            not isinstance(producer_node, Node)
            or producer_node.op != "call_function"
            or len(producer_node.users) != 1
        ):
            continue

        if producer_node.target in affine_ops:
            weight_node = producer_node.args[1]
            _annotate_input_qspec_map(
                producer_node,
                weight_node,
                quant_config.weight,
            )
            _annotate_output_qspec(relu_node, quant_config.activation)
            _mark_as_annotated([producer_node, weight_node, relu_node])
        elif producer_node.target in arithmetic_ops:
            _annotate_output_qspec(relu_node, quant_config.activation)
            _mark_as_annotated([producer_node, relu_node])


def _annotate_rmsnorm_pattern(graph: Graph, quant_config: QuantizationConfig) -> None:
    """Annotate RMSNorm subgraphs of ``graph`` as single units.

    Two spellings of ``x * rsqrt(mean(x**2, -1, keepdim=True) + eps)`` are
    matched: one squaring with ``x * x`` and one with ``x.pow(2)``. Literals
    such as eps are ignored while matching.

    For each match whose interior ops are not yet annotated:

    * the interior ``call_function`` ops that have a per-op annotator are
      marked as annotated, so they get no per-op annotation of their own;
    * the match's returning nodes get ``quant_config.activation`` as their
      output spec.

    Each pattern's forward returns ``(output, {})``. The trailing dict is the
    (empty) name-to-node map that ``SubgraphMatcherWithNameNodeMap`` expects.
    """

    class ExecuTorchPattern(torch.nn.Module):
        def forward(self, x):
            norm = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + 1e-6)
            return norm, {}

    class MTKPattern(torch.nn.Module):
        def forward(self, x):
            norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
            return norm, {}

    for pattern_module_cls in (ExecuTorchPattern, MTKPattern):
        exported = export_for_training(pattern_module_cls(), (torch.randn(3, 3),))
        subgraph_matcher = SubgraphMatcherWithNameNodeMap(
            exported.module(),
            ignore_literals=True,
            remove_overlapping_matches=False,
        )
        for found in subgraph_matcher.match(graph):
            # Skip the matched inputs; claim only ops a per-op annotator
            # would otherwise handle.
            claimed = [
                candidate
                for candidate in found.nodes_map.values()
                if candidate not in found.placeholder_nodes
                and candidate.op == "call_function"
                and candidate.target in OP_TO_ANNOTATOR
            ]
            # Overlapping matches are kept by the matcher. The first one
            # to claim its ops wins.
            if any(_is_annotated(candidate) for candidate in claimed):
                continue
            _mark_as_annotated(claimed)
            for output_node in found.returning_nodes:
                _annotate_output_qspec(output_node, quant_config.activation)


def annotate_placeholder(node: Node, quant_config: QuantizationConfig) -> None:
    """Annotate a graph-input (placeholder) node.

    A float32 input (per ``_is_float_activation_tensor``) gets
    ``quant_config.activation`` as its output spec. Every placeholder not yet
    annotated, float or not, is then marked as annotated. One that is already
    annotated is left untouched.
    """
    if not _is_annotated(node):
        if _is_float_activation_tensor(node):
            _annotate_output_qspec(node, quant_config.activation)
        _mark_as_annotated([node])


@register_annotator(
    [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.linear.default,
    ]
)
def annotate_affine_ops(node: Node, quant_config: QuantizationConfig) -> None:
    """Annotate a conv1d / conv2d / linear node.

    The weight input (``node.args[1]``) gets ``quant_config.weight`` as its
    input spec, and the node's output gets ``quant_config.activation``.
    Other inputs (such as a bias) get no spec here. Nodes already marked as
    annotated, for example by the fused-activation pattern, are skipped.
    """
    if _is_annotated(node):
        return

    weight = node.args[1]
    _annotate_input_qspec_map(node, weight, quant_config.weight)
    _annotate_output_qspec(node, quant_config.activation)

    # The weight is a constant node; mark it too so nothing else annotates it.
    _mark_as_annotated([node, weight])


@register_annotator(
    [
        torch.ops.aten.add.Scalar,
        torch.ops.aten.add.Tensor,
        torch.ops.aten.add_.Scalar,
        torch.ops.aten.add_.Tensor,
        torch.ops.aten.bmm.default,
        torch.ops.aten.div.Scalar,
        torch.ops.aten.div.Tensor,
        torch.ops.aten.div_.Scalar,
        torch.ops.aten.div_.Tensor,
        torch.ops.aten.divide.Scalar,
        torch.ops.aten.divide.Tensor,
        torch.ops.aten.gelu.default,
        torch.ops.aten.group_norm.default,
        torch.ops.aten.layer_norm.default,
        torch.ops.aten.leaky_relu.default,
        torch.ops.aten.matmul.default,
        torch.ops.aten.mul.Scalar,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.mul_.Scalar,
        torch.ops.aten.mul_.Tensor,
        torch.ops.aten.pow.Scalar,
        torch.ops.aten.pow.Tensor_Scalar,
        torch.ops.aten.pow.Tensor_Tensor,
        torch.ops.aten.prelu.default,
        torch.ops.aten.rsub.Scalar,
        torch.ops.aten.rsub.Tensor,
        torch.ops.aten.silu.default,
        torch.ops.aten.sub.Scalar,
        torch.ops.aten.sub.Tensor,
        torch.ops.aten.sub_.Scalar,
        torch.ops.aten.sub_.Tensor,
    ]
)
def annotate_output_qspec(node: Node, quant_config: QuantizationConfig) -> None:
    """Give ``node`` an output spec of ``quant_config.activation`` only.

    No input specs are attached here. A node that is already annotated, for
    example as part of an RMSNorm or fused-activation pattern, is skipped.
    """
    if not _is_annotated(node):
        _annotate_output_qspec(node, quant_config.activation)
        _mark_as_annotated([node])


@register_annotator([torch.ops.aten.embedding.default])
def annotate_embedding_op(node: Node, quant_config: QuantizationConfig) -> None:
    """Annotate an ``aten.embedding`` node.

    Its first argument (the embedding weight, per the aten.embedding schema)
    gets an input spec. The node's output gets no output spec here. Nodes
    that are already annotated are skipped.
    """
    if not _is_annotated(node):
        table = node.args[0]
        # NOTE(review): the table uses the *activation* spec rather than
        # quant_config.weight — presumably deliberate; confirm.
        _annotate_input_qspec_map(node, table, quant_config.activation)
        _mark_as_annotated([node])
