# Owner(s): ["oncall: jit"]
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase


r"""
Test TorchScript exception handling.
"""


class TestException(TestCase):
    def test_pyop_exception_message(self):
        class Foo(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 10, kernel_size=5)

            @torch.jit.script_method
            def forward(self, x):
                return self.conv(x)

        foo = Foo()
        # testing that the correct error message propagates
        with self.assertRaisesRegex(
            RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"
        ):
            foo(torch.ones([123]))  # wrong size

    def test_builtin_error_messsage(self):
        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):

            @torch.jit.script
            def close_match(x):
                return x.masked_fill(True)

        with self.assertRaisesRegex(
            RuntimeError,
            "This op may not exist or may not be currently " "supported in TorchScript",
        ):

            @torch.jit.script
            def unknown_op(x):
                torch.set_anomaly_enabled(True)
                return x

    def test_exceptions(self):
        cu = torch.jit.CompilationUnit(
            """
            def foo(cond):
                if bool(cond):
                    raise ValueError(3)
                return 1
        """
        )

        cu.foo(torch.tensor(0))
        with self.assertRaisesRegex(torch.jit.Error, "3"):
            cu.foo(torch.tensor(1))

        def foo(cond):
            a = 3
            if bool(cond):
                raise ArbitraryError(a, "hi")  # noqa: F821
                if 1 == 2:
                    raise ArbitraryError  # noqa: F821
            return a

        with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
            torch.jit.script(foo)

        def exception_as_value():
            a = Exception()
            print(a)

        with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
            torch.jit.script(exception_as_value)

        @torch.jit.script
        def foo_no_decl_always_throws():
            raise RuntimeError("Hi")

        # function that has no declared type but always throws set to None
        output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
        self.assertTrue(str(output_type) == "NoneType")

        @torch.jit.script
        def foo_decl_always_throws():
            # type: () -> Tensor
            raise Exception("Hi")  # noqa: TRY002

        output_type = next(foo_decl_always_throws.graph.outputs()).type()
        self.assertTrue(str(output_type) == "Tensor")

        def foo():
            raise 3 + 4

        with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
            torch.jit.script(foo)

        # a escapes scope
        @torch.jit.script
        def foo():
            if 1 == 1:
                a = 1
            else:
                if 1 == 1:
                    raise Exception("Hi")  # noqa: TRY002
                else:
                    raise Exception("Hi")  # noqa: TRY002
            return a

        self.assertEqual(foo(), 1)

        @torch.jit.script
        def tuple_fn():
            raise RuntimeError("hello", "goodbye")

        with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
            tuple_fn()

        @torch.jit.script
        def no_message():
            raise RuntimeError

        with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
            no_message()

    def test_assertions(self):
        cu = torch.jit.CompilationUnit(
            """
            def foo(cond):
                assert bool(cond), "hi"
                return 0
        """
        )

        cu.foo(torch.tensor(1))
        with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
            cu.foo(torch.tensor(0))

        @torch.jit.script
        def foo(cond):
            assert bool(cond), "hi"

        foo(torch.tensor(1))
        # we don't currently validate the name of the exception
        with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
            foo(torch.tensor(0))

    def test_python_op_exception(self):
        @torch.jit.ignore
        def python_op(x):
            raise Exception("bad!")  # noqa: TRY002

        @torch.jit.script
        def fn(x):
            return python_op(x)

        with self.assertRaisesRegex(
            RuntimeError, "operation failed in the TorchScript interpreter"
        ):
            fn(torch.tensor(4))

    def test_dict_expansion_raises_error(self):
        def fn(self):
            d = {"foo": 1, "bar": 2, "baz": 3}
            return {**d}

        with self.assertRaisesRegex(
            torch.jit.frontend.NotSupportedError, "Dict expansion "
        ):
            torch.jit.script(fn)

    def test_custom_python_exception(self):
        class MyValueError(ValueError):
            pass

        @torch.jit.script
        def fn():
            raise MyValueError("test custom exception")

        with self.assertRaisesRegex(
            torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"
        ):
            fn()

    def test_custom_python_exception_defined_elsewhere(self):
        from jit.myexception import MyKeyError

        @torch.jit.script
        def fn():
            raise MyKeyError("This is a user defined key error")

        with self.assertRaisesRegex(
            torch.jit.Error,
            "jit.myexception.MyKeyError: This is a user defined key error",
        ):
            fn()
