# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Sequence

import torch
from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
from executorch.backends.qualcomm.quantizer.quantizer import (
    get_16a8w_qnn_ptq_config,
    get_8a8w_qnn_ptq_config,
    get_ptq_per_channel_quant_config,
    QuantizationConfig,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)
from torch.fx import Node


def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
    """
    This function is specific for matmul op 16a8w.
    """

    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec

        input_act1 = node.args[1]
        input_spec1 = quantization_config.weight
        input_qspec_map[input_act1] = input_spec1

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
        input_nodes = node.args[0]

        first_input_node = input_nodes[0]
        input_qspec_map = {}
        input_qspec_map[first_input_node] = quantization_config.input_activation
        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
            (first_input_node, node)
        )

        for input_node in input_nodes[1:]:
            if input_node not in input_qspec_map:
                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=share_qparams_with_input_act0_qspec,
            _annotated=True,
        )

    def annotate_single_in_single_out(
        node: Node, quantization_config: QuantizationConfig
    ) -> None:

        input_qspec_map = {}
        input_act = node.args[0]
        input_qspec_map[input_act] = quantization_config.input_activation

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_matmul_input1(node: Node):
        quantization_config_8a8w = get_8a8w_qnn_ptq_config(
            act_symmetric=True, act_observer=MinMaxObserver
        )
        while isinstance(node, Node) and node.op == "call_function":
            if node.target in [
                torch.ops.aten.permute.default,
                torch.ops.aten.transpose.int,
            ]:
                annotate_single_in_single_out(node, quantization_config_8a8w)
                node = node.args[0]
            elif node.target == torch.ops.aten.cat.default:
                annotate_cat(node, quantization_config_8a8w)
                node = node.args[0][0]
            else:
                node = node.args[0]

    quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)

    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
            annotate_matmul(node, quantization_config_16a8w)
            annotate_matmul_input1(node.args[1])


def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
    """
    This function is specific for llama matmul op 16a8w.
    """

    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec
        input_act1 = node.args[1]
        input_spec1 = quantization_config.weight
        input_qspec_map[input_act1] = input_spec1
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None:
        input = node.args[0]
        value = node.args[2]
        input_qspec_map = {}
        input_qspec_map[input] = quantization_config.input_activation
        input_qspec_map[value] = SharedQuantizationSpec((input, node))
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=SharedQuantizationSpec((input, node)),
            _annotated=True,
        )

    def annotate_single_in_single_out(
        node: Node, quantization_config: QuantizationConfig
    ) -> None:
        input_qspec_map = {}
        input_act = node.args[0]
        input_qspec_map[input_act] = quantization_config.input_activation
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
        input_nodes = node.args[0]
        assert isinstance(input_nodes, Sequence)
        first_input_node = input_nodes[0]
        input_qspec_map = {}
        assert isinstance(first_input_node, Node)
        assert isinstance(node, Node)
        input_qspec_map[first_input_node] = quantization_config.input_activation
        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
            (first_input_node, node)
        )
        for input_node in input_nodes[1:]:
            if input_node not in input_qspec_map:
                assert isinstance(input_node, Node)
                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=share_qparams_with_input_act0_qspec,
            _annotated=True,
        )

    def is_edge_condition(node: Node):
        if not isinstance(node, Node) or node.op != "call_function":
            return True
        return False

    def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig):
        if is_edge_condition(node):
            return
        if node.target in [
            torch.ops.aten.index_put.default,
            torch.ops.aten.index_put_.default,
        ]:
            annotate_index_put(node, quantization_config)
            annotate_matmul_input1(node.args[0], quantization_config)
        elif node.target == torch.ops.aten.cat.default:
            annotate_cat(node, quantization_config)
            # Expect that the inputs of the cat op are select ops
            for arg in node.args[0]:
                annotate_matmul_input1(arg, quantization_config)
        else:
            annotate_single_in_single_out(node, quantization_config)
            annotate_matmul_input1(node.args[0], quantization_config)

    # Annotate 16a8w for matmul op to get better performance
    quantization_config_16a8w = get_16a8w_qnn_ptq_config()
    # Annotate 8a8w for second input of matmul until past_kv_cache
    quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True)
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
            if "nn_module_stack" in node.meta:
                module_values_list = list(node.meta["nn_module_stack"].values())
                full_qualified_name = module_values_list[-1][0]
                if "SDPA" in full_qualified_name:
                    annotate_matmul(node, quantization_config_16a8w)
                    annotate_matmul_input1(node.args[1], quantization_config_8a8w)


def custom_annotate_llama_last_conv_16a8w(gm: torch.fx.GraphModule) -> None:
    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec

        weight = node.args[1]
        input_qspec_map[weight] = quantization_config.weight

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
        torch.uint16, weight_dtype=torch.int8
    )
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
            if "nn_module_stack" in node.meta:
                module_values_list = list(node.meta["nn_module_stack"].values())
                full_qualified_name = module_values_list[0][0]
                if full_qualified_name == "L['self'].llama.output":
                    annotate_conv2d(
                        node, quantization_config=quantization_config_16a8w_per_channel
                    )


def custom_annotate_matmul_16a8w(gm: torch.fx.GraphModule):
    """
    Annotate matmul op with 16a8w quantization config
    """

    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec
        input_act1 = node.args[1]
        input_spec1 = quantization_config.weight
        input_qspec_map[input_act1] = input_spec1
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    # Annotate 16a8w for matmul op to get better performance
    quantization_config_16a8w = get_16a8w_qnn_ptq_config()
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
            annotate_matmul(node, quantization_config_16a8w)


def get_custom_quant_ios_dtype(
    cache_shape: torch.Size,
    node: torch.fx.Node,
    kv_dtype=torch.uint8,
    sharding_dtype=torch.uint16,
):
    """
    This function is specific for llama inputs and outputs
    """
    if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
        return kv_dtype

    # Tag index put node before copy node, because copy is a skipped node in qnn
    if (
        exir_ops.edge.aten.index_put.default == node.target
        and node.meta["val"].shape == cache_shape
    ):
        return kv_dtype

    # Tag sharding io
    if exir_ops.edge.llama.fallback.default in [
        u.target for u in list(node.users.keys())
    ] + [node.target]:
        return sharding_dtype

    # Tag index op as quantized tensors. It is caused by sharding
    if exir_ops.edge.aten.index.Tensor in [
        u.target for u in list(node.users.keys())
    ] + [node.target]:
        return sharding_dtype
