# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator
from types import NoneType
from typing import cast, List, Optional, Union

import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)
from executorch.backends.vulkan.utils import (
    is_constant,
    is_get_attr_node,
    is_param_node,
)
from executorch.exir.backend.utils import DelegateMappingBuilder

from executorch.exir.tensor import TensorSpec
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
from torch.export import ExportedProgram
from torch.fx import Node

# Python scalar types that can be serialized directly as a graph value.
_ScalarType = Union[bool, int, float]

# Every kind of operator argument the graph builder is declared to accept.
_Argument = Union[
    Node,  # result of another node in the graph
    None,  # Union normalizes None to NoneType
    _ScalarType,
    TensorSpec,
    List[_ScalarType],
    List[Node],
    str,
]

# Use a module-scoped logger. logging.getLogger("") returns the ROOT logger, so
# calling setLevel() on it would change the log level of the entire process as a
# side effect of importing this module. INFO records emitted through this logger
# still propagate to whatever handlers the application has configured.
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class VkGraphBuilder:
    def __init__(
        self,
        program: ExportedProgram,
        delegate_mapping_builder: DelegateMappingBuilder,
    ) -> None:
        self.program = program
        self.delegate_mapping_builder = delegate_mapping_builder
        self.chain = []
        self.values = []
        self.input_ids = []
        self.output_ids = []
        self.const_tensors = []

        # Mapping from Node to VkValue id
        self.node_to_value_ids = {}

        # For logging
        self.seen_ops = set()

    @staticmethod
    def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
        if torch_dtype == torch.bool:
            return vk_graph_schema.VkDataType.BOOL
        elif torch_dtype == torch.uint8:
            return vk_graph_schema.VkDataType.UINT8
        elif torch_dtype == torch.int8:
            return vk_graph_schema.VkDataType.INT8
        elif torch_dtype == torch.int32:
            return vk_graph_schema.VkDataType.INT32
        elif torch_dtype == torch.float16:
            return vk_graph_schema.VkDataType.FLOAT16
        elif torch_dtype == torch.float32:
            return vk_graph_schema.VkDataType.FLOAT32
        # Narrowing conversion for index tensor produced by max_poolNd_with_indices.
        elif torch_dtype == torch.int64:
            return vk_graph_schema.VkDataType.INT32
        else:
            raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")

    def get_constant(self, node: Node) -> Optional[torch.Tensor]:
        """
        Returns the constant associated with the given node in the exported program.
        Returns None if the node is not a constant within the exported program
        """
        if is_constant(self.program, node):
            constant_name = (
                self.program.graph_signature.inputs_to_lifted_tensor_constants[
                    node.name
                ]
            )
            if constant_name in self.program.constants:
                return self.program.constants[constant_name]
            else:
                return None

        return None

    def get_param_tensor(self, node: Node) -> torch.Tensor:
        tensor = None
        if node is None:
            raise RuntimeError("node is None")
        elif is_param(self.program, node):
            tensor = get_param(self.program, node)
        elif is_buffer(self.program, node):
            tensor = get_buffer(self.program, node)
        elif is_constant(self.program, node):
            tensor = self.get_constant(node)
        elif is_get_attr_node(node):
            # This is a hack to support both lifted and unlifted graph
            try:
                tensor = getattr(node.graph.owning_module, node.target)
            except AttributeError:
                tensor = getattr(self.program.graph_module, node.target)
        else:
            raise RuntimeError(f"unsupported param type, {node.op}.")

        assert tensor is not None
        return tensor

    def maybe_add_constant_tensor(self, node: Node) -> int:
        constant_id = -1
        if is_param_node(self.program, node):
            constant_id = len(self.const_tensors)
            self.const_tensors.append(self.get_param_tensor(node))

        return constant_id

    def create_node_value(self, node: Node) -> int:
        # If the node has been marked as a scalar tensor, create a SymInt instead of a tensor
        if node.meta.get("vkdg_is_scalar_tensor", False):
            new_id = self.create_symint_value()
            self.node_to_value_ids[node] = new_id
            return new_id

        spec = node.meta.get("spec")
        if isinstance(spec, TensorSpec):
            constant_id = self.maybe_add_constant_tensor(node)
            new_id = self.create_tensor_value(spec, constant_id)
            self.node_to_value_ids[node] = new_id
            return new_id
        elif isinstance(spec, list) or isinstance(spec, tuple):
            # pyre-ignore[6]: pyre having hard time to infer Node type inside
            # the container.
            new_id = self.create_value_list_value(spec)
            self.node_to_value_ids[node] = new_id
            return new_id
        else:
            raise RuntimeError(f"Cannot create value for spec of type {type(spec)}")

    def create_null_value(self) -> int:
        new_id = len(self.values)
        self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Null()))
        return new_id

    def create_scalar_value(self, scalar: _ScalarType) -> int:
        new_id = len(self.values)
        if isinstance(scalar, bool):
            self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Bool(scalar)))
        elif isinstance(scalar, int):
            self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar)))
        elif isinstance(scalar, float):
            self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))
        return new_id

    def create_symint_value(self) -> int:
        new_id = len(self.values)
        self.values.append(vk_graph_schema.VkValue(vk_graph_schema.SymInt(0)))
        return new_id

    def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
        # Negative id indicates that this tensor will have its own dedicated memory.
        mem_obj_id = -1
        if spec.mem_obj_id is not None:
            mem_obj_id = spec.mem_obj_id

        storage_type = VkStorageType.DEFAULT_STORAGE
        memory_layout = VkMemoryLayout.DEFAULT_LAYOUT
        if hasattr(spec, "vk_storage_type"):
            # pyre-ignore[16]
            storage_type = spec.vk_storage_type
        if hasattr(spec, "vk_memory_layout"):
            # pyre-ignore[16]
            memory_layout = spec.vk_memory_layout

        new_id = len(self.values)
        self.values.append(
            vk_graph_schema.VkValue(
                value=vk_graph_schema.VkTensor(
                    datatype=self.get_vk_datatype(spec.dtype),
                    dims=spec.shape,
                    constant_id=constant_id,
                    mem_obj_id=mem_obj_id,
                    storage_type=storage_type,
                    memory_layout=memory_layout,
                )
            )
        )
        return new_id

    def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
        new_id = len(self.values)
        if len(arg) == 0:
            self.values.append(
                vk_graph_schema.VkValue(vk_graph_schema.IntList(items=[]))
            )
        elif isinstance(arg[0], bool):
            self.values.append(
                vk_graph_schema.VkValue(
                    vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg])
                )
            )
        elif isinstance(arg[0], int):
            self.values.append(
                vk_graph_schema.VkValue(
                    vk_graph_schema.IntList(items=[cast(int, e) for e in arg])
                )
            )
        elif isinstance(arg[0], float):
            self.values.append(
                vk_graph_schema.VkValue(
                    vk_graph_schema.DoubleList(items=[cast(float, e) for e in arg])
                )
            )
        return new_id

    def create_value_list_value(self, arg: tuple | list) -> int:
        self.values.append(
            vk_graph_schema.VkValue(
                vk_graph_schema.ValueList(
                    items=[self.get_or_create_value_for(e) for e in arg]
                )
            )
        )
        return len(self.values) - 1

    def create_string_value(self, string: str) -> int:
        new_id = len(self.values)
        self.values.append(
            vk_graph_schema.VkValue(vk_graph_schema.String(string_val=string))
        )
        return new_id

    def get_or_create_value_for(self, arg: _Argument):
        if isinstance(arg, Node):
            # If the Node has already been processed, return the existing id.
            if arg in self.node_to_value_ids:
                return self.node_to_value_ids[arg]
            return self.create_node_value(arg)
        elif (
            isinstance(arg, NoneType)
            or isinstance(arg, torch.device)
            or isinstance(arg, torch.dtype)
            or isinstance(arg, torch.layout)
            or isinstance(arg, torch.memory_format)
        ):
            return self.create_null_value()
        elif isinstance(arg, _ScalarType):
            return self.create_scalar_value(arg)
        elif isinstance(arg, TensorSpec):
            return self.create_tensor_value(arg)
        elif isinstance(arg, list) and (
            len(arg) == 0 or isinstance(arg[0], _ScalarType)
        ):
            # pyre-ignore[6]
            return self.create_scalar_list_value(arg)
        elif isinstance(arg, list) and isinstance(arg[0], Node):
            return self.create_value_list_value(arg)
        elif isinstance(arg, torch.fx.immutable_collections.immutable_list):
            # pyre-ignore[6]
            return self.create_value_list_value(arg)
        elif isinstance(arg, str):
            return self.create_string_value(arg)
        else:
            raise RuntimeError(f"Cannot create value for arg of type {type(arg)}")

    def process_placeholder_node(self, node: Node) -> None:
        # ignores any tensors that don't get used in any ops
        if len(node.users) == 0:
            return None
        ids = self.create_node_value(node)
        if not is_param_node(self.program, node):
            if isinstance(ids, int):
                self.input_ids.append(ids)
            else:
                self.input_ids += ids

    def process_getitem_node(self, node: Node) -> None:
        # Find ValueList id from the collection node.
        collection_node = node.all_input_nodes[0]
        list_id = self.node_to_value_ids[collection_node]

        # Extract the target Value id from ValueList.
        valuelist_id = node.args[1]
        value_id = self.values[list_id].value.items[valuelist_id]

        # Map Node to Value id.
        self.node_to_value_ids[node] = value_id

    def process_call_function_node(self, node) -> None:
        operator_call_args = []

        self.seen_ops.add(node.target)

        for i, schema_arg in enumerate(node.target._schema.arguments):
            if not schema_arg.kwarg_only and i < len(node.args):
                function_arg = node.args[i]
            elif schema_arg.name in node.kwargs:
                function_arg = node.kwargs[schema_arg.name]
            else:
                function_arg = schema_arg.default_value

            # Create a Value for each function argument. If the argument has been
            # previously encountered, then use the existing Value id.
            operator_call_args.append(self.get_or_create_value_for(function_arg))

        # Add output node
        operator_call_args.append(self.create_node_value(node))
        operator_node_id = (
            0
            if not self.delegate_mapping_builder
            else self.delegate_mapping_builder.insert_delegate_mapping_entry(node)
        )
        self.chain.append(
            vk_graph_schema.OperatorCall(
                node_id=operator_node_id,  # pyre-ignore[6]: this is going to be an int
                name=node.target.__name__,
                args=operator_call_args,
            ),
        )

    def process_getattr_node(self, node: Node) -> None:
        self.create_node_value(node)

    def process_output_node(self, node: Node) -> None:
        for out_node in node.all_input_nodes:
            if out_node not in self.node_to_value_ids:
                raise AssertionError(
                    "Cannot find input to output node in node_to_value_ids. This means "
                    "the output node is being serialized before its corresponding "
                    "internal node which is not allowed."
                )
            self.output_ids.append(self.node_to_value_ids[out_node])

    def process_node(self, node: Node, call_node_debug_hdl: int) -> None:
        if node.op == "placeholder":
            self.process_placeholder_node(node)
        elif node.op == "call_function":
            if node.target == operator.getitem:
                self.process_getitem_node(node)
            else:
                node.meta["debug_handle"] = call_node_debug_hdl
                self.process_call_function_node(node)
        elif node.op == "get_attr":
            self.process_getattr_node(node)
        elif node.op == "output":
            self.process_output_node(node)
        else:
            raise AssertionError(f"Unsupported node op: {node.op}")

    def build_graph(self) -> vk_graph_schema.VkGraph:
        call_node_debug_hdl = 0
        for node in self.program.graph_module.graph.nodes:
            self.process_node(node, call_node_debug_hdl)
            call_node_debug_hdl += 1

        logger.info("Operators included in this Vulkan partition: ")
        for op in self.seen_ops:
            logger.info(f"    {op.__name__}")

        return vk_graph_schema.VkGraph(
            version="0",
            chain=self.chain,
            values=self.values,
            input_ids=self.input_ids,
            output_ids=self.output_ids,
            constants=[],
            shaders=[],
        )
