# Owner(s): ["module: unknown"]

import os.path
import sys
import tempfile
import unittest

from model import get_custom_op_library_path, Model

import torch
import torch._library.utils as utils
from torch import ops
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase


torch.ops.import_module("pointwise")


class TestCustomOperators(TestCase):
    def setUp(self):
        self.library_path = get_custom_op_library_path()
        ops.load_library(self.library_path)

    def test_custom_library_is_loaded(self):
        self.assertIn(self.library_path, ops.loaded_libraries)

    def test_op_with_no_abstract_impl_pystub(self):
        x = torch.randn(3, device="meta")
        if utils.requires_set_python_module():
            with self.assertRaisesRegex(RuntimeError, "pointwise"):
                torch.ops.custom.tan(x)
        else:
            # Smoketest
            torch.ops.custom.tan(x)

    def test_op_with_incorrect_abstract_impl_pystub(self):
        x = torch.randn(3, device="meta")
        with self.assertRaisesRegex(RuntimeError, "pointwise"):
            torch.ops.custom.cos(x)

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    def test_dynamo_pystub_suggestion(self):
        x = torch.randn(3)

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            return torch.ops.custom.asin(x)

        with self.assertRaisesRegex(
            RuntimeError,
            r"unsupported operator: .* you may need to `import nonexistent`",
        ):
            f(x)

    def test_abstract_impl_pystub_faketensor(self):
        from functorch import make_fx

        x = torch.randn(3, device="cpu")
        self.assertNotIn("my_custom_ops", sys.modules.keys())

        with self.assertRaises(
            torch._subclasses.fake_tensor.UnsupportedOperatorException
        ):
            gm = make_fx(torch.ops.custom.nonzero.default, tracing_mode="symbolic")(x)

        torch.ops.import_module("my_custom_ops")
        gm = make_fx(torch.ops.custom.nonzero.default, tracing_mode="symbolic")(x)
        self.assertExpectedInline(
            """\
def forward(self, arg0_1):
    nonzero = torch.ops.custom.nonzero.default(arg0_1);  arg0_1 = None
    return nonzero
""".strip(),
            gm.code.strip(),
        )

    def test_abstract_impl_pystub_meta(self):
        x = torch.randn(3, device="meta")
        self.assertNotIn("my_custom_ops2", sys.modules.keys())
        with self.assertRaisesRegex(NotImplementedError, r"'my_custom_ops2'"):
            y = torch.ops.custom.sin.default(x)
        torch.ops.import_module("my_custom_ops2")
        y = torch.ops.custom.sin.default(x)

    def test_calling_custom_op_string(self):
        output = ops.custom.op2("abc", "def")
        self.assertLess(output, 0)
        output = ops.custom.op2("abc", "abc")
        self.assertEqual(output, 0)

    def test_calling_custom_op(self):
        output = ops.custom.op(torch.ones(5), 2.0, 3)
        self.assertEqual(type(output), list)
        self.assertEqual(len(output), 3)
        for tensor in output:
            self.assertTrue(tensor.allclose(torch.ones(5) * 2))

        output = ops.custom.op_with_defaults(torch.ones(5))
        self.assertEqual(type(output), list)
        self.assertEqual(len(output), 1)
        self.assertTrue(output[0].allclose(torch.ones(5)))

    def test_calling_custom_op_with_autograd(self):
        x = torch.randn((5, 5), requires_grad=True)
        y = torch.randn((5, 5), requires_grad=True)
        output = ops.custom.op_with_autograd(x, 2, y)
        self.assertTrue(output.allclose(x + 2 * y + x * y))

        go = torch.ones((), requires_grad=True)
        output.sum().backward(go, False, True)
        grad = torch.ones(5, 5)

        self.assertEqual(x.grad, y + grad)
        self.assertEqual(y.grad, x + grad * 2)

        # Test with optional arg.
        x.grad.zero_()
        y.grad.zero_()
        z = torch.randn((5, 5), requires_grad=True)
        output = ops.custom.op_with_autograd(x, 2, y, z)
        self.assertTrue(output.allclose(x + 2 * y + x * y + z))

        go = torch.ones((), requires_grad=True)
        output.sum().backward(go, False, True)
        self.assertEqual(x.grad, y + grad)
        self.assertEqual(y.grad, x + grad * 2)
        self.assertEqual(z.grad, grad)

    def test_calling_custom_op_with_autograd_in_nograd_mode(self):
        with torch.no_grad():
            x = torch.randn((5, 5), requires_grad=True)
            y = torch.randn((5, 5), requires_grad=True)
            output = ops.custom.op_with_autograd(x, 2, y)
            self.assertTrue(output.allclose(x + 2 * y + x * y))

    def test_calling_custom_op_inside_script_module(self):
        model = Model()
        output = model.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))

    def test_saving_and_loading_script_module_with_custom_op(self):
        model = Model()
        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
        # close the file after creation and try to remove it manually.
        file = tempfile.NamedTemporaryFile(delete=False)
        try:
            file.close()
            model.save(file.name)
            loaded = torch.jit.load(file.name)
        finally:
            os.unlink(file.name)

        output = loaded.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))


if __name__ == "__main__":
    run_tests()
