# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools
from typing import Callable, List, Optional

import torch
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


@register_annotator("mm")
def _annotate_mm(
    gm: torch.fx.GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    mm_partitions = get_source_partitions(
        gm.graph, [torch.mm, torch.bmm, torch.matmul], filter_fn
    )
    mm_partitions = list(itertools.chain.from_iterable(mm_partitions.values()))
    annotated_partitions = []
    for mm_partition in mm_partitions:
        annotated_partitions.append(mm_partition.nodes)
        mm_node = mm_partition.output_nodes[0]

        if arm_quantizer_utils.is_annotated(mm_node):
            continue

        input_act_qspec = quantization_config.get_input_act_qspec()
        output_act_qspec = quantization_config.get_output_act_qspec()

        input_qspec_map = {}
        input_act0 = mm_node.args[0]
        if isinstance(input_act0, Node):
            if not arm_quantizer_utils.is_input_ok_for_quantization(input_act0, gm):
                continue
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = mm_node.args[1]
        if isinstance(input_act1, Node):
            if not arm_quantizer_utils.is_input_ok_for_quantization(input_act1, gm):
                continue
            input_qspec_map[input_act1] = input_act_qspec

        mm_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    return annotated_partitions
