# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
from typing import Tuple

import torch
from jit.test_hooks_modules import (
    create_forward_tuple_input,
    create_module_forward_multiple_inputs,
    create_module_forward_single_input,
    create_module_hook_return_nothing,
    create_module_multiple_hooks_multiple_inputs,
    create_module_multiple_hooks_single_input,
    create_module_no_forward_input,
    create_module_same_hook_repeated,
    create_submodule_forward_multiple_inputs,
    create_submodule_forward_single_input,
    create_submodule_forward_single_input_return_not_tupled,
    create_submodule_hook_return_nothing,
    create_submodule_multiple_hooks_multiple_inputs,
    create_submodule_multiple_hooks_single_input,
    create_submodule_no_forward_input,
    create_submodule_same_hook_repeated,
    create_submodule_to_call_directly_with_hooks,
    ModuleDirectforwardSubmodCall,
    ModuleForwardSingleInput,
    ModuleForwardTupleInput,
)


# Make the shared helper files in test/ importable: resolve this file's real
# location (following symlinks), step up two directory levels, and append that
# directory to the module search path.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    # This module has no standalone entry point; executing it as a script is
    # treated as a usage error that points at the test/test_jit.py driver.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


# Tests for JIT forward hooks and pre-hooks
class TestHooks(JitTestCase):
    """Tests for scripting modules that carry forward hooks / forward pre-hooks.

    Most cases build a module with one of the ``create_*`` helpers from
    ``jit.test_hooks_modules`` and hand it to ``checkModule`` together with
    its forward arguments (always passed as a tuple). Other cases compare
    eager and scripted outputs directly, or check the ``RuntimeError`` that
    ``torch.jit.script`` raises for hooks that are misnamed or mistyped.
    """

    def test_module_no_forward_input(self):
        self.checkModule(create_module_no_forward_input(), ())

    def test_submodule_no_forward_input(self):
        self.checkModule(create_submodule_no_forward_input(), ())

    def test_module_forward_multiple_inputs(self):
        self.checkModule(
            create_module_forward_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_module_multiple_hooks_multiple_inputs(self):
        self.checkModule(
            create_module_multiple_hooks_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_module_forward_single_input(self):
        self.checkModule(create_module_forward_single_input(), ("a",))

    def test_module_same_hook_repeated(self):
        self.checkModule(create_module_same_hook_repeated(), ("a",))

    def test_module_hook_return_nothing(self):
        self.checkModule(create_module_hook_return_nothing(), ("a",))

    def test_module_multiple_hooks_single_input(self):
        self.checkModule(create_module_multiple_hooks_single_input(), ("a",))

    def test_submodule_forward_multiple_inputs(self):
        self.checkModule(
            create_submodule_forward_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_submodule_multiple_hooks_multiple_inputs(self):
        self.checkModule(
            create_submodule_multiple_hooks_multiple_inputs(),
            (["a"], "no_pre_hook"),
        )

    def test_submodule_forward_single_input(self):
        self.checkModule(create_submodule_forward_single_input(), ("a",))

    def test_submodule_called_directly_with_hooks(self):
        # Calling the submodule on its own (not through the parent) must give
        # the same result in eager mode and after scripting the parent.
        module = create_submodule_to_call_directly_with_hooks()
        module_scripted = torch.jit.script(module)

        submodule = module.submodule
        scripted_submodule = module_scripted.submodule

        self.assertEqual(submodule("a"), scripted_submodule("a"))

    def test_submodule_same_hook_repeated(self):
        self.checkModule(create_submodule_same_hook_repeated(), ("a",))

    def test_submodule_hook_return_nothing(self):
        self.checkModule(create_submodule_hook_return_nothing(), ("a",))

    def test_submodule_multiple_hooks_single_input(self):
        # Single string forward argument, passed as a one-element args tuple
        # like the other single-input tests.
        self.checkModule(create_submodule_multiple_hooks_single_input(), ("a",))

    def test_forward_tuple_input(self):
        # forward takes one argument that is itself a tuple: args == ((3,),)
        self.checkModule(create_forward_tuple_input(), ((3,),))

    def test_submodule_forward_single_input_return_not_tupled(self):
        self.checkModule(
            create_submodule_forward_single_input_return_not_tupled(), ("a",)
        )

    def test_hook_method_name_collision(self):
        # Hooks can't have the same name as methods: scripting must reject a
        # pre-hook named `foo` when that name is already taken on the class.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def foo(self, input: Tuple[str]) -> Tuple[str]:
            assert self.name == "inner_mod_name"
            assert input[0] == "a_outermod"
            return ("pre_hook_override_name",)

        m.submodule.register_forward_pre_hook(foo)

        with self.assertRaisesRegex(
            RuntimeError,
            "Can't define hook: foo on class: .+ "
            "because a method or hook with that name already exists.",
        ):
            torch.jit.script(m)

    def test_hook_hook_name_collision(self):
        # Test edge case of two hooks sharing name but not python definition.
        # The redefinitions below are deliberate (hence the noqa markers).
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def prehook(self, input: Tuple[str]) -> Tuple[str]:
            return "This is the first hook"

        m.submodule.register_forward_pre_hook(prehook)

        def prehook(self, input: Tuple[str]) -> Tuple[str]:  # noqa: F811
            return "This is the second hook"

        m.submodule.register_forward_pre_hook(prehook)

        with self.assertRaisesRegex(
            RuntimeError,
            "Pre-hook '.+' on .+ has at least two different python "
            "definitions. Please use unique names for all hooks.",
        ):
            torch.jit.script(m)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def hook(self, input: Tuple[str], output: str):
            return "This is the first hook"

        m.submodule.register_forward_hook(hook)

        def hook(self, input: Tuple[str]):  # noqa: F811
            return "This is the second hook"

        m.submodule.register_forward_hook(hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "Hook '.+' on .+ has at least two different python "
            "definitions. Please use unique names for all hooks.",
        ):
            torch.jit.script(m)

    def test_module_direct_forward_invocation(self):
        # Test that hooks are only invoked when the module is
        # called directly and not when forward is called.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            return ("pre_hook_override_name",)

        def forward_hook(self, input: Tuple[str], output: str):
            assert self.name == "outer_mod_name"
            assert input == ("pre_hook_override_name",)
            output = output + "_fh"
            return output

        m.register_forward_pre_hook(pre_hook)
        m.register_forward_hook(forward_hook)

        m_scripted = torch.jit.script(m)

        # `.forward(...)` skips the hooks in both eager and scripted mode ...
        self.assertEqual(m.forward("a"), m_scripted.forward("a"))
        # ... while calling the scripted module runs them, changing the result.
        self.assertNotEqual(m_scripted("a"), m_scripted.forward("a"))

    def test_submodule_direct_forward_invocation(self):
        # NOTE(review): ModuleDirectforwardSubmodCall presumably calls
        # `self.submodule.forward(...)` (per its name), bypassing the
        # submodule's hooks, while ModuleForwardSingleInput calls
        # `self.submodule(...)`; the final assertion expects them to differ.
        m_submod_forward_call = ModuleDirectforwardSubmodCall(
            "outer_mod_name", "inner_mod_name"
        )
        m_submod_call = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            return ("pre_hook_override_name",)

        def forward_hook(self, input: Tuple[str], output: str):
            assert input == ("pre_hook_override_name",)
            return output + "_fh"

        m_submod_forward_call.submodule.register_forward_pre_hook(pre_hook)
        m_submod_forward_call.submodule.register_forward_hook(forward_hook)
        m_submod_call.submodule.register_forward_pre_hook(pre_hook)
        m_submod_call.submodule.register_forward_hook(forward_hook)

        m_submod_forward_call_scripted = torch.jit.script(m_submod_forward_call)
        m_submod_call_scripted = torch.jit.script(m_submod_call)

        self.assertEqual(
            m_submod_forward_call_scripted("a"), m_submod_forward_call("a")
        )
        self.assertNotEqual(
            m_submod_forward_call_scripted("a"), m_submod_call_scripted("a")
        )

    # TODO: add this test back once figured out how to print error msg
    @unittest.skip("TODO: re-enable once the hook compilation hint is printed")
    def test_hook_compilation_hint(self):
        # Tests if hook error message is printed out if erroring after schema check.
        # Useful for when user is scripting hooks while not aware of it.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            assert self.name == "outer_mod_name"
            assert input[4] == "a"  # out of bounds tuple range
            return ("pre_hook_override_name",)

        m.register_forward_pre_hook(pre_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "This error occurred while scripting the forward pre-hook 'pre_hook'",
        ):
            torch.jit.script(m)

    def test_wrong_pre_hook_signatures(self):
        # Each sub-case registers one malformed pre-hook on a fresh module and
        # checks the schema error raised by torch.jit.script.
        # correct signature: pre_hook_c(self, input: Tuple[str])
        def pre_hook_wrong_input1(self, input: Tuple[None]) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input1)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong inner types for the input tuple argument",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_input2(self, input: Tuple[str], input2: str) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input2)

        with self.assertRaisesRegex(
            RuntimeError,
            "was expected to only have exactly 2 inputs but it had 3 inputs",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_input3(self, input: int) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input3)

        with self.assertRaisesRegex(
            RuntimeError,
            "expected the input argument to be typed as a Tuple but"
            " found type: 'int' instead",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_output(self, input: Tuple[str]) -> int:
            return 1  # expecting Tuple[str], str, or None

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_output)

        with self.assertRaisesRegex(
            RuntimeError,
            "returned the wrong type of: 'int'",
        ):
            torch.jit.script(m)

        def pre_hook_no_output_annotation(self, input: Tuple[str]):
            return 1  # expecting Tuple[str], str, or None

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_no_output_annotation)

        with self.assertRaisesRegex(
            RuntimeError,
            "is missing a return annotation. Return annotations"
            " are required, please add one.",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_tuple_return(self, input: Tuple[Tuple[int]]) -> Tuple[int]:
            return (11,)  # doesn't work with eager, inner tuple lost

        m = ModuleForwardTupleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_tuple_return)

        with self.assertRaisesRegex(
            RuntimeError,
            "When forward has a single tuple input argument, "
            "the return needs to be 'None' or a nested tuple containing "
            r"forward's input tuple argument as in: 'Tuple\[Tuple\[int\]\]'",
        ):
            torch.jit.script(m)

    def test_wrong_hook_signatures(self):
        # Each sub-case registers malformed forward hook(s) on a fresh module
        # and checks the schema error raised by torch.jit.script.
        # correct signature:
        #   def forward_hook(self, input: Tuple[str], output: str)
        def forward_hook_wrong_input1(self, input: Tuple[str, str], output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input1)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong number of contained types for the "
            r"input argument's Tuple. Received type: 'Tuple\[str, str\]'",
        ):
            torch.jit.script(m)

        def forward_hook_wrong_input2(self, input: str, output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input2)

        with self.assertRaisesRegex(
            RuntimeError,
            "expected the input argument to be typed as a Tuple "
            "but found type: 'str' instead.",
        ):
            torch.jit.script(m)

        def forward_hook_wrong_input3(self, input: Tuple[None], output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input3)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong inner types for the input tuple"
            r" argument. Received type: 'Tuple\[NoneType\]'",
        ):
            torch.jit.script(m)

        def forward_hook_wrong_output(self, input: Tuple[str], output: Tuple[str]):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_output)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong type for the output argument. Received"
            r" type: 'Tuple\[str\]'. Expected type: 'str'",
        ):
            torch.jit.script(m)

        # A hook's `output` parameter must match what the previous hook
        # returned: forward_hook_correct turns `str` into `Tuple[str]`, so the
        # next hook declaring `output: str` is rejected.
        def forward_hook_correct(self, input: Tuple[str], output: str):
            return (output,)

        def forward_hook_wrong_output_from_prev_hook(
            self, input: Tuple[str], output: str
        ):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_correct)
        m.register_forward_hook(forward_hook_wrong_output_from_prev_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong type for the output argument. "
            r"Received type: 'str'. Expected type: 'Tuple\[str\]'",
        ):
            torch.jit.script(m)
