# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#

# pyre-unsafe
from typing import cast, Dict

import numpy as np
import serializer.tosa_serializer as ts
import torch
import torch.fx
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_quant_utils import (
    get_quant_arg_upstream,
    get_quantized_node_output_dtype,
    is_node_quantized,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import (
    getNodeArgs,
    is_bias_node_for_quantized_conv,
    tosa_shape,
)
from torch.export.exported_program import ExportedProgram


def process_call_function(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    node_visitors: Dict[str, NodeVisitor],
    tosa_spec: TosaSpecification,
):
    """Serialize a ``call_function`` node into ``tosa_graph``.

    The node's own output tensor is first registered in the current basic
    block. The node is then handed to the ``NodeVisitor`` registered in
    ``node_visitors`` under ``node.target.__name__``.

    Args:
        node: The FX ``call_function`` node to serialize.
        tosa_graph: Serializer that receives the output tensor and operator.
        node_visitors: Visitors keyed by the target's ``__name__``.
        tosa_spec: Target TOSA specification (used in the error message).

    Raises:
        RuntimeError: If no visitor is registered for the node's target.
    """
    # Operands of the node, converted to TosaArg.
    operand_args = getNodeArgs(node)
    # The node itself is the value being produced.
    result_arg = TosaArg(node)

    quantized = is_node_quantized(node)
    # For quantized nodes, the output dtype comes from
    # get_quantized_node_output_dtype rather than from the TosaArg's dtype.
    result_dtype = (
        map_dtype(get_quantized_node_output_dtype(node))
        if quantized
        else result_arg.dtype
    )
    tosa_graph.currRegion.currBasicBlock.addTensor(
        result_arg.name,
        tosa_shape(result_arg.shape, result_arg.dim_order),
        result_dtype,
    )

    # pyre-ignore[16]: Undefined attribute.
    op_name = node.target.__name__
    if op_name not in node_visitors:
        raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}")

    node_visitors[op_name].define_node(
        node,
        tosa_graph,
        operand_args,
        result_arg,
        quantized,
    )


def process_inputs(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    tosa_spec: TosaSpecification,
):
    """Serialize a user-input placeholder as a TOSA graph input tensor.

    The tensor is created with ``data=None`` and a placeholder filename of
    ``<node name>.npy``. Quantized nodes use the dtype reported by
    ``get_quantized_node_output_dtype``; otherwise the ``TosaArg`` dtype is used.
    ``tosa_spec`` is accepted for interface symmetry but is not used here.

    Raises:
        RuntimeError: If the input's ``dim_order`` is not the default
            (contiguous) order ``(0, 1, ..., rank - 1)``.
    """
    fake_val = node.meta["val"]
    default_order = tuple(range(fake_val.dim()))
    actual_order = fake_val.dim_order()
    # Inputs must be in the default (contiguous) dim_order.
    if actual_order != default_order:
        raise RuntimeError(
            f"Arm backend only supports contiguous memory format for inputs. "
            f"Expected dim_order: {default_order}, but got: {actual_order} for node {node.name}"
        )

    arg = TosaArg(node)
    if is_node_quantized(node):
        input_dtype = map_dtype(get_quantized_node_output_dtype(node))
    else:
        input_dtype = arg.dtype

    input_tensor = ts.TosaSerializerTensor(
        arg.name,
        tosa_shape(arg.shape, arg.dim_order),
        input_dtype,
        data=None,
        placeholderFilename=arg.name + ".npy",
    )
    tosa_graph.addInputTensor(input_tensor)


def process_quantized_bias(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    parameter_values,
):
    """Quantize a bias to int32 and add it to ``tosa_graph`` as a constant.

    The bias scale is the product of the upstream quantization scales of the
    consumer's input and weight nodes. Each bias value is divided by that
    scale, rounded, and cast to ``np.int32``. The constant is named after
    ``node``.

    Args:
        node: The bias placeholder node.
        tosa_graph: Serializer that receives the constant.
        parameter_values: Float bias values (array-like supporting ``/``,
            ``.round()`` and ``.astype``).
    """
    # Only the first user of the bias is consulted. Its input nodes must be
    # exactly (input, weight, bias), or the unpacking below raises.
    consumer = list(node.users)[0]
    activation_node, weight_node, _bias_node = consumer.all_input_nodes

    activation_scale = get_quant_arg_upstream(activation_node).scale
    weight_scale = get_quant_arg_upstream(weight_node).scale
    bias_scale = activation_scale * weight_scale

    scaled = parameter_values / bias_scale
    quantized_bias = scaled.round().astype(np.int32)

    tosa_graph.addConst(
        quantized_bias.shape,
        ts.DType.INT32,
        quantized_bias,
        name=node.name,
    )


def process_inputs_to_parameters(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
):
    """Serialize a parameter placeholder (bias or non-quantized weight).

    If ``is_bias_node_for_quantized_conv(node)`` holds, the bias is quantized
    to int32 via ``process_quantized_bias``. Any other parameter is transposed
    into the node's ``dim_order`` and added as a constant with the node's dtype.

    Args:
        node: Placeholder node listed in ``inputs_to_parameters``.
        tosa_graph: Serializer that receives the constant.
        edge_program: Program whose ``state_dict`` holds the parameter data.
        tosa_spec: Target TOSA specification; checked for integer/float support.

    Raises:
        AssertionError: If the parameter is not a ``torch.Tensor``, or the spec
            lacks integer support (quantized bias) or float support (FP32 data).
    """
    arg = TosaArg(node)
    parameter_name = edge_program.graph_signature.inputs_to_parameters[node.name]
    parameter_data = edge_program.state_dict[parameter_name]

    assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
    parameter_values = parameter_data.detach().numpy()

    if is_bias_node_for_quantized_conv(node):
        # BI bias
        assert tosa_spec.support_integer(), f"{tosa_spec} doesn't support integer"
        process_quantized_bias(node, tosa_graph, parameter_values)
        return

    # MI weights or bias.
    # TosaArg.dtype is a TOSA dtype (it is passed straight to addConst below),
    # so it must be compared against ts.DType rather than a torch dtype.
    # Comparing against torch.float32 was never true and disabled this check.
    if arg.dtype == ts.DType.FP32:
        assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

    parameter_values = np.transpose(parameter_values, arg.dim_order)

    tosa_graph.addConst(
        parameter_values.shape, arg.dtype, parameter_values, name=node.name
    )


def process_inputs_to_buffers(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
):
    """Serialize a buffer placeholder (intended for quantized weights) as a constant.

    The buffer tensor is taken from ``edge_program.state_dict`` and permuted
    into the node's ``dim_order``. It is then added with the node's dtype and
    named after ``node``.

    Raises:
        AssertionError: If the buffer entry is not a ``torch.Tensor``.
    """
    arg = TosaArg(node)
    signature = edge_program.graph_signature
    buffer_tensor = edge_program.state_dict[signature.inputs_to_buffers[node.name]]

    assert isinstance(buffer_tensor, torch.Tensor), "Expect Attr to be tensor"
    raw_values = buffer_tensor.detach().numpy()

    # TODO: fragile temporary fix. Every buffer is permuted into the node's
    # dim_order, not only weights. Per the original note, other buffers stored
    # here (e.g. mean/var) have shape (1,), where the permutation is a no-op.
    permuted = np.transpose(raw_values, arg.dim_order)

    tosa_graph.addConst(permuted.shape, arg.dtype, permuted, name=node.name)


def process_inputs_to_lifted_tensor_constants(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
):
    """Serialize a lifted-tensor-constant placeholder as a TOSA constant.

    The tensor is looked up in ``edge_program.tensor_constants`` and added
    with its own shape and the node's dtype.

    NOTE(review): unlike parameters and buffers, no transpose into the
    node's ``dim_order`` is applied here. Confirm this is intended.
    """
    const_arg = TosaArg(node)
    constants_map = edge_program.graph_signature.inputs_to_lifted_tensor_constants
    constant_key = constants_map[const_arg.name]
    values = edge_program.tensor_constants[constant_key].detach().numpy()

    tosa_graph.addConst(values.shape, const_arg.dtype, values, name=const_arg.name)


def process_placeholder(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
):
    """Serialize a placeholder node according to its role in the graph signature.

    The placeholder is dispatched, in order, as a user input, parameter,
    buffer or lifted tensor constant.

    Raises:
        AssertionError: If the node's name and target differ, or it has args.
        NotImplementedError: For lifted custom objects.
        RuntimeError: If the placeholder matches none of the known categories.
    """
    assert node.name == node.target, "Expect placeholder name and target to match"
    assert len(node.args) == 0, "Can't handle default input values"

    signature = edge_program.graph_signature
    name = node.name

    if name in signature.user_inputs:
        process_inputs(node, tosa_graph, tosa_spec)
        return
    if name in signature.inputs_to_parameters:
        process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
        return
    if name in signature.inputs_to_buffers:
        process_inputs_to_buffers(node, tosa_graph, edge_program)
        return
    if name in signature.inputs_to_lifted_tensor_constants:
        process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
        return
    if name in signature.inputs_to_lifted_custom_objs:
        raise NotImplementedError(
            "Placeholder is of type 'lifted custom object' which is not supported."
        )
    raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")


def process_output(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
):
    """Mark every value returned by the graph's output node as a TOSA output.

    ``node.args[0]`` is treated as the tuple of returned nodes. Each one is
    looked up by name among the current basic block's tensors, so it must
    already have been registered there (otherwise a ``KeyError`` is raised).
    """
    returned_nodes = cast(tuple[torch.fx.Node, ...], node.args[0])
    for returned in returned_nodes:
        registered = tosa_graph.currRegion.currBasicBlock.tensors[returned.name]
        tosa_graph.addOutputTensor(registered)
