# Owner(s): ["module: autograd"]

import importlib
import inspect
import json
import logging
import os
import pkgutil
import unittest
from typing import Callable

import torch
from torch._utils_internal import get_file_path_2
from torch.testing._internal.common_utils import (
    IS_JETSON,
    IS_MACOS,
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)


log = logging.getLogger(__name__)


class TestPublicBindings(TestCase):
    def test_no_new_reexport_callables(self):
        """
        This test aims to stop the introduction of new re-exported callables into
        torch whose names do not start with _. Such callables are made available as
        torch.XXX, which may not be desirable.
        """
        reexported_callables = sorted(
            k
            for k, v in vars(torch).items()
            if callable(v) and not v.__module__.startswith("torch")
        )
        self.assertTrue(
            all(k.startswith("_") for k in reexported_callables), reexported_callables
        )

    def test_no_new_bindings(self):
        """
        This test aims to stop the introduction of new JIT bindings into torch._C
        whose names do not start with _. Such bindings are made available as
        torch.XXX, which may not be desirable.

        If your change causes this test to fail, add your new binding to a relevant
        submodule of torch._C, such as torch._C._jit. If your binding really needs
        to be available as torch.XXX, add it to torch._C and add it to the allowlist
        below.

        If you have removed a binding, remove it from the allowlist as well.
        """

        # This allowlist contains every binding in torch._C that is copied into torch at
        # the time of writing. It was generated with
        #
        #   {elem for elem in dir(torch._C) if not elem.startswith("_")}
        torch_C_allowlist_superset = {
            "AggregationType",
            "AliasDb",
            "AnyType",
            "Argument",
            "ArgumentSpec",
            "AwaitType",
            "autocast_decrement_nesting",
            "autocast_increment_nesting",
            "AVG",
            "BenchmarkConfig",
            "BenchmarkExecutionStats",
            "Block",
            "BoolType",
            "BufferDict",
            "StorageBase",
            "CallStack",
            "Capsule",
            "ClassType",
            "clear_autocast_cache",
            "Code",
            "CompilationUnit",
            "CompleteArgumentSpec",
            "ComplexType",
            "ConcreteModuleType",
            "ConcreteModuleTypeBuilder",
            "cpp",
            "CudaBFloat16TensorBase",
            "CudaBoolTensorBase",
            "CudaByteTensorBase",
            "CudaCharTensorBase",
            "CudaComplexDoubleTensorBase",
            "CudaComplexFloatTensorBase",
            "CudaDoubleTensorBase",
            "CudaFloatTensorBase",
            "CudaHalfTensorBase",
            "CudaIntTensorBase",
            "CudaLongTensorBase",
            "CudaShortTensorBase",
            "DeepCopyMemoTable",
            "default_generator",
            "DeserializationStorageContext",
            "device",
            "DeviceObjType",
            "DictType",
            "DisableTorchFunction",
            "DisableTorchFunctionSubclass",
            "DispatchKey",
            "DispatchKeySet",
            "dtype",
            "EnumType",
            "ErrorReport",
            "ExcludeDispatchKeyGuard",
            "ExecutionPlan",
            "FatalError",
            "FileCheck",
            "finfo",
            "FloatType",
            "fork",
            "FunctionSchema",
            "Future",
            "FutureType",
            "Generator",
            "GeneratorType",
            "get_autocast_cpu_dtype",
            "get_autocast_dtype",
            "get_autocast_ipu_dtype",
            "get_default_dtype",
            "get_num_interop_threads",
            "get_num_threads",
            "Gradient",
            "Graph",
            "GraphExecutorState",
            "has_cuda",
            "has_cudnn",
            "has_lapack",
            "has_mkl",
            "has_mkldnn",
            "has_mps",
            "has_openmp",
            "has_spectral",
            "iinfo",
            "import_ir_module_from_buffer",
            "import_ir_module",
            "InferredType",
            "init_num_threads",
            "InterfaceType",
            "IntType",
            "SymFloatType",
            "SymBoolType",
            "SymIntType",
            "IODescriptor",
            "is_anomaly_enabled",
            "is_anomaly_check_nan_enabled",
            "is_autocast_cache_enabled",
            "is_autocast_cpu_enabled",
            "is_autocast_ipu_enabled",
            "is_autocast_enabled",
            "is_grad_enabled",
            "is_inference_mode_enabled",
            "JITException",
            "layout",
            "ListType",
            "LiteScriptModule",
            "LockingLogger",
            "LoggerBase",
            "memory_format",
            "merge_type_from_type_comment",
            "ModuleDict",
            "Node",
            "NoneType",
            "NoopLogger",
            "NumberType",
            "OperatorInfo",
            "OptionalType",
            "OutOfMemoryError",
            "ParameterDict",
            "parse_ir",
            "parse_schema",
            "parse_type_comment",
            "PyObjectType",
            "PyTorchFileReader",
            "PyTorchFileWriter",
            "qscheme",
            "read_vitals",
            "RRefType",
            "ScriptClass",
            "ScriptClassFunction",
            "ScriptDict",
            "ScriptDictIterator",
            "ScriptDictKeyIterator",
            "ScriptList",
            "ScriptListIterator",
            "ScriptFunction",
            "ScriptMethod",
            "ScriptModule",
            "ScriptModuleSerializer",
            "ScriptObject",
            "ScriptObjectProperty",
            "SerializationStorageContext",
            "set_anomaly_enabled",
            "set_autocast_cache_enabled",
            "set_autocast_cpu_dtype",
            "set_autocast_dtype",
            "set_autocast_ipu_dtype",
            "set_autocast_cpu_enabled",
            "set_autocast_ipu_enabled",
            "set_autocast_enabled",
            "set_flush_denormal",
            "set_num_interop_threads",
            "set_num_threads",
            "set_vital",
            "Size",
            "StaticModule",
            "Stream",
            "StreamObjType",
            "Event",
            "StringType",
            "SUM",
            "SymFloat",
            "SymInt",
            "TensorType",
            "ThroughputBenchmark",
            "TracingState",
            "TupleType",
            "Type",
            "unify_type_list",
            "UnionType",
            "Use",
            "Value",
            "set_autocast_gpu_dtype",
            "get_autocast_gpu_dtype",
            "vitals_enabled",
            "wait",
            "Tag",
            "set_autocast_xla_enabled",
            "set_autocast_xla_dtype",
            "get_autocast_xla_dtype",
            "is_autocast_xla_enabled",
        }

        torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}

        # torch.TensorBase is explicitly removed in torch/__init__.py, so it is included here (#109940)
        explicitly_removed_torch_C_bindings = {"TensorBase"}

        torch_C_bindings = torch_C_bindings - explicitly_removed_torch_C_bindings

        # Check that the torch._C bindings are all in the allowlist. Since
        # bindings can change based on how PyTorch was compiled (e.g. with/without
        # CUDA), the two may not be an exact match but the bindings should be
        # a subset of the allowlist.
        difference = torch_C_bindings.difference(torch_C_allowlist_superset)
        msg = f"torch._C had bindings that are not present in the allowlist:\n{difference}"
        self.assertTrue(torch_C_bindings.issubset(torch_C_allowlist_superset), msg)

    @staticmethod
    def _is_mod_public(modname):
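        # A module is public only if no component of its dotted path starts
        # with `_`, e.g. "torch.nn.functional" is public while "torch._C" and
        # "torch.nn._reduction" are not.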
        split_strs = modname.split(".")
        for elem in split_strs:
            if elem.startswith("_"):
                return False
        return True

    @unittest.skipIf(
        IS_WINDOWS or IS_MACOS,
        "Inductor/Distributed modules hard fail on windows and macos",
    )
    @skipIfTorchDynamo("Broken and not relevant for now")
    def test_modules_can_be_imported(self):
        failures = []

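        # `pkgutil.walk_packages` invokes this hook when importing a package
        # raises; pkgutil does not pass the exception through, so we record a
        # generic ImportError in its place.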
        def onerror(modname):
            failures.append(
                (modname, ImportError("exception occurred importing package"))
            )

        for mod in pkgutil.walk_packages(torch.__path__, "torch.", onerror=onerror):
            modname = mod.name
            try:
                # TODO: fix "torch/utils/model_dump/__main__.py"
                # which calls sys.exit() when we try to import it
                if "__main__" in modname:
                    continue
                importlib.import_module(modname)
            except Exception as e:
                # Some current failures are not ImportError
                log.exception("import_module failed")
                failures.append((modname, e))

        # It is ok to add new entries here but please be careful that these modules
        # do not get imported by public code.
        private_allowlist = {
            "torch._inductor.codegen.cuda.cuda_kernel",
            # TODO(#133647): Remove the onnx._internal entries after
            # onnx and onnxscript are installed in CI.
            "torch.onnx._internal.exporter",
            "torch.onnx._internal.exporter._analysis",
            "torch.onnx._internal.exporter._building",
            "torch.onnx._internal.exporter._capture_strategies",
            "torch.onnx._internal.exporter._compat",
            "torch.onnx._internal.exporter._core",
            "torch.onnx._internal.exporter._decomp",
            "torch.onnx._internal.exporter._dispatching",
            "torch.onnx._internal.exporter._fx_passes",
            "torch.onnx._internal.exporter._ir_passes",
            "torch.onnx._internal.exporter._isolated",
            "torch.onnx._internal.exporter._onnx_program",
            "torch.onnx._internal.exporter._registration",
            "torch.onnx._internal.exporter._reporting",
            "torch.onnx._internal.exporter._schemas",
            "torch.onnx._internal.exporter._tensors",
            "torch.onnx._internal.exporter._verification",
            "torch.onnx._internal.fx._pass",
            "torch.onnx._internal.fx.analysis",
            "torch.onnx._internal.fx.analysis.unsupported_nodes",
            "torch.onnx._internal.fx.decomposition_skip",
            "torch.onnx._internal.fx.diagnostics",
            "torch.onnx._internal.fx.fx_onnx_interpreter",
            "torch.onnx._internal.fx.fx_symbolic_graph_extractor",
            "torch.onnx._internal.fx.onnxfunction_dispatcher",
            "torch.onnx._internal.fx.op_validation",
            "torch.onnx._internal.fx.passes",
            "torch.onnx._internal.fx.passes._utils",
            "torch.onnx._internal.fx.passes.decomp",
            "torch.onnx._internal.fx.passes.functionalization",
            "torch.onnx._internal.fx.passes.modularization",
            "torch.onnx._internal.fx.passes.readability",
            "torch.onnx._internal.fx.passes.type_promotion",
            "torch.onnx._internal.fx.passes.virtualization",
            "torch.onnx._internal.fx.type_utils",
            "torch.testing._internal.common_distributed",
            "torch.testing._internal.common_fsdp",
            "torch.testing._internal.dist_utils",
            "torch.testing._internal.distributed.common_state_dict",
            "torch.testing._internal.distributed._shard.sharded_tensor",
            "torch.testing._internal.distributed._shard.test_common",
            "torch.testing._internal.distributed._tensor.common_dtensor",
            "torch.testing._internal.distributed.ddp_under_dist_autograd_test",
            "torch.testing._internal.distributed.distributed_test",
            "torch.testing._internal.distributed.distributed_utils",
            "torch.testing._internal.distributed.fake_pg",
            "torch.testing._internal.distributed.multi_threaded_pg",
            "torch.testing._internal.distributed.nn.api.remote_module_test",
            "torch.testing._internal.distributed.rpc.dist_autograd_test",
            "torch.testing._internal.distributed.rpc.dist_optimizer_test",
            "torch.testing._internal.distributed.rpc.examples.parameter_server_test",
            "torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test",
            "torch.testing._internal.distributed.rpc.faulty_agent_rpc_test",
            "torch.testing._internal.distributed.rpc.faulty_rpc_agent_test_fixture",
            "torch.testing._internal.distributed.rpc.jit.dist_autograd_test",
            "torch.testing._internal.distributed.rpc.jit.rpc_test",
            "torch.testing._internal.distributed.rpc.jit.rpc_test_faulty",
            "torch.testing._internal.distributed.rpc.rpc_agent_test_fixture",
            "torch.testing._internal.distributed.rpc.rpc_test",
            "torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture",
            "torch.testing._internal.distributed.rpc_utils",
            "torch._inductor.codegen.cuda.cuda_template",
            "torch._inductor.codegen.cuda.gemm_template",
            "torch._inductor.codegen.cpp_template",
            "torch._inductor.codegen.cpp_gemm_template",
            "torch._inductor.codegen.cpp_micro_gemm",
            "torch._inductor.codegen.cpp_template_kernel",
            "torch._inductor.runtime.triton_helpers",
            "torch.ao.pruning._experimental.data_sparsifier.lightning.callbacks.data_sparsity",
            "torch.backends._coreml.preprocess",
            "torch.contrib._tensorboard_vis",
            "torch.distributed._composable",
            "torch.distributed._functional_collectives",
            "torch.distributed._functional_collectives_impl",
            "torch.distributed._shard",
            "torch.distributed._sharded_tensor",
            "torch.distributed._sharding_spec",
            "torch.distributed._spmd.api",
            "torch.distributed._spmd.batch_dim_utils",
            "torch.distributed._spmd.comm_tensor",
            "torch.distributed._spmd.data_parallel",
            "torch.distributed._spmd.distribute",
            "torch.distributed._spmd.experimental_ops",
            "torch.distributed._spmd.parallel_mode",
            "torch.distributed._tensor",
            "torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
            "torch.distributed.algorithms._optimizer_overlap",
            "torch.distributed.rpc._testing.faulty_agent_backend_registry",
            "torch.distributed.rpc._utils",
            "torch.ao.pruning._experimental.data_sparsifier.benchmarks.dlrm_utils",
            "torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_disk_savings",
            "torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_forward_time",
            "torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_model_metrics",
            "torch.ao.pruning._experimental.data_sparsifier.lightning.tests.test_callbacks",
            "torch.csrc.jit.tensorexpr.scripts.bisect",
            "torch.csrc.lazy.test_mnist",
            "torch.distributed._shard.checkpoint._fsspec_filesystem",
            "torch.distributed._tensor.examples.visualize_sharding_example",
            "torch.distributed.checkpoint._fsspec_filesystem",
            "torch.distributed.examples.memory_tracker_example",
            "torch.testing._internal.distributed.rpc.fb.thrift_rpc_agent_test_fixture",
            "torch.utils._cxx_pytree",
            "torch.utils.tensorboard._convert_np",
            "torch.utils.tensorboard._embedding",
            "torch.utils.tensorboard._onnx_graph",
            "torch.utils.tensorboard._proto_graph",
            "torch.utils.tensorboard._pytorch_graph",
            "torch.utils.tensorboard._utils",
        }

        # No new entries should be added to this list.
        # All public modules should be importable on all platforms.
        public_allowlist = {
            "torch.distributed.algorithms.ddp_comm_hooks",
            "torch.distributed.algorithms.model_averaging.averagers",
            "torch.distributed.algorithms.model_averaging.hierarchical_model_averager",
            "torch.distributed.algorithms.model_averaging.utils",
            "torch.distributed.checkpoint",
            "torch.distributed.constants",
            "torch.distributed.distributed_c10d",
            "torch.distributed.elastic.agent.server",
            "torch.distributed.elastic.rendezvous",
            "torch.distributed.fsdp",
            "torch.distributed.launch",
            "torch.distributed.launcher",
            "torch.distributed.nn",
            "torch.distributed.nn.api.remote_module",
            "torch.distributed.optim",
            "torch.distributed.optim.optimizer",
            "torch.distributed.rendezvous",
            "torch.distributed.rpc.api",
            "torch.distributed.rpc.backend_registry",
            "torch.distributed.rpc.constants",
            "torch.distributed.rpc.internal",
            "torch.distributed.rpc.options",
            "torch.distributed.rpc.rref_proxy",
            "torch.distributed.elastic.rendezvous.etcd_rendezvous",
            "torch.distributed.elastic.rendezvous.etcd_rendezvous_backend",
            "torch.distributed.elastic.rendezvous.etcd_store",
            "torch.distributed.rpc.server_process_global_profiler",
            "torch.distributed.run",
            "torch.distributed.tensor.parallel",
            "torch.distributed.utils",
            "torch.utils.tensorboard",
            "torch.utils.tensorboard.summary",
            "torch.utils.tensorboard.writer",
            "torch.ao.quantization.experimental.fake_quantize",
            "torch.ao.quantization.experimental.linear",
            "torch.ao.quantization.experimental.observer",
            "torch.ao.quantization.experimental.qconfig",
        }

        errors = []
        for mod, exc in failures:
            if mod in public_allowlist:
                # TODO: Ensure this is the right error type
                continue
            if mod in private_allowlist:
                continue
            errors.append(
                f"{mod} failed to import with error {type(exc).__qualname__}: {str(exc)}"
            )
        self.assertEqual("", "\n".join(errors))

    # AttributeError: module 'torch.distributed' has no attribute '_shard'
    @unittest.skipIf(IS_WINDOWS or IS_JETSON or IS_MACOS, "Distributed Attribute Error")
    @skipIfTorchDynamo("Broken and not relevant for now")
    def test_correct_module_names(self):
        """
        An API is considered public, if  its  `__module__` starts with `torch.`
        and there is no name in `__module__` or the object itself that starts with "_".
        Each public package should either:
        - (preferred) Define `__all__` and all callables and classes in there must have their
         `__module__` start with the current submodule's path. Things not in `__all__` should
          NOT have their `__module__` start with the current submodule.
        - (for simple python-only modules) Not define `__all__` and all the elements in `dir(submod)` must have their
          `__module__` that start with the current submodule.
        """

        failure_list = []
        with open(
            get_file_path_2(os.path.dirname(__file__), "allowlist_for_publicAPI.json")
        ) as json_file:
            # No new entries should be added to this allow_dict.
            # New APIs must follow the public API guidelines.
            allow_dict = json.load(json_file)
            # To keep modifications to `allowlist_for_publicAPI.json` minimal, we
            # duplicate the entries for migrated modules here, from their original
            # locations onto their new ones.
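            # For instance, a (hypothetical) "being_migrated" entry
            # {"torch.foo": "torch.bar"} copies the exceptions recorded under
            # "torch.foo" so that they also apply to "torch.bar".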

            for modname in allow_dict["being_migrated"]:
                if modname in allow_dict:
                    allow_dict[allow_dict["being_migrated"][modname]] = allow_dict[
                        modname
                    ]

        def test_module(modname):
            try:
                if "__main__" in modname:
                    return
                mod = importlib.import_module(modname)
            except Exception:
                # It is ok to ignore failures here: the test above already
                # ensures that these modules either import cleanly or are
                # explicitly allowlisted.
                return
            if not self._is_mod_public(modname):
                return
            # Verifies that each public API has the correct module name and naming semantics.
            def check_one_element(elem, modname, mod, *, is_public, is_all):
                obj = getattr(mod, elem)

                # torch.dtype is neither a class nor callable, so we need to check for it separately
                if not (
                    isinstance(obj, (Callable, torch.dtype)) or inspect.isclass(obj)
                ):
                    return
                elem_module = getattr(obj, "__module__", None)

                # Only used for nice error message below
                why_not_looks_public = ""
                if elem_module is None:
                    why_not_looks_public = (
                        "because it does not have a `__module__` attribute"
                    )

                # If a module is being migrated from foo.a to bar.a (i.e. the entry
                # {"foo": "bar"}), we treat the new location as the module's
                # canonical package, even if bar.py contains a "from foo import a".
                modname = allow_dict["being_migrated"].get(modname, modname)
                elem_modname_starts_with_mod = (
                    elem_module is not None
                    and elem_module.startswith(modname)
                    and "._" not in elem_module
                )
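                # e.g. (illustrative) a `__module__` of "torch.nn.modules.linear"
                # counts as within "torch.nn", while "torch.nn._reduction" would
                # not, because its path contains a private component ("._").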
                if not why_not_looks_public and not elem_modname_starts_with_mod:
                    why_not_looks_public = (
                        f"because its `__module__` attribute (`{elem_module}`) is not within the "
                        f"torch library or does not start with the submodule where it is defined (`{modname}`)"
                    )

                # elem's name must NOT begin with an `_` and its module name
                # SHOULD start with its current module since it's a public API
                looks_public = not elem.startswith("_") and elem_modname_starts_with_mod
                if not why_not_looks_public and not looks_public:
                    why_not_looks_public = f"because it starts with `_` (`{elem}`)"
                if is_public != looks_public:
                    if modname in allow_dict and elem in allow_dict[modname]:
                        return
                    if is_public:
                        why_is_public = (
                            f"it is inside the module's (`{modname}`) `__all__`"
                            if is_all
                            else "it is an attribute that does not start with `_` on a module that "
                            "does not have `__all__` defined"
                        )
                        fix_is_public = (
                            f"remove it from the modules's (`{modname}`) `__all__`"
                            if is_all
                            else f"either define a `__all__` for `{modname}` or add a `_` at the beginning of the name"
                        )
                    else:
                        assert is_all
                        why_is_public = (
                            f"it is not inside the module's (`{modname}`) `__all__`"
                        )
                        fix_is_public = (
                            f"add it from the modules's (`{modname}`) `__all__`"
                        )
                    if looks_public:
                        why_looks_public = (
                            "it does look public because it follows the rules from the doc above "
                            "(does not start with `_` and has a proper `__module__`)."
                        )
                        fix_looks_public = "make its name start with `_`"
                    else:
                        why_looks_public = why_not_looks_public
                        if not elem_modname_starts_with_mod:
                            fix_looks_public = (
                                "make sure the `__module__` is properly set and points to a submodule "
                                f"of `{modname}`"
                            )
                        else:
                            fix_looks_public = (
                                "remove the `_` at the beginning of the name"
                            )
                    failure_list.append(f"# {modname}.{elem}:")
                    is_public_str = "" if is_public else " NOT"
                    failure_list.append(
                        f"  - Is{is_public_str} public: {why_is_public}"
                    )
                    looks_public_str = "" if looks_public else " NOT"
                    failure_list.append(
                        f"  - Does{looks_public_str} look public: {why_looks_public}"
                    )
                    # The *_str suffixes below are intentionally swapped: the fix
                    # for `is_public` moves it toward how the element currently
                    # looks, and vice versa.
                    failure_list.append(
                        "  - You can do either of these two things to fix this problem:"
                    )
                    failure_list.append(
                        f"    - To make it{looks_public_str} public: {fix_is_public}"
                    )
                    failure_list.append(
                        f"    - To make it{is_public_str} look public: {fix_looks_public}"
                    )

            if hasattr(mod, "__all__"):
                public_api = mod.__all__
                all_api = dir(mod)
                for elem in all_api:
                    check_one_element(
                        elem, modname, mod, is_public=elem in public_api, is_all=True
                    )
            else:
                all_api = dir(mod)
                for elem in all_api:
                    if not elem.startswith("_"):
                        check_one_element(
                            elem, modname, mod, is_public=True, is_all=False
                        )

        for mod in pkgutil.walk_packages(torch.__path__, "torch."):
            modname = mod.name
            test_module(modname)
        test_module("torch")

        msg = (
            "All the APIs below do not meet our guidelines for public API from "
            "https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation.\n"
        )
        msg += (
            "Make sure that everything that is public is expected (in particular that the module "
            "has a properly populated `__all__` attribute) and that everything that is supposed to be public "
            "does look public (it does not start with `_` and has a `__module__` that is properly populated)."
        )

        msg += "\n\nFull list:\n"
        msg += "\n".join(map(str, failure_list))

        # An empty list is falsy in Python
        self.assertTrue(not failure_list, msg)


if __name__ == "__main__":
    run_tests()
