# Owner(s): ["oncall: jit"]

import io
import os
import sys
from pathlib import Path
from typing import NamedTuple, Optional

import torch
from torch import Tensor
from torch.testing._internal.common_utils import skipIfTorchDynamo, TemporaryFileName


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import clear_class_registry, JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestSaveLoad(JitTestCase):
    def test_different_modules(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                return x

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(torch.jit.script(Foo()), second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_different_functions(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """

        def lol(x):
            return x

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        def lol(x):  # noqa: F811
            return "hello"

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(torch.jit.script(Foo()), second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_different_interfaces(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """

        @torch.jit.interface
        class MyInterface:
            def bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface:
            def __init__(self) -> None:
                pass

            def bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self) -> None:
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.bar(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        @torch.jit.interface
        class MyInterface:
            def not_bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script  # noqa: F811
        class ImplementInterface:  # noqa: F811
            def __init__(self) -> None:
                pass

            def not_bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self) -> None:
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.not_bar(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(torch.jit.script(Foo()), second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_many_collisions(self):
        class MyCoolNamedTuple(NamedTuple):
            a: int

        @torch.jit.interface
        class MyInterface:
            def bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface:
            def __init__(self) -> None:
                pass

            def bar(self, x):
                return x

        def lol(x):
            return x

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                x = lol(x)
                x = self.interface.bar(x)

                return x, MyCoolNamedTuple(a=5)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        @torch.jit.interface
        class MyInterface:
            def not_bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script  # noqa: F811
        class ImplementInterface:  # noqa: F811
            def __init__(self) -> None:
                pass

            def not_bar(self, x):
                return x

        def lol(x):  # noqa: F811
            return "asdofij"

        class MyCoolNamedTuple(NamedTuple):  # noqa: F811
            a: str

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                self.interface.not_bar(x)
                x = lol(x)
                return x, MyCoolNamedTuple(a="hello")

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(second_script_module, second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x, named_tuple_1 = self.first(x)
                x, named_tuple_2 = self.second(x)
                return len(x + named_tuple_2.a) + named_tuple_1.a

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_save_load_with_extra_files(self):
        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return a

        # specifically test binary data
        value = b"bar\x00\xffbaz"

        expected_extra_files = {}
        expected_extra_files["foo"] = value
        # verify that str to bytes conversion also works
        expected_extra_files["foo2"] = "bar"
        m = MyMod()

        # Save to file.
        with TemporaryFileName() as fname:
            m.save(fname, _extra_files=expected_extra_files)
            # values don't matter
            extra_files = {"foo": "", "foo2": None}
            torch.jit.load(fname, _extra_files=extra_files)
            self.assertEqual(value, extra_files["foo"])
            # results come back always as bytes
            self.assertEqual(b"bar", extra_files["foo2"])

            # Use torch.jit API
            torch.jit.save(m, fname, _extra_files=expected_extra_files)
            extra_files["foo"] = ""
            torch.jit.load(fname, _extra_files=extra_files)
            self.assertEqual(value, extra_files["foo"])

        # Save to buffer.
        buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files))
        extra_files = {"foo": ""}
        torch.jit.load(buffer, _extra_files=extra_files)
        self.assertEqual(value, extra_files["foo"])

        # Use torch.jit API
        buffer = io.BytesIO()
        torch.jit.save(m, buffer, _extra_files=expected_extra_files)
        buffer.seek(0)
        extra_files = {"foo": ""}
        torch.jit.load(buffer, _extra_files=extra_files)
        self.assertEqual(value, extra_files["foo"])

        # Non-existent file 'bar'
        with self.assertRaises(RuntimeError):
            extra_files["bar"] = ""
            torch.jit.load(buffer, _extra_files=extra_files)

    def test_save_load_using_pathlib(self):
        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return 2 * a

        m = MyMod()

        # Save then load.
        with TemporaryFileName() as fname:
            path = Path(fname)
            m.save(path)
            m2 = torch.jit.load(path)

        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        self.assertTrue(torch.equal(m(x), m2(x)))

    def test_save_nonexit_file(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return 2 * x

        script_module = torch.jit.script(Foo())
        with self.assertRaises(RuntimeError):
            script_module.save("NonExist/path/test.pt")

    def test_save_namedtuple_input_only(self):
        """
        Even if a NamedTuple is only used as an input argument, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: int

        class MyModule(torch.nn.Module):
            def forward(self, x: FooTuple) -> torch.Tensor:
                return torch.tensor(3)

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded(FooTuple(a=5))
        self.assertEqual(output, torch.tensor(3))

    def test_save_namedtuple_input_only_forwardref(self):
        """
        Even if a NamedTuple is only used as an input argument, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: "int"

        class MyModule(torch.nn.Module):
            def forward(self, x: FooTuple) -> torch.Tensor:
                return torch.tensor(3)

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded(FooTuple(a=5))
        self.assertEqual(output, torch.tensor(3))

    def test_save_namedtuple_output_only(self):
        """
        Even if a NamedTuple is only used as an output argument, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: int

        class MyModule(torch.nn.Module):
            def forward(self) -> Optional[FooTuple]:
                return None

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded()
        self.assertEqual(output, None)

    def test_save_load_params_buffers_submodules(self):
        """
        Check that parameters, buffers, and submodules are the same after loading.
        """

        class Submodule(torch.nn.Module):
            pass

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("submodule_a", Submodule())
                self.register_parameter(
                    "parameter_a", torch.nn.Parameter(torch.randn(4))
                )
                self.buffer = torch.nn.Buffer(torch.randn(4))
                self.t = torch.rand(4)  # not buffer

                self.parameter_b = torch.nn.Parameter(torch.randn(4))
                self.submodule_b = Submodule()
                self.buffer_b = torch.nn.Buffer(torch.randn(4))

        m = TestModule()
        m_loaded = self.getExportImportCopy(torch.jit.script(m))

        # Check submodules.
        self.assertEqual(
            len(list(m.named_modules())), len(list(m_loaded.named_modules()))
        )
        for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()):
            m_name, _ = m_s
            loaded_name, _ = loaded_s
            self.assertEqual(m_name, loaded_name)

        # Check parameters.
        self.assertEqual(len(list(m.parameters())), len(list(m_loaded.parameters())))
        for m_p, loaded_p in zip(m.parameters(), m_loaded.parameters()):
            self.assertEqual(m_p, loaded_p)

        # Check buffers.
        self.assertEqual(
            len(list(m.named_buffers())), len(list(m_loaded.named_buffers()))
        )
        for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()):
            m_name, m_buffer = m_b
            loaded_name, loaded_buffer = loaded_b
            self.assertEqual(m_name, loaded_name)
            self.assertEqual(m_buffer, loaded_buffer)

    def test_save_load_meta_tensors(self):
        """
        Check that parameters, buffers, and submodules are the same after loading
        for a module with parameters and buffers that are meta tensors
        """

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 3, device="meta")
                self.bar = torch.nn.Linear(3, 4)
                self.buffer = torch.nn.Buffer(torch.randn(4, device="meta"))

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        m = Foo()
        m_loaded = self.getExportImportCopy(torch.jit.script(m))
        # Check submodules.
        self.assertEqual(
            len(list(m.named_modules())), len(list(m_loaded.named_modules()))
        )
        self.assertEqual(
            {name for name, _ in m.named_modules()},
            {name for name, _ in m_loaded.named_modules()},
        )
        # Check parameters.
        m_params = dict(m.named_parameters())
        m_loaded_params = dict(m_loaded.named_parameters())
        self.assertEqual(len(m_params), len(m_loaded_params))
        self.assertEqual(m_params, m_loaded_params)
        # Check buffers.
        m_buffers = dict(m.named_buffers())
        m_loaded_buffers = dict(m_loaded.named_buffers())
        self.assertEqual(len(m_buffers), len(m_loaded_buffers))
        self.assertEqual(m_buffers, m_loaded_buffers)
        # Check params and buffers that are/are not meta tensors
        self.assertTrue(m_params["foo.weight"].is_meta)
        self.assertTrue(m_loaded_params["foo.weight"].is_meta)
        self.assertTrue(m_params["foo.bias"].is_meta)
        self.assertTrue(m_loaded_params["foo.bias"].is_meta)
        self.assertFalse(m_params["bar.weight"].is_meta)
        self.assertFalse(m_loaded_params["bar.weight"].is_meta)
        self.assertFalse(m_params["bar.bias"].is_meta)
        self.assertFalse(m_loaded_params["bar.bias"].is_meta)
        self.assertTrue(m_buffers["buffer"].is_meta)
        self.assertTrue(m_loaded_buffers["buffer"].is_meta)

    def test_save_load_meta_tensors_to_device(self):
        """
        Check that when loading a module with meta tensors to device, the meta tensors
        stay on meta, but non-meta tensors are set to the indicated device.
        """

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 3, device="meta")
                self.bar = torch.nn.Linear(3, 4)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        m = Foo()

        m_loaded = self.getExportImportCopy(torch.jit.script(m), map_location="cpu")
        # Check submodules.
        self.assertEqual(
            len(list(m.named_modules())), len(list(m_loaded.named_modules()))
        )
        self.assertEqual(
            {name for name, _ in m.named_modules()},
            {name for name, _ in m_loaded.named_modules()},
        )
        # Check parameters.
        m_params = dict(m.named_parameters())
        m_loaded_params = dict(m_loaded.named_parameters())
        self.assertEqual(len(m_params), len(m_loaded_params))
        self.assertEqual(m_params, m_loaded_params)
        # Check params and buffers that are/are not meta tensors
        self.assertTrue(m_params["foo.weight"].is_meta)
        self.assertTrue(m_loaded_params["foo.weight"].is_meta)
        self.assertTrue(m_params["foo.bias"].is_meta)
        self.assertTrue(m_loaded_params["foo.bias"].is_meta)
        self.assertTrue(m_params["bar.weight"].is_cpu)
        self.assertTrue(m_loaded_params["bar.weight"].is_cpu)
        self.assertTrue(m_params["bar.bias"].is_cpu)
        self.assertTrue(m_loaded_params["bar.bias"].is_cpu)

    def test_save_load_with_saved_traced_inputs(self):
        """
        Check that saving and loading with traced inputs works as expected
        """

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.ones(1)

        def get_loaded_inputs(inputs):
            traced_module = torch.jit.trace(module, input1)
            traced_inputs = list(traced_module.graph.inputs())
            with TemporaryFileName() as fname:
                path = Path(fname)
                traced_module.save(path)
                print(traced_module.graph)
                loaded_module = torch.jit.load(path, _restore_shapes=True)
                print(loaded_module.graph)
                return traced_inputs, list(loaded_module.graph.inputs())

        module = Module()
        input_tensor = torch.rand(1, 3, 24, 24)
        # Validate that with no input specified the traced inputs are stored
        traced_module = torch.jit.trace(module, input_tensor)
        traced_inputs = list(traced_module.graph.inputs())
        self.assertEqual(
            traced_module._c._retrieve_traced_inputs()["forward"], [input_tensor]
        )
        with TemporaryFileName() as fname:
            path = Path(fname)
            traced_module.save(path)
            loaded_module = torch.jit.load(path, _restore_shapes=True)
            loaded_inputs = list(loaded_module.graph.inputs())
            self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
            self.assertEqual(
                traced_inputs[1].type().sizes(), loaded_inputs[1].type().sizes()
            )
            # Validate that if no shapes are requested previous functionality remains
            loaded_module = torch.jit.load(path)
            loaded_inputs = list(loaded_module.graph.inputs())
            self.assertEqual(loaded_inputs[1].type().sizes(), None)

        # Validate that inputs aren't saved when requested not to
        traced_module = torch.jit.trace(module, input_tensor, _store_inputs=False)
        traced_inputs = list(traced_module.graph.inputs())
        self.assertEqual(len(traced_module._c._retrieve_traced_inputs()), 0)

        with TemporaryFileName() as fname:
            path = Path(fname)
            traced_module.save(path)
            loaded_module = torch.jit.load(path, _restore_shapes=True)
            loaded_inputs = list(loaded_module.graph.inputs())
            self.assertEqual(loaded_inputs[1].type().sizes(), None)
            # Validate that if no shapes are requested previous functionality remains
            loaded_module = torch.jit.load(path)
            loaded_inputs = list(loaded_module.graph.inputs())
            self.assertEqual(loaded_inputs[1].type().sizes(), None)

        # Validate that complex inputs work
        # Testing dict of list with empty tensors
        input1 = {
            "1000": (
                torch.tensor([0]),
                torch.tensor([], dtype=torch.int64),
                torch.tensor([]),
            )
        }
        traced_inputs, loaded_inputs = get_loaded_inputs(input1)
        self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())

        # Testing dict of list
        input2 = {
            "1000": (
                torch.tensor([0]),
                torch.tensor([1500000, 1500004], dtype=torch.int64),
                torch.tensor([2.0, 3.0]),
            )
        }
        traced_inputs, loaded_inputs = get_loaded_inputs(input2)
        self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())

        # Testing list
        input3 = [
            torch.tensor([0]),
            torch.tensor([1500000, 1500004], dtype=torch.int64),
            torch.tensor([2.0, 3.0]),
        ]

        traced_inputs, loaded_inputs = get_loaded_inputs(input3)
        self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())

        # Testing list of dict of list
        input4 = [
            {
                "1000": (
                    torch.tensor([0]),
                    torch.tensor([1500000, 1500004], dtype=torch.int64),
                    torch.tensor([2.0, 3.0]),
                )
            }
        ]

        traced_inputs, loaded_inputs = get_loaded_inputs(input4)
        self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())

    @skipIfTorchDynamo("too slow")
    def test_save_load_large_string_attribute(self):
        """
        Check if the model with string > 4GB can be loaded.
        """
        import psutil

        if psutil.virtual_memory().available < 60 * 1024 * 1024 * 1024:
            # Profiled the test execution, and got this number to be safe to run the test
            self.skipTest(
                "Doesn't have enough memory to run test_save_load_large_string_attribute"
            )

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = "x" * (2**32 + 1)

            def forward(self, i) -> int:
                return len(self.x) + i.numel()

        inp = torch.ones(0)
        ts = torch.jit.script(Model())
        ts_output = ts(inp)

        b = io.BytesIO(ts.save_to_buffer())
        del ts

        loaded_ts = torch.jit.load(b)
        del b
        loaded_output = loaded_ts(inp)
        self.assertEqual(ts_output, loaded_output)


def script_module_to_buffer(script_module):
    module_buffer = io.BytesIO(
        script_module._save_to_buffer_for_lite_interpreter(_use_flatbuffer=True)
    )
    module_buffer.seek(0)
    return module_buffer


class TestSaveLoadFlatbuffer(JitTestCase):
    def test_different_modules(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        first_script_module = torch.jit.script(Foo())
        first_saved_module = script_module_to_buffer(first_script_module)

        clear_class_registry()

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                return x

        second_script_module = torch.jit.script(Foo())
        second_saved_module = script_module_to_buffer(second_script_module)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = script_module_to_buffer(sm)
        sm = torch.jit.load(contains_both)

    def test_different_functions(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """

        def lol(x):
            return x

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = script_module_to_buffer(first_script_module)
        clear_class_registry()

        def lol(x):  # noqa: F811
            return "hello"

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = script_module_to_buffer(second_script_module)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = script_module_to_buffer(sm)
        sm = torch.jit.load(contains_both)

    def test_different_interfaces(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """

        @torch.jit.interface
        class MyInterface:
            def bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface:
            def __init__(self) -> None:
                pass

            def bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self) -> None:
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.bar(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = script_module_to_buffer(first_script_module)
        clear_class_registry()

        @torch.jit.interface
        class MyInterface:
            def not_bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script  # noqa: F811
        class ImplementInterface:  # noqa: F811
            def __init__(self) -> None:
                pass

            def not_bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self) -> None:
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.not_bar(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = script_module_to_buffer(second_script_module)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = script_module_to_buffer(sm)
        sm = torch.jit.load(contains_both)

    def test_many_collisions(self):
        class MyCoolNamedTuple(NamedTuple):
            a: int

        @torch.jit.interface
        class MyInterface:
            def bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface:
            def __init__(self) -> None:
                pass

            def bar(self, x):
                return x

        def lol(x):
            return x

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                x = lol(x)
                x = self.interface.bar(x)

                return x, MyCoolNamedTuple(a=5)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = script_module_to_buffer(first_script_module)

        clear_class_registry()

        @torch.jit.interface
        class MyInterface:
            def not_bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script  # noqa: F811
        class ImplementInterface:  # noqa: F811
            def __init__(self) -> None:
                pass

            def not_bar(self, x):
                return x

        def lol(x):  # noqa: F811
            return "asdofij"

        class MyCoolNamedTuple(NamedTuple):  # noqa: F811
            a: str

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                self.interface.not_bar(x)
                x = lol(x)
                return x, MyCoolNamedTuple(a="hello")

        second_script_module = torch.jit.script(Foo())
        second_saved_module = script_module_to_buffer(second_script_module)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x, named_tuple_1 = self.first(x)
                x, named_tuple_2 = self.second(x)
                return len(x + named_tuple_2.a) + named_tuple_1.a

        sm = torch.jit.script(ContainsBoth())
        contains_both = script_module_to_buffer(sm)
        sm = torch.jit.load(contains_both)

    def test_save_load_using_pathlib(self):
        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return 2 * a

        m = MyMod()

        # Save then load.
        with TemporaryFileName() as fname:
            path = Path(fname)
            torch.jit.save_jit_module_to_flatbuffer(m, path)
            m2 = torch.jit.load(path)

        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        self.assertTrue(torch.equal(m(x), m2(x)))

    def test_save_namedtuple_input_only(self):
        """
        Even if a NamedTuple is only used as an input argument, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: int

        class MyModule(torch.nn.Module):
            def forward(self, x: FooTuple) -> torch.Tensor:
                return torch.tensor(3)

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded(FooTuple(a=5))
        self.assertEqual(output, torch.tensor(3))

    def test_save_namedtuple_output_only(self):
        """
        Even if a NamedTuple is only used as an output argument, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: int

        class MyModule(torch.nn.Module):
            def forward(self) -> Optional[FooTuple]:
                return None

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded()
        self.assertEqual(output, None)

    def test_module_info_flatbuffer(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
        first_saved_module.seek(0)
        ff_info = torch.jit._serialization.get_flatbuffer_module_info(
            first_saved_module
        )
        self.assertEqual(ff_info["bytecode_version"], 9)
        self.assertEqual(ff_info["operator_version"], 1)
        self.assertEqual(ff_info["type_names"], set())
        self.assertEqual(ff_info["opname_to_num_args"], {"aten::linear": 3})

        self.assertEqual(len(ff_info["function_names"]), 1)
        self.assertTrue(next(iter(ff_info["function_names"])).endswith("forward"))

    def test_save_load_params_buffers_submodules(self):
        """
        Check that parameters, buffers, and submodules are the same after loading.
        """

        class Submodule(torch.nn.Module):
            pass

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("submodule_a", Submodule())
                self.register_parameter(
                    "parameter_a", torch.nn.Parameter(torch.randn(4))
                )
                self.buffer = torch.nn.Buffer(torch.randn(4))
                self.t = torch.rand(4)  # not buffer

                self.parameter_b = torch.nn.Parameter(torch.randn(4))
                self.submodule_b = Submodule()
                self.buffer_b = torch.nn.Buffer(torch.randn(4))

        m = TestModule()
        m_loaded = self.getExportImportCopy(torch.jit.script(m))

        # Check submodules.
        self.assertEqual(
            len(list(m.named_modules())), len(list(m_loaded.named_modules()))
        )
        for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()):
            m_name, _ = m_s
            loaded_name, _ = loaded_s
            self.assertEqual(m_name, loaded_name)

        # Check parameters.
        self.assertEqual(len(list(m.parameters())), len(list(m_loaded.parameters())))
        for m_p, loaded_p in zip(m.parameters(), m_loaded.parameters()):
            self.assertEqual(m_p, loaded_p)

        # Check buffers.
        self.assertEqual(
            len(list(m.named_buffers())), len(list(m_loaded.named_buffers()))
        )
        for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()):
            m_name, m_buffer = m_b
            loaded_name, loaded_buffer = loaded_b
            self.assertEqual(m_name, loaded_name)
            self.assertEqual(m_buffer, loaded_buffer)

    def test_save_load_with_extra_files(self):
        """
        Check that parameters, buffers, and submodules are the same after loading.
        """

        class Module(torch.nn.Module):
            def forward(self, x: Tensor):
                return x

        module = Module()
        script_module = torch.jit.script(module)

        extra_files = {"abc.json": b"[1,2,3]"}
        script_module_io = script_module._save_to_buffer_for_lite_interpreter(
            _extra_files=extra_files, _use_flatbuffer=True
        )

        re_extra_files = {}
        torch._C._get_model_extra_files_from_buffer(script_module_io, re_extra_files)

        self.assertEqual(extra_files, re_extra_files)
