# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
import os
from typing import Any, cast

import numpy as np
import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import (
    get_quant_arg_downstream,
    get_quant_arg_upstream,
    q_op,
)
from executorch.exir.dialects._ops import ops as exir_ops
from serializer.tosa_serializer import TosaOp
from torch.fx import Node

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
if TOSA_DBG_VERBOSE:
    logging.basicConfig(level=logging.INFO)
    logger.setLevel(logging.INFO)


def dbg_node(node):
    # Debug output of node information
    logger.info("OP")
    logger.info(f"  op is {node.op}")
    logger.info(f"  name is {node.name}")
    logger.info(f"  node target is {node.target}")
    logger.info(f"  node args is {node.args}")
    logger.info(f"  node kwargs is {node.kwargs}")
    logger.info("  node.meta = ")
    for k, v in node.meta.items():
        logger.info(f"    '{k}' = {v}")
        if isinstance(v, list):
            for i in v:
                logger.info(f"      {i} ")


# Output TOSA flatbuffer and test harness file
def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
    filename = f"output{suffix}.tosa"

    logger.info(f"Emitting debug output to: {path=}, {suffix=}")

    os.makedirs(path, exist_ok=True)

    fb = tosa_graph.serialize()
    js = tosa_graph.writeJson(filename)

    filepath_tosa_fb = os.path.join(path, filename)
    with open(filepath_tosa_fb, "wb") as f:
        f.write(fb)
    assert os.path.exists(filepath_tosa_fb), "Failed to write TOSA flatbuffer"

    filepath_desc_json = os.path.join(path, f"desc{suffix}.json")
    with open(filepath_desc_json, "w") as f:
        f.write(js)
    assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON"


def dbg_fail(node, tosa_graph, path):
    dbg_tosa_dump(tosa_graph, path)
    logger.warn("Internal error due to poorly handled node:")
    dbg_node(node)
    logger.warn(f"Debug output captured in '{path}'.")
    raise RuntimeError("TOSA Internal Error on node, enable logging for further info.")


# Helper function to match TOSA's broadcasting rank requirement
# Ref: TOSA 0.80.0 specification - 1.9.3. Data Layouts from
# https://www.mlplatform.org/tosa/tosa_spec.html
def promote_shape(tosa_fb, arg, promoted_shape, out_dtype):
    assert np.prod(arg.shape) == np.prod(promoted_shape), "Incompatible promoted shape"
    reshape_res = tosa_fb.addIntermediate(promoted_shape, out_dtype)
    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(promoted_shape)
    tosa_fb.addOperator(TosaOp.Op().RESHAPE, [arg.name], [reshape_res.name], attr)
    return reshape_res


# Helper transpose function to match TOSA's shape requirements
# E.g., TOSA 0.80.0 specification - 2.3.3 CONV2D shapes:
# https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d
def transpose_helper(tosa_fb, input, new_order, out_dtype):
    # Check new_order's length is equal to input rank
    assert len(input.shape) == len(new_order), "Wrong shape order length"

    # Check no duplications
    assert len(set(new_order)) == len(new_order), "Contain duplicated dim numbers"

    # Check all dims are valid
    for idx in new_order:
        if idx < 0:
            assert True, "Negative dim number"
        elif idx >= len(input.shape):
            assert True, "Dim is greater than input rank"

    input_shape_transpoed = [input.shape[i] for i in new_order]
    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(new_order)
    input_transposed = tosa_fb.addIntermediate(input_shape_transpoed, out_dtype)
    tosa_fb.addOperator(
        TosaOp.Op().TRANSPOSE, [input.name], [input_transposed.name], attr
    )
    return input_transposed


def getNodeArgs(node: Node) -> list[TosaArg]:
    return [TosaArg(arg) for arg in node.args]


def get_input_tensor(node: Node) -> TosaArg:
    return TosaArg(node.args[0])


def get_output_node(node: Node) -> Node:
    return list(node.users)[0]


""" TOSA reshape returns a tensor with the same type/values as the input.
    No data conversion happens during a reshape operation. """


def build_reshape(tosa_fb, input_name, new_shape, output_name):
    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(new_shape)
    tosa_fb.addOperator(TosaOp.Op().RESHAPE, [input_name], [output_name], attr)


def is_bias_node_for_quantized_conv(node):
    consumer_node = list(node.users)[0]
    return (
        consumer_node.target == exir_ops.edge.aten.convolution.default
        and list(consumer_node.users)[0].target == q_op
    )


def is_consumer_node_depthwise_conv2d(node):
    consumer_node = list(node.users)[0]
    if consumer_node.target == exir_ops.edge.aten.convolution.default:
        inputs = getNodeArgs(consumer_node)
        group = inputs[-1]
        in_channels = inputs[0].shape[1]
        out_channels = inputs[1].shape[0]
        if (in_channels == group.number) and (out_channels % in_channels) == 0:
            return True

    return False


def build_avg_pool_2d_common(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    input_tensor: TosaArg,
    kernel_size: list,
    stride: list,
    padding: list,
    is_quant_node: bool,
    output: TosaArg,
):
    accumulator_type = input_tensor.dtype

    if is_quant_node:
        # Accumulator type always is int32 when input tensor is an integer type.
        accumulator_type = ts.DType.INT32

    # Initilize zero point to zero.
    input_zp = 0
    output_zp = 0

    if is_quant_node:
        input_zp = get_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp
        output_zp = get_quant_arg_downstream(list(node.users)[0]).zp

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(
        kernel=kernel_size,
        stride=stride,
        pad=padding,
        input_zp=input_zp,
        output_zp=output_zp,
        accum_dtype=accumulator_type,
    )

    tosa_graph.addOperator(
        TosaOp.Op().AVG_POOL2D,
        [input_tensor.name],
        [output.name],
        attr,
    )


def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]:
    """Returns two input nodes to 'node' in order. If 'node' only has one input,
    it is returned twice.

    Fails if there are no input nodes.
    Fails if there are >2 input nodes and 'check' is True,
    """

    num_inputs = len(node.all_input_nodes)
    assert num_inputs > 0, f"Node '{node.name}' requires >0 input, got {num_inputs}."

    input1 = node.all_input_nodes[0]
    if num_inputs == 1:
        input2 = node.all_input_nodes[0]
    else:
        input2 = node.all_input_nodes[1]
    if check:
        assert (
            num_inputs <= 2
        ), f"Node '{node.name}' requires <=2 inputs, got {num_inputs}."

    return input1, input2


def tosa_shape(shape, dim_order):
    return tuple([shape[dim] for dim in dim_order])


def expand_dims(
    tosa_graph: ts.TosaSerializer,
    input_node: TosaArg,
    dtype: int,
    dim: int,
) -> Any:
    """Inserts TOSA operators into the tosa_graph, that perform the equivalent
    of the expand_dims (a.k.a unsqueeze) operation. A new axis is created at the
    dim location.

    Args:
        tosa_graph (ts.TosaSerializer): The TOSA graph to manipulate.
        input_node (TosaArg): The parent node of the expand dim operations.
        dtype (ts.DType): The data type expand dims operations.
        dim (int): The dimension to expand.

    Returns:
        Any: The output tensor of the inserted operation in the TOSA graph.
    """
    new_shape = list(input_node.shape)
    new_shape.insert(dim, 1)

    intermediate = tosa_graph.addIntermediate(new_shape, dtype)

    build_reshape(tosa_graph, input_node.name, new_shape, intermediate.name)

    return intermediate


def get_resize_parameters(
    input_size: torch.Tensor,
    output_size: torch.Tensor,
    resize_mode: int,
    align_corners: bool,
):
    """Get the tosa.resize parameters based on the input and output size.

    Args:
        input_size (torch.Tensor): Size of the input
        output_size (torch.Tensor): Size of the output
        resize_mode (tosa.ResizeMode): The TOSA resize mode
        align_corners (bool): Align the corners pixels of the input and output

    Returns:
        scale_n (torch.Tensor), scale_d (torch.Tensor),
        offset (torch.Tensor), border (torch.Tensor)
    """
    assert torch.all(input_size > 0)
    assert torch.all(output_size > 0)

    scale_n = torch.tensor(
        [
            so - 1 if align_corners and si > 1 and so > 1 else so
            for si, so in zip(input_size, output_size)
        ]
    )
    scale_d = torch.tensor(
        [
            si - 1 if align_corners and si > 1 and so > 1 else si
            for si, so in zip(input_size, output_size)
        ]
    )

    gcd = torch.gcd(scale_n, scale_d)
    scale_n = scale_n // gcd
    scale_d = scale_d // gcd

    # No half-pixel centre support in PyTorch, no offset needed
    offset = torch.zeros_like(input_size)
    border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset

    return scale_n, scale_d, offset, border
