# Generates RegisterCodegenUnboxedKernels.cpp, UnboxingFunctions.h and UnboxingFunctions.cpp.

from __future__ import annotations

import argparse
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Sequence, TYPE_CHECKING

import yaml

from torchgen.api import cpp, unboxing
from torchgen.api.translate import translate
from torchgen.api.types import CppSignatureGroup
from torchgen.api.unboxing import convert_arguments
from torchgen.context import method_with_native_function
from torchgen.gen import cpp_string, get_custom_build_selector, parse_native_yaml
from torchgen.model import Argument, NativeFunction, NativeFunctionsGroup, Variant
from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target


if TYPE_CHECKING:
    from torchgen.selective_build.selector import SelectiveBuilder


# Emits UnboxingFunctions.h (declarations) and UnboxingFunctions.cpp (definitions).
@dataclass(frozen=True)
class ComputeUnboxingFunctions:
    """Render the ``at::unboxing`` entry point for a single native function.

    With ``Target.DECLARATION`` the result is the header prototype. With
    ``Target.DEFINITION`` it is a body that:
    - converts the arguments from the stack,
    - drops them,
    - calls the C++ API, and
    - packs any result back onto the stack.

    Operators the selective-build ``selector`` does not report as root
    operators render as the empty string.
    """

    target: Literal[Target.DECLARATION, Target.DEFINITION]
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        # Selective build: unselected operators contribute no code at all.
        if not self.selector.is_root_operator(f"aten::{f.func.name}"):
            return ""

        entry_name = f.func.name.unambiguous_name()

        if self.target is Target.DECLARATION:
            # Note [The ATen Codegen Unboxing API]
            # Similar to the ATen Operators API, ATen Codegen Unboxing API lives in the at::unboxing namespace, and
            # will be used by codegen unboxing wrappers (CodegenUnboxingWrappers.cpp).
            # The Wrappers will be registered into torch::jit::OperatorRegistry using RegisterOperators API.
            #
            # Important characteristics about the Codegen Unboxing API:
            # (1) It follows the OperatorRegistry API.
            #     This is kind of necessary to avoid overhead.
            #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
            #     would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Under the hood it calls C++ API.
            return f"""
// aten::{f.func}
TORCH_API void {entry_name}(Stack & stack);
"""

        # Definition. Use the method form of the faithful C++ signature when the
        # operator has a method variant, otherwise the free-function form.
        cpp_sig = CppSignatureGroup.from_native_function(
            f, method=(Variant.method in f.variants)
        ).most_faithful_signature()

        # `bindings` are the C++ locals unpacked from the stack; `conversions`
        # are the C++ statements that produce them.
        bindings, conversions = convert_arguments(f)

        # Map the unpacked locals onto the signature's parameters, in order.
        call_args = ", ".join(
            e.expr
            for e in translate(bindings, cpp_sig.arguments(), method=cpp_sig.method)
        )
        receiver = "self_base." if cpp_sig.method else "at::"

        if f.func.returns:
            result_decl = "auto result_ = "
            # Trailing 16 spaces kept on purpose so generated files stay
            # byte-for-byte stable.
            result_push = "\n    pack(stack, std::move(result_));\n" + " " * 16
        else:
            result_decl = ""
            result_push = ""

        conversion_block = "\n\t".join(conversions)
        return f"""
// aten::{f.func}
TORCH_API void {entry_name}(Stack & stack) {{
    {conversion_block}

    drop(stack, {len(bindings)});

    {result_decl}{receiver}{cpp_sig.name()}({call_args});
    {result_push}
}}
"""


# Emits RegisterCodegenUnboxedKernels.cpp.
@dataclass(frozen=True)
class ComputeCodegenUnboxedKernels:
    """Render one ``OperatorGenerator(...)`` registration entry per operator.

    The entry describes the schema: name, overload, and arguments with their
    defaults and returns. Its lambda records the call and forwards to
    ``at::unboxing::<name>(stack)``. Operators the selective-build
    ``selector`` does not report as root operators render as ``""``.
    """

    selector: SelectiveBuilder

    @staticmethod
    def _schema_default(argument: Argument) -> str:
        """Return the C++ expression placed in the ``c10::Argument`` default slot."""
        if not argument.default:
            return "c10::IValue(::std::nullopt)"
        # The unboxing wrappers use the faithful C++ API, which avoids the
        # overhead of wrapping/unwrapping TensorOptions. Schema parsing still
        # needs each argument's default, so translate the schema default into
        # a C++ expression here.
        default_cpp = cpp.default_expr(argument.default, argument.type, symint=False)
        # Braced-init defaults are wrapped as an int list.
        # NOTE(review): every default whose C++ form starts with "{" takes this
        # path regardless of element type — confirm that is intended.
        wrapper = "c10::IntArrayRef" if default_cpp.startswith("{") else "c10::IValue"
        return f"{wrapper}({default_cpp})"

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        if not self.selector.is_root_operator(f"aten::{f.func.name}"):
            return ""

        # Always describe the operator through the free-function (method=False)
        # faithful signature.
        cpp_sig = CppSignatureGroup.from_native_function(
            f, method=False
        ).most_faithful_signature()

        # cpp_string() yields a quoted, escaped C++ string literal; drop the
        # surrounding quotes so the schema can sit inside a // comment.
        schema_str = cpp_string(str(cpp_sig.func))[1:-1]

        arg_entries = []
        for binding in cpp_sig.arguments():
            # The method=False faithful API yields only plain Arguments, never
            # SelfArgument/TensorOptionsArgument.
            assert isinstance(binding.argument, Argument)
            arg_entries.append(
                f'c10::Argument("{binding.name}", nullptr, ::std::nullopt, '
                f"{self._schema_default(binding.argument)})"
            )

        ret_entries = [f'c10::Argument("{ret.name or ""}")' for ret in f.func.returns]

        separator = ",\n\t\t"
        return f"""
// aten::{schema_str}
OperatorGenerator(
    "aten::{f.func.name.name}",
    "{f.func.name.overload_name}",
    {{
        {separator.join(arg_entries)}
    }},
    {{
        {separator.join(ret_entries)}
    }},
    [](Stack & stack) {{
        RECORD_FUNCTION("{cpp_sig.name()}", std::vector<c10::IValue>());
        at::unboxing::{unboxing.name(f)}(stack);
    }},
    aliasAnalysisFromSchema()
),
"""


def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
) -> None:
    """Write the unboxing sources through ``cpu_fm``.

    Three files are written:
    - UnboxingFunctions.cpp
    - UnboxingFunctions.h
    - RegisterCodegenUnboxedKernels.cpp

    Both .cpp files are sharded by each function's ``root_name``. Multiple
    shards are used only once ``selector.operators`` holds at least 100
    entries: 5 for the definitions and 10 for the registrations. Below that,
    each file is a single shard.
    """
    # a best practice threshold of operators to enable sharding
    sharding_threshold = 100
    use_sharding = len(selector.operators) >= sharding_threshold

    def by_root_name(fn: NativeFunction | NativeFunctionsGroup) -> str:
        return fn.root_name

    define_unboxing = ComputeUnboxingFunctions(Target.DEFINITION, selector)
    declare_unboxing = ComputeUnboxingFunctions(Target.DECLARATION, selector)
    register_kernel = ComputeCodegenUnboxedKernels(selector)

    cpu_fm.write_sharded(
        "UnboxingFunctions.cpp",
        native_functions,
        key_fn=by_root_name,
        env_callable=lambda fn: {"definitions": [define_unboxing(fn)]},
        num_shards=5 if use_sharding else 1,
        sharded_keys={"definitions"},
    )
    cpu_fm.write(
        "UnboxingFunctions.h",
        lambda: {
            "declarations": list(mapMaybe(declare_unboxing, native_functions)),
        },
    )
    cpu_fm.write_sharded(
        "RegisterCodegenUnboxedKernels.cpp",
        native_functions,
        key_fn=by_root_name,
        env_callable=lambda fn: {"unboxed_ops": [register_kernel(fn)]},
        num_shards=10 if use_sharding else 1,
        sharded_keys={"unboxed_ops"},
    )


def main(args: list[str]) -> None:
    """Command-line entry point: parse ``args`` and generate the unboxing sources.

    The operator registration allowlist comes from ``--op-registration-allowlist``
    if given. Otherwise it comes from the YAML file named by
    ``--TEST-ONLY-op-registration-allowlist-yaml-path``. With neither, it is
    ``None``. The allowlist and ``--op-selection-yaml-path`` together build the
    selective-build selector.

    ``native_functions.yaml`` and ``tags.yaml`` are read from
    ``<source-path>/native``. When ``--output-dependencies`` is set, the file
    manager's ``write_outputs`` is also called with that (resolved) path.
    """
    parser = argparse.ArgumentParser(description="Generate unboxing source files")
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="aten/src/ATen",
    )
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/aten/src/ATen",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--op-registration-allowlist",
        "--op_registration_allowlist",
        nargs="*",
        help="filter op registrations by the allowlist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--TEST-ONLY-op-registration-allowlist-yaml-path",
        "--TEST_ONLY_op_registration_allowlist_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "which contains a list of operators. It is to serve testing purpose and "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )

    options = parser.parse_args(args)

    # An explicit CLI allowlist takes precedence over the test-only YAML one.
    # A falsy value (flag absent, or given with no items) falls through.
    if options.op_registration_allowlist:
        op_registration_allowlist = options.op_registration_allowlist
    elif options.TEST_ONLY_op_registration_allowlist_yaml_path:
        # Explicit encoding so the result doesn't depend on the host locale.
        with open(
            options.TEST_ONLY_op_registration_allowlist_yaml_path, encoding="utf-8"
        ) as f:
            op_registration_allowlist = yaml.safe_load(f)
    else:
        op_registration_allowlist = None

    selector = get_custom_build_selector(
        op_registration_allowlist,
        options.op_selection_yaml_path,
    )

    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
    # Only the native functions are needed; backend indices were never used here.
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions

    cpu_fm = make_file_manager(options=options)
    gen_unboxing(native_functions=native_functions, cpu_fm=cpu_fm, selector=selector)

    if options.output_dependencies:
        # `resolved.parent / resolved.name` is just `resolved`, so pass it directly.
        depfile_path = Path(options.output_dependencies).resolve()
        cpu_fm.write_outputs(depfile_path.stem, str(depfile_path))


# Script entry point: pass the command-line arguments, minus the program
# name, to main().
if __name__ == "__main__":
    main(sys.argv[1:])
