# Owner(s): ["module: dynamo"]

import collections
import dis
import sys
import unittest

import torch
import torch._dynamo.test_case
from torch._dynamo import bytecode_analysis, bytecode_transformation
from torch._dynamo.testing import skipIfNotPy311, skipIfNotPy312


class BytecodeTests(torch._dynamo.test_case.TestCase):
    @skipIfNotPy311
    def test_linetable_311_writer1(self):
        def fn():
            a = 10
            b = 20
            # prevent LOAD_FAST_LOAD_FAST in 3.13 by wrapping b with g()
            c = a + g(b)
            f = "linetable_writer"
            return f"Test if {f} generates correct co_linetable: {c}"

        keys = bytecode_transformation.get_code_keys()
        code_options = {k: getattr(fn.__code__, k) for k in keys}
        result = bytecode_transformation.clean_and_assemble_instructions(
            bytecode_transformation.cleaned_instructions(fn.__code__),
            keys,
            code_options,
        )
        l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
        self.assertEqual(len(l1), len(l2))
        for p1, p2 in zip(l1, l2):
            self.assertEqual(p1, p2)
        # TODO co_lnotab is deprecated in 3.12 and will be removed in 3.14
        # In 3.11+,. it is computed lazily from other linetable attributes (e.g. co_linetable),
        # so we do not set this attribute ourselves.
        self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab)

    @skipIfNotPy311
    def test_linetable_311_writer2(self):
        """
        test large ops (LOAD_METHOD) and EXTENDED_ARGS
        fn_str is in the form:
        def fn():
            ...
            x0 = 1
            x1 = 1
            ...
            l = [x0, x1, ...]
        """
        fn_str = f"""\
def fn():
    foo.bar(1, 2, 3)
{str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))}
    l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}]
        """
        locals = {}
        exec(fn_str, {}, locals)
        fn = locals["fn"]
        orig_inst_str = "\n".join(list(map(str, dis.get_instructions(fn))))
        self.assertIn("EXTENDED_ARG", orig_inst_str)
        load_method_str = "LOAD_ATTR" if sys.version_info >= (3, 12) else "LOAD_METHOD"
        self.assertIn(load_method_str, orig_inst_str)
        keys = bytecode_transformation.get_code_keys()
        code_options = {k: getattr(fn.__code__, k) for k in keys}
        result = bytecode_transformation.clean_and_assemble_instructions(
            bytecode_transformation.cleaned_instructions(fn.__code__),
            keys,
            code_options,
        )
        new_inst_str = "\n".join(list(map(str, result[0])))
        self.assertIn("EXTENDED_ARG", new_inst_str)
        self.assertIn(load_method_str, new_inst_str)
        l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
        self.assertEqual(len(l1), len(l2))
        for p1, p2 in zip(l1, l2):
            self.assertEqual(p1, p2)
        self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab)

    @unittest.skipIf(
        sys.version_info < (3, 10) or sys.version_info >= (3, 11),
        "linetable test for Python 3.10",
    )
    def test_linetable_310_writer(self):
        def fn():
            a = 10
            b = 20
            c = a + b
            f = "linetable_writer"
            return f"Test if {f} generates correct co_linetable: {c}"

        inst = dis.get_instructions(fn)
        result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
        self.assertTrue(result[1] == fn.__code__.co_linetable)

    @unittest.skipIf(sys.version_info >= (3, 10), "use lnotab when python < 3.10")
    def test_lnotab_writer(self):
        def fn():
            a = 10
            b = 20
            c = a + b
            f = "lnotab_writer"
            return f"Test if {f} generates correct co_lnotab: {c}"

        inst = dis.get_instructions(fn)
        result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
        self.assertTrue(result[1] == fn.__code__.co_lnotab)

    def test_if_tensor_is_none(self):
        """
        Python 3.11 adds new jump instructions that check if
        TOS is None. We do not support these instructions.
        """

        def f(x, y):
            z = 1
            if x is None:
                z *= 2
            if y is not None:
                z *= 3
            return z

        opt_f = torch._dynamo.optimize("eager", nopython=True)(f)
        self.assertEqual(opt_f(None, torch.ones(2)), 6)

        if sys.version_info >= (3, 11):
            insts = bytecode_transformation.cleaned_instructions(f.__code__)
            for inst in insts:
                self.assertNotIn("_NONE", inst.opname)

    @skipIfNotPy311
    def test_py311_jump_offset(self):
        new_inst = bytecode_transformation.create_instruction
        consts = (None, 1, 2, 3, 4)

        def create_test_code(jump_opname, target_idx):
            targets = [
                new_inst("LOAD_CONST", argval=1),
                new_inst("LOAD_CONST", argval=3),
            ]
            jump_to_target_inst = new_inst(jump_opname, target=targets[target_idx])
            """
            pseudocode of generated bytecode:
            def test_py311_fn():
                goto target1
            target0:
                return 1
            target1:
                goto [target0/target2] (via fwd or bwd jump)
                return 2
            target2:
                return 3
                return 4
            """
            # test with LOAD_GLOBAL since it has a different instruction size
            insts = [
                new_inst("RESUME", arg=0),
                new_inst("JUMP_FORWARD", target=jump_to_target_inst),
                targets[0],
                new_inst("LOAD_GLOBAL", arg=0, argval="print"),
                new_inst("POP_TOP"),
                new_inst("RETURN_VALUE"),
                jump_to_target_inst,
                new_inst("LOAD_CONST", argval=2),
                new_inst("LOAD_GLOBAL", arg=0, argval="print"),
                new_inst("POP_TOP"),
                new_inst("RETURN_VALUE"),
                targets[1],
                new_inst("RETURN_VALUE"),
                new_inst("LOAD_CONST", argval=4),
                new_inst("RETURN_VALUE"),
            ]
            code_options = collections.OrderedDict(
                [
                    ("co_argcount", 0),
                    ("co_posonlyargcount", 0),
                    ("co_kwonlyargcount", 0),
                    ("co_nlocals", 0),
                    ("co_stacksize", 2),
                    ("co_flags", 3),
                    ("co_code", b""),
                    ("co_consts", consts),
                    ("co_names", ("print",)),
                    ("co_varnames", ()),
                    ("co_filename", __file__),
                    ("co_name", "test_py311_fn"),
                    ("co_qualname", "test_py311_fn"),
                    ("co_firstlineno", 1),
                    ("co_linetable", b""),
                    ("co_exceptiontable", b""),
                    ("co_freevars", ()),
                    ("co_cellvars", ()),
                ]
            )
            return bytecode_transformation.clean_and_assemble_instructions(
                insts,
                list(code_options.keys()),
                code_options,
            )

        # format: jump_opname, target_idx, expected forward jump, expected return value
        test_args = (
            ("JUMP_FORWARD", 0, False, 1),
            ("JUMP_FORWARD", 1, True, 3),
            ("JUMP_BACKWARD", 0, False, 1),
            ("JUMP_BACKWARD", 1, True, 3),
        )

        for test in test_args:
            insts, code = create_test_code(test[0], test[1])
            # check if offset of latest jump instruction is forward/backward
            for inst in reversed(insts):
                if inst.opname.startswith("JUMP"):
                    if test[2]:
                        self.assertIn("FORWARD", inst.opname)
                    else:
                        self.assertIn("BACKWARD", inst.opname)
                    break
            # run the code and check result

            def dummy_fn():
                pass

            dummy_fn.__code__ = code
            self.assertEqual(dummy_fn(), test[3])

            dummy_opt = torch._dynamo.optimize("eager")(dummy_fn)
            self.assertEqual(dummy_opt(), test[3])

    def test_exception_table_encode_varint(self):
        # these numbers have no real meaning to them
        nums = [
            0b111_101010_000000,
            0b1100_111000_010101_101010,
        ]
        b = bytecode_transformation.encode_exception_table_varint(
            nums[0]
        ) + bytecode_transformation.encode_exception_table_varint(nums[1])
        nums_new = []
        b_iter = iter(bytes(b))
        while True:
            try:
                nums_new.append(
                    bytecode_transformation.decode_exception_table_varint(b_iter)
                )
            except StopIteration:
                break
        self.assertEqual(nums, nums_new)

    @skipIfNotPy311
    def test_exception_table_parsing(self):
        def fn():
            try:
                with a():
                    b()
                c()
            except Exception:
                d()
            finally:
                e()
            f()

        tab = bytecode_transformation.parse_exception_table(
            fn.__code__.co_exceptiontable
        )
        b = bytecode_transformation.assemble_exception_table(tab)
        self.assertEqual(b, fn.__code__.co_exceptiontable)

    @skipIfNotPy311
    def test_exception_table_e2e(self):
        def fn():
            try:
                with a():
                    b()
                c()
            except Exception:
                d()
            finally:
                e()
            f()

        def nothing(*args):
            pass

        code = bytecode_transformation.transform_code_object(fn.__code__, nothing)
        self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable)

    @skipIfNotPy311
    def test_exception_table_e2e_2(self):
        # last instructions of an exn_table entry is a large instruction
        # i.e., LOAD_GLOBAL a
        def fn():
            try:
                return a
            except Exception:
                pass

        def nothing(*args):
            pass

        code = bytecode_transformation.transform_code_object(fn.__code__, nothing)
        self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable)

    @skipIfNotPy311
    def test_exception_table_entry_propagation(self):
        insts = []
        for _ in range(10):
            insts.append(bytecode_transformation.create_instruction("NOP"))
        insts[8].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[0], insts[9], insts[0], 0, True
        )
        insts[0].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[0], insts[0], insts[1], 0, True
        )
        insts[1].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[0], insts[2], insts[2], 0, True
        )
        insts[5].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[4], insts[6], insts[3], 0, True
        )
        insts[9].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[9], insts[9], insts[4], 0, True
        )
        insts[7].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[7], insts[9], insts[5], 0, True
        )
        bytecode_transformation.propagate_inst_exn_table_entries(insts)
        expected = [1, 2, 2, 0, 3, 3, 3, 5, 5, 4]
        for inst, exp in zip(insts, expected):
            self.assertIsNotNone(inst.exn_tab_entry)
            self.assertIs(inst.exn_tab_entry.target, insts[exp])

    @skipIfNotPy311
    def test_compute_exception_table_nested(self):
        insts = []
        for _ in range(20):
            insts.append(bytecode_transformation.create_instruction("NOP"))
        insts[10].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[1], insts[10], insts[0], 0, True
        )
        insts[0].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[1], insts[1], insts[1], 0, True
        )
        insts[1].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[1], insts[3], insts[2], 0, True
        )
        insts[5].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[5], insts[7], insts[3], 0, True
        )
        insts[9].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[10], insts[10], insts[4], 0, True
        )
        insts[7].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[8], insts[10], insts[5], 0, True
        )
        insts[14].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[13], insts[17], insts[6], 0, True
        )
        insts[16].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            insts[15], insts[16], insts[7], 0, True
        )
        bytecode_transformation.update_offsets(insts)
        tab = bytecode_transformation.compute_exception_table(insts)
        expected = [
            (1, 1, 1),
            (2, 3, 2),
            (4, 4, 0),
            (5, 7, 3),
            (8, 9, 5),
            (10, 10, 4),
            (13, 14, 6),
            (15, 16, 7),
            (17, 17, 6),
        ]
        self.assertEqual(len(tab), len(expected))
        for entry, exp in zip(tab, expected):
            self.assertEqual(entry.start, exp[0] * 2)
            self.assertEqual(entry.end, exp[1] * 2)
            self.assertEqual(entry.target, exp[2] * 2)

    @skipIfNotPy311
    def test_remove_dead_code_with_exn_table_entries(self):
        create_instruction = bytecode_transformation.create_instruction
        target1 = create_instruction("NOP")
        target2 = create_instruction("NOP")
        target3 = create_instruction("NOP")
        exn_start = create_instruction("NOP")
        exn_end = create_instruction("NOP")
        insts = [
            create_instruction("JUMP_FORWARD", target=target1),
            exn_start,  # dead
            target1,
            create_instruction("JUMP_FORWARD", target=target3),
            exn_end,  # dead
            target2,
            target3,
        ]
        exn_start.exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
            exn_start, exn_end, target2, 0, True
        )
        bytecode_transformation.propagate_inst_exn_table_entries(insts)
        insts = bytecode_analysis.remove_dead_code(insts)
        self.assertEqual(len(insts), 5)
        self.assertNotIn(exn_start, insts)
        self.assertNotIn(exn_end, insts)
        self.assertIn(target2, insts)
        self.assertIn(target3, insts)
        bytecode_transformation.update_offsets(insts)
        tab = bytecode_transformation.compute_exception_table(insts)
        self.assertEqual(len(tab), 1)
        self.assertEqual(tab[0].start, 2)
        self.assertEqual(tab[0].end, 4)
        self.assertEqual(tab[0].target, 6)

    def test_bytecode_from_template(self):
        def fn(d1):
            for k, v in d1.items():
                d2[k] = v

        varname_map = {"d1": "var1", "d2": "var2", "k": "var3", "v": "var4"}
        insts = bytecode_transformation.bytecode_from_template(fn, varname_map)
        for inst in insts:
            self.assertIsNone(inst.starts_line)
            if inst.opname.startswith("LOAD"):
                self.assertNotIn(inst.argval, varname_map)
                if inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR"):
                    self.assertIsNone(inst.arg)
            self.assertFalse(inst.opname.startswith("RETURN"))

    @skipIfNotPy311
    def test_bytecode_from_template_noprefix(self):
        # Test that 3.11+ prefix instructions are removed
        def gen_fn():
            cl = None

            def fn():
                return cl

            return fn

        fn = gen_fn()

        dis_insts = list(dis.get_instructions(fn))
        names = {inst.opname for inst in dis_insts}
        self.assertIn("RESUME", names)
        self.assertIn("COPY_FREE_VARS", names)

        insts = bytecode_transformation.bytecode_from_template(fn)
        names = {inst.opname for inst in insts}
        self.assertNotIn("RESUME", names)
        self.assertNotIn("COPY_FREE_VARS", names)

    def test_bytecode_from_template_noreturn1(self):
        # Test that functions with multiple returns will have their
        # returns replaced with jumps to the end
        def fn():
            if x:
                return y
            z = 3
            return z

        dis_insts = list(dis.get_instructions(fn))
        dis_returns = list(filter(lambda x: x.opname.startswith("RETURN"), dis_insts))
        self.assertGreater(len(dis_returns), 1)
        self.assertTrue(dis_insts[-1].opname.startswith("RETURN"))

        insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
        self.assertEqual(insts[-1].opname, "NOP")
        self.assertEqual(len(dis_insts), len(insts))
        for i0, i1 in zip(dis_insts, insts):
            if i0.opname.startswith("RETURN"):
                if i1 is insts[-1]:
                    continue
                self.assertIn("JUMP", i1.opname)
                self.assertIs(i1.target, insts[-1])

    # Should work with 3.10, but testing with 3.11+ is sufficient.
    # In 3.8, `fn` ends with a RETURN_VALUE.
    @skipIfNotPy311
    def test_bytecode_from_template_noreturn2(self):
        # Test function that doesn't end with RETURN_VALUE
        def fn():
            if x:
                return x
            if x:
                return x
            raise RuntimeError

        dis_insts = list(dis.get_instructions(fn))
        self.assertFalse(dis_insts[-1].opname.startswith("RETURN"))

        insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
        self.assertEqual(insts[-1].opname, "NOP")
        self.assertEqual(insts[-2].opname, dis_insts[-1].opname)
        self.assertEqual(len(dis_insts) + 1, len(insts))
        for i0, i1 in zip(dis_insts, insts):
            if i0.opname.startswith("RETURN"):
                self.assertIn("JUMP", i1.opname)
                self.assertIs(i1.target, insts[-1])

    @skipIfNotPy312
    def test_bytecode_from_template_noreturn_const(self):
        # Test 3.12+ RETURN_CONST
        def fn():
            if x:
                return 1
            return 0

        dis_insts = list(dis.get_instructions(fn))
        dis_return_consts = list(
            filter(lambda x: x.opname == "RETURN_CONST", dis_insts)
        )
        self.assertGreater(len(dis_return_consts), 1)
        self.assertTrue(dis_insts[-1].opname == "RETURN_CONST")

        insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
        self.assertEqual(insts[-1].opname, "NOP")
        insts_i = 0
        for i, inst in enumerate(dis_insts):
            if inst.opname == "RETURN_CONST":
                self.assertEqual(insts[insts_i].opname, "LOAD_CONST")
                insts_i += 1
                if insts_i != len(insts) - 1:
                    self.assertIn("JUMP", insts[insts_i].opname)
                    self.assertIs(insts[insts_i].target, insts[-1])
            insts_i += 1


class BytecodeHookTests(torch._dynamo.test_case.TestCase):
    def test_bytecode_hook(self):
        def fn(a, b):
            return a - b * 10

        def hook(code, out_code):
            print(code)
            print(out_code)
            return code

        torch._dynamo.reset()
        handle = torch._dynamo.convert_frame.register_bytecode_hook(hook)
        try:
            opt_fn = torch.compile(fn)
            for i in range(2, 12):
                opt_fn(torch.randn(i), torch.randn(i))
        finally:
            handle.remove()


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
