# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Any, Dict, Tuple

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import (
    QCOM_AXIS,
    QCOM_AXIS_ORDER,
    QCOM_BITWIDTH,
    QCOM_DTYPE,
    QCOM_ENCODING,
    QCOM_OFFSET,
    QCOM_QUANT_ATTRS,
    QCOM_QUANT_MAX,
    QCOM_QUANT_MIN,
    QCOM_REQUANTIZE,
    QCOM_SCALE,
    QCOM_SCALE_OFFSET,
    QCOM_SCALES,
    QCOM_ZERO_POINT,
    QCOM_ZERO_POINTS,
)

from executorch.exir.dialects._ops import ops as exir_ops

from .utils import (
    deduce_dtype,
    get_parameter,
    is_graph_input,
    is_graph_output,
    is_parameter,
)


QNN_QUANT_TYPE_MAP = {
    torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8,
    torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16,
    torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_32,
    # Note that there is no int64 tensor data type in QNN.
    torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
    torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
}
QNN_TENSOR_TYPE_MAP = {
    torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
    torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
    torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
    torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,
    torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
    torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
    torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
    float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
}
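
# Illustrative note: a quantized tensor resolves its QNN data type through
# QNN_QUANT_TYPE_MAP using its quantization dtype, while non-quantized tensors
# fall back to QNN_TENSOR_TYPE_MAP keyed by the plain torch dtype, e.g.
#   QNN_QUANT_TYPE_MAP[torch.uint8]     -> QNN_DATATYPE_UFIXED_POINT_8
#   QNN_TENSOR_TYPE_MAP[torch.float32]  -> QNN_DATATYPE_FLOAT_32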

PER_CHANNEL_ENCODING = {
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
}

PER_TENSOR_ENCODING = {
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
}


class NodeVisitor:
    """
    Node visitor pattern for visiting nodes in an edge IR graph and
    converting them into QNN op/tensor wrappers
    """

    def __init__(
        self,
        external_ids,
        edge_program: torch.export.ExportedProgram,
        enable_tensor_dump,
    ) -> None:
        self.external_ids = external_ids or {}
        self.edge_program = edge_program
        self.enable_tensor_dump = enable_tensor_dump

    def get_tensor(self, input_node, op_node, idx=None):
        """
        Get the tensor value, permuted to the op's axis order if one is annotated
        """

        def _get_tensor(node, index):
            if index is not None:
                assert isinstance(index, int)
                if is_parameter(node, self.edge_program):
                    return get_parameter(node, self.edge_program)[index]
                return node.meta["val"][index]

            if is_parameter(node, self.edge_program):
                return get_parameter(node, self.edge_program)
            return node.meta["val"]

        tensor = _get_tensor(input_node, idx)
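        # If the op is annotated with a QNN axis order, permute the tensor into
        # that layout, e.g. an axis order of (0, 2, 3, 1) turns an NCHW tensor
        # of shape (1, 3, 224, 224) into NHWC shape (1, 224, 224, 3).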
        if len(tensor.shape) != 0 and QCOM_AXIS_ORDER in op_node.meta:
            tensor = tensor.permute(dims=op_node.meta[QCOM_AXIS_ORDER]).contiguous()
        return tensor

    def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
        quant_config = copy.deepcopy(quant_attrs)

        scales = quant_attrs[QCOM_SCALES]
        zero_points = quant_attrs[QCOM_ZERO_POINTS]
        assert len(scales) == len(
            zero_points
        ), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}"

        scale_offset = []
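        # QNN encodes each channel as a (scale, offset) pair where the offset is
        # the negated zero point, i.e. dequantization follows
        # real = (quant + offset) * scale. For example (illustrative values),
        # scale = 0.05 and zero_point = 10 become Qnn_ScaleOffset_t(0.05, -10).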
        for i in range(len(scales)):
            # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
            scale_offset.append(
                PyQnnWrapper.Qnn_ScaleOffset_t(scales[i], -zero_points[i])
            )

        user_0 = list(node.users)[0]
        # QNN convolution weights always place the output channel last (e.g.
        # conv2d weights are laid out as HWIO), so per-channel quantization of
        # a convolution weight uses axis 3.
        if "convolution" in user_0.target.__name__ and user_0.args[1] == node:
            quant_config[QCOM_AXIS] = 3
        else:
            quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]

        quant_config[QCOM_SCALE_OFFSET] = scale_offset
        # special case for 4 bits
        if (
            quant_config[QCOM_DTYPE] == torch.int8
            and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
        ):
            quant_config[QCOM_BITWIDTH] = 4
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
                quant_config,
            )
        return (
            PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET,
            quant_config,
        )

    def make_qnn_per_tensor_config(self, quant_attrs: Dict):
        quant_config = copy.deepcopy(quant_attrs)
        # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
        quant_config[QCOM_OFFSET] = -quant_attrs[QCOM_ZERO_POINT]
        # special case for 4 bits
        if (
            quant_config[QCOM_DTYPE] == torch.int8
            and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
        ):
            quant_config[QCOM_BITWIDTH] = 4
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET,
                quant_config,
            )
        return (
            PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
            quant_config,
        )

    def get_quant_encoding_conf(
        self, node: torch.fx.Node, is_input_tensor: bool = False
    ) -> Tuple[Any, Dict]:
        if not node.meta.get(QCOM_QUANT_ATTRS, None):
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
                {},
            )
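        # When this tensor feeds the calling op builder as an input and the node
        # carries requantize attributes, use those attributes so the encoding
        # matches what the consuming op expects.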
        quant_attrs = (
            node.meta[QCOM_REQUANTIZE]
            if QCOM_REQUANTIZE in node.meta and is_input_tensor
            else node.meta[QCOM_QUANT_ATTRS]
        )
        if quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING:
            return self.make_qnn_per_channel_config(node, quant_attrs)

        return self.make_qnn_per_tensor_config(quant_attrs)

    def get_quant_tensor_value(
        self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
    ) -> torch.Tensor:
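        """
        Quantize a float tensor using the node's quant attributes:
        q = round(x / scale + zero_point), cast to the target dtype.
        4-bit values are additionally masked to the low nibble so the
        backend reads the packed data correctly.
        """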
        if quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING:
            scale = quant_attrs[QCOM_SCALE]
            zero_point = quant_attrs[QCOM_ZERO_POINT]
        else:  # per channel case
            scale = quant_attrs[QCOM_SCALES]
            zero_point = quant_attrs[QCOM_ZERO_POINTS]

        dtype = quant_configs[QCOM_DTYPE]

        tensor = tensor.div(scale).add(zero_point).round().to(dtype)
        # Mask to the low nibble so the backend reads 4-bit data correctly
        if quant_configs.get(QCOM_BITWIDTH) == 4:
            mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
            tensor = torch.bitwise_and(mask, tensor)
        return tensor

    def get_tensor_type(
        self,
        node: torch.fx.Node,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
    ) -> PyQnnWrapper.Qnn_TensorType_t:
        is_input = is_graph_input(node, self.edge_program)
        is_output = is_graph_output(node)
        # handle logic for input/output tensors
        if is_input or is_output:
            assert (
                node in self.external_ids
            ), f"Node {node}, is_input: {is_input}, is_output: {is_output}, ext_ids: {self.external_ids.keys()}"
            if is_input:
                return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_WRITE
            if is_output:
                return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ

        if is_parameter(node, self.edge_program):
            return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
        # When tensor dumping is enabled, expose native tensors as APP_READ so
        # their contents can be read back; only native tensors are dumped.
        if (
            self.enable_tensor_dump
            and tensor_type == PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
        ):
            return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ
        return tensor_type

    def get_data_type(
        self,
        tensor: torch.Tensor,
        quant_config: Dict,
    ) -> PyQnnWrapper.Qnn_DataType_t:
        if quant_config:
            quant_config[QCOM_DTYPE] = deduce_dtype(tensor, quant_config)
            return QNN_QUANT_TYPE_MAP[quant_config[QCOM_DTYPE]]

        return QNN_TENSOR_TYPE_MAP[tensor.dtype]

    def define_custom_tensor_wrapper(
        self,
        node_name: str,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
        dtype: PyQnnWrapper.Qnn_DataType_t,
        quant_encoding: PyQnnWrapper.Qnn_QuantizationEncoding_t,
        quant_configs: dict,
        dims: torch.Size,
        tensor: torch.Tensor,
        is_fake_tensor: bool,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
        wrapper_idx: int = 0,
    ) -> PyQnnWrapper.TensorWrapper:
        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
            return cached
        if is_fake_tensor:
            tensor_wrapper = PyQnnWrapper.TensorWrapper(
                node_name,
                tensor_type,
                dtype,
                quant_encoding,
                quant_configs,
                len(dims),
                dims,
                np.array([]),
                False,
            )
        else:
            # Non-fake tensors can be supported here once there is a need
            return None
        nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
        return tensor_wrapper

    def define_tensor(
        self,
        node: torch.fx.Node,
        tensor: torch.Tensor,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
        is_input_tensor: bool,
        node_name: str = None,
        wrapper_idx: int = 0,
    ) -> PyQnnWrapper.TensorWrapper:
        """
        Convert torch.Tensor to TensorWrapper

        Args:
            node: EdgeIR Node
            tensor: EdgeIR Tensor
            tensor_type: QNN tensor type
            nodes_to_wrappers: Dict mapping node names to their cached
                               TensorWrappers, keyed by wrapper index
            is_input_tensor: Whether the tensor is an input of the op builder
                             that is calling this function
            node_name: Optional cache-key override; defaults to node.name
            wrapper_idx: Index used when one node maps to multiple wrappers
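
        Example (illustrative sketch of a call from an op builder; the
        surrounding variables are assumptions, not part of this module):
            input_tensor = self.get_tensor(input_node, node)
            input_tensor_wrapper = self.define_tensor(
                input_node,
                input_tensor,
                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
                nodes_to_wrappers,
                is_input_tensor=True,
            )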
        """
        if node_name is None:
            node_name = node.name

        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
            return cached

        tensor_name = f"{node.name}_{wrapper_idx}"
        if is_graph_input(node, self.edge_program):
            tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
        if is_graph_output(node):
            tensor_name = "output_" + tensor_name
        dims = [1] if len(tensor.size()) == 0 else tensor.size()
        tensor_type = self.get_tensor_type(node, tensor_type)
        quant_encoding, quant_configs = self.get_quant_encoding_conf(
            node, is_input_tensor
        )
        dtype = self.get_data_type(tensor, quant_configs)
        if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
            tensor_wrapper = PyQnnWrapper.TensorWrapper(
                tensor_name,
                tensor_type,
                dtype,
                quant_encoding,
                quant_configs,
                len(dims),
                dims,
                np.array([]),
                False,
            )
        else:
            if quant_configs:
                tensor = self.get_quant_tensor_value(
                    tensor,
                    node.meta[QCOM_QUANT_ATTRS],
                    quant_configs,
                )
            tensor_wrapper = PyQnnWrapper.TensorWrapper(
                tensor_name,
                tensor_type,
                dtype,
                quant_encoding,
                quant_configs,
                len(dims),
                dims,
                tensor.detach().numpy(),
                True,
            )
        nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
        return tensor_wrapper

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Convert torch.fx.Node to OpWrapper"""
        raise NotImplementedError("NodeVisitor must be extended!")


# Maps each visitor's target name to its registered NodeVisitor subclass
_node_visitor_dict = {}


def register_node_visitor(visitor):
    """Register node visitor into _node_visitor_dict"""
    assert (
        isinstance(visitor, type)
        and issubclass(visitor, NodeVisitor)
        and hasattr(visitor, "target")
    ), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
    for target in visitor.target:
        _node_visitor_dict[target] = visitor
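
# Illustrative sketch of how an op builder registers itself (the class name and
# target string below are hypothetical examples, not part of this module):
#
#   @register_node_visitor
#   class Add(NodeVisitor):
#       target = ["aten.add.Tensor"]
#
#       def define_node(self, node, nodes_to_wrappers):
#           ...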


def generate_node_to_external_map(
    edge_program: torch.export.ExportedProgram,
) -> Dict[torch.fx.Node, int]:
    node_to_external_map = {}
    for node in edge_program.graph_module.graph.nodes:
        # The order in which we visit placeholder nodes matches the *args order
        # of this graph module's forward(*args) signature. Use that order as the
        # external_id so the right arg can be extracted from *args at runtime.
        if is_graph_input(node, edge_program):
            node_to_external_map[node] = len(node_to_external_map)
    for node in edge_program.graph_module.graph.nodes:
        if is_graph_output(node):
            node_to_external_map[node] = len(node_to_external_map)
    return node_to_external_map


def get_node_visitors(
    edge_program: torch.export.ExportedProgram,
    enable_tensor_dump=False,
) -> Dict[str, NodeVisitor]:
    """Create a new class instance at runtime, and put them in a dict"""
    node_to_external_map = generate_node_to_external_map(edge_program)
    node_visitors = {}
    for target, visitor in _node_visitor_dict.items():
        assert callable(
            visitor
        ), f"Expeting a callable class, but got {visitor} of type {type(visitor)}"
        node_visitors[target] = visitor(
            node_to_external_map, edge_program, enable_tensor_dump
        )
    return node_visitors
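
# Illustrative sketch of how the visitor dict is typically consumed when
# lowering a graph (this loop is an assumption, not defined in this module):
#
#   node_visitors = get_node_visitors(edge_program)
#   for node in edge_program.graph_module.graph.nodes:
#       if node.op == "call_function":
#           op_wrapper = node_visitors[node.target.__name__].define_node(
#               node, nodes_to_wrappers
#           )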
