# Owner(s): ["oncall: package/deploy"]

import os
import zipfile
from sys import version_info
from tempfile import TemporaryDirectory
from textwrap import dedent
from unittest import skipIf

import torch
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
)


try:
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase

from pathlib import Path


packaging_directory = Path(__file__).parent


@skipIf(
    IS_FBCODE or IS_SANDCASTLE or IS_WINDOWS,
    "Tests that use temporary files are disabled in fbcode",
)
class DirectoryReaderTest(PackageTestCase):
    """Tests use of DirectoryReader as accessor for opened packages."""

    @skipIfNoTorchVision
    @skipIf(
        True,
        "Does not work with latest TorchVision, see https://github.com/pytorch/pytorch/issues/81115",
    )
    def test_loading_pickle(self):
        """
        Test basic saving and loading of modules and pickles from a DirectoryReader.
        """
        resnet = resnet18()

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

        zip_file = zipfile.ZipFile(filename, "r")

        with TemporaryDirectory() as temp_dir:
            zip_file.extractall(path=temp_dir)
            importer = PackageImporter(Path(temp_dir) / Path(filename).name)
            dir_mod = importer.load_pickle("model", "model.pkl")
            input = torch.rand(1, 3, 224, 224)
            self.assertEqual(dir_mod(input), resnet(input))

    def test_loading_module(self):
        """
        Test basic saving and loading of a packages from a DirectoryReader.
        """
        import package_a

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_module("package_a")

        zip_file = zipfile.ZipFile(filename, "r")

        with TemporaryDirectory() as temp_dir:
            zip_file.extractall(path=temp_dir)
            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
            dir_mod = dir_importer.import_module("package_a")
            self.assertEqual(dir_mod.result, package_a.result)

    def test_loading_has_record(self):
        """
        Test DirectoryReader's has_record().
        """
        import package_a  # noqa: F401

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_module("package_a")

        zip_file = zipfile.ZipFile(filename, "r")

        with TemporaryDirectory() as temp_dir:
            zip_file.extractall(path=temp_dir)
            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
            self.assertTrue(dir_importer.zip_reader.has_record("package_a/__init__.py"))
            self.assertFalse(dir_importer.zip_reader.has_record("package_a"))

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_resource_reader(self):
        """Tests DirectoryReader as the base for get_resource_reader."""
        filename = self.temp()
        with PackageExporter(filename) as pe:
            # Layout looks like:
            #    package
            #    |-- one/
            #    |   |-- a.txt
            #    |   |-- b.txt
            #    |   |-- c.txt
            #    |   +-- three/
            #    |       |-- d.txt
            #    |       +-- e.txt
            #    +-- two/
            #       |-- f.txt
            #       +-- g.txt
            pe.save_text("one", "a.txt", "hello, a!")
            pe.save_text("one", "b.txt", "hello, b!")
            pe.save_text("one", "c.txt", "hello, c!")

            pe.save_text("one.three", "d.txt", "hello, d!")
            pe.save_text("one.three", "e.txt", "hello, e!")

            pe.save_text("two", "f.txt", "hello, f!")
            pe.save_text("two", "g.txt", "hello, g!")

        zip_file = zipfile.ZipFile(filename, "r")

        with TemporaryDirectory() as temp_dir:
            zip_file.extractall(path=temp_dir)
            importer = PackageImporter(Path(temp_dir) / Path(filename).name)
            reader_one = importer.get_resource_reader("one")

            # Different behavior from still zipped archives
            resource_path = os.path.join(
                Path(temp_dir), Path(filename).name, "one", "a.txt"
            )
            self.assertEqual(reader_one.resource_path("a.txt"), resource_path)

            self.assertTrue(reader_one.is_resource("a.txt"))
            self.assertEqual(
                reader_one.open_resource("a.txt").getbuffer(), b"hello, a!"
            )
            self.assertFalse(reader_one.is_resource("three"))
            reader_one_contents = list(reader_one.contents())
            reader_one_contents.sort()
            self.assertSequenceEqual(
                reader_one_contents, ["a.txt", "b.txt", "c.txt", "three"]
            )

            reader_two = importer.get_resource_reader("two")
            self.assertTrue(reader_two.is_resource("f.txt"))
            self.assertEqual(
                reader_two.open_resource("f.txt").getbuffer(), b"hello, f!"
            )
            reader_two_contents = list(reader_two.contents())
            reader_two_contents.sort()
            self.assertSequenceEqual(reader_two_contents, ["f.txt", "g.txt"])

            reader_one_three = importer.get_resource_reader("one.three")
            self.assertTrue(reader_one_three.is_resource("d.txt"))
            self.assertEqual(
                reader_one_three.open_resource("d.txt").getbuffer(), b"hello, d!"
            )
            reader_one_three_contents = list(reader_one_three.contents())
            reader_one_three_contents.sort()
            self.assertSequenceEqual(reader_one_three_contents, ["d.txt", "e.txt"])

            self.assertIsNone(importer.get_resource_reader("nonexistent_package"))

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_package_resource_access(self):
        """Packaged modules should be able to use the importlib.resources API to access
        resources saved in the package.
        """
        mod_src = dedent(
            """\
            import importlib.resources
            import my_cool_resources

            def secret_message():
                return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
            """
        )
        filename = self.temp()
        with PackageExporter(filename) as pe:
            pe.save_source_string("foo.bar", mod_src)
            pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")

        zip_file = zipfile.ZipFile(filename, "r")

        with TemporaryDirectory() as temp_dir:
            zip_file.extractall(path=temp_dir)
            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
            self.assertEqual(
                dir_importer.import_module("foo.bar").secret_message(),
                "my sekrit plays",
            )

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_importer_access(self):
        filename = self.temp()
        with PackageExporter(filename) as he:
            he.save_text("main", "main", "my string")
            he.save_binary("main", "main_binary", b"my string")
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                t = resources.load_text('main', 'main')
                b = resources.load_binary('main', 'main_binary')
                """
            )
            he.save_source_string("main", src, is_package=True)

        zip_file = zipfile.ZipFile(filename, "r")

        with TemporaryDirectory() as temp_dir:
            zip_file.extractall(path=temp_dir)
            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
            m = dir_importer.import_module("main")
            self.assertEqual(m.t, "my string")
            self.assertEqual(m.b, b"my string")

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_resource_access_by_path(self):
        """
        Tests that packaged code can used importlib.resources.path.
        """
        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_binary("string_module", "my_string", b"my string")
            src = dedent(
                """\
                import importlib.resources
                import string_module

                with importlib.resources.path(string_module, 'my_string') as path:
                    with open(path, mode='r', encoding='utf-8') as f:
                        s = f.read()
                """
            )
            e.save_source_string("main", src, is_package=True)

        zip_file = zipfile.ZipFile(filename, "r")

        with TemporaryDirectory() as temp_dir:
            zip_file.extractall(path=temp_dir)
            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
            m = dir_importer.import_module("main")
            self.assertEqual(m.s, "my string")

    def test_scriptobject_failure_message(self):
        """
        Test basic saving and loading of a ScriptModule in a directory.
        Currently not supported.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        zip_file = zipfile.ZipFile(filename, "r")

        with self.assertRaisesRegex(
            RuntimeError,
            "Loading ScriptObjects from a PackageImporter created from a "
            "directory is not supported. Use a package archive file instead.",
        ):
            with TemporaryDirectory() as temp_dir:
                zip_file.extractall(path=temp_dir)
                dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
                dir_mod = dir_importer.load_pickle("res", "mod.pkl")


if __name__ == "__main__":
    run_tests()
