# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import _operator
from typing import List, Tuple

import torch

from executorch.backends.qualcomm.builders.utils import is_parameter
from executorch.backends.qualcomm.utils.constants import (
    QCOM_AXIS_ORDER,
    QCOM_INSERTED_PERMUTE,
    QCOM_LAYOUT_CHANGE,
    QCOM_QUANT_ATTRS,
    QCOM_REQUANTIZE,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.sym_util import eval_shape

from .utils import dq_ops, q_ops


class LayoutTransform(ExportPass):
    """
    QNN delegate requires channel last layout format, this pass aims to
    help generate the correct transformation by inserting the fewest
    'permute' operators in the graph.

    Starting from every layout-sensitive node, the pass walks the graph in
    both directions (arguments and users). Nodes that are layout-sensitive
    or layout-agnostic get tagged in ``node.meta`` with the axis order of
    their output. When the walk reaches a node that is neither, that is a
    layout boundary, and (if ``insert_permute`` is enabled) a
    ``permute_copy`` node is inserted there.
    """

    # Ops that are treated as the seeds of the traversal.
    layout_sensitive_ops = {
        exir_ops.edge.aten.avg_pool2d.default,
        exir_ops.edge.aten.convolution.default,
        exir_ops.edge.aten.max_pool2d_with_indices.default,
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        exir_ops.edge.aten.native_group_norm.default,
        exir_ops.edge.aten.pixel_shuffle.default,
        exir_ops.edge.aten.pixel_unshuffle.default,
        exir_ops.edge.aten.upsample_bilinear2d.default,
        exir_ops.edge.aten.upsample_nearest2d.default,
    }

    # Ops through which the traversal keeps propagating the layout tag.
    layout_agnostic_ops = {
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.bmm.default,
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.ceil.default,
        exir_ops.edge.aten.clamp.default,
        exir_ops.edge.aten.constant_pad_nd.default,
        exir_ops.edge.aten.div.Tensor,
        exir_ops.edge.aten.full.default,
        exir_ops.edge.aten.gelu.default,
        exir_ops.edge.aten.hardswish.default,
        exir_ops.edge.aten.hardsigmoid.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.leaky_relu.default,
        exir_ops.edge.aten.linear.default,
        exir_ops.edge.aten._log_softmax.default,
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.pow.Tensor_Scalar,
        exir_ops.edge.aten.prelu.default,
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten._softmax.default,  # TODO: Need to find a new solution to do "axis_order" to transform axis.
        exir_ops.edge.aten.sqrt.default,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.sum.dim_IntList,
        exir_ops.edge.aten.topk.default,
        exir_ops.edge.aten._to_copy.default,
        exir_ops.edge.aten.split_with_sizes.default,
        *q_ops,
        *dq_ops,
        _operator.getitem,
    }

    # rank -> (source layout, target layout); each letter names one axis.
    layout_type = {
        1: ("N", "N"),
        2: ("NC", "NC"),
        3: ("NCW", "NWC"),
        4: ("NCHW", "NHWC"),
        5: ("NCDHW", "NDHWC"),
    }

    @classmethod
    def get_axis_order(cls, size: List[int], reverse=False) -> Tuple[int, ...]:
        """
        Return the permutation that maps one layout onto the other.

        Args:
            size: shape of the tensor; only its length (rank) is used.
            reverse: when False the permutation goes from the first layout
                of the ``layout_type`` entry to the second one (e.g. rank 4
                yields ``(0, 2, 3, 1)``); when True it goes the other way
                (e.g. rank 4 yields ``(0, 3, 1, 2)``).

        Raises:
            KeyError: if the rank has no entry in ``layout_type``.
        """
        old_layout, new_layout = cls.layout_type[len(size)]
        if reverse:
            old_layout, new_layout = new_layout, old_layout
        # for every axis of the destination, look up where it sits in the source
        return tuple(old_layout.find(axis) for axis in new_layout)

    def __init__(
        self, edge_program: torch.export.ExportedProgram, insert_permute=False
    ):
        """
        Args:
            edge_program: exported program, used to recognize parameter nodes.
            insert_permute: when True, ``call`` inserts ``permute_copy`` nodes
                at layout boundaries; when False nodes are only tagged.
        """
        super().__init__()
        self.edge_program = edge_program
        self.insert_permute = insert_permute
        self.qdq_opset = {*q_ops, *dq_ops}
        # meta key written by mark_as_transformed; swapped temporarily in call()
        self.transformed_tag = QCOM_AXIS_ORDER

    def mark_as_transformed(self, node: torch.fx.Node) -> None:
        """
        Store the axis order of ``node``'s output under ``self.transformed_tag``.

        For multi-output nodes (``meta["val"]`` is a tuple / list) the shape is
        taken from the output selected by the node's first user, which must be
        a ``getitem``.

        Raises:
            AssertionError: if the first user of a multi-output node is not
                a ``getitem``.
        """
        val = node.meta["val"]
        if isinstance(val, (tuple, list)):
            getitem_node = list(node.users.keys())[0]
            if getitem_node.target.__name__ != "getitem":
                raise AssertionError(
                    "Expected node's user to be getitem, "
                    f"got {getitem_node.target.__name__}"
                )
            index = getitem_node.args[1]
            val = val[index]
        node.meta[self.transformed_tag] = self.get_axis_order(eval_shape(val.shape))

    def is_transformed_node(self, node: torch.fx.Node) -> bool:
        """Return True if ``node`` carries the currently active tag in its meta."""
        if not hasattr(node, "meta"):
            return False
        return self.transformed_tag in node.meta

    def is_layout_sensitive(self, node: torch.fx.Node) -> bool:
        """Return True if ``node.target`` is listed in ``layout_sensitive_ops``."""
        return node.target in self.layout_sensitive_ops

    def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
        """
        Return True if the layout tag may be propagated through ``node``.

        Reductions (``mean.dim`` / ``sum.dim_IntList``) only qualify when their
        third positional argument is present and truthy, and q/dq ops only
        qualify when they are tagged with ``QCOM_REQUANTIZE``.
        """
        if node.target in (
            exir_ops.edge.aten.mean.dim,
            exir_ops.edge.aten.sum.dim_IntList,
        ):
            # if dimension is not kept, we'll have no clue how to do layout transform
            if len(node.args) < 3 or not node.args[2]:
                return False
        if node.target in self.qdq_opset:
            return QCOM_REQUANTIZE in node.meta
        return node.target in self.layout_agnostic_ops

    def is_edge_condition(self, node):
        """
        Return True if the traversal has to stop at ``node``.

        That is the case for anything that is not an fx node, for nodes that
        are already tagged, ``get_attr`` nodes, permutes inserted by this pass,
        nodes producing a 0-dim value, and parameters of the edge program.
        The checks short-circuit, so later ones (which read ``meta["val"]``)
        are skipped once an earlier one already matched.
        """
        if not isinstance(node, torch.fx.Node):
            return True

        if self.is_transformed_node(node) or node.op == "get_attr":
            return True

        if (
            node.target == exir_ops.edge.aten.permute_copy.default
            and node.meta.get(QCOM_INSERTED_PERMUTE, False)
        ):
            return True

        if node.op != "output":
            val = node.meta["val"]
            if not isinstance(val, (tuple, list)) and len(val.shape) == 0:
                return True

        return bool(is_parameter(node, self.edge_program))

    def insert_node(self, graph_module, node, revert_layout: bool) -> None:
        """
        Insert a ``permute_copy`` right after ``node`` and rewire its users.

        Does nothing unless ``self.insert_permute`` is set. Only users whose
        "transformed" state differs from that of ``node`` are redirected to
        the new permute.

        Args:
            graph_module: graph module owning ``node``.
            node: producer whose output gets permuted.
            revert_layout: forwarded to ``get_axis_order`` as ``reverse``.

        Raises:
            AssertionError: if ``node`` is multi-output and its first user is
                not a ``getitem``.
        """
        if not self.insert_permute:
            return
        with graph_module.graph.inserting_after(node):
            # snapshot before the permute itself becomes a user of `node`
            users = node.users.copy()
            val = node.meta["val"]
            # accept list as well as tuple, consistent with mark_as_transformed
            if isinstance(val, (tuple, list)):
                getitem_node = list(node.users.keys())[0]
                if getitem_node.target.__name__ != "getitem":
                    raise AssertionError(
                        f"Expected bn node's user to be getitem, got {getitem_node.target.__name__}"
                    )
                index = getitem_node.args[1]
                tensor = val[index]
            else:
                tensor = val

            permute = self.create_call_function_node(
                graph_module,
                exir_ops.edge.aten.permute_copy.default,
                (
                    node,
                    self.get_axis_order(eval_shape(tensor.shape), revert_layout),
                ),
            )
            permute.meta["val"] = tensor
            permute.meta[QCOM_QUANT_ATTRS] = node.meta.get(QCOM_QUANT_ATTRS)
            # we need this to check the annotation boundary
            permute.meta[QCOM_INSERTED_PERMUTE] = True

            # this is the case when residual connection happened:
            # e.g. consider following graph
            # x --> permute --> layer_norm --> permute --> conv2d --> add
            #               └-------------------------------------┙
            # we should have permute node to be correctly inserted as:
            # x --> permute --> layer_norm --> permute --> qnn_permute --> conv2d --> add
            #               └--------------------------------------> qnn_permute -┙
            # i.e. insert permute by condition between user and current node
            #      if there are multiple users included
            is_node_transformed = self.is_transformed_node(node)
            for user in users:
                is_user_transformed = (
                    self.is_transformed_node(user) or QCOM_LAYOUT_CHANGE in user.meta
                )
                # insert permute only in exclusive condition
                if is_node_transformed != is_user_transformed:
                    user.replace_input_with(node, permute)

    def create_call_function_node(
        self,
        graph_module: torch.fx.GraphModule,
        target: torch.fx.node.Target,
        args: Tuple[torch.fx.node.Argument, ...],
    ):
        """Create a ``call_function`` node at the graph's current insert point."""
        return graph_module.graph.create_node(
            "call_function",
            target=target,
            args=args,
        )

    @staticmethod
    def _flatten_node_args(node: torch.fx.Node) -> list:
        """Return ``node.args`` as a snapshot with list arguments flattened one level."""
        flat = []
        for arg in node.args:
            if isinstance(arg, list):
                flat.extend(arg)
            else:
                flat.append(arg)
        return flat

    def traverse(self, node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> None:
        """Propagate the layout annotation to the arguments and users of ``node``."""
        for arg in self._flatten_node_args(node):
            self.annotate_layout(arg, graph_module, revert_layout=False)

        # snapshot: the graph may gain permute nodes while we iterate
        for user in set(node.users.keys()):
            self.annotate_layout(user, graph_module, revert_layout=True)

    def annotate_layout(
        self, node: torch.fx.Node, graph_module: torch.fx.GraphModule, revert_layout
    ) -> None:
        """
        Tag ``node`` and keep traversing, or insert a permute at the boundary.

        Args:
            node: candidate reached from an already tagged node; may also be
                a non-node argument, which is ignored.
            graph_module: graph module being transformed.
            revert_layout: False when ``node`` was reached as an argument
                (permute goes after ``node`` itself), True when reached as a
                user (permute goes after each tagged argument of ``node``).
        """
        if self.is_edge_condition(node):
            return

        if self.is_layout_agnostic(node) or self.is_layout_sensitive(node):
            self.mark_as_transformed(node)
            self.traverse(node, graph_module)
            return

        # boundary between transformed and untransformed regions
        if not revert_layout:
            self.insert_node(graph_module, node, revert_layout=revert_layout)
            return

        for arg in self._flatten_node_args(node):
            if self.is_transformed_node(arg):
                self.insert_node(graph_module, arg, revert_layout=revert_layout)

    def _annotate_from(self, seeds, graph_module: torch.fx.GraphModule) -> None:
        """Start a traversal from every seed node that is not tagged yet."""
        for node in seeds:
            if not self.is_transformed_node(node):
                self.mark_as_transformed(node)
                self.traverse(node, graph_module)

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Run the pass on ``graph_module``.

        With ``insert_permute`` enabled, a first traversal tags nodes with
        ``QCOM_LAYOUT_CHANGE`` without touching the graph, so the second
        traversal (tagging with ``QCOM_AXIS_ORDER``) knows which users will
        change layout when deciding where permutes are required.

        Returns:
            ``PassResult`` wrapping the (possibly re-traced) graph module,
            always flagged as modified.
        """
        graph = graph_module.graph
        sensitive_nodes = [
            node for node in graph.nodes if self.is_layout_sensitive(node)
        ]
        # perform first run traversal for identifying nodes subjected to layout changes
        if self.insert_permute:
            self.insert_permute, self.transformed_tag = False, QCOM_LAYOUT_CHANGE
            try:
                self._annotate_from(sensitive_nodes, graph_module)
            finally:
                # always restore, so a failure does not leave the pass in dry-run mode
                self.insert_permute, self.transformed_tag = True, QCOM_AXIS_ORDER

        self._annotate_from(sensitive_nodes, graph_module)

        graph_module.recompile()
        if not self.insert_permute:
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)