# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
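
"""Annotator that propagates quantization specs through data-layout ops.

The ops in _SUPPORTED_OPS (reshape, slice, transpose, cat, ...) move data
around without computing on it, so their outputs can share quantization
parameters with their inputs instead of being annotated independently.
"""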
from typing import Callable, List, Optional

import torch
import torch.fx
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import SharedQuantizationSpec
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
from torch.fx import Node


_SUPPORTED_OPS = [
    # DATA LAYOUT OPS
    torch.ops.aten.squeeze.default,
    torch.ops.aten.squeeze_copy.default,
    torch.ops.aten.squeeze_copy.dim,
    torch.ops.aten.squeeze.dim,
    torch.ops.aten.squeeze.dims,
    torch.ops.aten.unsqueeze.default,
    torch.ops.aten.unsqueeze_copy.default,
    torch.ops.aten.reshape.default,
    torch.ops.aten.repeat.default,
    torch.ops.aten.expand_copy.default,
    torch.ops.aten.expand.default,
    # Disabling these as there seems to be an issue with support for complex
    # datatypes in torch:
    # torch.ops.aten.view_as_complex.default,
    # torch.ops.aten.view_as_complex_copy.default,
    # torch.ops.aten.view_as_real.default,
    # torch.ops.aten.view_as_real_copy.default,
    torch.ops.aten.view.default,
    torch.ops.aten.view_as.default,
    torch.ops.aten.view_copy.default,
    torch.ops.aten.select.int,
    torch.ops.aten.select_copy.int,
    torch.ops.aten.slice.Tensor,
    torch.ops.aten.slice_copy.Tensor,
    torch.ops.aten.split.Tensor,
    torch.ops.aten.split_with_sizes.default,
    torch.ops.aten.transpose.Dimname,
    torch.ops.aten.transpose.int,
    torch.ops.aten.transpose_copy.int,
    torch.ops.aten.tile.default,
    torch.ops.aten.flip.default,
    torch.ops.aten.cat.default,
    torch.ops.aten.stack.default,
    torch.ops.aten.chunk.default,
    torch.ops.aten.contiguous.default,
]


@register_annotator("generic")
def _annotate_generic(
    gm: torch.fx.GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Propagate qspecs to generic ops like unsqueeze, reshape etc."""
    annotated_partitions = []

    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in _SUPPORTED_OPS:
            continue
        if filter_fn and not filter_fn(node):
            continue
        if arm_quantizer_utils.is_annotated(node):
            continue

        input_acts = node.args[0]

        # Check whether the first argument is a list of tensors; this lets
        # multi-input ops such as stack/cat be annotated the same way as
        # single-input ops.
        has_multi_inputs = isinstance(input_acts, list)

        input_act0 = input_acts[0] if has_multi_inputs else input_acts

        # Use a non-shared quantization spec for the first input, since a
        # SharedQuantizationSpec here can lead to recursion.
        _annotate_input_qspec_map(
            node, input_act0, quantization_config.get_input_act_qspec()
        )
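        # The edge (input_act0, node) now carries a concrete qspec; the
        # remaining inputs and the output are annotated to share it.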
        shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node))

        if has_multi_inputs:
            # All remaining inputs share the first input's qspec.
            for input_act in input_acts[1:]:
                if input_act is not input_act0:
                    node.meta["quantization_annotation"].input_qspec_map[
                        input_act
                    ] = shared_with_input0_qspec

        _annotate_output_qspec(node, shared_with_input0_qspec)
        arm_quantizer_utils.mark_nodes_as_annotated([node])
        annotated_partitions.append([node])

    return annotated_partitions
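

# Illustrative sketch (hypothetical names, not part of this module's API):
# for a multi-input node such as cat = aten.cat.default([x, y]), the
# annotation produced above is equivalent to
#
#   input_qspec_map = {
#       x: quantization_config.get_input_act_qspec(),  # concrete qspec
#       y: SharedQuantizationSpec((x, cat)),           # shares x's qspec
#   }
#   output_qspec = SharedQuantizationSpec((x, cat))
#
# so x, y, and the cat output all resolve to the same scale/zero-point.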
