#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import argparse
import os
import sys
from typing import Any, List

import yaml

from torchgen.code_template import CodeTemplate


# One clause of the generated should_include_kernel_dtype() body: it holds when
# the runtime operator_name equals $operator_name and the $dtype_checks
# expression (a disjunction of scalar_type comparisons) is true.
ops_and_dtypes_template_str = (
    '((exec_aten::string_view(operator_name).compare("$operator_name") == 0)'
    "\n"
    "        && ($dtype_checks))"
)
ops_and_dtypes_template = CodeTemplate(ops_and_dtypes_template_str)

# Template for the generated C++ header. $body is substituted with a C++
# boolean expression over `operator_name` and `scalar_type`; the emitted
# should_include_kernel_dtype() simply returns that expression.
# NOTE(review): despite the variable name, this content is written to a file
# named selected_op_variants.h by write_selected_op_variants().
selected_kernel_dtypes_h_template_str = """#pragma once
/**
 * Generated by executorch/codegen/tools/gen_selected_op_variants.py
 */

inline constexpr bool should_include_kernel_dtype(
  const char *operator_name,
  exec_aten::ScalarType scalar_type
) {
  return $body;
}
"""
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)

# Maps a ScalarType enum value, written as it appears in kernel keys (a decimal
# string), to the C++ enumerator name. Tuple position == enum value, so the
# order below must mirror:
# https://github.com/pytorch/executorch/blob/main/runtime/core/portable_type/scalar_type.h
dtype_enum_to_type = {
    str(enum_value): type_name
    for enum_value, type_name in enumerate(
        (
            "Byte",  # 0
            "Char",  # 1
            "Short",  # 2
            "Int",  # 3
            "Long",  # 4
            "Half",  # 5
            "Float",  # 6
            "Double",  # 7
            "ComplexHalf",  # 8
            "ComplexFloat",  # 9
            "ComplexDouble",  # 10
            "Bool",  # 11
            "QInt8",  # 12
            "QUInt8",  # 13
            "QInt32",  # 14
            "BFloat16",  # 15
            "QUInt4x2",  # 16
            "QUInt2x4",  # 17
            "Bits1x8",  # 18
            "Bits2x4",  # 19
            "Bits4x2",  # 20
            "Bits8",  # 21
            "Bits16",  # 22
        )
    )
}


def _dtype_conditions(
    operator_name: str, kernel_metadata_list: List[str]
) -> List[str]:
    """Return the C++ scalar_type conditions selected for one operator.

    Each entry of ``kernel_metadata_list`` is either ``"default"`` or a kernel
    key such as ``v1/6;0,1|6;0,1``: the part after ``/`` is split on ``|``,
    and each piece starts with a ScalarType enum value before its first ``;``
    (the values after ``;`` are not used here).

    Args:
        operator_name: Fully qualified operator name; used only in error
            messages.
        kernel_metadata_list: The operator's kernel keys from
            ``et_kernel_metadata``.

    Returns:
        ``["true"]`` (accept every dtype) when any entry is ``"default"`` or
        has no ``/``, or when no dtype is listed at all; otherwise one
        ``scalar_type == exec_aten::ScalarType::<Name>`` string per distinct
        dtype, sorted by name.

    Raises:
        ValueError: If a kernel key references an enum value that is missing
            from ``dtype_enum_to_type``.
    """
    dtype_enums = set()
    for kernel_metadata in kernel_metadata_list:
        if kernel_metadata == "default" or "/" not in kernel_metadata:
            # A non-specialized kernel must handle every dtype, so it wins
            # regardless of where it appears relative to specialized keys.
            return ["true"]
        for tensor_meta in kernel_metadata.split("/")[1].split("|"):
            dtype_enum = tensor_meta.split(";")[0]
            if dtype_enum:  # e.g. "v1/" carries no tensor metadata
                dtype_enums.add(dtype_enum)
    if not dtype_enums:
        return ["true"]
    try:
        dtype_names = sorted(dtype_enum_to_type[enum] for enum in dtype_enums)
    except KeyError as err:
        raise ValueError(
            f"Unknown ScalarType enum value {err.args[0]!r} in kernel metadata "
            f"for operator {operator_name!r}"
        ) from err
    return ["scalar_type == exec_aten::ScalarType::" + name for name in dtype_names]


def write_selected_op_variants(yaml_file_path: str, output_dir: str) -> None:
    """Generate ``selected_op_variants.h`` from a selected_operators.yaml file.

    Reads the ``et_kernel_metadata`` mapping (operator name -> list of kernel
    keys) and emits a header whose should_include_kernel_dtype() returns true
    for every (operator, dtype) pair the metadata selects. Operators absent
    from the mapping match no clause; if the mapping is empty or missing, the
    function body is just ``true``.

    Args:
        yaml_file_path: Path to selected_operators.yaml.
        output_dir: Existing directory in which selected_op_variants.h is
            written (overwritten if present).

    Raises:
        TypeError: If ``et_kernel_metadata`` is present but not a mapping.
        ValueError: If a kernel key references an unknown dtype enum value.
    """
    # Example et_kernel_metadata entry:
    #   aten::add.out: [v1/6;0,1|6;0,1|6;0,1|6;0,1]   # Float, 0, 1
    with open(yaml_file_path, "r") as selected_operators_file:
        # safe_load returns None for an empty document.
        selected_operators_dict = yaml.safe_load(selected_operators_file) or {}
    et_kernel_metadata = selected_operators_dict.get("et_kernel_metadata") or {}
    if not isinstance(et_kernel_metadata, dict):
        raise TypeError(
            "et_kernel_metadata must be a mapping, got "
            f"{type(et_kernel_metadata).__name__}"
        )

    body_parts = []
    for operator_name, kernel_metadata_list in et_kernel_metadata.items():
        dtype_checks = " || ".join(
            _dtype_conditions(operator_name, kernel_metadata_list or [])
        )
        # Strip only a leading "aten::" namespace, never an inner occurrence.
        short_name = operator_name
        if short_name.startswith("aten::"):
            short_name = short_name[len("aten::") :]
        body_parts.append(
            ops_and_dtypes_template.substitute(
                operator_name=short_name,
                dtype_checks=dtype_checks,
            )
        )
    body = "\n || ".join(body_parts) if body_parts else "true"

    header_contents = selected_kernel_dtypes_h_template.substitute(body=body)
    selected_op_variants_path = os.path.join(output_dir, "selected_op_variants.h")
    with open(selected_op_variants_path, "wb") as out_file:
        out_file.write(header_contents.encode("utf-8"))


def main(argv: List[Any]) -> None:
    """Parse command-line options and write ``selected_op_variants.h``.

    Args:
        argv: Command-line arguments without the program name
            (e.g. ``sys.argv[1:]``).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Generate selected_op_variants.h, which defines "
            "should_include_kernel_dtype() for the operator/dtype pairs "
            "listed in selected_operators.yaml"
        )
    )
    parser.add_argument(
        "--yaml-file-path",
        "--yaml_file_path",
        help="Path to the selected_operators.yaml file to read",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        help="Directory in which to write selected_op_variants.h",
        required=True,
    )

    options = parser.parse_args(argv)
    write_selected_op_variants(options.yaml_file_path, options.output_dir)


if __name__ == "__main__":
    # argparse wants only the options, so drop the script path in argv[0].
    main(argv=sys.argv[1:])
