# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import IntEnum
from typing import Optional, Set, Tuple

import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)

from executorch.exir.tensor import TensorSpec

from torch._export.utils import is_buffer, is_param

from torch._subclasses.fake_tensor import FakeTensor

from torch.export import ExportedProgram

##
## Node type determination
##


def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` is an FX node whose op is ``"get_attr"``.

    Non-Node inputs are accepted and simply yield False.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"


def is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool:
    return node.name in program.graph_signature.inputs_to_lifted_tensor_constants


def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check if the given node is a parameter-like value within the exported program:
    a ``get_attr`` node, a parameter, a buffer, or a lifted tensor constant.
    """
    if is_get_attr_node(node):
        return True
    # any() stops at the first predicate that matches, mirroring an `or` chain.
    return any(
        predicate(program, node) for predicate in (is_param, is_buffer, is_constant)
    )


def is_symint_node(node: torch.fx.Node) -> bool:
    """
    Return True if the node's ``meta["val"]`` is a ``torch.SymInt``.

    A node without a ``"val"`` entry is not considered a SymInt node.
    """
    # A missing key yields None, which is never a SymInt.
    value = node.meta.get("val")
    return isinstance(value, torch.SymInt)


def is_tensor_node(node: torch.fx.Node) -> bool:
    """
    Return True if the node produces a tensor, or a list/tuple made up only of tensors.

    A node counts as a tensor node if it carries a ``"spec"`` entry in its meta.
    Otherwise the node's ``meta["val"]`` must be a FakeTensor, or a list/tuple whose
    elements are all FakeTensors. An empty list/tuple qualifies.
    """
    meta = node.meta
    # All nodes with tensor values are tagged by the SpecPropPass transform
    if "spec" in meta:
        return True

    value = meta.get("val")
    if isinstance(value, (list, tuple)):
        return all(isinstance(item, FakeTensor) for item in value)
    return isinstance(value, FakeTensor)


##
## Memory Layout, Storage Type Determination
##

# (x, y, z) extents of an image/texture.
ImageExtents = Tuple[int, int, int]

# Default texture extent limits along (x, y, z), and default buffer limit
# expressed as a number of elements (128 Mi elements).
DEFAULT_TEXTURE_LIMITS: ImageExtents = (16384, 16384, 2048)
DEFAULT_BUFFER_LIMIT: int = 128 * 1024 * 1024


class PackedDim(IntEnum):
    """
    Names the tensor dimension a packed layout packs along.

    Values are explicit integers (0, 1, 2). Keep them stable, since IntEnum
    members compare equal to their int values.
    """

    WIDTH = 0
    HEIGHT = 1
    CHANNELS = 2


# Every member of PackedDim.
all_packed_dims: Set[PackedDim] = set(PackedDim)

# Candidate storage types.
all_storage_types: Set[VkStorageType] = {VkStorageType.BUFFER, VkStorageType.TEXTURE_3D}

# Candidate memory layouts; each one packs a different dimension.
all_memory_layouts: Set[VkMemoryLayout] = {
    VkMemoryLayout.TENSOR_WIDTH_PACKED,
    VkMemoryLayout.TENSOR_HEIGHT_PACKED,
    VkMemoryLayout.TENSOR_CHANNELS_PACKED,
}


def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> bool:
    """
    Checks whether the tensors produced by the given node can fit within the device's
    GPU buffer limit, which represents the maximum number of elements that can be stored
    in a GPU buffer.

    Args:
        node: A tensor node (``is_tensor_node(node)`` must hold). Its ``meta["val"]``
            is expected to be a FakeTensor or a list/tuple of FakeTensors.
        buffer_limit: The element-count limit to compare each tensor's ``numel()`` against.

    Returns:
        True if every tensor produced by the node has strictly fewer than
        ``buffer_limit`` elements. An empty list/tuple yields True.

    Raises:
        RuntimeError: If ``meta["val"]`` is neither a FakeTensor nor a list/tuple.
    """
    # NOTE: is_tensor_node also accepts nodes that only carry "spec"; such a node
    # without "val" raises KeyError on the lookup below.
    assert is_tensor_node(node)

    val = node.meta["val"]
    # NOTE(review): the comparison is strict, so a tensor with exactly
    # `buffer_limit` elements is rejected. Confirm this is intended.
    if isinstance(val, FakeTensor):
        return val.numel() < buffer_limit
    if isinstance(val, (list, tuple)):
        return all(t.numel() < buffer_limit for t in val)
    raise RuntimeError(f"Cannot get numel for val of type {type(val)}")


def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
    """
    Calculate the (x, y, z) image extents that will be used to represent a tensor with
    the given sizes and memory layout in the Vulkan Delegate.

    The last three dims of ``sizes`` give width, height and channels; any that are
    missing count as 1. ``sizes[0]`` is the batch when there are at least 4 dims. The
    dim named by ``layout`` is divided by 4, rounding up. The z extent is
    channels * batch.

    NOTE(review): for tensors with more than 4 dims, only ``sizes[0]`` contributes to
    the batch term, and the dims between it and channels are ignored.

    Raises:
        RuntimeError: If ``layout`` is not width-, height- or channels-packed.
    """
    ndim = len(sizes)

    def dim_from_end(offset: int):
        # offset=1 -> last dim; missing dims default to 1
        return sizes[-offset] if ndim >= offset else 1

    def ceil_div4(extent):
        return (extent + 3) // 4

    w, h, c = dim_from_end(1), dim_from_end(2), dim_from_end(3)
    n = sizes[0] if ndim >= 4 else 1

    if layout == VkMemoryLayout.TENSOR_WIDTH_PACKED:
        w = ceil_div4(w)
    elif layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED:
        h = ceil_div4(h)
    elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED:
        c = ceil_div4(c)
    else:
        raise RuntimeError(f"Unsupported memory layout {layout}")

    return w, h, c * n


def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool:
    """
    Return True if each extent is less than or equal to the limit at the same index.

    Iterates over ``extents``, so ``limits`` must be at least as long.
    """
    return all(extent <= limits[axis] for axis, extent in enumerate(extents))


def valid_texture_memory_layouts(
    tensor_sizes: torch.Size, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given tensor sizes, determine the set of memory layouts which will produce a texture
    that can fit within the specified device limits.

    Args:
        tensor_sizes: Sizes of the tensor to be stored as a texture.
        texture_limits: Maximum (x, y, z) image extents allowed.

    Returns:
        The subset of ``all_memory_layouts`` whose required image extents (see
        ``required_image_extents``) are all within ``texture_limits``.
    """
    # The result is a set, so iterating the source set directly (no list copy) is fine.
    return {
        layout
        for layout in all_memory_layouts
        if extents_are_valid(
            required_image_extents(tensor_sizes, layout), texture_limits
        )
    }


def possible_node_memory_layouts(
    node: torch.fx.Node, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given a node, determine the set of memory layouts which can be used to represent all
    tensors involved in the computation.

    For a node producing a single FakeTensor, this is the set of layouts whose texture
    fits within ``texture_limits``. For a node producing a list/tuple of tensors, a
    layout is usable only if it fits for every tensor, so the per-tensor sets are
    intersected.

    Args:
        node: A tensor node (``is_tensor_node(node)`` must hold).
        texture_limits: Maximum (x, y, z) image extents allowed.

    Returns:
        The usable layouts. The set is empty if ``meta["val"]`` is an empty list/tuple
        or is neither a FakeTensor nor a list/tuple.
    """
    assert is_tensor_node(node)
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return valid_texture_memory_layouts(val.shape, texture_limits)

    if not isinstance(val, (list, tuple)) or len(val) == 0:
        return set()

    # set_node_spec_attr writes one value to every spec of a multi-output node, so
    # a layout must be valid for all outputs. That requires an intersection, not a union.
    per_tensor_layouts = [
        valid_texture_memory_layouts(fake_tensor.shape, texture_limits)
        for fake_tensor in val
    ]
    return set.intersection(*per_tensor_layouts)


##
## TensorSpec Utils
##


def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
    """
    Set ``attr`` to ``value`` on the TensorSpec(s) stored in ``node.meta["spec"]``.

    A single TensorSpec is updated directly. For a list/tuple of specs, every element
    (each asserted to be a TensorSpec) gets the same value.

    Raises:
        RuntimeError: If the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]

    if isinstance(spec, TensorSpec):
        targets = [spec]
    elif isinstance(spec, (list, tuple)):
        targets = spec
    else:
        raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")

    for target in targets:
        assert isinstance(target, TensorSpec)
        setattr(target, attr, value)


def get_node_spec_attr(node: torch.fx.Node, attr: str, return_first: bool = True):
    """
    Read ``attr`` from the TensorSpec(s) stored in ``node.meta["spec"]``.

    Args:
        node: A node whose meta contains a ``"spec"`` entry.
        attr: Name of the attribute to read.
        return_first: For list/tuple specs, return only the first spec's value (the
            default) instead of a list with one value per spec.

    Returns:
        The attribute value, or None wherever a spec lacks the attribute. For an empty
        list/tuple spec, returns None if ``return_first`` is True, otherwise an empty list.

    Raises:
        RuntimeError: If the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]
    if isinstance(spec, TensorSpec):
        return getattr(spec, attr, None)
    elif isinstance(spec, (list, tuple)):
        if return_first:
            # The attribute lives on the element spec, not on the container, so
            # look it up on spec[0]. Guard against an empty container.
            return getattr(spec[0], attr, None) if len(spec) > 0 else None
        else:
            return [getattr(s, attr, None) for s in spec]
    else:
        raise RuntimeError(f"Cannot get attr for spec of type {type(spec)}")


def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]:
    """
    Return the ``vk_storage_type`` recorded on the node's spec.

    Thin wrapper over get_node_spec_attr with ``return_first=True``. It yields None
    when the attribute is missing.
    """
    storage_type = get_node_spec_attr(node, "vk_storage_type", return_first=True)
    return storage_type


def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]:
    """
    Return the ``vk_memory_layout`` recorded on the node's spec.

    Thin wrapper over get_node_spec_attr with ``return_first=True``. It yields None
    when the attribute is missing.
    """
    memory_layout = get_node_spec_attr(node, "vk_memory_layout", return_first=True)
    return memory_layout
