# Owner(s): ["oncall: jit"]

import os
import sys
from typing import Any, List

import torch
import torch.nn as nn
from torch import Tensor
from torch.testing._internal.jit_utils import JitTestCase, make_global


# Make the helper files in test/ importable.  realpath() resolves symlinks
# first; the two dirname() calls then drop the file name and one directory
# level, leaving the grandparent directory of this file on sys.path.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(
        os.path.realpath(__file__),
    ),
)
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    # This module only defines test cases; refuse to run standalone and point
    # the user at test/test_jit.py instead.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)


class OrigModule(nn.Module):
    """Baseline module used as the initial value of interface-typed attributes.

    ``forward`` evaluates to ``3 * input + 2`` by way of ``one``; ``two`` is
    not called from ``forward``.
    """

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        # Sum of both operands, offset by one.
        summed = inp1 + inp2
        return summed + 1

    def two(self, input: Tensor) -> Tensor:
        # Constant offset of two.
        shifted = input + 2
        return shifted

    def forward(self, input: Tensor) -> Tensor:
        # Calling ``one`` from ``forward`` is what makes it compiled when the
        # module is scripted.
        inner = self.one(input, input)
        return input + inner + 1


class NewModule(nn.Module):
    """Replacement module exposing ``one`` and ``forward`` with other arithmetic.

    ``forward`` evaluates to ``input * (input + 1) + 1`` by way of ``one``.
    """

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        # Product of both operands, offset by one.
        product = inp1 * inp2
        return product + 1

    def forward(self, input: Tensor) -> Tensor:
        # Routing through ``one`` keeps it compiled alongside ``forward``.
        shifted = input + 1
        return self.one(input, shifted)


class TestModuleInterface(JitTestCase):
    """TorchScript tests for ``@torch.jit.interface`` module and class interfaces.

    Covers calling methods through interface-typed values, interface subtyping
    checks, swapping interface-typed submodules on scripted modules, and
    freezing modules that hold interface-typed submodules.
    """

    def test_not_submodule_interface_call(self):
        """Calling a method the interface does not declare fails to compile.

        ``OrigModule`` defines ``two``, but ``proxy_mod`` is annotated as
        ``ModuleInterface``, which only declares ``one``. So
        ``self.proxy_mod.two`` must be rejected, and that expression must be
        highlighted in the error.
        """
        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

        class TestNotModuleInterfaceCall(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.two(input)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "object has no attribute or method", "self.proxy_mod.two"
        ):
            torch.jit.script(TestNotModuleInterfaceCall())

    def test_module_interface(self):
        """Scripted modules can be used through module and class interfaces.

        Lists of scripted ``FooMod``/``BarMod`` instances are passed to scripted
        functions typed with either ``OneTwoModule`` (a module interface) or
        ``OneTwoClass`` (a class interface). Calling a method that the interface
        does not declare (``forward2``) must fail to compile.
        """
        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                pass

            def two(self, x: Tensor) -> Tensor:
                pass

            def forward(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.interface
        class OneTwoClass:
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                pass

            def two(self, x: Tensor) -> Tensor:
                pass

        class FooMod(nn.Module):
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                return x + y

            def two(self, x: Tensor) -> Tensor:
                return 2 * x

            def forward(self, x: Tensor) -> Tensor:
                return self.one(self.two(x), x)

        class BarMod(nn.Module):
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                return x * y

            def two(self, x: Tensor) -> Tensor:
                return 2 / x

            def forward(self, x: Tensor) -> Tensor:
                return self.two(self.one(x, x))

            @torch.jit.export
            def forward2(self, x: Tensor) -> Tensor:
                return self.two(self.one(x, x)) + 1

        # NOTE(review): make_global presumably exposes the locally defined
        # interfaces at module scope so the nested functions' annotations can
        # be resolved by TorchScript — confirm in jit_utils.
        make_global(OneTwoModule, OneTwoClass)

        def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
            return mod_list[0].forward(x) + mod_list[1].forward(x)

        def use_class_interface(mod_list: List[OneTwoClass], x: Tensor) -> Tensor:
            return mod_list[0].two(x) + mod_list[1].one(x, x)

        scripted_foo_mod = torch.jit.script(FooMod())
        scripted_bar_mod = torch.jit.script(BarMod())
        self.checkScript(
            use_module_interface,
            (
                [scripted_foo_mod, scripted_bar_mod],
                torch.rand(3, 4),
            ),
        )
        self.checkScript(
            use_class_interface,
            (
                [scripted_foo_mod, scripted_bar_mod],
                torch.rand(3, 4),
            ),
        )

        def call_module_interface_on_other_method(
            mod_interface: OneTwoModule, x: Tensor
        ) -> Tensor:
            return mod_interface.forward2(x)

        # Calling a method that OneTwoModule does not declare (forward2) must
        # fail to compile, even though the concrete BarMod exports it.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "object has no attribute or method", "mod_interface.forward2"
        ):
            self.checkScript(
                call_module_interface_on_other_method,
                (
                    scripted_bar_mod,
                    torch.rand(3, 4),
                ),
            )

    def test_module_doc_string(self):
        """Docstrings in an interface method body do not break the interface.

        ``TestInterface.forward`` has string literals both before and after
        ``pass``. A module holding an ``OrigModule`` under that interface must
        still script and produce the same results as eager mode.
        """
        @torch.jit.interface
        class TestInterface(nn.Module):
            def one(self, inp1, inp2):
                # type: (Tensor, Tensor) -> Tensor
                pass

            def forward(self, input):
                # type: (Tensor) -> Tensor
                r"""stuff 1"""
                r"""stuff 2"""
                pass  # noqa: PIE790
                r"""stuff 3"""

        class TestModule(nn.Module):
            proxy_mod: TestInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input):
                # type: (Tensor) -> Tensor
                return self.proxy_mod.forward(input)

        input = torch.randn(3, 4)
        self.checkModule(TestModule(), (input,))

    def test_module_interface_subtype(self):
        """Checks which values are accepted as subtypes of a module interface.

        These are rejected:
        - a scripted class instance (``Foo``), even though its methods match;
        - a ``ScriptModule`` whose methods do not match (``WrongMod``).

        These are accepted:
        - implementations whose ``forward`` takes a wider argument type
          (``Any`` instead of ``Tensor``);
        - implementations whose ``forward`` returns a narrower type
          (``Tensor`` instead of ``Any``).
        """
        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                pass

            def two(self, x: Tensor) -> Tensor:
                pass

            def forward(self, x: Tensor) -> Tensor:
                pass

        make_global(OneTwoModule)

        @torch.jit.script
        def as_module_interface(x: OneTwoModule) -> OneTwoModule:
            return x

        @torch.jit.script
        class Foo:
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                return x + y

            def two(self, x: Tensor) -> Tensor:
                return 2 * x

            def forward(self, x: Tensor) -> Tensor:
                return self.one(self.two(x), x)

        # check class object is not a subtype of module interface
        with self.assertRaisesRegex(
            RuntimeError, "ScriptModule class can be subtype of module interface"
        ):
            as_module_interface(Foo())

        class WrongMod(nn.Module):
            def two(self, x: int) -> int:
                return 2 * x

            def forward(self, x: Tensor) -> Tensor:
                return x + torch.randn(3, self.two(3))

        scripted_wrong_mod = torch.jit.script(WrongMod())

        # wrong module that is not compatible with module interface
        with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
            as_module_interface(scripted_wrong_mod)

        # Check that interface implementations can be contravariant in argument types and covariant in return type.
        @torch.jit.interface
        class TensorToAny(nn.Module):
            def forward(self, input: torch.Tensor) -> Any:
                pass

        make_global(TensorToAny)

        @torch.jit.script
        def as_tensor_to_any(x: TensorToAny) -> TensorToAny:
            return x

        @torch.jit.interface
        class AnyToAny(nn.Module):
            def forward(self, input: Any) -> Any:
                pass

        make_global(AnyToAny)

        @torch.jit.script
        def as_any_to_any(x: AnyToAny) -> AnyToAny:
            return x

        class TensorToAnyImplA(nn.Module):
            def forward(self, input: Any) -> Any:
                return input

        class TensorToAnyImplB(nn.Module):
            def forward(self, input: Any) -> torch.Tensor:
                return torch.tensor([1])

        class AnyToAnyImpl(nn.Module):
            def forward(self, input: Any) -> torch.Tensor:
                return torch.tensor([1])

        as_tensor_to_any(torch.jit.script(TensorToAnyImplA()))
        as_tensor_to_any(torch.jit.script(TensorToAnyImplB()))
        as_any_to_any(torch.jit.script(AnyToAnyImpl()))

    def test_module_interface_inheritance(self):
        """Declaring an interface that subclasses a concrete module (nn.ReLU) is rejected."""
        with self.assertRaisesRegex(
            RuntimeError, "does not support inheritance yet. Please directly"
        ):

            @torch.jit.interface
            class InheritMod(nn.ReLU):
                def three(self, x: Tensor) -> Tensor:
                    return 3 * x

    def test_module_swap(self):
        """An interface-typed submodule of a scripted module can be swapped.

        Replacing ``proxy_mod`` with another ScriptModule that satisfies the
        interface changes the result. Replacing it with a plain, non-scripted
        ``nn.Module`` raises.
        """
        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

            def forward(self, input: Tensor) -> Tensor:
                pass

        class TestModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.forward(input)

        scripted_mod = torch.jit.script(TestModule())
        input = torch.randn(3, 4)
        # OrigModule.forward: input + (input + input + 1) + 1
        self.assertEqual(scripted_mod(input), 3 * input + 2)

        # module swap with a module that has the same interface
        scripted_mod.proxy_mod = torch.jit.script(NewModule())
        self.assertEqual(scripted_mod(input), input * (input + 1) + 1)

        # module swap with non-scripted module should throw error
        with self.assertRaisesRegex(
            RuntimeError, "a ScriptModule with non-scripted module"
        ):
            scripted_mod.proxy_mod = NewModule()

    def test_module_swap_wrong_module(self):
        """Swapping in a ScriptModule that does not satisfy the interface is rejected.

        ``NewModuleWrong`` has no ``one`` method, and its ``forward`` is
        ``int -> int`` rather than ``Tensor -> Tensor``.
        """
        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

            def forward(self, input: Tensor) -> Tensor:
                pass

        class NewModuleWrong(nn.Module):
            def forward(self, input: int) -> int:
                return input + 1

        class TestModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.forward(input)

        scripted_mod = torch.jit.script(TestModule())
        # module swap with an incompatible interface
        with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
            scripted_mod.proxy_mod = torch.jit.script(NewModuleWrong())

    def test_module_swap_no_lazy_compile(self):
        """Interface methods must actually be compiled on the replacement module.

        ``NewModuleMethodNotLazyCompile.one`` is never called from ``forward``,
        so swapping that module in is rejected. Exporting ``one`` with
        ``@torch.jit.export`` makes the swap succeed.
        """
        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

            def forward(self, input: Tensor) -> Tensor:
                pass

        class TestModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.forward(input)

        class NewModuleMethodNotLazyCompile(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                return inp1 * inp2 + 1

            def forward(self, input: Tensor) -> Tensor:
                return input + 1

        scripted_mod = torch.jit.script(TestModule())
        # module swap with a module that has the same interface, but whose method
        # does not get lazily compiled from forward; the user needs to export it
        # explicitly for the swap to work
        with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
            scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodNotLazyCompile())

        class NewModuleMethodManualExport(nn.Module):
            @torch.jit.export
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                return inp1 * inp2 + 1

            def forward(self, input: Tensor) -> Tensor:
                return input + 1

        scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodManualExport())
        input = torch.randn(3, 4)
        self.assertEqual(scripted_mod(input), input + 1)

    def test_module_swap_no_module_interface(self):
        """Without an interface annotation, swaps must keep the same JIT type.

        ``proxy_mod`` is a plain submodule attribute here. Swapping in another
        scripted ``OrigModule`` works. Swapping in a scripted ``NewModule``
        fails with a type-mismatch error naming both classes.
        """
        class TestNoModuleInterface(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod(input)

        scripted_no_module_interface = torch.jit.script(TestNoModuleInterface())
        # proxy_mod is swapped with a new ScriptModule that shares the same JIT type, so this should succeed.
        scripted_no_module_interface.proxy_mod = torch.jit.script(OrigModule())
        # proxy_mod is not interface-typed and NewModule has a different JIT type, so this should fail.
        # NOTE: the expected message embeds the qualified name of this test module.
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' "
            + r"for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'",
        ):
            scripted_no_module_interface.proxy_mod = torch.jit.script(NewModule())

    def test_script_module_as_interface_swap(self):
        """Legacy ``torch.jit.ScriptModule`` subclasses work as interface implementations.

        An nn.Module holding an ``OrigScriptModule`` under ``ModuleInterface``
        can be scripted. Its ``proxy_mod`` can then be swapped for a
        ``NewScriptModule`` instance.
        """
        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

            def forward(self, input: Tensor) -> Tensor:
                pass

        class OrigScriptModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                return inp1 + inp2 + 1

            @torch.jit.script_method
            def forward(self, input: Tensor) -> Tensor:
                return input + self.one(input, input) + 1

        class NewScriptModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                return inp1 * inp2 + 1

            @torch.jit.script_method
            def forward(self, input: Tensor) -> Tensor:
                return self.one(input, input + 1)

        class TestNNModuleWithScriptModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigScriptModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.forward(input)

        input = torch.randn(3, 4)
        scripted_mod = torch.jit.script(TestNNModuleWithScriptModule())
        self.assertEqual(scripted_mod(input), 3 * input + 2)

        scripted_mod.proxy_mod = NewScriptModule()
        self.assertEqual(scripted_mod(input), input * (input + 1) + 1)

    # The call to proxy_mod's forward goes through an interface type and cannot
    # be inlined. NOTE(review): this comment previously said freezing throws an
    # error here, but the test below expects both freeze calls to succeed.
    def test_freeze_module_with_interface(self):
        """Freezing a module that holds an interface-typed submodule succeeds.

        Freezing without ``freezeInterfaces`` must not raise; its result is
        discarded. Freezing with ``freezeInterfaces=True`` must produce a
        module whose output matches the unfrozen scripted module.
        """
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = 20

            def forward(self, x):
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 0

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> int:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()  # folded

        # (TestModule.forward below is compiled by torch.jit.script.)
            def forward(self, x):
                return self.proxy_mod(x) + self.sub(x)

        m = torch.jit.script(TestModule())
        m.eval()
        # Only checks that freezing without freezeInterfaces does not raise.
        mf = torch._C._freeze_module(m._c)
        # Assume interface has no aliasing
        mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
        input = torch.tensor([1])
        out_s = m.forward(input)
        out_f = mf.forward(input)
        self.assertEqual(out_s, out_f)

    def test_freeze_module_with_setattr_in_interface(self):
        """Freezing succeeds when the interface implementation mutates an attribute.

        After ``m.proxy_mod = m.sub``, the interface call reaches
        ``SubModule.forward``, which does ``self.b += 2``. Freezing with
        ``freezeInterfaces=True`` must not raise. The frozen module is not
        otherwise inspected.
        """
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = 20

            def forward(self, x):
                self.b += 2
                return self.b

            @torch.jit.export
            def getb(self, x):
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 0

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> int:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()

            def forward(self, x):
                return self.proxy_mod(x) + self.sub.getb(x)

        m = torch.jit.script(TestModule())
        m.proxy_mod = m.sub
        m.eval()
        mf = torch._C._freeze_module(m._c, freezeInterfaces=True)

    def test_freeze_module_with_inplace_mutation_in_interface(self):
        """Freezing succeeds when the interface implementation mutates a tensor in place.

        After ``m.proxy_mod = m.sub``, the interface call reaches
        ``SubModule.forward``, which writes into ``self.b[0]``. Freezing with
        ``freezeInterfaces=True`` must not raise. The frozen module is not
        otherwise inspected.
        """
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.tensor([1.5])

            def forward(self, x):
                self.b[0] += 2
                return self.b

            @torch.jit.export
            def getb(self, x):
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([0.5])

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> Tensor:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()

            def forward(self, x):
                y = self.proxy_mod(x)
                z = self.sub.getb(x)
                return y[0] + z[0]

        m = torch.jit.script(TestModule())
        m.proxy_mod = m.sub
        # NOTE(review): presumably makes sub.b and proxy_mod.b the same tensor
        # (proxy_mod now refers to sub), so the in-place write is aliased.
        m.sub.b = m.proxy_mod.b
        m.eval()
        mf = torch._C._freeze_module(m._c, freezeInterfaces=True)

    def test_freeze_module_with_mutated_interface(self):
        """Freezing rejects a module that reassigns an interface-typed attribute.

        ``TestModule.forward`` executes ``self.proxy_mod = self.sub``. Freezing
        with ``freezeInterfaces=True`` must raise the SetAttr error.
        """
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.tensor([1.5])

            def forward(self, x):
                return self.b

            @torch.jit.export
            def getb(self, x):
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([0.5])

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> Tensor:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()

            def forward(self, x):
                self.proxy_mod = self.sub
                y = self.proxy_mod(x)
                z = self.sub.getb(x)
                return y[0] + z[0]

        m = torch.jit.script(TestModule())
        m.eval()
        with self.assertRaisesRegex(
            RuntimeError, "Freezing does not support SetAttr on an interface type."
        ):
            mf = torch._C._freeze_module(m._c, freezeInterfaces=True)

    def test_freeze_module_with_interface_and_fork(self):
        """Freezing succeeds when an interface call is reached through ``torch.jit._fork``.

        ``MainModule.forward`` runs ``TestModule.forward`` both via
        ``torch.jit._fork`` and directly. Freezing with
        ``freezeInterfaces=True`` must not raise. The frozen module is not
        otherwise inspected.
        """
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.tensor([1.5])

            def forward(self, x):
                self.b[0] += 3.2
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([0.5])

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> Tensor:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()

            def forward(self, x):
                y = self.proxy_mod(x)
                z = self.sub(x)
                return y + z

        class MainModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.test = TestModule()

            def forward(self, x):
                fut = torch.jit._fork(self.test.forward, x)
                y = self.test(x)
                z = torch.jit._wait(fut)
                return y + z

        m = torch.jit.script(MainModule())
        m.eval()
        mf = torch._C._freeze_module(m._c, freezeInterfaces=True)

    def test_module_apis_interface(self):
        """Iterating ``self.modules()`` fails when a submodule is interface-typed.

        ``method`` iterates ``self.modules()`` in a module whose ``proxy_mod``
        is annotated with an interface. Scripting is expected to fail with
        "Could not compile".
        """
        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

        class TestModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input):
                return input * 2

            @torch.jit.export
            def method(self, input):
                for module in self.modules():
                    input = module(input)
                return input

        with self.assertRaisesRegex(Exception, "Could not compile"):
            scripted_mod = torch.jit.script(TestModule())
