# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import torch
from executorch.backends.arm._passes.arm_pass_utils import (
    create_node,
    get_param_tensor,
    insert_q_dq_pair,
    is_param_node,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class Conv1dUnsqueezePass(ExportPass):
    """
    This pass is used to change conv1d ops into conv2d since TOSA only
    supports 2d and 3d convolution. This is done by modifying the graph to do the
    following:
    1) unsqueeze the convolution's input from 3d to 4d
    2) if the input to unsqueeze is quantized, insert q/dq-pair after unsqueeze
    3) perform a conv2d (with a modified version of the original conv1d args)
    4) squeeze the output back down to 3d.
    5) if all users of squeeze are quantized, insert q/dq-pair before squeeze

    The pass mutates both the graph and the weight tensors stored in the
    exported program it was constructed with.
    """

    def __init__(self, exported_program: ExportedProgram) -> None:
        """
        Args:
            exported_program: the program whose state_dict/constants hold the
                convolution weights that this pass rewrites.
        """
        super().__init__()
        self.exported_program = exported_program

    def unsqueeze_kernel_weights(self, kernel_node):
        """
        Unsqueezes the weights of a conv1d to make it 4 dimensional.

        The stored tensor is replaced by a contiguous copy with a trailing
        dimension of size 1, wrapped in a non-trainable ``torch.nn.Parameter``.
        Calling this twice on the same node adds two dimensions, so callers
        must make sure each kernel node is only processed once.

        Args:
            kernel_node: the weights of conv1d node to be unsqueezed

        Raises:
            AssertionError: if no tensor can be found for ``kernel_node``.
        """
        kernel_param_3d = get_param_tensor(self.exported_program, kernel_node)
        if kernel_param_3d is None:
            raise AssertionError("Expected param tensor for the kernel node")

        kernel_param_4d = torch.nn.Parameter(
            data=kernel_param_3d.data.contiguous().unsqueeze(dim=-1),
            requires_grad=False,
        )

        program = self.exported_program
        signature = program.graph_signature

        # Write the 4d tensor back to wherever the original tensor lives.
        if torch._export.utils.is_param(program, kernel_node):
            parameter_name = signature.inputs_to_parameters[kernel_node.name]
            program.state_dict[parameter_name] = kernel_param_4d
        elif torch._export.utils.is_buffer(program, kernel_node):
            buffer_name = signature.inputs_to_buffers[kernel_node.name]
            program.state_dict[buffer_name] = kernel_param_4d
        elif torch._export.utils.is_lifted_tensor_constant(program, kernel_node):
            constant_name = signature.inputs_to_lifted_tensor_constants[
                kernel_node.name
            ]
            program.constants[constant_name] = kernel_param_4d
        else:
            # Not a lifted input (presumably a get_attr node): the tensor is an
            # attribute of the owning module. The node's meta["val"] is left
            # untouched in this case.
            setattr(
                kernel_node.graph.owning_module,
                kernel_node.target,
                kernel_param_4d,
            )
            return

        # Keep the node's metadata in sync with the new weight shape.
        kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Rewrite every conv1d in ``graph_module`` as
        unsqueeze -> conv2d -> squeeze.

        A convolution is treated as 1d when its stride argument has exactly
        one element; all other convolutions are left untouched.

        Args:
            graph_module: the graph module to transform in place.

        Returns:
            A PassResult holding the retraced graph module. The ``modified``
            flag is always True.

        Raises:
            AssertionError: if the weight of a conv1d is not a parameter node
                or has no tensor associated with it.
        """
        graph = graph_module.graph
        # Weight nodes that already got their extra dimension. The same weight
        # can feed several conv1d nodes; unsqueezing it once per conv would
        # leave it with more than four dimensions.
        unsqueezed_kernels = set()

        # Iterate over a snapshot since nodes are inserted while looping.
        for node in list(graph.nodes):
            if node.op != "call_function":
                continue
            if node.target != exir_ops.edge.aten.convolution.default:
                continue

            stride = list(node.args[3])
            if len(stride) != 1:
                # skip conv if it is not 1d
                continue

            # Look through a dequantize node to find the actual weight node.
            kernel_node = node.args[1]
            if kernel_node.target == dq_op:
                kernel_node = kernel_node.args[0]

            if not is_param_node(self.exported_program, kernel_node):
                raise AssertionError(
                    "Expected op for convolution weight node to be a get_attr node or a parameter"
                )

            # (a) Make the weights 4d, once per weight node.
            if kernel_node not in unsqueezed_kernels:
                self.unsqueeze_kernel_weights(kernel_node)
                unsqueezed_kernels.add(kernel_node)

            # (b) Extend stride, padding, and dilation for extra dim
            node.args = (
                node.args[0],
                node.args[1],
                node.args[2],
                node.args[3] + [1],  # stride
                node.args[4] + [0],  # padding
                node.args[5] + [1],  # dilation
                node.args[6],
                node.args[7] + [0],
                node.args[8],
            )

            # (c) Add unsqueeze to input (3d -> 4d) and squeeze to output (4d -> 3d)
            # unsqueeze -> conv2d -> squeeze
            with graph.inserting_before(node):
                input_node = node.args[0]
                unsqueeze_before = create_node(
                    graph, exir_ops.edge.aten.unsqueeze_copy.default
                )
                unsqueeze_before.args = (
                    input_node,  # Input is node's original input
                    -1,  # Last Dimension
                )
                node.replace_input_with(input_node, unsqueeze_before)

            # If Quantized we must insert unsqueeze --> q --> dq --> node
            if input_node.target == dq_op:
                q_params = input_node.args[1:]
                insert_q_dq_pair(graph, unsqueeze_before, q_params)

            with graph.inserting_after(node):
                squeeze_after = create_node(
                    graph,
                    exir_ops.edge.aten.squeeze_copy.dims,
                )
                squeeze_after.args = (
                    node,  # Input is the conv node
                    [-1],  # Last dimension
                )
                # Redirect everything that consumed the conv to the squeeze.
                original_users = [
                    user for user in node.users if user != squeeze_after
                ]
                for user in original_users:
                    user.replace_input_with(node, squeeze_after)

            # If quantized, insert conv2d --> q --> dq --> squeeze
            # all() is True for an empty list, so require at least one user
            # before reading the quantization parameters from the first one.
            if original_users and all(
                original_user.target == q_op for original_user in original_users
            ):
                q_params = original_users[0].args[1:]
                insert_q_dq_pair(graph, node, q_params)

        graph_module.recompile()
        # Since we are overriding "call", we need to call the parent's "call"
        # to retrace the graph and regenerate metadata
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
