# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import numbers
import operator
from functools import partial
from typing import Callable, Dict, List, Sequence, Tuple

import torch
from torch._ops import OpOverload

from torch._subclasses import FakeTensor
from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize

from torch.ao.quantization.observer import FixedQParamsObserver
from torch.ao.quantization.quantizer import (
    DerivedQuantizationSpec,
    QuantizationAnnotation,
    QuantizationSpec,
    SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
from torch.fx import Node

from .qconfig import (
    get_16a16w_qnn_ptq_config,
    get_16a4w_qnn_qat_config,
    get_8a8w_qnn_qat_config,
    QuantizationConfig,
)


QUANT_ANNOTATION_KEY = "quantization_annotation"
OP_ANNOTATOR: Dict[OpOverload, Callable] = {}


def register_annotator(ops: List[OpOverload]):
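    """Register the decorated function as the annotator for every op overload in ``ops``."""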
    def decorator(annotator: Callable):
        for op in ops:
            OP_ANNOTATOR[op] = annotator
        # Return the annotator unchanged so the decorated name stays bound to the function.
        return annotator

    return decorator


def _is_annotated(nodes: List[Node]):
    """
    Given a list of nodes (representing an operator pattern),
    return True if any of the nodes is annotated, otherwise return False.
    """
    annotated = False
    for node in nodes:
        annotated = annotated or (
            QUANT_ANNOTATION_KEY in node.meta
            and node.meta[QUANT_ANNOTATION_KEY]._annotated
        )
    return annotated


def _is_float_tensor(node: Node):
    """Check if the node's tensor is a float tensor, so that we can skip quantization for the node
    since observers only work with float tensors.
    """
    if (
        not isinstance(node, Node)
        or "val" not in node.meta
        or not isinstance(node.meta["val"], FakeTensor)
    ):
        return False
    return node.meta["val"].dtype == torch.float32


def _mark_nodes_as_annotated(nodes: List[Node]):
    for node in nodes:
        if QUANT_ANNOTATION_KEY not in node.meta:
            node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation()
        node.meta[QUANT_ANNOTATION_KEY]._annotated = True


def annotate_in_out_obs_sharing_op(
    node: Node, quantization_config: QuantizationConfig
) -> None:
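    """Share the input node's observer with this node's input and output
    (used for shape/view-like ops that should not change quantization parameters).
    """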
    if _is_annotated([node]):
        return

    input_act = node.args[0]
    assert isinstance(input_act, Node)

    # Only annotate an input/output observer-sharing operator
    # when the output of its input node is already annotated.
    if (
        QUANT_ANNOTATION_KEY not in input_act.meta
        or not input_act.meta[QUANT_ANNOTATION_KEY]._annotated
        or input_act.meta[QUANT_ANNOTATION_KEY].output_qspec is None
    ):
        return

    act_qspec = SharedQuantizationSpec(input_act)
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={
            input_act: act_qspec,
        },
        output_qspec=act_qspec,
        _annotated=True,
    )


def annotate_single_in_single_out(
    node: Node, quantization_config: QuantizationConfig
) -> None:
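    """Annotate the first input with the config's activation spec and, if the node
    produces a float tensor, annotate its output as well.
    """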
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_qspec_map[input_act] = quantization_config.input_activation

    if _is_float_tensor(node):
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )


@register_annotator([torch.ops.aten.topk.default])
def annotate_topk(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return
    # Reuse single_in_single_out since we don't want to quantize the indices output
    annotate_single_in_single_out(node, quantization_config)


def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None:
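    """Annotate a binary op: each float tensor input gets the input activation spec,
    and a float output gets the output activation spec.
    """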
    if _is_annotated([node]):
        return

    input_act_qspec = quantization_config.input_activation
    output_act_qspec = (
        quantization_config.output_activation if _is_float_tensor(node) else None
    )

    input_qspec_map = {}
    input_act0 = node.args[0]
    if _is_float_tensor(input_act0):
        input_qspec_map[input_act0] = input_act_qspec

    input_act1 = node.args[1]
    if _is_float_tensor(input_act1):
        input_qspec_map[input_act1] = input_act_qspec

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_act_qspec,
        _annotated=True,
    )


@register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor])
def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.sub, torch.ops.aten.sub.Tensor])
def annotate_sub(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator(
    [torch.ops.aten.mul, torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar]
)
def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator(
    [torch.ops.aten.div, torch.ops.aten.div.Tensor, torch.ops.aten.divide.Tensor]
)
def annotate_div(node: Node, quantization_config: QuantizationConfig) -> None:
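    """Annotate div. Tensor/Tensor division is handled as a regular binary op; for a
    tensor divided by a constant scalar, the output qparams are derived from the
    input observer (scale_out = scale_in / divisor).
    """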
    def _derived_inp1_const_div_quant_spec(
        node: torch.fx.Node, output_qspec: QuantizationSpec
    ) -> DerivedQuantizationSpec:
        def _derive_div_qparams_fn(
            obs_or_fqs: List,
            const_val: float,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            inp_0_obs_or_fq = obs_or_fqs[0]
            inp_0_scale, inp_0_zp = inp_0_obs_or_fq.calculate_qparams()
            derived_scale = inp_0_scale / const_val
            return (derived_scale, inp_0_zp)

        inp_0 = node.args[0]
        const_inp_1 = node.args[1]
        _derive_div_qparams_with_const_fn = partial(
            _derive_div_qparams_fn, const_val=const_inp_1
        )

        q_min = (
            torch.iinfo(output_qspec.dtype).min
            if output_qspec.quant_min is None
            else output_qspec.quant_min
        )
        q_max = (
            torch.iinfo(output_qspec.dtype).max
            if output_qspec.quant_max is None
            else output_qspec.quant_max
        )
        return DerivedQuantizationSpec(
            derived_from=[(inp_0, node)],
            derive_qparams_fn=_derive_div_qparams_with_const_fn,
            dtype=output_qspec.dtype,
            quant_min=q_min,
            quant_max=q_max,
            ch_axis=0,
            qscheme=output_qspec.qscheme,
        )

    if all(isinstance(a, Node) for a in node.args):
        annotate_binary(node, quantization_config)
    # special constant divisor case
    elif isinstance(node.args[0], Node) and isinstance(node.args[1], numbers.Number):
        if _is_annotated([node]):
            return

        input_act_qspec = quantization_config.input_activation
        output_act_qspec = _derived_inp1_const_div_quant_spec(
            node, quantization_config.output_activation
        )
        input_qspec_map = {}
        input_act0 = node.args[0]
        if _is_float_tensor(input_act0):
            input_qspec_map[input_act0] = input_act_qspec

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    else:
        raise NotImplementedError(f"No quant annotation is implemented for {node}.")


@register_annotator([torch.ops.aten.rsub.Scalar])
def annotate_rsub(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.sum.dim_IntList])
def annotate_sum(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.ceil.default])
def annotate_ceil(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.clamp.default])
def annotate_clamp(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.tanh.default])
def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [torch.ops.aten.hardswish.default, torch.ops.aten.hardswish_.default]
)
def annotate_hardswish(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardsigmoid_.default]
)
def annotate_hardsigmoid(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default])
def annotate_hardtanh(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.mean.default])
def annotate_mean(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.max_pool2d.default])
def annotate_max_pool2d(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.max_pool2d_with_indices.default])
def annotate_max_pool2d_with_indices(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
def annotate_adaptive_avgpool2d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.avg_pool2d.default])
def annotate_avgpool2d(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.permute.default])
def annotate_permute(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [
        torch.ops.aten.leaky_relu.default,
        torch.ops.aten.leaky_relu_.default,
        torch.ops.aten.prelu.default,
    ]
)
def annotate_prelu(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default])
def annotate_view(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.pixel_shuffle.default])
def annotate_pixel_shuffle_default(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.pixel_unshuffle.default])
def annotate_pixel_unshuffle_default(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.upsample_bilinear2d.vec])
def annotate_upsample_bilinear2d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.upsample_nearest2d.vec])
def annotate_upsample_nearest2d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [
        torch.ops.aten.softmax.int,
        torch.ops.aten._softmax.default,
        torch.ops.aten._safe_softmax.default,
    ]
)
def annotate_softmax(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.log_softmax.int])
def annotate_log_softmax(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.pad.default])
def annotate_pad(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.reshape.default])
def annotate_reshape(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.select.int])
def annotate_select(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.mean.dim])
def annotate_mean_dim(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.slice.Tensor])
def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.sqrt.default])
def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.gelu.default])
def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.scaled_dot_product_attention.default])
def annotate_scaled_dot_product_attention(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [
        torch.ops.aten.squeeze.default,
        torch.ops.aten.squeeze.dim,
        torch.ops.aten.squeeze_copy.dims,
    ]
)
def annotate_squeeze(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.rms_norm.default])
def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
    act_node = node.args[0]
    weight_node = node.args[2]

    if _is_annotated([node]):
        return

    # TODO: currently only 16a16w is supported
    _annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )

    _annotate_input_qspec_map(
        node,
        weight_node,
        quantization_config.input_activation,
    )
    nodes_to_mark_annotated = [node]
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)


@register_annotator([torch.ops.aten.rsqrt.default])
def annotate_rsqrt(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default])
def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> None:
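    """Annotate sigmoid with fixed output qparams, since its output range is known
    to be [0, 1].
    """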
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_qspec_map[input_act] = quantization_config.input_activation

    out_qconf = quantization_config.output_activation

    q_max = (
        torch.iinfo(out_qconf.dtype).max
        if out_qconf.quant_max is None
        else out_qconf.quant_max
    )
    q_min = (
        torch.iinfo(out_qconf.dtype).min
        if out_qconf.quant_min is None
        else out_qconf.quant_min
    )

    scale = 1 / (q_max - q_min + 1)

    fixed_obs_ctr = observer = FixedQParamsObserver.with_args(
        scale=scale,
        zero_point=0,
        dtype=quantization_config.output_activation.dtype,
        qscheme=torch.per_tensor_affine,
        quant_max=q_max,
        quant_min=q_min,
    )
    if quantization_config in (
        get_8a8w_qnn_qat_config(),
        get_16a4w_qnn_qat_config(),
    ):
        fixed_obs_ctr = FixedQParamsFakeQuantize.with_args(
            observer=observer,
            scale=scale,
            zero_point=0,
            dtype=quantization_config.output_activation.dtype,
            qscheme=torch.per_tensor_affine,
            quant_max=q_max,
            quant_min=q_min,
        )

    # Map the sigmoid output onto the fixed 0~1 range.
    out_act_quantization_spec = QuantizationSpec(
        dtype=quantization_config.output_activation.dtype,
        quant_max=q_max,
        quant_min=q_min,
        observer_or_fake_quant_ctr=fixed_obs_ctr,
        qscheme=torch.per_tensor_affine,
    )

    if _is_float_tensor(node):
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=out_act_quantization_spec,
            _annotated=True,
        )


@register_annotator([torch.ops.aten.pow.Tensor_Scalar])
def annotate_pow(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.unsqueeze.default])
def annotate_unsqueeze(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [
        torch.ops.aten.unsqueeze_copy.default,
    ]
)
def annotate_unsqueeze_copy(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.transpose.int])
def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.embedding.default])
def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None:
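    """Annotate the embedding weight table (args[0]) and share its qparams with the
    output; index inputs are left unquantized.
    """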
    weight = node.args[0]

    input_qspec_map = {}
    input_qspec_map[weight] = quantization_config.input_activation

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=SharedQuantizationSpec((weight, node)),
        _annotated=True,
    )


@register_annotator([torch.ops.aten.index.Tensor])
def annotate_index(node: Node, quantization_config: QuantizationConfig) -> None:
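    """Prefer sharing qparams with the already-annotated input; otherwise observe the
    input and share its qparams with the output.
    """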
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        input_qspec_map = {}
        input = node.args[0]
        input_qspec_map[input] = quantization_config.input_activation
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=SharedQuantizationSpec((input, node)),
            _annotated=True,
        )


@register_annotator(
    [torch.ops.aten.index_put.default, torch.ops.aten.index_put_.default]
)
def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None:
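    """Annotate index_put: the values tensor and the output both share the qparams of
    the tensor being written into (args[0]).
    """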
    input = node.args[0]
    value = node.args[2]

    input_qspec_map = {}
    input_qspec_map[input] = quantization_config.input_activation
    input_qspec_map[value] = SharedQuantizationSpec((input, node))

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=SharedQuantizationSpec((input, node)),
        _annotated=True,
    )


@register_annotator([torch.ops.aten.expand.default])
def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.group_norm.default])
def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> None:
    act_node = node.args[0]
    weight_node = node.args[2]
    bias_node = None
    if len(node.args) > 3:
        bias_node = node.args[3]

    if _is_annotated([node]):
        return

    _annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )
    _annotate_input_qspec_map(
        node,
        weight_node,
        quantization_config.weight,
    )
    nodes_to_mark_annotated = [node, weight_node]
    if bias_node:
        _annotate_input_qspec_map(
            node,
            bias_node,
            quantization_config.bias,
        )
        nodes_to_mark_annotated.append(bias_node)
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)


@register_annotator([torch.ops.aten.flatten.using_ints])
def annotate_flatten(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.stack.default])
def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
    # Skip int64 stacks entirely, since observers only work with float tensors.
    node_tensor = node.meta.get("val")
    if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64:
        return

    input_qspec_map = {}
    for input_act in node.args[0]:
        assert isinstance(input_act, Node)
        input_qspec_map[input_act] = quantization_config.input_activation

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )


@register_annotator([torch.ops.aten.matmul.default])
def annotate_matmul(node: Node, quantization_config: QuantizationConfig) -> None:
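    """Annotate matmul. For 16-bit activation configs (torch.int32 dtype), the second
    operand uses the 16a16w weight spec (see the inline comment below).
    """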
    if _is_annotated([node]):
        return

    input_act_qspec = quantization_config.input_activation
    output_act_qspec = quantization_config.output_activation

    input_qspec_map = {}
    input_act0 = node.args[0]
    if isinstance(input_act0, Node):
        input_qspec_map[input_act0] = input_act_qspec

    input_act1 = node.args[1]
    if isinstance(input_act1, Node):
        # In matmul, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized.
        if input_act_qspec.dtype == torch.int32:
            # we should use int16 for mm / bmm instead of int4
            input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight
        else:
            input_qspec_map[input_act1] = input_act_qspec

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_act_qspec,
        _annotated=True,
    )


@register_annotator([torch.ops.aten.bmm.default])
def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None:
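    """Annotate bmm with the same 16-bit weight handling as matmul, and override
    source_fn_stack so a later pass partitions this node as a plain bmm.
    """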
    if _is_annotated([node]):
        return

    input_act_qspec = quantization_config.input_activation
    output_act_qspec = quantization_config.output_activation

    input_qspec_map = {}
    input_act0 = node.args[0]
    if isinstance(input_act0, Node):
        input_qspec_map[input_act0] = input_act_qspec

    input_act1 = node.args[1]
    if isinstance(input_act1, Node):
        # In bmm, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized.
        if input_act_qspec.dtype == torch.int32:
            # we should use int16 for mm / bmm instead of int4
            input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight
        else:
            input_qspec_map[input_act1] = input_act_qspec

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_act_qspec,
        _annotated=True,
    )

    # A later pass uses get_source_partitions; MultiheadAttention shares the same source,
    # so override source_fn_stack to partition this node as a plain bmm.
    node.meta["source_fn_stack"] = [(node, torch.bmm)]


@register_annotator(
    [
        torch.ops.aten.conv2d.default,
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv_transpose2d.input,
    ]
)
def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_spec = quantization_config.input_activation
    input_qspec_map[input_act] = input_spec

    weight = node.args[1]
    assert isinstance(weight, Node)
    input_qspec_map[weight] = quantization_config.weight

    if len(node.args) > 2:
        bias = node.args[2]
        if isinstance(bias, Node):
            if callable(quantization_config.bias):
                input_qspec_map[bias] = quantization_config.bias(node)
            else:
                input_qspec_map[bias] = quantization_config.bias

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )


@register_annotator([torch.ops.aten.linear.default])
def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
    act_node = node.args[0]
    weight_node = node.args[1]
    bias_node = None
    if len(node.args) > 2:
        bias_node = node.args[2]

    if _is_annotated([node]):
        return

    _annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )
    _annotate_input_qspec_map(
        node,
        weight_node,
        quantization_config.weight,
    )
    nodes_to_mark_annotated = [node, weight_node]
    if bias_node:
        if callable(quantization_config.bias):
            bias_config = quantization_config.bias(node)
        else:
            bias_config = quantization_config.bias
        _annotate_input_qspec_map(node, bias_node, bias_config)
        nodes_to_mark_annotated.append(bias_node)
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)

    # A later pass uses get_source_partitions; MultiheadAttention shares the same source,
    # so override source_fn_stack to partition this node as a plain linear.
    node.meta["source_fn_stack"] = [(node, torch.nn.Linear)]


@register_annotator([torch.ops.aten._native_batch_norm_legit_no_training.default])
def annotate_batch_norm(node: Node, quantization_config: QuantizationConfig) -> None:
    act, weight, bias = node.args[0:3]
    if _is_annotated([node]):
        return

    _annotate_input_qspec_map(
        node,
        act,
        quantization_config.input_activation,
    )
    # QNN requires uint8 instead of int8 for the weight, so reuse the input activation spec
    _annotate_input_qspec_map(
        node,
        weight,
        quantization_config.input_activation,
    )
    _annotate_input_qspec_map(
        node,
        bias,
        quantization_config.bias,
    )
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated([node, *node.args[0:3]])


@register_annotator([operator.getitem])
def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    if _is_float_tensor(node):
        _annotate_output_qspec(node, quantization_config.output_activation)
        _mark_nodes_as_annotated([node])


@register_annotator([torch.ops.aten.layer_norm.default])
def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None:
    act_node = node.args[0]
    weight_node = node.args[2]
    bias_node = None
    if len(node.args) > 3:
        bias_node = node.args[3]

    if _is_annotated([node]):
        return
    input_act_qspec = quantization_config.input_activation

    _annotate_input_qspec_map(
        node,
        act_node,
        input_act_qspec,
    )
    if input_act_qspec.dtype == torch.int32:
        _annotate_input_qspec_map(
            node,
            weight_node,
            get_16a16w_qnn_ptq_config().weight,
        )
    else:
        _annotate_input_qspec_map(
            node,
            weight_node,
            input_act_qspec,
        )
    nodes_to_mark_annotated = [node, weight_node]
    if bias_node:
        _annotate_input_qspec_map(
            node,
            bias_node,
            quantization_config.bias,
        )
        nodes_to_mark_annotated.append(bias_node)
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)


@register_annotator([torch.ops.aten.cat.default, torch.ops.aten.concat.default])
def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None:
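    """Annotate cat/concat: observe the first input and share its qparams with the
    remaining inputs and the output.
    """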
    input_nodes = node.args[0]
    if _is_annotated([node]):
        return

    assert isinstance(input_nodes, Sequence)

    first_input_node = input_nodes[0]
    input_qspec_map = {}
    assert isinstance(first_input_node, Node)
    assert isinstance(node, Node)
    input_qspec_map[first_input_node] = quantization_config.input_activation
    share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
        (first_input_node, node)
    )

    for input_node in input_nodes[1:]:
        if input_node not in input_qspec_map:
            assert isinstance(input_node, Node)
            input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=share_qparams_with_input_act0_qspec,
        _annotated=True,
    )


@register_annotator([torch.ops.aten.unbind.int])
def annotate_unbind(node: Node, quantization_config: QuantizationConfig) -> None:
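    """Annotate only the input of unbind; skip entirely when the node's value is an
    int64 tensor.
    """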
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_qspec_map[input_act] = quantization_config.input_activation

    node_tensor = node.meta.get("val")
    if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64:
        return

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        _annotated=True,
    )


@register_annotator([torch.ops.aten.split.Tensor, torch.ops.aten.chunk.default])
def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
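    """Annotate split/chunk: observe the input and give every user (getitem) node the
    output activation spec.
    """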
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_qspec_map[input_act] = quantization_config.input_activation

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        _annotated=True,
    )

    for user in node.users:
        user.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )
