# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, List, Optional

###################################
## Generic Test Suite Definition ##
###################################


class TestSuite:
    """Container for the settings of one generated operator test suite.

    Every attribute is a plain instance attribute initialised to a default
    in ``__init__``; callers are expected to overwrite the ones they need
    after construction.

    Attributes:
        input_cases: The list handed to the constructor, stored as-is
            (not copied).
        prepacked_args: Strings, empty by default. ``supports_prepack``
            reports whether this list is non-empty.
        requires_prepack: Boolean flag, ``False`` by default. Not read
            anywhere inside this class.
        dtypes: Strings, ``["at::kFloat", "at::kHalf"]`` by default.
            NOTE(review): presumably ATen scalar-type names spliced into
            generated code -- confirm against the consumer.
        data_gen: String, ``"make_rand_tensor"`` by default.
            NOTE(review): looks like the name of a data-generation
            helper -- confirm against the consumer.
        data_range: Two-element tuple, ``(0, 1)`` by default.
        arg_dtype: Dict, empty by default.
        arg_data_range: Dict, empty by default.
        atol: Absolute tolerance kept as a string, ``"1e-5"`` by default.
        rtol: Relative tolerance kept as a string, ``"1e-5"`` by default.
        is_view_op: Boolean flag, ``False`` by default.
        test_name_suffix: Optional string, ``None`` by default.
    """

    def __init__(self, input_cases: List[Any]) -> None:
        """Store ``input_cases`` and reset every other setting to its default."""
        # Cases and prepacking configuration.
        self.input_cases: List[Any] = input_cases
        self.prepacked_args: List[str] = list()
        self.requires_prepack: bool = False
        self.dtypes: List[str] = ["at::kFloat", "at::kHalf"]

        # Suite-wide data generation defaults.
        self.data_gen: str = "make_rand_tensor"
        self.data_range: tuple = (0, 1)

        # Lookup tables, empty until a caller fills them in.
        self.arg_dtype: dict = dict()
        self.arg_data_range: dict = dict()

        # Tolerances are stored as strings rather than floats.
        self.atol: str = "1e-5"
        self.rtol: str = "1e-5"

        self.is_view_op: bool = False
        self.test_name_suffix: Optional[str] = None

    def supports_prepack(self) -> bool:
        """Return ``True`` when at least one prepacked argument is listed."""
        return len(self.prepacked_args) != 0


##################################
## Vulkan Test Suite Definition ##
##################################


class VkTestSuite(TestSuite):
    """``TestSuite`` extended with Vulkan-specific settings.

    Deliberately a plain class rather than a ``@dataclass``: the class
    declares no class-level fields and defines its own ``__init__``, so the
    decorator would only add a field-less ``__eq__`` (every instance would
    compare equal to every other), set ``__hash__`` to ``None`` (instances
    would be unhashable) and a ``__repr__`` that always prints
    ``VkTestSuite()``.

    Attributes:
        storage_types: Strings, ``["utils::kTexture3D"]`` by default.
        layouts: Strings, ``["utils::kChannelsPacked"]`` by default.
        data_gen: String, ``"make_rand_tensor"`` by default.
    """

    def __init__(self, input_cases: List[Any]) -> None:
        """Initialise the base settings, then add the Vulkan defaults.

        Args:
            input_cases: Forwarded unchanged to ``TestSuite.__init__``.
        """
        super().__init__(input_cases)
        self.storage_types: List[str] = ["utils::kTexture3D"]
        self.layouts: List[str] = ["utils::kChannelsPacked"]
        # Same value the base initialiser assigns; set explicitly here too.
        self.data_gen: str = "make_rand_tensor"
