# Owner(s): ["module: dynamo"]
import dataclasses
import importlib
import inspect
import math
import types
import unittest
import unittest.mock
import warnings
from typing import Any, Dict, Set

import torch
import torch._dynamo.config as config
import torch._dynamo.test_case
import torch._functorch.deprecated as deprecated_func
from torch._dynamo.trace_rules import (
    LEGACY_MOD_INLINELIST,
    load_object,
    manual_torch_name_rule_map,
    MOD_INLINELIST,
    torch_c_binding_in_graph_functions,
    torch_non_c_binding_in_graph_functions,
)
from torch._dynamo.utils import hashable, is_safe_constant, istype
from torch._dynamo.variables import TorchInGraphFunctionVariable, UserFunctionVariable
from torch.testing._internal.common_utils import skipIfWindows


try:
    from .utils import create_dummy_module_and_function
except ImportError:
    from utils import create_dummy_module_and_function


# Fully-qualified names of C-binding functions that are deliberately NOT keys of
# `trace_rules.torch_c_binding_in_graph_functions`. `test_torch_name_rule_map_updated`
# loads these names and counts the resulting objects as "used", so the heuristic
# walk reporting them does not fail the consistency check.
ignored_c_binding_in_graph_function_names = {
    # Ignored because they have manual rules defined at `trace_rules.manual_torch_name_rule_map`.
    "torch._nested_tensor_from_mask",
    "torch._nested_from_padded",
    "torch.sparse_compressed_tensor",
    "torch.sparse_bsc_tensor",
    "torch.sparse_bsr_tensor",
    "torch.sparse_coo_tensor",
    "torch.sparse_csc_tensor",
    "torch.sparse_csr_tensor",
    "torch.cuda._get_device_properties",
    # Ignored; these go through the rules defined at `trace_rules.check` instead.
    "torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode",
    "torch._cslt_sparse_mm_search",
    "torch._C._abort",
    "torch._C._mps_is_on_macos_or_newer",
    "torch._C._swap_tensor_impl",
    "torch._C._unsafe_reset_storage",
    "torch._dynamo.eval_frame.reset_code",
    "torch._C.autocast_decrement_nesting",
    "torch._C.autocast_increment_nesting",
    "torch._C.clear_autocast_cache",
    "torch._C.set_anomaly_enabled",
    "torch._C.set_autocast_cache_enabled",
    "torch._C.set_autocast_cpu_dtype",
    "torch._C.set_autocast_cpu_enabled",
    "torch._C.set_autocast_enabled",
    "torch._C.set_autocast_gpu_dtype",
    "torch._C.set_autocast_ipu_dtype",
    "torch._C.set_autocast_ipu_enabled",
    "torch._C.set_autocast_xla_dtype",
    "torch._C.set_autocast_xla_enabled",
    "torch.resize_as_",
    "torch.resize_as_sparse_",
    "torch._C._data_address",
    "torch._C._is_cow_tensor",
    "torch._lazy_clone",
    "torch._test_parallel_materialize",
    "torch._C._storage_address",
    "torch._C._pickle_save",
    "torch._validate_sparse_compressed_tensor_args",
    "torch._validate_sparse_csr_tensor_args",
    "torch._validate_sparse_bsr_tensor_args",
    "torch._validate_sparse_csc_tensor_args",
    "torch._validate_sparse_coo_tensor_args",
    "torch._validate_sparse_bsc_tensor_args",
    "torch._validate_compressed_sparse_indices",
}
# The `torch._C._te.set_llvm_*` setters are only ignored when LLVM support is
# enabled (presumably they are only exposed by LLVM-enabled builds — confirm).
if torch._C._llvm_enabled():
    ignored_c_binding_in_graph_function_names |= {
        "torch._C._te.set_llvm_aot_workflow",
        "torch._C._te.set_llvm_target_cpu",
        "torch._C._te.set_llvm_target_attrs",
        "torch._C._te.set_llvm_target_triple",
    }


# Helper function to dump the torch name rule map generated based on
# the heuristic defined in gen_allowed_objs_and_ids.
def dump_allowed_torch_name_rule_map() -> None:
    m = gen_allowed_objs_and_ids(record=True, c_binding_only=False).name_rule_map
    for k, v in m.items():
        print(f'"{k}": {v.__name__},')


@dataclasses.dataclass
class AllowedObjects:
    """
    Result of the torch/math walk performed by `gen_allowed_objs_and_ids`.

    Bundles the id -> qualified-name map of every collected object with the
    function sets and the name -> dynamo wrapping rule map that the in-graph
    heuristic records during the same walk.
    """

    # id(obj) -> dotted name under which the object was found during the walk.
    object_ids: Dict[int, str]
    # Recorded objects that have no `__code__` attribute (C bindings).
    c_binding_in_graph_functions: Set[Any]
    # Recorded objects that have a `__code__` attribute (Python functions); only
    # populated when the walk runs with `c_binding_only=False`.
    non_c_binding_in_graph_functions: Set[Any]
    # Qualified name -> dynamo variable class used to wrap that object.
    name_rule_map: Dict[str, Any]


def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObjects:
    """
    Walk torch.* and math.* and collect the ids of the objects found there.

    Submodules are recursed into when their name starts with ``torch.`` and
    their defining module passes the prefix filter below. Modules whose name
    starts with an entry of ``config.allowed_functions_module_string_ignorelist``
    are not walked at all.

    Args:
        record: If True, run the in-graph-function heuristic on every collected
            non-module object and fill the function sets and the name -> rule
            map of the result. If False, those stay empty.
        c_binding_only: Only relevant when ``record`` is True. If True, objects
            with a ``__code__`` attribute (Python functions) are not added to
            either function set, although their names still go into the
            name -> rule map.

    Returns:
        AllowedObjects: the id -> dotted-name map plus the recorded sets/map.
    """
    from torch._ops import HigherOrderOperator

    torch_object_ids = {}
    c_binding_in_graph_functions = set()
    non_c_binding_in_graph_functions = set()
    torch_name_rule_map = {}

    # Object types the heuristic treats as "functions".
    function_types = (
        types.FunctionType,
        types.BuiltinFunctionType,
        types.MethodDescriptorType,
        types.WrapperDescriptorType,
    )

    # Objects excluded from the walk. We want to trace through `grad` and
    # `vmap`; the remaining entries are excluded the same way.
    excluded_objs = (
        torch.func.grad,
        deprecated_func.grad,
        torch.func.vmap,
        deprecated_func.vmap,
        torch.nn.functional.triplet_margin_with_distance_loss,
        torch.cond,
    )

    allowed_modules = ("torch", "math")
    allowed_modules_dot = tuple(m + "." for m in allowed_modules)
    # Prefix-matched against the name of an object's defining module. Entries
    # without a trailing dot also match siblings sharing the prefix (e.g.
    # "torch._tensor" matches "torch._tensor_str").
    # torch.nn.modules.rnn is disallowed because these modules internally
    # flatten their parameters.  This flattening process will call
    # Tensor.set_ with a Storage, and Storages cannot be traced with
    # AOTAutograd; so we need to graph-break. To ensure this, we inline
    # these functions, rather than keep them opaque-ly in the graph.
    # Built once per call instead of once per inspected object.
    disallowed_modules = (
        "torch.optim.",
        "torch.nn.modules.rnn.",
        "torch._dynamo.",
        "torch._C._dynamo.",
        "torch._inductor.",
        "torch._C.inductor.",
        "torch.fx.",
        "torch._C._autograd",
        "torch._C._cudart",
        "torch._C._distributed_autograd",
        "torch._C._distributed_c10d",
        "torch._C._distributed_rpc",
        "torch._C._functorch",
        "torch._C._monitor",
        "torch._C._nvtx",
        "torch._C._lazy",
        "torch._C._profiler",
        "torch.__config__",
        "torch._custom_op",
        "torch._decomp",
        "torch._dispatch",
        "torch._export",
        "torch._functorch.make_functional",
        "torch._functorch.compile_utils",
        "torch._functorch.partitioners",
        "torch._functorch.aot_autograd",
        "torch._functorch.compilers",
        "torch._functorch.fx_minifier",
        "torch.autograd.profiler_util",
        "torch.autograd.profiler",
        "torch._jit_internal",
        "torch._library",
        "torch._lobpcg",
        "torch._logging",
        "torch._meta_registrations",
        "torch._namedtensor_internals",
        "torch._numpy",
        "torch._sources",
        "torch._subclasses",
        "torch._tensor",
        "torch._tensor_str",
        "torch._utils",
        "torch._utils_internal",
        "torch._vmap_internals",
        "torch.compiler",
        "torch.distributed",  # also covers every "torch.distributed." submodule
        "torch.export",
        "torch.hub",
        "torch.jit",
        "torch.library",
        "torch.masked.maskedtensor",
        "torch.nn.init",
        "torch.nn.modules.module",
        "torch.nn.parallel",
        "torch.nn.utils",
        "torch.multiprocessing",
        "torch.onnx",
        "torch.overrides",
        "torch.package",
        "torch.profiler",
        "torch.serialization",
        "torch.storage",
        "torch.utils",
    )

    # In some platforms, these functions were loaded as classes instead of functions.
    # To mitigate these weird cases, we need this special check.
    # The attributes are looked up lazily, only when the check actually runs.
    def is_special_functions(obj):
        return hashable(obj) and obj in {
            torch._C._cuda_isCurrentStreamCapturing,
            torch._C._graph_pool_handle,
        }

    # Record `obj` (found as `module.<name>`) in the name -> rule map and in the
    # C-binding / Python-function set if it looks like a torch function or method.
    def heuristic_record_if_in_graph_function(obj, module, name):
        # Look through decorators. hasattr can raise arbitrary exceptions on
        # exotic objects, so unwrapping is deliberately best-effort.
        try:
            if hasattr(obj, "__wrapped__"):
                obj = obj.__wrapped__
        except Exception:
            pass
        if not (isinstance(obj, function_types) or is_special_functions(obj)):
            return
        torch_name_rule_map[f"{module.__name__}.{name}"] = TorchInGraphFunctionVariable
        if not hasattr(obj, "__code__"):
            c_binding_in_graph_functions.add(obj)
        elif not c_binding_only:
            non_c_binding_in_graph_functions.add(obj)

    # True when obj's defining module is torch/math (or below) and not disallowed.
    def _is_allowed_module_prefix(obj):
        module = inspect.getmodule(obj)
        if module is None:
            return False
        mod_name = module.__name__
        if mod_name.startswith(disallowed_modules):
            return False
        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

    # Depth-first walk over `module.__dict__`, filling torch_object_ids (and the
    # heuristic tables when `record` is set).
    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        # Snapshot the namespace so it is not iterated while it might change.
        for name, obj in list(module.__dict__.items()):
            if id(obj) in torch_object_ids:
                continue
            # Dynamo allows all builtins into the graph and does not attempt
            # to introspect into them. We don't want to allow instances of
            # HigherOrderOperator into the graph all the time (Dynamo needs
            # to introspect the body functions of these HigherOrderOperator
            # first, decide they are safe, and then allow them into the graph).
            # So we exclude HigherOrderOperator from being a builtin.
            if isinstance(obj, HigherOrderOperator):
                continue
            if obj in excluded_objs:
                continue

            qualified_name = f"{module.__name__}.{name}"
            if isinstance(obj, types.ModuleType):
                if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                    obj
                ):
                    torch_object_ids[id(obj)] = qualified_name
                    _find_torch_objects(obj)
            elif _is_allowed_module_prefix(obj) or (
                # Objects without a defining module (e.g. some C bindings) are
                # kept unless they are safe constants.
                inspect.getmodule(obj) is None
                and not is_safe_constant(obj)
            ):
                if record:
                    heuristic_record_if_in_graph_function(obj, module, name)
                torch_object_ids[id(obj)] = qualified_name

    with warnings.catch_warnings():
        # Silence torch.distributed UserWarnings only for the duration of the
        # walk instead of permanently changing the process-wide filters.
        warnings.filterwarnings(
            "ignore", category=UserWarning, module="torch.distributed"
        )
        _find_torch_objects(torch)
        _find_torch_objects(math)

    return AllowedObjects(
        torch_object_ids,
        c_binding_in_graph_functions,
        non_c_binding_in_graph_functions,
        torch_name_rule_map,
    )


class TraceRuleTests(torch._dynamo.test_case.TestCase):
    """Consistency checks for the name-based rules in `torch._dynamo.trace_rules`."""

    @staticmethod
    def _patch_force_inline(qualified_name):
        """
        Return a context manager that makes dynamo inline ``qualified_name``.

        A copy of ``manual_torch_name_rule_map`` with ``qualified_name`` mapped
        to ``UserFunctionVariable`` is patched into
        ``trace_rules.torch_name_rule_map`` together with the C-binding and
        non-C-binding maps. ``get_torch_obj_rule_map`` is replaced by its
        undecorated version to bypass its ``functools.lru_cache``, so a
        previously cached result is not reused.
        """
        name_rule_overrides = manual_torch_name_rule_map.copy()
        name_rule_overrides[qualified_name] = UserFunctionVariable
        return unittest.mock.patch.multiple(
            "torch._dynamo.trace_rules",
            torch_name_rule_map=[
                name_rule_overrides,
                torch_c_binding_in_graph_functions,
                torch_non_c_binding_in_graph_functions,
            ],
            get_torch_obj_rule_map=torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
        )

    def _check_set_equality(self, generated, used, rule_map, ignored_set):
        """
        Assert that the heuristic-generated objects match the objects in use.

        ``rule_map`` and ``ignored_set`` are names interpolated into the failure
        messages to point at where missing or stale entries live.
        """
        new_objs = generated - used
        removed_objs = used - generated
        msg1 = (
            f"New torch objects: {new_objs} "
            f"were not added to `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. "
            "Refer the instruction in `torch/_dynamo/trace_rules.py` for more details."
        )
        msg2 = (
            f"Existing torch objects: {removed_objs} were removed. "
            f"Please remove them from `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. "
            "Refer the instruction in `torch/_dynamo/trace_rules.py` for more details."
        )
        self.assertFalse(new_objs, msg1)
        self.assertFalse(removed_objs, msg2)

    # We are using python function and module string names for these inlinelist,
    # this unit test is to make sure the functions/modules can be correctly imported
    # or loaded in case there is typo in the strings.
    def test_skipfiles_inlinelist(self):
        # Sorted so failures are reported in a deterministic order.
        for m in sorted(LEGACY_MOD_INLINELIST.union(MOD_INLINELIST)):
            self.assertIsInstance(
                importlib.import_module(m),
                types.ModuleType,
                f"{m} from trace_rules.MOD_INLINELIST/LEGACY_MOD_INLINELIST is not a python module, please check and correct it.",
            )

    @unittest.skip(
        "This test keeps getting broken and our disable infra is not handling well. see #120627"
    )
    def test_torch_name_rule_map_updated(self):
        # Generate the allowed objects based on the heuristic defined in
        # `gen_allowed_objs_and_ids` in this file.
        objs = gen_allowed_objs_and_ids(record=True, c_binding_only=True)
        # Test C binding in graph functions are updated in torch_name_rule_map.
        generated = objs.c_binding_in_graph_functions
        names = (
            set(torch_c_binding_in_graph_functions.keys())
            | ignored_c_binding_in_graph_function_names
        )
        loaded = (load_object(x) for x in names)
        # Names that can't be loaded on this build are skipped.
        used = {obj for obj in loaded if obj is not None}
        self._check_set_equality(
            generated,
            used,
            "torch_c_binding_in_graph_functions",
            "ignored_c_binding_in_graph_function_names",
        )
        # For non C binding in graph functions, we only test if they can be loaded successfully.
        function_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            types.MethodDescriptorType,
            types.WrapperDescriptorType,
        )
        for f in torch_non_c_binding_in_graph_functions:
            self.assertIsInstance(
                load_object(f),
                function_types,
                f"{f} from trace_rules.torch_non_c_binding_in_graph_functions "
                "could not be loaded as a function.",
            )

    def test_force_inline_torch_function(self):
        # `torch._dynamo.utils.istype` is skipped by default
        def fn(x):
            if istype(x, torch.Tensor):
                return x + 1
            else:
                return x - 1

        self.assertNotIn("torch._dynamo", LEGACY_MOD_INLINELIST)
        self.assertNotIn("torch._dynamo", MOD_INLINELIST)

        # Force inline `torch._dynamo.utils.istype` by setting trace rule.
        with self._patch_force_inline("torch._dynamo.utils.istype"):
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_force_inline_custom_function(self):
        mod, func = create_dummy_module_and_function()

        def fn(x):
            return func(x)

        # Force inline `mod.func` by setting trace rule.
        with self._patch_force_inline(f"{mod.__name__}.{func.__name__}"):
            # First add the module to SKIP_DIRS so that it is skipped by default.
            # NOTE(review): this registration is not undone when the test ends.
            torch._dynamo.trace_rules.add(mod.__name__)
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)


class TestModuleSurviveSkipFiles(torch._dynamo.test_case.TestCase):
    """Check that dynamo still processes frames of a module from the testing utilities."""

    @unittest.skipIf(
        not torch.distributed.is_available(),
        "need to import MLP module from distributed",
    )
    @skipIfWindows(
        msg="AssertionError: False is not true : MLP did not survive skip files"
    )
    def test_module_survive_skip_files(self):
        from torch.testing._internal.common_fsdp import MLP

        mlp = MLP(3)
        sample = torch.randn((2, 3))
        frames_before = torch._dynamo.convert_frame.FRAME_COUNTER
        mlp.compile(backend="eager")
        mlp(sample)
        # Running the compiled MLP must advance FRAME_COUNTER (presumably the
        # count of frames handled by convert_frame); no advance means the MLP
        # was filtered out by the skip files.
        new_frames = torch._dynamo.convert_frame.FRAME_COUNTER - frames_before
        self.assertTrue(new_frames > 0, "MLP did not survive skip files")


if __name__ == "__main__":
    # Delegate to dynamo's test runner.
    import torch._dynamo.test_case as dynamo_test_case

    dynamo_test_case.run_tests()
