# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from collections import Counter
from typing import Callable, List

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.backends.transforms.addmm_mm_to_linear import (
    apply_addmm_mm_to_linear_transform,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload as edge_op
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass

from torch.fx.passes.utils.source_matcher_utils import (
    get_source_partitions,
    SourcePartition,
)

from .utils import dq_ops, get_quant_attrs, q_ops


class ConvertToLinear(ExportPass):
    """
    Convert the decomposed graph of torch.nn.Linear (addmm / mm / bmm plus the
    surrounding view_copy / permute_copy / expand_copy ops) back into a single
    aten.linear op, and handle the quantization tag that goes missing for the
    addmm op after decomposition.
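
    A sketch of the addmm case (node names are illustrative):
        addmm(bias, input, permute_copy(weight)) -> linear(input, weight, bias)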
    """

    view_copy = exir_ops.edge.aten.view_copy.default
    permute_copy = exir_ops.edge.aten.permute_copy.default
    expand_copy = exir_ops.edge.aten.expand_copy.default
    linear = exir_ops.edge.aten.linear.default
    add = exir_ops.edge.aten.add.Tensor
    addmm = exir_ops.edge.aten.addmm.default
    bmm = exir_ops.edge.aten.bmm.default
    mm = exir_ops.edge.aten.mm.default

    addmm_patterns = [
        {view_copy: 2, permute_copy: 1, addmm: 1},
        {permute_copy: 1, addmm: 1},
    ]

    bmm_patterns = [
        {view_copy: 3, permute_copy: 1, expand_copy: 2, add: 1, bmm: 1},
        {view_copy: 3, permute_copy: 1, expand_copy: 2, bmm: 1},
    ]

    mm_patterns = [
        {view_copy: 2, permute_copy: 1, mm: 1},
        {permute_copy: 1, mm: 1},
    ]

    def __init__(self):
        super().__init__()

    def _get_original_input(
        self, inputs: List[torch.fx.Node], cur_node: torch.fx.Node
    ) -> torch.fx.Node:
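        # Walk up the first-arg chain, skipping wrapper ops such as
        # view_copy / permute_copy, until one of the partition inputs
        # (or an argument-less node) is reached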
        while cur_node not in inputs and cur_node.args:
            cur_node = cur_node.args[0]
        return cur_node

    def _convert_to_linear(
        self,
        gm: torch.fx.GraphModule,
        src_partition: SourcePartition,
        extract_ops_fn: Callable,
    ):
        inputs = src_partition.input_nodes
        # output_nodes may also contain placeholders (input buffers such as
        # argX_X) and sym_size nodes, so keep only the real output
        outputs = [
            node
            for node in src_partition.output_nodes
            if node.target != torch.ops.aten.sym_size.int and node.op != "placeholder"
        ]
        assert (
            len(outputs) == 1
        ), f"Unexpected number of outputs for a torch.nn.Linear module: expected 1, got {outputs}"
        output = outputs[0]

        ops = extract_ops_fn(src_partition.nodes)
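        # extract_ops_fn returns [input, weight, fn] for mm/bmm and
        # [input, weight, fn, bias] when a bias is present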
        input_node, weight_node, fn_node = ops[:3]
        bias_node = None if len(ops) == 3 else ops[3]

        # QNN HTP does not support keep_dims, so the view_copy (reshape) has to
        # stay in the graph for now. If the original input comes from a
        # dequantize op, carry its quant attributes over to the new input node.
        original_input = self._get_original_input(inputs, input_node)
        if original_input.target in dq_ops:
            input_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs(
                gm, original_input.args[0]
            )
        args = [input_node, weight_node]
        if bias_node:
            args.append(bias_node)

        # Create the linear node in front of the partition output and reroute
        # all users of the matched op to it
        with gm.graph.inserting_before(output):
            linear_node = gm.graph.create_node(
                "call_function", self.linear, tuple(args)
            )
            linear_node.meta = fn_node.meta
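            # if the partition output feeds a quantize op, stamp its quant
            # attributes onto the new linear node so they survive the rewrite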
            output_user = next(iter(output.users))
            if output_user.target in q_ops:
                linear_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs(
                    gm, output_user
                )
            for user in fn_node.users.copy():
                user.replace_input_with(fn_node, linear_node)

        # Since QNN linear does not keep dims, squeeze and unsqueeze (as
        # view_copy nodes) have to be added around the linear node
        # TODO: Find a more general conditional statement.
        linear_output = linear_node.meta["val"]
        if linear_output.dim() == 3 and linear_output.shape[0] == 1:
            with gm.graph.inserting_after(input_node):
                input_users = list(input_node.users.keys())
                input_tensor = input_node.meta["val"]
                squeeze_dim = input_tensor.shape[-2:]
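                # squeeze the rank-3 input (1, M, K) down to (M, K) via view_copy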
                squeeze_node = gm.graph.create_node(
                    "call_function",
                    self.view_copy,
                    (
                        input_node,
                        squeeze_dim,
                    ),
                )
                # meta has to be copied key by key so the fake tensor can be
                # updated without mutating input_node's meta
                for k, v in input_node.meta.items():
                    squeeze_node.meta[k] = v
                squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim)
                for user in input_users:
                    if user == linear_node:
                        user.replace_input_with(input_node, squeeze_node)

            with gm.graph.inserting_after(linear_node):
                output_users = list(linear_node.users.keys())
                unsqueeze_dim = linear_output.shape
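                # restore the original rank-3 output shape (1, M, N) via view_copy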
                unsqueeze_node = gm.graph.create_node(
                    "call_function",
                    self.view_copy,
                    (
                        linear_node,
                        unsqueeze_dim,
                    ),
                )
                # meta has to be copied key by key so updating linear_node's
                # meta below does not affect unsqueeze_node's meta
                for k, v in linear_node.meta.items():
                    unsqueeze_node.meta[k] = v
                # update the linear node's meta to the squeezed 2D output shape
                linear_node.meta["val"] = linear_output.reshape(
                    linear_output.shape[-2:]
                )
                for user in output_users:
                    user.replace_input_with(linear_node, unsqueeze_node)

    def _extract_mm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]:
        mm_node = [n for n in partitioned_nodes if n.target == self.mm][0]
        # weight -> permute_copy -> input of mm
        weight_node = mm_node.args[1].args[0]
        input_node = mm_node.args[0]
        return [input_node, weight_node, mm_node]

    def _extract_addmm_ops(
        self, partitioned_nodes: List[edge_op]
    ) -> List[torch.fx.Node]:
        addmm_node = [n for n in partitioned_nodes if n.target == self.addmm][0]
        # weight -> permute_copy -> input of addmm
        weight_node = addmm_node.args[2].args[0]
        input_node = addmm_node.args[1]
        bias_node = addmm_node.args[0]
        return [input_node, weight_node, addmm_node, bias_node]

    def _extract_bmm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]:
        bmm_node = [n for n in partitioned_nodes if n.target == self.bmm][0]
        add_node = [n for n in partitioned_nodes if n.target == self.add]
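        # an add node appears in the partition only when the Linear has a bias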

        # weight -> permute_copy -> expand_copy -> view_copy -> input of bmm
        weight_node = bmm_node.args[1].args[0].args[0].args[0]
        # input -> expand_copy -> view_copy -> input of bmm
        input_node = bmm_node.args[0].args[0].args[0]

        if add_node:
            bias_node = add_node[0].args[1]
            return [input_node, weight_node, add_node[0], bias_node]
        return [input_node, weight_node, bmm_node]

    def _convert(self, graph_module: torch.fx.GraphModule):
        partitions = get_source_partitions(graph_module.graph, [torch.nn.Linear])
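        # each partition is a cluster of nodes traced back to one
        # torch.nn.Linear call in the source model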
        for _, src_partitions in partitions.items():
            for src_partition in src_partitions:
                op_cnt = Counter(
                    [
                        n.target
                        for n in src_partition.nodes
                        if isinstance(n.target, edge_op)
                    ]
                )
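                # op_cnt is this partition's op-count signature; compare it
                # against the known decomposition patterns above. If a linear
                # op is already present there is nothing to convert.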
                if self.linear in op_cnt:
                    continue
                elif op_cnt in self.addmm_patterns:
                    self._convert_to_linear(
                        graph_module, src_partition, self._extract_addmm_ops
                    )
                elif op_cnt in self.mm_patterns:
                    self._convert_to_linear(
                        graph_module, src_partition, self._extract_mm_ops
                    )
                elif op_cnt in self.bmm_patterns:
                    self._convert_to_linear(
                        graph_module, src_partition, self._extract_bmm_ops
                    )
                else:
                    raise AssertionError(
                        f"Found a new pattern {op_cnt} that needs to be "
                        "converted to a linear op"
                    )

    def call(self, graph_module: torch.fx.GraphModule):
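        # Fold every Linear source partition back into a single aten.linear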
        self._convert(graph_module)
        # get_source_partitions cannot be used for the remaining cases because
        # MultiheadAttention shares the same source, so apply the generic
        # addmm/mm-to-linear transform directly
        apply_addmm_mm_to_linear_transform(graph_module.graph)
        dead_code_elimination_pass(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)
