# Owner(s): ["module: dynamo"]

import collections
import re
import sys
import time
from io import StringIO

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.comptime import comptime


# comptime callbacks cannot currently capture free variables from the
# enclosing function, so the tests below pass state to them through these
# module-level globals instead:
#   FILE - text stream the callbacks print into (each test rebinds it to a
#          fresh StringIO before compiling)
#   SELF - the running TestCase, so callbacks can call its assert* methods
# Because this state is shared at module level, the tests are not safe to
# run concurrently within a single process.
FILE, SELF = None, None


class ComptimeTests(torch._dynamo.test_case.TestCase):
    def test_print_single(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        def comptime_print(e):
            @comptime
            def _(ctx):
                ctx.print(ctx.get_local("e"), file=FILE)

        Employee = collections.namedtuple("Employee", ["name", "id"])

        class mylist(list):
            pass

        @torch._dynamo.optimize(cnt, dynamic=True)
        def f(x):
            y = x * 2
            comptime_print(y)
            comptime_print(2)
            comptime_print([y, 2])
            comptime_print((y, 2))
            comptime_print({"foo": y})
            comptime_print(range(1, 3))
            comptime_print(Employee("foo", 2))
            comptime_print(mylist([1, 2]))
            comptime_print(collections.defaultdict(lambda: None))
            comptime_print(set())
            comptime_print({"a", "b"})
            comptime_print(x.size(0))
            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue().strip(),
            """\
FakeTensor(..., size=(s0,))
2
[FakeTensor(..., size=(s0,)), 2]
(FakeTensor(..., size=(s0,)), 2)
{'foo': FakeTensor(..., size=(s0,))}
range(1, 3, 1)
Employee(name='foo', id=2)
[1, 2]
defaultdict(NestedUserFunctionVariable(), {})
set()
{'a','b'}
s0""",
        )

    def test_print_graph(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_graph(verbose=False, file=FILE)

            # Test the compact notation doesn't error or graph break;
            # you'll have to visually inspect to see that it printed
            comptime.print_graph()

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue().strip(),
            """\
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    y = l_x_ * 2;  l_x_ = y = None""",
        )

    def test_print_disas(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_disas(file=FILE)

            comptime.print_disas()

            return y + 3

        def munge_disas(s):
            re.sub(
                r"^(?: +\d+)?(?: +(-->)) \+\d+ ([A-Za-z0-9_]+)",
                "\1 \3",
                s,
                flags=re.MULTILINE,
            )

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        out = FILE.getvalue()
        # Check that the instruction offset is working
        self.assertIn("-->", out)
        # Check that the bytecode resembles what we expect
        self.assertIn("STORE_FAST", out)
        if sys.version_info < (3, 11):
            self.assertIn("BINARY_MULTIPLY", out)
        else:
            self.assertIn("BINARY_OP", out)

    def test_print_value_stack(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        def g(x):
            @comptime
            def _(ctx):
                ctx.print_value_stack(file=FILE, stacklevel=1)

            return x

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x + g(x)

            return y + comptime.print_value_stack_and_return(y * 2)

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue(),
            """\
- FakeTensor(..., size=(2,))
""",
        )

    def test_print_locals(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_locals(file=FILE)

            comptime.print_locals()

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue(),
            """\
x = FakeTensor(..., size=(2,))
y = FakeTensor(..., size=(2,))
""",
        )

    # Just make sure it doesn't crash
    def test_print_direct(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x, z):
            y = x * 2
            lambda: z
            comptime.print(z)
            return y + 3

        f(torch.randn(2), torch.randn(2))

    def test_sleep(self):
        sleep_time = 5
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x, z, should_sleep):
            if should_sleep:
                comptime.sleep(sleep_time)
            y = x * 2
            return y + 3

        start = time.time()
        f(torch.randn(2), torch.randn(2), False)
        total_no_sleep = time.time() - start

        start = time.time()
        f(torch.randn(2), torch.randn(2), True)
        total_with_sleep = time.time() - start

        self.assertTrue(total_with_sleep > sleep_time)
        # Hopefully this won't be flaky
        self.assertTrue(abs(total_with_sleep - sleep_time - total_no_sleep) < 3)

    # Just make sure it doesn't crash
    def test_get_local_closure_variable(self):
        global SELF
        SELF = self
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            z = 3

            def g():
                @comptime
                def _(ctx):
                    r = ctx.get_local("z")
                    SELF.assertEqual(repr(r), "3")

                comptime.print(z)
                return 2

            y = x * g()
            return y + 3

        f(torch.randn(2))

    def test_print_bt(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        def g(x):
            @comptime
            def _(ctx):
                ctx.print_bt(file=FILE)

            comptime.print_bt()

            return x + 3

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2
            y = g(y)
            return y + 3

        def munge_filenames(s):
            return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s)

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        bt = FILE.getvalue()
        self.assertIn("y = g(y)", bt)

    def test_print_guards(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_guards(file=FILE)

            comptime.print_guards()

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            re.sub(r"\s+$", "", FILE.getvalue().rstrip(), flags=re.MULTILINE),
            """\

        local "L['x']" TENSOR_MATCH
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' GRAD_MODE
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' DETERMINISTIC_ALGORITHMS
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' TORCH_FUNCTION_STATE
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' DEFAULT_DEVICE
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        shape_env '' SHAPE_ENV
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }""",
        )

    def test_graph_break(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                pass

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        cnt.frame_count = 0

        @torch._dynamo.optimize(cnt)
        def g(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.graph_break()

            y = y + 2

            comptime.graph_break()

            return y * 3

        g(torch.randn(2))
        self.assertEqual(cnt.frame_count, 3)

    def test_get_local(self):
        global SELF, FILE
        SELF = self
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2
            lit = 2

            @comptime
            def _(ctx):
                y = ctx.get_local("y")
                SELF.assertEqual(y.as_fake().size(0), 2)
                SELF.assertEqual(y.size(0), 2)
                # Trigger a graph write (TODO: this is not so
                # useful right now as there's no way to make use
                # of the output proxy; maybe it's useful for inserting
                # side-effectful operations into the graph)
                y.as_proxy() + 4
                ctx.print_graph(verbose=False, file=FILE)
                SELF.assertIs(y.python_type(), torch.Tensor)
                lit = ctx.get_local("lit")
                SELF.assertEqual(lit.as_python_constant(), 2)

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue().strip(),
            """\
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    y = l_x_ * 2;  l_x_ = None
    add = y + 4;  y = add = None""",
        )


if __name__ == "__main__":
    # Hand control to the Dynamo test runner, reached through the
    # torch._dynamo.test_case module imported at the top of this file.
    torch._dynamo.test_case.run_tests()
