# Owner(s): ["oncall: jit"]

import copy
import io
import os
import sys
import unittest
from typing import Optional

import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo


# Put the parent of this file's directory (test/) on sys.path so the helper
# files that live there can be imported.
_this_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_this_dir)
sys.path.append(pytorch_test_dir)
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
)
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    # Direct execution is unsupported: fail immediately and point the user at
    # the aggregate runner instead.
    _direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_direct_run_message)


@skipIfTorchDynamo("skipping as a precaution")
class TestTorchbind(JitTestCase):
    """Exercise custom classes from the ``_TorchScriptTesting`` namespace.

    ``setUp`` loads the torchbind test library before each test; the tests
    then use ``torch.classes._TorchScriptTesting`` classes and
    ``torch.ops._TorchScriptTesting`` ops from eager Python, from scripted
    functions and modules, and from traced modules.
    """

    def setUp(self):
        # The load_library call below is not portable to these environments,
        # so skip every test in the class there.
        if IS_SANDCASTLE or IS_MACOS or IS_FBCODE:
            raise unittest.SkipTest("non-portable load_library call used in test")
        # The library file name is platform specific; look up only the one
        # that applies.
        lib_name = "torchbind_test.dll" if IS_WINDOWS else "libtorchbind_test.so"
        torch.ops.load_library(str(find_library_location(lib_name)))

    def test_torchbind(self):
        # Runs ``f`` both eagerly and scripted and requires the two results to
        # agree under ``cmp_key``. The keys are also returned for callers that
        # want to inspect them.
        def test_equality(f, cmp_key):
            eager_key = cmp_key(f())
            scripted_key = cmp_key(torch.jit.script(f)())
            self.assertEqual(eager_key, scripted_key)
            return (eager_key, scripted_key)

        def f():
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            val.increment(1)
            return val

        # The eager and scripted calls each build their own _Foo instance, so
        # compare them through the value reported by info().
        test_equality(f, lambda x: x.info())

        # Construct outside the ``with`` block so that only the ill-typed
        # increment() call can satisfy the expected error.
        val = torch.classes._TorchScriptTesting._Foo(5, 3)
        with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
            val.increment("foo")

        def f():
            ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            return ss.pop()

        test_equality(f, lambda x: x)

        def f():
            ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
            ss1.push(ss2.pop())
            return ss1.pop() + ss2.pop()

        test_equality(f, lambda x: x)

        # test nn module with prepare_scriptable function
        class NonJitableClass:
            def __init__(self, int1, int2):
                self.int1 = int1
                self.int2 = int2

            def return_vals(self):
                return self.int1, self.int2

        class CustomWrapper(torch.nn.Module):
            def __init__(self, foo):
                super().__init__()
                self.foo = foo

            def forward(self) -> None:
                self.foo.increment(1)
                return

            def __prepare_scriptable__(self):
                int1, int2 = self.foo.return_vals()
                foo = torch.classes._TorchScriptTesting._Foo(int1, int2)
                return CustomWrapper(foo)

        # The wrapper holds a plain Python object; scripting is only expected
        # to succeed through the _Foo-backed module that
        # __prepare_scriptable__ returns.
        foo = CustomWrapper(NonJitableClass(1, 2))
        torch.jit.script(foo)

    def test_torchbind_take_as_arg(self):
        # A custom class instance can be passed into a scripted function and
        # handed back out again.
        global StackString  # see [local resolution in python]
        StackString = torch.classes._TorchScriptTesting._StackString

        def foo(stackstring):
            # type: (StackString)
            stackstring.push("lel")
            return stackstring

        script_input = torch.classes._TorchScriptTesting._StackString([])
        scripted = torch.jit.script(foo)
        script_output = scripted(script_input)
        self.assertEqual(script_output.pop(), "lel")

    def test_torchbind_return_instance(self):
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
            return ss

        scripted = torch.jit.script(foo)
        # Ensure we are creating the object and calling __init__
        # rather than calling the __init__wrapper nonsense
        fc = (
            FileCheck()
            .check("prim::CreateObject()")
            .check('prim::CallMethod[name="__init__"]')
        )
        fc.run(str(scripted.graph))
        out = scripted()
        self.assertEqual(out.pop(), "mom")
        self.assertEqual(out.pop(), "hi")

    def test_torchbind_return_instance_from_method(self):
        # clone() is taken before the pop, so the clone still holds both
        # strings while the original holds only the first.
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
            clone = ss.clone()
            ss.pop()
            return ss, clone

        scripted = torch.jit.script(foo)
        out = scripted()
        self.assertEqual(out[0].pop(), "hi")
        self.assertEqual(out[1].pop(), "mom")
        self.assertEqual(out[1].pop(), "hi")

    def test_torchbind_def_property_getter_setter(self):
        # Property reads and writes must behave the same eagerly and scripted
        # (checkScript compares the two).
        def foo_getter_setter_full():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            # getX method intentionally adds 2 to x
            old = fooGetterSetter.x
            # setX method intentionally adds 2 to x
            fooGetterSetter.x = old + 4
            new = fooGetterSetter.x
            return old, new

        self.checkScript(foo_getter_setter_full, ())

        def foo_getter_setter_lambda():
            foo = torch.classes._TorchScriptTesting._FooGetterSetterLambda(5)
            old = foo.x
            foo.x = old + 4
            new = foo.x
            return old, new

        self.checkScript(foo_getter_setter_lambda, ())

    def test_torchbind_def_property_just_getter(self):
        def foo_just_getter():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            # getY method intentionally adds 4 to x
            return fooGetterSetter, fooGetterSetter.y

        scripted = torch.jit.script(foo_just_getter)
        out, result = scripted()
        self.assertEqual(result, 10)

        # Assigning to the getter-only property must fail from eager Python...
        with self.assertRaisesRegex(RuntimeError, "can't set attribute"):
            out.y = 5

        def foo_not_setter():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            old = fooGetterSetter.y
            fooGetterSetter.y = old + 4
            # getY method intentionally adds 4 to x
            return fooGetterSetter.y

        # ...and must be rejected when the function is scripted, with the
        # offending assignment highlighted.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            "Tried to set read-only attribute: y",
            "fooGetterSetter.y = old + 4",
        ):
            torch.jit.script(foo_not_setter)

    def test_torchbind_def_property_readwrite(self):
        def foo_readwrite():
            fooReadWrite = torch.classes._TorchScriptTesting._FooReadWrite(5, 6)
            old = fooReadWrite.x
            fooReadWrite.x = old + 4
            return fooReadWrite.x, fooReadWrite.y

        self.checkScript(foo_readwrite, ())

        def foo_readwrite_error():
            fooReadWrite = torch.classes._TorchScriptTesting._FooReadWrite(5, 6)
            fooReadWrite.y = 5
            return fooReadWrite

        # Writing to ``y`` must be rejected at script time.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Tried to set read-only attribute: y", "fooReadWrite.y = 5"
        ):
            torch.jit.script(foo_readwrite_error)

    def test_torchbind_take_instance_as_method_arg(self):
        # One custom class instance is passed as an argument to a method of
        # another.
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
            ss.merge(ss2)
            return ss

        scripted = torch.jit.script(foo)
        out = scripted()
        self.assertEqual(out.pop(), "hi")
        self.assertEqual(out.pop(), "mom")

    def test_torchbind_return_tuple(self):
        def f():
            val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
            return val.return_a_tuple()

        scripted = torch.jit.script(f)
        tup = scripted()
        self.assertEqual(tup, (1337.0, 123))

    def test_torchbind_save_load(self):
        # A scripted function that uses a custom class must survive an
        # export/import round trip.
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
            ss.merge(ss2)
            return ss

        scripted = torch.jit.script(foo)
        self.getExportImportCopy(scripted)

    def test_torchbind_lambda_method(self):
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            return ss.top()

        scripted = torch.jit.script(foo)
        self.assertEqual(scripted(), "mom")

    def test_torchbind_class_attr_recursive(self):
        class FooBar(torch.nn.Module):
            def __init__(self, foo_model):
                super().__init__()
                self.foo_mod = foo_model

            def forward(self) -> int:
                return self.foo_mod.info()

            def to_ivalue(self):
                torchbind_model = torch.classes._TorchScriptTesting._Foo(
                    self.foo_mod.info(), 1
                )
                return FooBar(torchbind_model)

        # to_ivalue() rebuilds the module around a freshly constructed _Foo
        # before it is scripted.
        inst = FooBar(torch.classes._TorchScriptTesting._Foo(2, 3))
        scripted = torch.jit.script(inst.to_ivalue())
        self.assertEqual(scripted(), 6)

    def test_torchbind_class_attribute(self):
        class FooBar1234(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

            def forward(self):
                return self.f.top()

        inst = FooBar1234()
        scripted = torch.jit.script(inst)
        eic = self.getExportImportCopy(scripted)
        # NOTE(review): the expected contents differ from the ["3", "4"] used
        # at construction; presumably _StackString's serialization hooks
        # substitute them - confirm against the C++ test class.
        self.assertEqual(eic(), "deserialized")
        for expected in ["deserialized", "was", "i"]:
            self.assertEqual(eic.f.pop(), expected)

    def test_torchbind_getstate(self):
        class FooBar4321(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        scripted = torch.jit.script(inst)
        eic = self.getExportImportCopy(scripted)
        # NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
        # return {1, 3, 3, 7}. I tried to make this actually depend on the
        # values at instantiation in the test with some transformation, but
        # because it seems we serialize/deserialize multiple times, that
        # transformation isn't what you would expect it to be.
        self.assertEqual(eic(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(eic.f.pop(), expected)

    def test_torchbind_deepcopy(self):
        class FooBar4321(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        # Deep-copying the scripted module yields the same contents as the
        # export/import round trip in test_torchbind_getstate.
        inst = FooBar4321()
        scripted = torch.jit.script(inst)
        copied = copy.deepcopy(scripted)
        self.assertEqual(copied.forward(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(copied.f.pop(), expected)

    def test_torchbind_python_deepcopy(self):
        class FooBar4321(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        # Same check as test_torchbind_deepcopy, but on the plain (unscripted)
        # Python module.
        inst = FooBar4321()
        copied = copy.deepcopy(inst)
        self.assertEqual(copied(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(copied.f.pop(), expected)

    def test_torchbind_tracing(self):
        class TryTracing(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return torch.ops._TorchScriptTesting.take_an_instance(self.f)

        # A custom class attribute can be fed to a custom op under tracing.
        traced = torch.jit.trace(TryTracing(), ())
        self.assertEqual(torch.zeros(4, 4), traced())

    def test_torchbind_pass_wrong_type(self):
        # Passing a Tensor where the op expects a custom class instance must
        # raise.
        with self.assertRaisesRegex(RuntimeError, "but instead found type 'Tensor'"):
            torch.ops._TorchScriptTesting.take_an_instance(torch.rand(3, 4))

    def test_torchbind_tracing_nested(self):
        class TryTracingNest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        class TryTracing123(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nest = TryTracingNest()

            def forward(self):
                return torch.ops._TorchScriptTesting.take_an_instance(self.nest.f)

        # Same as test_torchbind_tracing, with the custom class instance held
        # by a submodule.
        traced = torch.jit.trace(TryTracing123(), ())
        self.assertEqual(torch.zeros(4, 4), traced())

    def test_torchbind_pickle_serialization(self):
        # Round-trip a bare custom class instance through torch.save and
        # torch.load using an in-memory buffer.
        nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
        b = io.BytesIO()
        torch.save(nt, b)
        b.seek(0)
        # weights_only=False as trying to load ScriptObject
        nt_loaded = torch.load(b, weights_only=False)
        for exp in [7, 3, 3, 1]:
            self.assertEqual(nt_loaded.pop(), exp)

    def test_torchbind_instantiate_missing_class(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "Tried to instantiate class 'foo.IDontExist', but it does not exist!",
        ):
            torch.classes.foo.IDontExist(3, 4, 5)

    def test_torchbind_optional_explicit_attr(self):
        class TorchBindOptionalExplicitAttr(torch.nn.Module):
            foo: Optional[torch.classes._TorchScriptTesting._StackString]

            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.classes._TorchScriptTesting._StackString(["test"])

            def forward(self) -> str:
                foo_obj = self.foo
                if foo_obj is not None:
                    return foo_obj.pop()
                else:
                    return "<None>"

        # Scripting must accept the explicitly annotated Optional attribute;
        # the test passes as long as compilation does not raise.
        mod = TorchBindOptionalExplicitAttr()
        torch.jit.script(mod)

    def test_torchbind_no_init(self):
        # Instantiating _NoInit must fail with an error that mentions
        # torch::init.
        with self.assertRaisesRegex(RuntimeError, "torch::init"):
            torch.classes._TorchScriptTesting._NoInit()

    def test_profiler_custom_op(self):
        inst = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        with torch.autograd.profiler.profile() as prof:
            torch.ops._TorchScriptTesting.take_an_instance(inst)

        # The custom op must appear in the profile under its qualified name.
        self.assertTrue(
            any(
                e.name == "_TorchScriptTesting::take_an_instance"
                for e in prof.function_events
            )
        )

    def test_torchbind_getattr(self):
        # getattr() with a default must return the default for a missing
        # field rather than raise.
        foo = torch.classes._TorchScriptTesting._StackString(["test"])
        self.assertIsNone(getattr(foo, "bar", None))

    def test_torchbind_attr_exception(self):
        foo = torch.classes._TorchScriptTesting._StackString(["test"])
        with self.assertRaisesRegex(AttributeError, "does not have a field"):
            # The bare attribute access is the operation under test.
            foo.bar

    def test_lambda_as_constructor(self):
        # The third constructor argument flips the sign of diff() for the same
        # pair of values.
        obj_no_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, False)
        self.assertEqual(obj_no_swap.diff(), 1)

        obj_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, True)
        self.assertEqual(obj_swap.diff(), -1)

    def test_staticmethod(self):
        def fn(inp: int) -> int:
            return torch.classes._TorchScriptTesting._StaticMethod.staticMethod(inp)

        self.checkScript(fn, (1,))

    def test_default_args(self):
        # Constructor and methods are called with and without their optional
        # arguments; checkScript compares eager and scripted results.
        def fn() -> int:
            obj = torch.classes._TorchScriptTesting._DefaultArgs()
            obj.increment(5)
            obj.decrement()
            obj.decrement(2)
            obj.divide()
            obj.scale_add(5)
            obj.scale_add(3, 2)
            obj.divide(3)
            return obj.increment()

        self.checkScript(fn, ())

        def gn() -> int:
            obj = torch.classes._TorchScriptTesting._DefaultArgs(5)
            obj.increment(3)
            obj.increment()
            obj.decrement(2)
            obj.divide()
            obj.scale_add(3)
            obj.scale_add(3, 2)
            obj.divide(2)
            return obj.decrement()

        self.checkScript(gn, ())
