# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import numpy as np

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import (
    dequantize_value,
    get_quant_arg_downstream,
    get_quant_arg_upstream,
    QuantArgs,
    quantize_value,
)
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class ExpVisitor(NodeVisitor):
    """Node visitor that serializes ``aten.exp.default`` into a TOSA graph.

    Non-quantized nodes are emitted as a TOSA ``EXP`` operator. Quantized
    nodes are emitted as a TOSA ``TABLE`` operator whose lookup table is
    produced by ``exp_table_8bit``.
    """

    target = "aten.exp.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Add the operator implementing ``node`` to ``tosa_graph``.

        Args:
            node: The FX node being lowered; it must have exactly one input
                node.
            tosa_graph: Serializer that receives the new operator.
            inputs: TOSA arguments of the node; only ``inputs[0]`` is read.
            output: TOSA argument naming the operator's result.
            is_quant_node: Selects the table-lookup path when true and the
                plain ``EXP`` path otherwise.
        """
        producers = node.all_input_nodes
        assert len(producers) == 1

        if not is_quant_node:
            # Float path: exp maps directly onto a single TOSA operator.
            tosa_graph.addOperator(TosaOp.Op().EXP, [inputs[0].name], [output.name])
            return

        # Quantized path. Assume quantized input is 8 bit.
        # Input parameters come from the producer of this node, output
        # parameters from the first user of this node.
        upstream_qargs = get_quant_arg_upstream(producers[0])
        first_user = list(node.users)[0]
        downstream_qargs = get_quant_arg_downstream(first_user)

        # Create attribute for 8 bit table lookup.
        lut = exp_table_8bit(upstream_qargs, downstream_qargs)
        lut_attr = ts.TosaSerializerAttribute()
        lut_attr.TableAttribute(lut)

        tosa_graph.addOperator(
            TosaOp.Op().TABLE, [inputs[0].name], [output.name], lut_attr
        )


def exp_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
    """
    Returns a 256-entry lookup table implementing exp() for int8 inputs.

    Entry ``i`` holds the quantized exp() result for the int8 input code
    ``i - 128``: the table always covers the complete int8 domain
    [-128, 127] in ascending order, each code exactly once, independent of
    ``in_quantargs.qmin`` / ``in_quantargs.qmax``.

    Args:
        in_quantargs: Quantization parameters used to dequantize each input
            code before exp is applied.
        out_quantargs: Quantization parameters used to quantize the exp
            result.

    Returns:
        A list of 256 values, each as returned by ``quantize_value``.
    """
    int8_limits = np.iinfo(np.int8)

    def exp(x):
        # Convert quantized input to floating point exp input space.
        v = dequantize_value(x, in_quantargs)
        # Compute exp.
        v = np.exp(v)
        # Convert exp output back to quantized space.
        return quantize_value(v, out_quantargs)

    # Enumerate the codes as plain Python ints rather than sampling with
    # np.linspace(..., dtype=np.int8):
    #  * float samples truncated to int8 only hit every code once when
    #    qmax - qmin == 255; any narrower range yields duplicated codes and
    #    entries that no longer line up with their table index.
    #  * NOTE(review): dequantize_value is defined elsewhere; if it subtracts
    #    the zero point directly, np.int8 scalars can overflow under NumPy 2
    #    promotion rules. Python ints avoid that regardless.
    return [exp(code) for code in range(int8_limits.min, int8_limits.max + 1)]
