# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Please refer to fbcode/caffe2/executorch/backends/vulkan/serialization/schema/schema.fbs for the schema definitions
"""

from dataclasses import dataclass
from enum import IntEnum
from typing import List, Union


@dataclass
class OperatorCall:
    node_id: int
    name: str
    args: List[int]


class VkDataType(IntEnum):
    BOOL = 0
    UINT8 = 1
    INT8 = 2
    INT32 = 3
    FLOAT16 = 4
    FLOAT32 = 5


class VkStorageType(IntEnum):
    BUFFER = 0
    TEXTURE_3D = 1
    TEXTURE_2D = 2
    DEFAULT_STORAGE = 255

    def __str__(self) -> str:
        return self.name


class VkMemoryLayout(IntEnum):
    TENSOR_WIDTH_PACKED = 0
    TENSOR_HEIGHT_PACKED = 1
    TENSOR_CHANNELS_PACKED = 2
    DEFAULT_LAYOUT = 255

    def __str__(self) -> str:
        return self.name


@dataclass
class VkTensor:
    datatype: VkDataType
    dims: List[int]
    constant_id: int
    mem_obj_id: int
    storage_type: VkStorageType = VkStorageType.DEFAULT_STORAGE
    memory_layout: VkMemoryLayout = VkMemoryLayout.DEFAULT_LAYOUT


@dataclass
class Null:
    pass


@dataclass
class Int:
    int_val: int


@dataclass
class Bool:
    bool_val: bool


@dataclass
class Double:
    double_val: float


@dataclass
class IntList:
    items: List[int]


@dataclass
class DoubleList:
    items: List[float]


@dataclass
class BoolList:
    items: List[bool]


@dataclass
class ValueList:
    items: List[int]


@dataclass
class String:
    string_val: str


@dataclass
class SymInt:
    value: int


GraphTypes = Union[
    Null,
    Int,
    Double,
    Bool,
    VkTensor,
    IntList,
    BoolList,
    DoubleList,
    ValueList,
    String,
    SymInt,
]


@dataclass
class VkValue:
    value: "GraphTypes"


@dataclass
class VkBytes:
    offset: int
    length: int


@dataclass
class VkGraph:
    version: str

    chain: List[OperatorCall]
    values: List[VkValue]

    input_ids: List[int]
    output_ids: List[int]

    constants: List[VkBytes]
    shaders: List[VkBytes]

    storage_type_override: VkStorageType = VkStorageType.DEFAULT_STORAGE
    memory_layout_override: VkMemoryLayout = VkMemoryLayout.DEFAULT_LAYOUT
