# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

#
# Utility functions for ArmQuantizer
#

import operator
from typing import Callable, cast, List

import torch
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch._subclasses import FakeTensor

from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)
from torch.fx import GraphModule, Node


def is_annotated(node: Node) -> bool:
    """Given a node return whether the node is annotated."""
    return (
        "quantization_annotation" in node.meta
        and cast(
            QuantizationAnnotation, node.meta["quantization_annotation"]
        )._annotated
    )


def are_annotated(nodes: List[Node]) -> bool:
    """Return True if at least one node of the operator pattern is annotated.

    Despite the plural name, this is an "any" check, not an "all" check.
    It stops at the first annotated node, and an empty list yields False.
    """
    return any(is_annotated(node) for node in nodes)


def mark_nodes_as_annotated(nodes: List[Node]) -> None:
    """Marks all nodes in list 'nodes' as annotated. If needed, an empty
    QuantizationAnnotation is added to the quantization_annotation node meta entry.
    """
    for node in nodes:
        if node is not None:
            if "quantization_annotation" not in node.meta:
                node.meta["quantization_annotation"] = QuantizationAnnotation()
            node.meta["quantization_annotation"]._annotated = True


def get_shared_qspec(
    node: Node, gm: GraphModule, quantization_config: QuantizationConfig
):
    """Build a quantization constellation in which both inputs and the output
    of ``node`` share quantization parameters.

    The first argument (``node.args[0]``) gets the config's input activation
    qspec. The second argument (``node.args[1]``) and the node's output share
    it through ``SharedQuantizationSpec((node.args[0], node))``.

    Parameters:
        node: a node with two inputs that should share quantization parameters.
        gm: the GraphModule containing ``node``; used to resolve get_attr
            targets when checking whether an input can be quantized.
        quantization_config: supplies the input activation QuantizationSpec to
            share.
    Returns:
        input_qspec_map: a dict[Node, QuantizationSpec] covering the Node
            arguments of ``node``. Non-Node arguments are omitted. If both
            arguments are the same node, it appears once with the input
            activation qspec.
        shared_with_input0_spec: the SharedQuantizationSpec to use as the
            output QuantizationSpec.

        Both outputs are None if either argument is a Node that fails
        ``is_input_ok_for_quantization``.
    """
    lhs = cast(Node, node.args[0])
    rhs = node.args[1]

    lhs_qspec = quantization_config.get_input_act_qspec()
    # NOTE(review): this is built even when node.args[0] is not a Node, in
    # which case the shared edge is absent from the returned map — confirm
    # callers never pass a non-Node first operand.
    shared_qspec = SharedQuantizationSpec((lhs, node))

    qspec_map = {}
    for operand, operand_qspec in ((lhs, lhs_qspec), (rhs, shared_qspec)):
        if not isinstance(operand, Node):
            continue
        if not is_input_ok_for_quantization(operand, gm):
            return None, None
        # setdefault keeps the first-input qspec when both operands are the
        # same node, so a node never ends up sharing with itself.
        qspec_map.setdefault(operand, operand_qspec)
    return qspec_map, shared_qspec


def is_input_ok_for_quantization(input_act: Node, gm: GraphModule):
    """Return whether ``input_act`` may be quantized.

    The input is rejected if it does not output a float tensor
    (``is_input_non_float_tensor``) or if it is a large scalar
    (``is_input_large_scalar``). The large-scalar check runs only when the
    float check passes.
    """
    if is_input_non_float_tensor(input_act):
        return False
    return not is_input_large_scalar(input_act, gm)


def get_node_target(module: torch.nn.Module | GraphModule, target_str: str):
    targets = target_str.split(".")
    for target in targets[:-1]:
        module = module.get_submodule(target)
    return getattr(module, targets[-1])


def is_input_large_scalar(node: Node, gm: GraphModule) -> bool:
    """Check whether ``node`` is a get_attr of a large single-element tensor.

    Such inputs are skipped for quantization because the histc op (used by
    HistogramObserver) only works for values up to a certain upper bound.

    Parameters:
        node: the input node to inspect.
        gm: the GraphModule that owns the attribute referenced by ``node``.
    Returns:
        True only for a get_attr node whose string target resolves to a tensor
        with exactly one element and ``abs(value) > 3.4028235e15``. Otherwise
        False, including when the target resolves to a non-tensor object
        such as a submodule.
    """
    if node.op != "get_attr" or not isinstance(node.target, str):
        return False
    attr = get_node_target(gm, node.target)
    # A get_attr target is not guaranteed to be a tensor (it may name a
    # submodule or another object); only tensors provide numel()/item().
    if not isinstance(attr, torch.Tensor):
        return False
    # torch.histc works until this upper bound
    HISTC_UPPER_BOUND = 3.4028235e15
    return attr.numel() == 1 and abs(attr.item()) > HISTC_UPPER_BOUND


def is_input_non_float_tensor(node: Node) -> bool:
    """Return True when ``node`` should be skipped because it is not a float tensor.

    Observers only work with float tensors. An input is therefore treated as
    non-float unless ``node.meta["val"]`` exists, is a FakeTensor, and has
    dtype ``torch.float32``. Other floating dtypes such as float16 or
    bfloat16 are also reported as non-float.
    """
    fake_val = node.meta.get("val")
    return not isinstance(fake_val, FakeTensor) or fake_val.dtype != torch.float32


def is_share_obs_or_fq_op(op: Callable) -> bool:
    """Return whether ``op`` can be quantized with a shared observer or fake
    quantizer, i.e. whether it may inherit its quantization spec from its
    parent node instead of getting its own.
    """
    aten = torch.ops.aten
    shareable_ops = (
        aten.hardtanh.default,
        aten.hardtanh_.default,
        aten.relu.default,
        aten.mean.default,
        aten.mean.dim,
        aten.permute.default,
        aten.permute_copy.default,
        # TODO: remove?
        aten.adaptive_avg_pool2d.default,
        aten.avg_pool2d.default,
        aten.max_pool2d.default,
        aten.full.default,
        aten.flatten.using_ints,
        aten.dropout.default,
        operator.getitem,
    )
    return op in shareable_ops


def propagate_annotation(model: GraphModule) -> None:
    """Propagate output qspecs downward into unannotated, shareable ops.

    Walks ``model.graph.nodes`` in order and selects each ``call_function``
    node that is not yet annotated and whose target passes
    ``is_share_obs_or_fq_op``. If the node's first argument is a Node whose
    annotation has an ``output_qspec``, the node is annotated so that its
    input edge from that parent and its own output both use
    ``SharedQuantizationSpec(parent)``.

    Annotations are written during the walk, so a later node can inherit from
    an earlier one that was just annotated. Propagation therefore continues
    until an already-annotated or non-shareable op is reached.
    """
    for raw_node in model.graph.nodes:
        node = cast(Node, raw_node)
        if is_annotated(node) or node.op != "call_function":
            continue
        if not is_share_obs_or_fq_op(cast(Callable, node.target)):
            continue

        parent = node.args[0]
        if not isinstance(parent, Node):
            continue

        parent_annotation = cast(
            QuantizationAnnotation | None,
            parent.meta.get("quantization_annotation", None),
        )
        # Nothing to inherit unless the parent already has an output qspec.
        if not (parent_annotation and parent_annotation.output_qspec):
            continue

        inherited = SharedQuantizationSpec(parent)
        node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={parent: inherited},
            output_qspec=inherited,
            _annotated=True,
        )
