# Owner(s): ["oncall: package/deploy"]

from io import BytesIO
from textwrap import dedent
from unittest import skipIf

import torch
from torch.package import PackageExporter, PackageImporter, sys_importer
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests


# torchvision is an optional dependency: record whether it can be imported
# and expose a decorator that skips tests when it cannot.
try:
    from torchvision.models import resnet18
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")

try:
    # Relative import: works when this file is loaded as a module of its
    # enclosing package.
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


@skipIf(
    True,
    "Does not work with recent torchvision, see https://github.com/pytorch/pytorch/issues/81115",
)
@skipIfNoTorchVision
class ModelTest(PackageTestCase):
    """End-to-end tests packaging an entire model.

    Every test builds a torchvision ``resnet18`` and round-trips it through
    ``torch.package``. The whole class is currently skipped unconditionally
    (see the issue linked in the decorator above) and is additionally skipped
    when torchvision cannot be imported.
    """

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_resnet(self):
        """Pickle a resnet into a package, load it, then re-export the loaded copy."""
        resnet = resnet18()

        f1 = self.temp()

        # Create a package that will save the model along with its code.
        with PackageExporter(f1) as e:
            # Put the pickled resnet in the package; by default this will
            # also save all the code files referenced by the objects in
            # the pickle.
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

        # We can now load the saved model.
        i = PackageImporter(f1)
        r2 = i.load_pickle("model", "model.pkl")

        # Test that it works: the loaded copy must agree with the original.
        example_input = torch.rand(1, 3, 224, 224)
        ref = resnet(example_input)
        self.assertEqual(r2(example_input), ref)

        # Functions also exist to get at the private modules in each package.
        # The returned module is not needed here, so it is not bound to a
        # name; the call only exercises the import.
        i.import_module("torchvision")

        f2 = BytesIO()
        # If we are doing transfer learning we might want to re-save
        # things that were loaded from a package.
        # We need to tell the exporter about any modules that
        # came from imported packages so that it can resolve
        # class names like torchvision.models.resnet.ResNet
        # to their source code.
        with PackageExporter(f2, importer=(i, sys_importer)) as e:
            # The importers given above are searched in order until the
            # first success, and that module is taken to be what
            # torchvision.models.resnet should be in this code package.
            # In the case of name collisions, such as trying to save a
            # ResNet from two different packages, we take the first thing
            # found in the path, so only ResNet objects from one importer
            # will work. This avoids a bunch of name mangling in the source
            # code. If you need to actually mix ResNet objects, we suggest
            # reconstructing the model objects using code from a single
            # package, using functions like save_state_dict and
            # load_state_dict to transfer state to the correct code objects.
            e.intern("**")
            e.save_pickle("model", "model.pkl", r2)

        f2.seek(0)

        i2 = PackageImporter(f2)
        r3 = i2.load_pickle("model", "model.pkl")
        self.assertEqual(r3(example_input), ref)

    @skipIfNoTorchVision
    def test_model_save(self):
        """Package one model two different ways and load both through the same API."""
        # This example shows how you might package a model
        # so that the creator of the model has flexibility about
        # how they want to save it but the 'server' can always
        # use the same API to load the package.

        # The convention is for each model to provide a
        # 'model' package with a 'load' function that actually
        # reads the model out of the archive.

        # How the load function is implemented is up to the packager.

        # Get our normal torchvision resnet.
        resnet = resnet18()

        f1 = BytesIO()
        # Option 1: save by pickling the whole model
        # + single-line, similar to torch.jit.save
        # - more difficult to edit the code after the model is created
        with PackageExporter(f1) as e:
            e.intern("**")
            e.save_pickle("model", "pickled", resnet)
            # Note that this source is the same for all models in this
            # approach, so it can be made part of an API that just takes
            # the model and packages it with this source.
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                # server knows to call model.load() to get the model,
                # maybe in the future it passes options as arguments by convension
                def load():
                    return resources.load_pickle('model', 'pickled')
                """
            )
            e.save_source_string("model", src, is_package=True)

        f2 = BytesIO()
        # Option 2: save with state dict
        # - more code to write to save/load the model
        # + but this code can be edited later to adapt the model
        with PackageExporter(f2) as e:
            e.intern("**")
            e.save_pickle("model", "state_dict", resnet.state_dict())
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                from torchvision.models.resnet import resnet18
                def load():
                    # if you want, you can later edit how resnet is constructed here
                    # to edit the model in the package, while still loading the original
                    # state dict weights
                    r = resnet18()
                    state_dict = resources.load_pickle('model', 'state_dict')
                    r.load_state_dict(state_dict)
                    return r
                """
            )
            e.save_source_string("model", src, is_package=True)

        # Regardless of how we chose to package, we can now use the model
        # in a server in the same way.
        example_input = torch.rand(1, 3, 224, 224)
        results = []
        for buffer in [f1, f2]:
            buffer.seek(0)
            importer = PackageImporter(buffer)
            the_model = importer.import_module("model").load()
            results.append(the_model(example_input))

        # Both packaging strategies must produce the same output.
        self.assertEqual(*results)

    @skipIfNoTorchVision
    def test_script_resnet(self):
        """Check that a resnet loaded from a package can be scripted, saved and reloaded."""
        resnet = resnet18()

        f1 = BytesIO()
        # Package the model by pickling it whole.
        with PackageExporter(f1) as e:
            e.intern("**")
            e.save_pickle("model", "pickled", resnet)

        f1.seek(0)

        i = PackageImporter(f1)
        loaded = i.load_pickle("model", "pickled")

        # Model should script successfully.
        scripted = torch.jit.script(loaded)

        # Scripted model should save and load successfully.
        f2 = BytesIO()
        torch.jit.save(scripted, f2)
        f2.seek(0)
        reloaded = torch.jit.load(f2)

        # The reloaded scripted model must agree with the original model.
        example_input = torch.rand(1, 3, 224, 224)
        self.assertEqual(reloaded(example_input), resnet(example_input))


# Hand off to the shared test runner when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
