# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from dataclasses import dataclass
from enum import IntEnum
from typing import List, Optional, Union

from executorch.exir.backend.compile_spec_schema import CompileSpec

from executorch.exir.scalar_type import ScalarType


@dataclass
class AllocationDetails:
    memory_id: int
    # Low 32 bits
    memory_offset_low: int
    # High 32 bits (typically zero)
    memory_offset_high: int

    @property
    def memory_offset(self) -> int:
        return self.memory_offset_low | (self.memory_offset_high << 32)


@dataclass
class OptionalTensorList:
    items: List[int]


class TensorShapeDynamism(IntEnum):
    """
    Check program.fbs for explanations of this enum.
    """

    STATIC = 0
    DYNAMIC_BOUND = 1
    DYNAMIC_UNBOUND = 2


@dataclass
class ExtraTensorInfo:
    """
    Check program.fbs for explanations of this enum.
    """

    mutable_data_segments_idx: Optional[int] = None
    fully_qualified_name: Optional[str] = None


@dataclass
class Tensor:
    scalar_type: ScalarType
    storage_offset: int
    sizes: List[int]
    dim_order: List[bytes]
    requires_grad: bool
    layout: int
    data_buffer_idx: int
    allocation_info: Optional[AllocationDetails]

    # check program.fbs for explanations.
    shape_dynamism: TensorShapeDynamism
    extra_tensor_info: Optional[ExtraTensorInfo] = None


@dataclass
class Null:
    pass


@dataclass
class Int:
    int_val: int


@dataclass
class Bool:
    bool_val: bool


@dataclass
class Double:
    double_val: Union[float, str]

    def __init__(self, double_val: float) -> None:
        if double_val == float("inf"):
            self.double_val = "inf"
        elif double_val == float("-inf"):
            self.double_val = "-inf"
        else:
            self.double_val = double_val

    def __post_init__(self) -> None:
        if isinstance(self.double_val, str):
            assert self.double_val in ["inf", "-inf"]
        else:
            assert isinstance(self.double_val, float)
            assert not self.double_val == float("inf")
            assert not self.double_val == float("-inf")


@dataclass
class String:
    string_val: str


@dataclass
class ContainerMetadata:
    encoded_inp_str: str
    encoded_out_str: str


@dataclass
class IntList:
    items: List[int]


@dataclass
class DoubleList:
    items: List[float]


@dataclass
class BoolList:
    items: List[bool]


@dataclass
class TensorList:
    items: List[int]


KernelTypes = Union[
    Int,
    Double,
    Bool,
    String,
    Tensor,
    IntList,
    BoolList,
    DoubleList,
    TensorList,
    Null,
    OptionalTensorList,
]


@dataclass
class EValue:
    # Union types must be specified as strings so DataclassEncoder can see them.
    val: "KernelTypes"


@dataclass
class Buffer:
    storage: bytes


@dataclass
class BackendDelegateInlineData:
    data: bytes


@dataclass
class KernelCall:
    op_index: int
    args: List[int]


@dataclass
class DelegateCall:
    delegate_index: int
    args: List[int]


@dataclass
class MoveCall:
    move_from: int
    move_to: int


@dataclass
class JumpFalseCall:
    cond_value_index: int
    destination_instruction: int


@dataclass
class FreeCall:
    value_index: int


InstructionArguments = Union[
    KernelCall,
    DelegateCall,
    MoveCall,
    JumpFalseCall,
    FreeCall,
]


@dataclass
class Instruction:
    instr_args: "InstructionArguments"


@dataclass
class Frame:
    filename: str
    lineno: int
    name: str
    context: str


@dataclass
class FrameList:
    items: List[Frame]


class DataLocation(IntEnum):
    INLINE = 0
    SEGMENT = 1


@dataclass
class BackendDelegateDataReference:
    location: DataLocation
    index: int


@dataclass
class BackendDelegate:
    id: str
    processed: BackendDelegateDataReference
    compile_specs: List[CompileSpec]


@dataclass
class Chain:
    inputs: List[int]
    outputs: List[int]
    instructions: List[Instruction]
    stacktrace: Optional[List[FrameList]]


@dataclass
class Operator:
    name: str
    overload: str


@dataclass
class ExecutionPlan:
    name: str
    container_meta_type: ContainerMetadata
    values: List[EValue]
    inputs: List[int]
    outputs: List[int]
    chains: List[Chain]
    operators: List[Operator]
    delegates: List[BackendDelegate]
    # the list index is memory buffer id, the value is the memory buffer size.
    # memory_buffer_id == 0 is special and is for constant memory buffer.
    # Runtime should use the len(constant_buffer) as the ground truch of
    # constant memory buffer size, and ignore non_const_buffer_sizes[0].
    non_const_buffer_sizes: List[int]


@dataclass
class DataSegment:
    offset: int
    size: int


@dataclass
class SubsegmentOffsets:
    segment_index: int
    offsets: List[int]


@dataclass
class Program:
    version: int
    execution_plan: List[ExecutionPlan]
    constant_buffer: List[Buffer]
    backend_delegate_data: List[BackendDelegateInlineData]
    segments: List[DataSegment]
    constant_segment: SubsegmentOffsets
    mutable_data_segments: Optional[List[SubsegmentOffsets]] = None
