# Owner(s): ["oncall: package/deploy"]

import inspect
import os
import platform
import sys
from io import BytesIO
from pathlib import Path
from textwrap import dedent
from unittest import skipIf

from torch.package import is_from_package, PackageExporter, PackageImporter
from torch.package.package_exporter import PackagingError
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_SANDCASTLE,
    run_tests,
    skipIfTorchDynamo,
)


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


class TestMisc(PackageTestCase):
    """Tests for one-off or random functionality. Try not to add to this!"""

    def test_file_structure(self):
        """
        Tests package's Directory structure representation of a zip file. Ensures
        that the returned Directory prints what is expected and filters
        inputs/outputs correctly.
        """

        def _tree_without_root(directory):
            # Render ``directory`` as text and drop its first line (the root
            # entry naming the buffer), whose rendering differs between
            # platforms (Windows/macOS/Unix). Dedent the remainder so it can be
            # compared with the dedented literals below.
            return dedent("\n".join(str(directory).split("\n")[1:]))

        buffer = BytesIO()

        export_plain = dedent(
            """\
                \u251c\u2500\u2500 .data
                \u2502   \u251c\u2500\u2500 extern_modules
                \u2502   \u251c\u2500\u2500 python_version
                \u2502   \u251c\u2500\u2500 serialization_id
                \u2502   \u2514\u2500\u2500 version
                \u251c\u2500\u2500 main
                \u2502   \u2514\u2500\u2500 main
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u251c\u2500\u2500 package_a
                \u2502   \u251c\u2500\u2500 __init__.py
                \u2502   \u2514\u2500\u2500 subpackage.py
                \u251c\u2500\u2500 byteorder
                \u2514\u2500\u2500 module_a.py
            """
        )
        export_include = dedent(
            """\
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u2514\u2500\u2500 package_a
                    \u2514\u2500\u2500 subpackage.py
            """
        )
        import_exclude = dedent(
            """\
                \u251c\u2500\u2500 .data
                \u2502   \u251c\u2500\u2500 extern_modules
                \u2502   \u251c\u2500\u2500 python_version
                \u2502   \u251c\u2500\u2500 serialization_id
                \u2502   \u2514\u2500\u2500 version
                \u251c\u2500\u2500 main
                \u2502   \u2514\u2500\u2500 main
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u251c\u2500\u2500 package_a
                \u2502   \u251c\u2500\u2500 __init__.py
                \u2502   \u2514\u2500\u2500 subpackage.py
                \u251c\u2500\u2500 byteorder
                \u2514\u2500\u2500 module_a.py
            """
        )

        with PackageExporter(buffer) as he:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            he.intern("**")
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle("obj", "obj.pkl", obj)
            he.save_text("main", "main", "my string")

        buffer.seek(0)
        hi = PackageImporter(buffer)

        # Unfiltered view: every file written above appears.
        self.assertEqual(_tree_without_root(hi.file_structure()), export_plain)

        # ``include`` globs: only the matching files (and their parent
        # directories) are expected to remain.
        self.assertEqual(
            _tree_without_root(
                hi.file_structure(include=["**/subpackage.py", "**/*.pkl"])
            ),
            export_include,
        )

        # ``exclude`` glob: nothing in this package matches ``*.storage``, so
        # the expected tree is the same as the unfiltered one.
        self.assertEqual(
            _tree_without_root(hi.file_structure(exclude="**/*.storage")),
            import_exclude,
        )

    def test_loaders_that_remap_files_work_ok(self):
        """
        A loader whose get_filename() reports a different path for module_a
        must be honored when packaging: the packaged source for module_a is
        expected to come from the remapped file. (Assumes a
        module_a_remapped_path.py sits next to module_a and contains the text
        "remapped_path"; that file is not shown here.)
        """
        from importlib.abc import MetaPathFinder
        from importlib.machinery import SourceFileLoader
        from importlib.util import spec_from_loader

        class LoaderThatRemapsModuleA(SourceFileLoader):
            def get_filename(self, name):
                result = super().get_filename(name)
                if name == "module_a":
                    return os.path.join(
                        os.path.dirname(result), "module_a_remapped_path.py"
                    )
                else:
                    return result

        class FinderThatRemapsModuleA(MetaPathFinder):
            def find_spec(self, fullname, path, target=None):
                """Try to find the original spec for module_a using all the
                remaining meta_path finders, then swap in the remapping loader."""
                if fullname != "module_a":
                    return None
                spec = None
                for finder in sys.meta_path:
                    if finder is self:
                        continue
                    if hasattr(finder, "find_spec"):
                        spec = finder.find_spec(fullname, path, target=target)
                    elif hasattr(finder, "load_module"):
                        spec = spec_from_loader(fullname, finder)
                    if spec is not None:
                        break
                assert spec is not None and isinstance(spec.loader, SourceFileLoader)
                spec.loader = LoaderThatRemapsModuleA(
                    spec.loader.name, spec.loader.path
                )
                return spec

        remapping_finder = FinderThatRemapsModuleA()
        sys.meta_path.insert(0, remapping_finder)
        # clear it from sys.modules so that we use the custom finder next time
        # it gets imported
        sys.modules.pop("module_a", None)
        try:
            buffer = BytesIO()
            with PackageExporter(buffer) as he:
                import module_a

                he.intern("**")
                he.save_module(module_a.__name__)

            buffer.seek(0)
            hi = PackageImporter(buffer)
            self.assertIn("remapped_path", hi.get_source("module_a"))
        finally:
            # pop it again to ensure it does not mess up other tests
            sys.modules.pop("module_a", None)
            # Remove our finder by identity rather than popping index 0, so a
            # finder inserted ahead of ours in the meantime is left alone.
            sys.meta_path.remove(remapping_finder)

    def test_python_version(self):
        """
        Tests that the current python version is stored in the package and is available
        via PackageImporter's python_version() method.
        """
        buffer = BytesIO()

        with PackageExporter(buffer) as he:
            from package_a.test_module import SimpleTest

            he.intern("**")
            obj = SimpleTest()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        hi = PackageImporter(buffer)

        self.assertEqual(hi.python_version(), platform.python_version())

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_python_version_from_package(self):
        """Tests loading a package with a python version embedded"""
        # The checked-in fixture is expected to record Python 3.9.7.
        importer1 = PackageImporter(
            f"{Path(__file__).parent}/package_e/test_nn_module.pt"
        )
        self.assertEqual(importer1.python_version(), "3.9.7")

    def test_file_structure_has_file(self):
        """
        Test Directory's has_file() method.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            import package_a.subpackage

            he.intern("**")
            obj = package_a.subpackage.PackageASubpackageObject()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)

        importer = PackageImporter(buffer)
        file_structure = importer.file_structure()
        self.assertTrue(file_structure.has_file("package_a/subpackage.py"))
        # A path without the .py suffix must not match the packaged file.
        self.assertFalse(file_structure.has_file("package_a/subpackage"))

    def test_exporter_content_lists(self):
        """
        Test content list API for PackageExporter's contained modules.
        """

        with PackageExporter(BytesIO()) as he:
            import package_b

            he.extern("package_b.subpackage_1")
            he.mock("package_b.subpackage_2")
            he.intern("**")
            he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
            self.assertEqual(he.externed_modules(), ["package_b.subpackage_1"])
            self.assertEqual(he.mocked_modules(), ["package_b.subpackage_2"])
            self.assertEqual(
                he.interned_modules(),
                ["package_b", "package_b.subpackage_0.subsubpackage_0"],
            )
            self.assertEqual(he.get_rdeps("package_b.subpackage_2"), ["package_b"])

        # Denying a module that the pickled object depends on is expected to
        # fail with PackagingError; denied_modules() is still queryable before
        # the exporter exits.
        with self.assertRaises(PackagingError):
            with PackageExporter(BytesIO()) as he:
                import package_b

                he.deny("package_b")
                he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
                self.assertEqual(he.denied_modules(), ["package_b"])

    def test_is_from_package(self):
        """is_from_package should work for objects and modules"""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        loaded_obj = pi.load_pickle("obj", "obj.pkl")

        self.assertFalse(is_from_package(package_a.subpackage))
        self.assertTrue(is_from_package(mod))

        self.assertFalse(is_from_package(obj))
        self.assertTrue(is_from_package(loaded_obj))

    def test_inspect_class(self):
        """Should be able to retrieve source for a packaged class."""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        packaged_class = pi.import_module(
            "package_a.subpackage"
        ).PackageASubpackageObject
        regular_class = package_a.subpackage.PackageASubpackageObject

        packaged_src = inspect.getsourcelines(packaged_class)
        regular_src = inspect.getsourcelines(regular_class)
        self.assertEqual(packaged_src, regular_src)

    def test_dunder_package_present(self):
        """
        The attribute '__torch_package__' should be populated on imported modules.
        """
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        self.assertTrue(hasattr(mod, "__torch_package__"))

    def test_dunder_package_works_from_package(self):
        """
        The attribute '__torch_package__' should be accessible from within
        the module itself, so that packaged code can detect whether it's
        being used in a packaged context or not.
        """
        import package_a.use_dunder_package as mod

        buffer = BytesIO()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_module(mod.__name__)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        imported_mod = pi.import_module(mod.__name__)
        self.assertTrue(imported_mod.is_from_package())
        self.assertFalse(mod.is_from_package())

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_std_lib_sys_hackery_checks(self):
        """
        The standard library performs sys.module assignment hackery which
        causes modules that do this hackery to fail on import. See
        https://github.com/pytorch/pytorch/issues/57490 for more information.
        """
        import package_a.std_sys_module_hacks

        buffer = BytesIO()
        original_module = package_a.std_sys_module_hacks.Module()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", original_module)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        loaded_module = pi.load_pickle("obj", "obj.pkl")
        # Calling the unpickled object must not raise.
        loaded_module()


if __name__ == "__main__":
    # Entry point when this file is executed directly (rather than collected
    # by a test runner); run_tests is torch's common_utils test driver.
    run_tests()
