# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Utility functions for TOSA quantized lowerings

import math
from typing import Callable, cast, NamedTuple, Sequence

import numpy as np

import serializer.tosa_serializer as ts
import torch.fx
import tosa.Op as TosaOp
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.exir.dialects._ops import ops as exir_ops
from serializer.tosa_serializer import TosaSerializerTensor
from torch.fx import Node


# Edge-dialect per-tensor quantize / dequantize ops. qargs_from_qnode reads
# their positional args to build QuantArgs.
q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
dq_q_ops = (q_op, dq_op)
# Ops that the quant-arg searches (search_quant_arg_downstream/upstream) walk
# through when looking for a neighbouring q/dq node. is_node_quantized asserts
# that such ops see identical QuantArgs on all their inputs and outputs.
# Custom ops can be appended later via register_passable_op.
passable_ops = [
    exir_ops.edge.aten.view_copy.default,
    exir_ops.edge.aten.permute_copy.default,
    exir_ops.edge.aten.squeeze_copy.dims,
    exir_ops.edge.aten.unsqueeze_copy.default,
    exir_ops.edge.aten.split_with_sizes_copy.default,
    exir_ops.edge.aten.repeat.default,
    exir_ops.edge.aten.clone.default,
    exir_ops.edge.aten.slice_copy.Tensor,
    exir_ops.edge.aten.cat.default,
]


def register_passable_op(op):
    """Add ``op`` to ``passable_ops`` so the quant-arg searches walk through it.

    We need to be able to add custom ops such as tosa_transpose to the
    passable_op list after they have been created. Registration is idempotent:
    an op that is already registered is not appended a second time.
    """
    if op not in passable_ops:
        passable_ops.append(op)


class QuantArgs(NamedTuple):
    """Per-tensor quantization parameters.

    Holds the scale, zero point, clamp range [qmin, qmax] and the quantized
    dtype, as read from a quantize/dequantize node.
    """

    scale: float
    zp: int
    qmin: int
    qmax: int
    dtype: torch.dtype

    def quantize_value(self, x):
        """Quantize ``x`` as ``clip(round(x / scale) + zp, qmin, qmax)`` cast to ``dtype``.

        A non-tensor ``x`` is first wrapped in a one-element float tensor, so
        the result is always a tensor.
        """
        values = x if isinstance(x, torch.Tensor) else torch.Tensor([x])
        shifted = torch.round(values / self.scale) + self.zp
        return shifted.clamp(self.qmin, self.qmax).to(self.dtype)

    def dequantize_value(self, qx: int) -> float:
        """Map a quantized value back to the real domain: ``(qx - zp) * scale``."""
        return self.scale * (qx - self.zp)


def quantize_value(x, qargs: QuantArgs, dtype=np.int8):
    """NumPy quantization: ``clip(round(x / scale) + zp, qmin, qmax)`` cast to ``dtype``.

    Note that ``dtype`` defaults to ``np.int8`` and is independent of
    ``qargs.dtype``.
    """
    rounded = np.round(x / qargs.scale) + qargs.zp
    clipped = np.clip(rounded, qargs.qmin, qargs.qmax)
    return clipped.astype(dtype)


def dequantize_value(qx, qargs: QuantArgs):
    """Map quantized value(s) ``qx`` back to the real domain: ``(qx - zp) * scale``."""
    return qargs.scale * (qx - qargs.zp)


def qargs_from_qnode(node: torch.fx.Node):
    """Build QuantArgs from the positional args of a quantize/dequantize node.

    ``node.args[1:6]`` are read as (scale, zero point, qmin, qmax, dtype);
    ``node.args[0]`` is presumably the tensor being (de)quantized and is unused.

    Raises:
        AssertionError: If ``node`` is not a q/dq op.
    """
    assert node.target in dq_q_ops, f"Op {node} is not a quant node."

    args = node.args
    scale = cast(float, args[1])
    zp = cast(int, args[2])
    qmin = cast(int, args[3])
    qmax = cast(int, args[4])
    dtype = cast(torch.dtype, args[5])
    return QuantArgs(scale=scale, zp=zp, qmin=qmin, qmax=qmax, dtype=dtype)


def get_neighbour_quant_args(
    node: torch.fx.Node,
) -> tuple[list[QuantArgs], list[QuantArgs]]:
    """Collect the QuantArgs reachable from the neighbours of ``node``.

    Returns:
        ``(user_q_args, input_q_args)``: the QuantArgs found by searching
        downstream from each user of ``node``, and upstream from each of its
        input nodes. Neighbours for which nothing is found are skipped.
    """
    downstream = (search_quant_arg_downstream(user) for user in node.users)
    user_q_args = [q_args for q_args in downstream if q_args]

    upstream = (search_quant_arg_upstream(src) for src in node.all_input_nodes)
    input_q_args = [q_args for q_args in upstream if q_args]

    return user_q_args, input_q_args


def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool:
    """Return True if every entry of ``q_arg_list`` equals the first one.

    An empty list is treated as vacuously uniform and yields True instead of
    failing on the missing first element.
    """
    if not q_arg_list:
        return True
    first_q_arg = q_arg_list[0]
    return all(q_arg == first_q_arg for q_arg in q_arg_list)


def is_node_quantized(node: torch.fx.Node) -> bool:
    """Return whether ``node`` participates in a quantized region.

    q/dq nodes are always quantized. Any other node is quantized if QuantArgs
    can be found among its neighbours (see get_neighbour_quant_args). For
    passable ops, all neighbouring QuantArgs must additionally be identical.

    Raises:
        AssertionError: If a passable op sees differing QuantArgs.
    """
    if node.target in dq_q_ops:
        return True

    user_q_args, input_q_args = get_neighbour_quant_args(node)
    neighbour_q_args = user_q_args + input_q_args

    # No quantization information anywhere around this node.
    if not neighbour_q_args:
        return False

    if node.target in passable_ops:
        assert all_q_args_equal(
            neighbour_q_args
        ), f"Node {node} needs same quantization parameters on all inputs and outputs."

    return True


def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None:
    """
    Iterates downward in the graph passing through 'passable_ops' to find and return a quantization node,
    starting with 'node'.
    If a passable node with multiple consumers is encountered,
    find QuantArgs for all consumers and assert that they are equal;
    consumers for which no QuantArgs are found are ignored.
    If a node not in passable_ops is encountered, return None.
    If a node without consumers is encountered, return None.

    Returns:
        The QuantArgs of the q/dq node that was found, or None.
    """
    if node.target in dq_q_ops:
        return qargs_from_qnode(node)
    if node.target not in passable_ops:
        return None

    consumer_nodes = list(node.users)
    if not consumer_nodes:
        return None
    if len(consumer_nodes) == 1:
        return search_quant_arg_downstream(consumer_nodes[0])

    consumer_qargs: list[QuantArgs] = []
    for consumer in consumer_nodes:
        quant_args = search_quant_arg_downstream(consumer)
        if quant_args is not None:
            consumer_qargs.append(quant_args)
    if not consumer_qargs:
        return None
    assert all_q_args_equal(
        consumer_qargs
    ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different consumers."
    return consumer_qargs[0]


def get_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs:
    """Strict variant of search_quant_arg_downstream: QuantArgs must be found.

    Raises:
        AssertionError: If no QuantArgs are found downstream of ``node``.
    """
    found = search_quant_arg_downstream(node)
    assert found, f"Did not find QuantArgs downstream for node {node}"
    return found


def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None:
    """
    Iterates upward in the graph passing through 'passable_ops' to find and return a quantization node,
    starting with 'node'.
    If a passable node with multiple inputs is encountered,
    find QuantArgs for all inputs and assert that they are equal;
    inputs for which no QuantArgs are found are ignored.
    If a node not in passable_ops is encountered, return None.
    If a node without inputs is encountered, return None.

    Returns:
        The QuantArgs of the q/dq node that was found, or None.
    """
    if node.target in dq_q_ops:
        return qargs_from_qnode(node)
    if node.target not in passable_ops:
        return None

    input_nodes = list(node.all_input_nodes)
    if not input_nodes:
        return None
    if len(input_nodes) == 1:
        return search_quant_arg_upstream(input_nodes[0])

    input_qargs: list[QuantArgs] = []
    for producer in input_nodes:
        quant_args = search_quant_arg_upstream(producer)
        if quant_args is not None:
            input_qargs.append(quant_args)
    if not input_qargs:
        return None
    assert all_q_args_equal(
        input_qargs
    ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different inputs."
    return input_qargs[0]


def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs:
    """Strict variant of search_quant_arg_upstream: QuantArgs must be found.

    Raises:
        AssertionError: If no QuantArgs are found upstream of ``node``.
    """
    found = search_quant_arg_upstream(node)
    assert found, f"Did not find QuantArgs upstream for node {node}"
    return found


def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype:
    """Return the quantized output dtype of ``node``.

    Resolution order:
      1. Callable targets whose ``__name__`` contains "tosa": the dtype of
         ``node.meta["val"]``.
      2. q/dq ops: their dtype argument, ``node.args[5]``.
      3. Otherwise the dtype of the first QuantArgs found downstream of the
         node's users or, for passable ops only, upstream of its inputs.

    Raises:
        RuntimeError: If none of the above yields a dtype.
    """
    # Not every callable target defines ``__name__`` (e.g. functools.partial);
    # treat those as non-TOSA instead of raising AttributeError.
    if isinstance(node.target, Callable) and "tosa" in getattr(
        node.target, "__name__", ""
    ):
        return node.meta["val"].dtype
    if node.target in dq_q_ops:
        return cast(torch.dtype, node.args[5])

    # if not a tosa node, nor a q/dq op, walk the graph until we find a q op
    user_q_args, input_q_args = get_neighbour_quant_args(node)
    if len(user_q_args) > 0:
        return user_q_args[0].dtype
    elif node.target in passable_ops and len(input_q_args) > 0:
        return input_q_args[0].dtype
    else:
        raise RuntimeError("No quantized node found in graph")


def is_scale32(type):
    """Return whether scale32 RESCALE mode applies to output element type ``type``.

    True only for ``ts.DType.INT8``. The parameter name shadows the ``type``
    builtin but is kept for keyword-argument compatibility.
    """
    uses_scale32 = type == ts.DType.INT8
    return uses_scale32


# TOSA uses the RESCALE operation to scale between values with differing precision.
# The RESCALE operator is defined using an integer multiply, add, and shift.
# This utility function is for calculating the multiplier and shift given a scale.
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
def compute_multiplier_and_shift(scale, scaleWidth=32):
    """Decompose a float ``scale`` into an integer multiplier and a right shift.

    The result satisfies ``scale ~= multiplier * 2**-shift``, where
    ``multiplier`` is a fixed-point mantissa with ``scaleWidth - 1`` fractional
    bits.

    Args:
        scale: The real-valued scale factor; must be a ``float``.
        scaleWidth: Bit width of the multiplier, 16 or 32.

    Returns:
        ``(multiplier, shift)`` as Python ints. ``shift`` is capped at 62; when
        the cap applies, the multiplier is right-shifted to compensate, which
        loses precision.

    Raises:
        AssertionError: If ``scaleWidth`` is not 16 or 32, or ``scale`` is
            not a float.
    """
    if scaleWidth == 16:
        offset = 15
    elif scaleWidth == 32:
        offset = 31
    else:
        raise AssertionError("unsupported scale width")

    assert isinstance(scale, float)

    # scale == mantissa * 2**exponent (math.frexp: 0.5 <= |mantissa| < 1 for
    # non-zero scale).
    mantissa, exponent = math.frexp(scale)
    shift = exponent

    const_2_power_15_or_31 = 1 << offset
    shifted_mantissa = round(mantissa * const_2_power_15_or_31)

    assert shifted_mantissa <= const_2_power_15_or_31

    # Rounding can push the mantissa up to exactly 2**offset, which does not
    # fit the signed multiplier; halve it and bump the exponent instead.
    # Integer division keeps the multiplier an int (true division would make
    # it a float and break the ``>>`` below and the serialized attribute).
    if shifted_mantissa == const_2_power_15_or_31:
        shifted_mantissa = shifted_mantissa // 2
        shift += 1

    # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
    shift = offset - shift

    # Multiplier must fit the signed width: INT16_MAX (2^15 - 1) or INT32_MAX (2^31 - 1).
    assert shifted_mantissa <= (const_2_power_15_or_31 - 1)

    multiplier = shifted_mantissa

    # Very small scales would need a shift above 62; cap the shift there and
    # fold the excess into the multiplier instead.
    if shift > 62:
        multiplier = multiplier >> min(31, shift - 62)
        shift = 62
    return multiplier, shift


def build_rescale(
    tosa_fb,
    scale,
    input_node,
    output_name,
    output_type,
    output_shape,
    input_zp,
    output_zp,
    is_double_round=False,
):
    """Append a RESCALE operator from ``input_node`` to ``output_name`` in ``tosa_fb``.

    The float ``scale`` is converted to a multiplier/shift pair; 32-bit
    multipliers (scale32) are used when ``is_scale32(output_type)`` holds,
    16-bit ones otherwise.

    Args:
        tosa_fb: TOSA serializer the operator is added to.
        scale: Float scale applied by the RESCALE.
        input_node: Source whose ``name`` is the operator input.
        output_name: Name of the operator output tensor.
        output_type: Output element type; selects scale32 vs 16-bit mode.
        output_shape: Currently unused.
        input_zp: Forwarded as the RESCALE ``input_zp`` attribute.
        output_zp: Forwarded as the RESCALE ``output_zp`` attribute.
        is_double_round: Forwarded as the RESCALE ``double_round`` attribute.
    """
    use_scale32 = is_scale32(output_type)
    multiplier, shift = compute_multiplier_and_shift(
        scale, 32 if use_scale32 else 16
    )

    rescale_attr = ts.TosaSerializerAttribute()
    rescale_attr.RescaleAttribute(
        input_zp=input_zp,
        output_zp=output_zp,
        multiplier=[multiplier],
        shift=[shift],
        scale32=use_scale32,
        double_round=is_double_round,
        per_channel=False,
        input_unsigned=False,
        output_unsigned=False,
    )

    tosa_fb.addOperator(
        TosaOp.Op().RESCALE, [input_node.name], [output_name], rescale_attr
    )


def build_rescale_to_int32(
    tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=False
) -> TosaSerializerTensor:
    """Add a RESCALE from ``input`` into a new INT32 intermediate tensor.

    Args:
        tosa_fb: TOSA serializer the intermediate and operator are added to.
        input: Tensor to rescale; its ``name`` and ``shape`` are used. (The
            name shadows the builtin but is kept for keyword compatibility.)
        input_zp: Forwarded as the RESCALE ``input_zp`` attribute; the output
            zero point is fixed to 0.
        rescale_scale: Float scale converted to a multiplier/shift pair.
        is_scale32: Whether the RESCALE uses 32-bit (True) or 16-bit (False)
            multipliers.
        is_double_round: Forwarded as the RESCALE ``double_round`` attribute.

    Returns:
        The new INT32 intermediate tensor holding the rescaled values.
    """
    # Compute the multiplier at the width the scale32 attribute advertises so a
    # 16-bit RESCALE never carries a 32-bit multiplier (as build_rescale does).
    scale_width = 32 if is_scale32 else 16
    multiplier, shift = compute_multiplier_and_shift(rescale_scale, scale_width)
    attr_rescale = ts.TosaSerializerAttribute()
    attr_rescale.RescaleAttribute(
        input_zp=input_zp,
        output_zp=0,
        multiplier=[multiplier],
        shift=[shift],
        scale32=is_scale32,
        double_round=is_double_round,
        per_channel=False,
        input_unsigned=False,
        output_unsigned=False,
    )
    rescaled_tensor = tosa_fb.addIntermediate(input.shape, ts.DType.INT32)
    tosa_fb.addOperator(
        TosaOp.Op().RESCALE,
        [input.name],
        [rescaled_tensor.name],
        attr_rescale,
    )

    return rescaled_tensor


def build_rescale_from_int32(
    tosa_fb,
    input_name,
    output_name,
    output_zp,
    rescale_scale,
    is_scale32=True,
    is_double_round=False,
) -> None:
    """Add a RESCALE from tensor ``input_name`` to tensor ``output_name``.

    The input zero point is fixed to 0, presumably because the input is an
    INT32 intermediate (see build_rescale_to_int32).

    Args:
        tosa_fb: TOSA serializer the operator is added to.
        input_name: Name of the tensor to rescale.
        output_name: Name of the tensor receiving the result.
        output_zp: Forwarded as the RESCALE ``output_zp`` attribute.
        rescale_scale: Float scale converted to a multiplier/shift pair.
        is_scale32: Whether the RESCALE uses 32-bit (True) or 16-bit (False)
            multipliers.
        is_double_round: Forwarded as the RESCALE ``double_round`` attribute.
    """
    # Compute the multiplier at the width the scale32 attribute advertises so a
    # 16-bit RESCALE never carries a 32-bit multiplier (as build_rescale does).
    scale_width = 32 if is_scale32 else 16
    multiplier, shift = compute_multiplier_and_shift(rescale_scale, scale_width)
    attr_rescale_output = ts.TosaSerializerAttribute()
    attr_rescale_output.RescaleAttribute(
        input_zp=0,
        output_zp=output_zp,
        multiplier=[multiplier],
        shift=[shift],
        scale32=is_scale32,
        double_round=is_double_round,
        per_channel=False,
        input_unsigned=False,
        output_unsigned=False,
    )

    tosa_fb.addOperator(
        TosaOp.Op().RESCALE, [input_name], [output_name], attr_rescale_output
    )


def rescale_nodes_to_int32(
    nodes: Sequence[Node], tosa_graph: ts.TosaSerializer
) -> tuple[list[TosaSerializerTensor], float]:
    """Rescale every node in ``nodes`` to INT32 on one shared scale.

    Each input is rescaled by ``qarg.scale / min_scale``, where ``min_scale``
    is the smallest upstream quantization scale among ``nodes``. All results are
    therefore expressed in units of ``min_scale``. One RESCALE op per node is
    added to ``tosa_graph``.

    Returns:
        ``(rescaled_tensors, min_scale)``. ``min_scale`` is what
        rescale_node_back_to_int8 needs to return to INT8.
    """
    tensors = [TosaArg(node) for node in nodes]

    # Permute each tensor's shape into TOSA dim order.
    for tensor in tensors:
        tensor.shape = [tensor.shape[axis] for axis in tensor.dim_order]

    qargs = [get_quant_arg_upstream(node) for node in nodes]

    # Bring every int8 input onto the common (smallest) scale in the integer domain.
    min_scale = min(qarg.scale for qarg in qargs)

    rescaled_nodes: list[TosaSerializerTensor] = [
        build_rescale_to_int32(tosa_graph, tensor, qarg.zp, qarg.scale / min_scale)
        for tensor, qarg in zip(tensors, qargs)
    ]
    return rescaled_nodes, min_scale


def rescale_node_back_to_int8(
    node: Node,
    last_tensor: TosaSerializerTensor,
    scale: float,
    tosa_graph: ts.TosaSerializer,
):
    """Add a RESCALE that writes ``last_tensor`` back as ``node``'s quantized output.

    The output QuantArgs are taken from the first user of ``node``, so ``node``
    must have at least one user.

    Parameters:
        node: The original node that is being handled by the rescales; its
            name becomes the RESCALE output.
        last_tensor: The tosa tensor to rescale back.
        scale: The scale ``last_tensor`` is expressed in, as returned by
            'rescale_nodes_to_int32'.
        tosa_graph: The tosa_graph to manipulate.
    """
    first_user = list(node.users)[0]
    out_qargs = get_quant_arg_downstream(first_user)

    # Convert from units of `scale` into units of the output quantization scale.
    to_output_scale = scale / out_qargs.scale

    build_rescale_from_int32(
        tosa_graph,
        last_tensor.name,
        node.name,
        out_qargs.zp,
        to_output_scale,
    )


""" Creates a TOSA rescale op based on conv2d parameters. """


def build_rescale_conv_output(
    tosa_fb,
    op,
    output_name,
    output_type,
    input_scale,
    weight_scale,
    output_scale,
    output_zp,
):
    """Add a RESCALE mapping a convolution result into the output quantization.

    The combined scale is ``(input_scale * weight_scale) / output_scale``. The
    tensor being rescaled (``op``) is assumed to be int32, so its zero point is
    0. Its ``shape`` is passed through as the output shape.
    """
    # TODO add check to verify if this is a Per-channel quantization.
    combined_scale = (input_scale * weight_scale) / output_scale

    build_rescale(
        tosa_fb,
        combined_scale,
        op,
        output_name,
        output_type,
        op.shape,
        input_zp=0,
        output_zp=output_zp,
    )
