# Owner(s): ["module: dynamo"]
import copy
import re
import unittest
from textwrap import dedent
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._dynamo.testing import CompileCounter, expectedFailureDynamic, rand_strided
from torch._functorch.aot_autograd import _aot_export_function, create_functional_call
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.profiler import profile
from torch.testing import FileCheck
from torch.testing._internal.common_utils import compare_equal_outs_and_grads


def maybe_dupe_op(x):
    y = x + 1
    z = x + 2
    if x.numel() < 5:
        return y, y
    else:
        return y, z


def is_dynamic_shape_test(test_name):
    return test_name.endswith("_dynamic_shapes")


aten = torch.ops.aten
lib = torch.library.Library("custom", "DEF")  # noqa: TOR901
lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)")
lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU")
lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta")


class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
    def test_LSTM(self):
        # https://github.com/pytorch/torchdynamo/issues/1147
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.self_mod_model_lstm_lstm = torch.nn.LSTM(
                    64, 64, num_layers=2, bidirectional=True
                )

            def forward(self, permute: torch.Tensor):
                self_mod_model_lstm_lstm = self.self_mod_model_lstm_lstm(permute)
                return (self_mod_model_lstm_lstm,)

        mod = Repro()

        aot_mod = torch._dynamo.optimize("aot_eager")(mod)

        args = [((92, 4, 64), (1, 5888, 92), torch.float32, "cpu", False)]
        args = [
            rand_strided(sh, st, dt, dev).requires_grad_(rg)
            for (sh, st, dt, dev, rg) in args
        ]

        eager_result = mod(*args)
        aot_result = aot_mod(*args)
        self.assertTrue(torch._dynamo.testing.same(eager_result, aot_result))

    def test_mutation(self):
        # https://github.com/pytorch/torchdynamo/issues/1301
        def fn(param, y):
            prev_grad = torch.is_grad_enabled()
            try:
                torch.set_grad_enabled(False)
                param.add_(y)
            finally:
                torch.set_grad_enabled(prev_grad)
            return y

        y = torch.randn(4)
        x = torch.nn.Parameter(torch.randn(4))
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        # This should not error: we mutated an autograd leaf under no_grad mode.
        aot_fn(x, y)

    def test_mutation1(self):
        def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
            getitem = diagonal_chunked_attention_scores[
                (
                    slice(None, None, None),
                    slice(None, None, None),
                    slice(None, 256, None),
                    slice(None, 257, None),
                )
            ]
            _stack0[
                (
                    slice(None, None, None),
                    slice(None, -1, None),
                    slice(None, None, None),
                    slice(256, None, None),
                )
            ] = getitem
            view = _stack0.view(1, 12, 1024, 513)
            return (view,)

        x = torch.randn(torch.Size([12, 4, 256, 513]))
        y = torch.randn(torch.Size([12, 3, 512, 513]))
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        aot_fn(x, y)

    def test_negative_testing_mutation(self):
        def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
            getitem = diagonal_chunked_attention_scores[
                (
                    slice(None, None, None),
                    slice(None, None, None),
                    slice(None, 256, None),
                    slice(None, 257, None),
                )
            ]
            _stack0 = torch.sin(_stack0)
            _stack0[
                (
                    slice(None, None, None),
                    slice(None, -1, None),
                    slice(None, None, None),
                    slice(256, None, None),
                )
            ] = getitem
            view = _stack0.view(1, 12, 1024, 513)
            return (view,)

        x = torch.randn(torch.Size([12, 4, 256, 513]))
        y = torch.randn(torch.Size([12, 3, 512, 513]))
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        aot_fn(x, y)

    def test_negative_testing(self):
        def fn(x, y):
            return torch.sin(x).add_(y)

        y = torch.randn(4)
        x = torch.randn(4)
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        aot_fn(x, y)

    def test_call_fn_with_non_const_inputs_aot_safe(self):
        class ModuleSpecialFwd(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=20, kernel_size=(5, 5)
                )

            def _conv_forward(self, x):
                return self.conv._conv_forward(x, self.conv.weight, self.conv.bias)

            def forward(self, x):
                return self._conv_forward(x)

        # Init mod
        mod = ModuleSpecialFwd()
        rx = torch.randn([3, 10, 10])

        # Run it for real
        real = mod(rx)

        # Run it in export
        graph, _ = torch._dynamo.export(mod)(rx)

        # Run exported graph with AOT
        self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))

        aot_fn = torch._dynamo.optimize("aot_eager")(graph)
        aot_fn(rx)

    def test_call_fn_with_non_const_inputs_aot_unsafe(self):
        class ModuleSpecialFwd(torch.nn.Module):
            def _some_bad_fwd(self, param, y):
                prev_grad = torch.is_grad_enabled()
                try:
                    torch.set_grad_enabled(False)
                    param.add_(y)
                finally:
                    torch.set_grad_enabled(prev_grad)
                return y

            def forward(self, x, y):
                return self._some_bad_fwd(x, y)

        # Init mod
        mod = ModuleSpecialFwd()
        x = torch.nn.Parameter(torch.randn(4))
        y = torch.randn([4])

        # Run it for real
        real = mod(x, y)

        # Run it in export
        graph, _ = torch._dynamo.export(mod)(x, y)

        # Assert equal
        self.assertTrue(torch._dynamo.testing.same(real, graph(x, y)))

        # Run exported graph with AOT
        aot_fn = torch._dynamo.optimize("aot_eager")(graph)
        # This should not error: we mutated an autograd leaf under no_grad mode.
        aot_fn(x, y)

    def test_call_fn_with_non_const_inputs_aot_unsafe_control_flow(self):
        class ModuleSpecialFwd(torch.nn.Module):
            def _some_bad_fwd(self, param, y):
                if y[0][0] < 3:
                    return y + param
                return param * y

            def forward(self, x, y):
                a = x * y
                a = self._some_bad_fwd(a, a)
                b = x + y
                return a * b

        # Init mod
        mod = ModuleSpecialFwd()
        x = torch.nn.Parameter(torch.randn([2, 2]))
        y = torch.randn([2, 2])

        # Run it for real
        real = mod(x, y)

        # Run it through optimize, with our capturing fn

        gms = []
        counter = CompileCounter()

        def capturing_fn(gm, inputs):
            nonlocal gms
            gms.append(gm)
            return counter(gm, inputs)

        optimized_mod = torch._dynamo.optimize(capturing_fn)(mod)

        # Assert equal
        self.assertTrue(torch._dynamo.testing.same(real, optimized_mod(x, y)))

        # Uncomment to reproduce commented out graphs below.
        # for gm in gms:
        #     print("GM CODE", gm.code)

        self.assertEqual(counter.frame_count, 4)
        self.assertEqual(counter.op_count, 7)
        # Graph 1
        # def forward(self, x : torch.nn.parameter.Parameter, y : torch.Tensor):
        #     mul = x * y;  x = y = None
        #     return (mul,)
        # BREAK
        # Graph 2
        # def forward(self, y : torch.Tensor):
        #     getitem = y[0];  y = None
        #     getitem_1 = getitem[0];  getitem = None
        #     lt = getitem_1 < 3;  getitem_1 = None
        #     return (lt,)
        # BREAK
        # Graph 3
        # def forward(self, param : torch.Tensor, y : torch.Tensor):
        #     add = y + param;  y = param = None
        #     return (add,)
        # BREAK
        # Graph 4
        # def forward(self, _stack0 : torch.Tensor, x : torch.nn.parameter.Parameter, y : torch.Tensor):
        #     add = x + y;  x = y = None
        #     mul = _stack0 * add;  _stack0 = add = None
        #     return (mul,)

        # Run fn with AOT
        torch._dynamo.reset()

        aot_fn = torch._dynamo.optimize("aot_eager")(optimized_mod)
        aot_fn(x, y)

    # Note: Dynamo recompilation guarding invalid grad
    #
    # This test is a spiritual equivalent to test_invalid_requires_grad_fake in test_autodispatch.py
    # The point of this test is to invoke aot_autograd in a way that would normally trigger an assertion
    # (This is what test_invalid_requires_grad_fake) does. However, the point of this test is to prove
    # that we do not hit this assertion, as dynamo recompiles correctly and protects this condition.
    #
    # Subnote: The reason for us having test_invalid_requires_grad_fake utilizing fake tensors
    # is because dynamo sends fake tensors down to aot_autograd.
    @patch("torch._functorch.config.debug_assert", True)
    def test_requires_grad_fake_via_dynamo_recompiles(self):
        class F(torch.nn.Module):
            def forward(self, x, y):
                return (x + y,)

        x = torch.randn(3, 3, requires_grad=True)
        y = torch.randn(3, 3, requires_grad=True)
        z = torch.randn(3, 3, requires_grad=False)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        fxy = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        compare_equal_outs_and_grads(self, F(), fxy, (x, y))
        compare_equal_outs_and_grads(self, F(), fxy, (x, z))
        self.assertIn(
            """tensor 'L['y']' requires_grad mismatch. expected requires_grad=1""",
            failure_reason,
        )

        # Reset failure reason
        failure_reason = None

        self.assertEqual(cc.frame_count, 2)

        torch._dynamo.reset()  # for new backend
        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        fxz = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        compare_equal_outs_and_grads(self, F(), fxz, (x, z))
        compare_equal_outs_and_grads(self, F(), fxz, (x, z))
        self.assertEqual(cc.frame_count, 1)
        self.assertTrue(failure_reason is None)

    def test_double_backward_errors(self):
        # Remove this test after we get double backward to actually work
        for grad_output in (torch.tensor(1.0, requires_grad=True), None):
            x = torch.tensor(1.0, requires_grad=True)
            err = "torch.compile with aot_autograd does not currently support double backward"

            # The following cases should be equivalent:

            # (1) double backward entirely inside compiled function
            def f1(x):
                y = x.sin().exp()
                (gx,) = torch.autograd.grad(
                    y, x, create_graph=True, grad_outputs=grad_output
                )
                torch.autograd.grad(gx, x)
                return gx

            compiled_f1 = torch.compile(backend="aot_eager")(f1)
            f1(x)
            with self.assertRaisesRegex(RuntimeError, err):
                compiled_f1(x)

            # (2) the second half of double backward outside compiled function
            def f2(x):
                y = x.sin().exp()
                (gx,) = torch.autograd.grad(
                    y, x, create_graph=True, grad_outputs=grad_output
                )
                return gx

            compiled_f2 = torch.compile(backend="aot_eager")(f2)
            gx = compiled_f2(x)
            with self.assertRaisesRegex(RuntimeError, err):
                torch.autograd.grad(gx, x)

            # (3) double backward entirely outside compiled function
            def f3(x):
                y = x.sin().exp()
                return y

            compiled_f3 = torch.compile(backend="aot_eager")(f3)
            y = compiled_f3(x)
            (gx,) = torch.autograd.grad(
                y, x, create_graph=True, grad_outputs=grad_output
            )
            with self.assertRaisesRegex(RuntimeError, err):
                torch.autograd.grad(gx, x)

        # create_graph=False
        def f4(x):
            y = x.sin().exp()
            return y

        compiled_f4 = torch.compile(backend="aot_eager")(f4)
        x = torch.tensor(1.0, requires_grad=True)
        y = compiled_f4(x)
        (gx,) = torch.autograd.grad(y, x, create_graph=False, grad_outputs=grad_output)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles(self):
        class F(torch.nn.Module):
            def forward(self, x, y):
                x = x.trunc_()
                y = y.trunc_()
                return (x + y,)

        x = torch.randn(3, 3, requires_grad=True)
        x1, x2, x3, x4 = x.clone(), x.clone(), x.clone(), x.clone()
        y = torch.randn(3, 3, requires_grad=True)
        y1, y2, y4 = y.clone(), y.clone(), y.clone()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        fxy = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        # Note: to prevent a recompilation between the two calls,
        # we need to clone x and y on each use.
        # fxy mutates the input's metadata, so otherwise dynamo will end up recompiling.
        fxy(x1, y1)
        fxy(x2, y2)

        self.assertTrue(failure_reason is None)

        # Reset failure reason
        failure_reason = None

        self.assertEqual(cc.frame_count, 1)

        torch._dynamo.reset()  # for new backend
        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        fxx = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        fxx(x3, x3)
        fxx(x4, y4)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn("""L['x'] is L['y']""", failure_reason)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg(self):
        class F(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mean = torch.nn.Parameter(torch.randn(3, 3))

            def forward(self, a, b, e, f):
                a.trunc_()
                b.trunc_()
                return (a + b + self.mean) * e * f

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        a1, a2 = a.clone(), a.clone()
        b1, b2 = b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a1, a1, 2, 2)
        f(a2, b2, 2, 2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

        torch._dynamo.reset()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        c = torch.randn(3, 3, requires_grad=True)
        d = torch.randn(3, 3, requires_grad=True)
        c3, c4 = c.clone(), c.clone()
        d3, d4 = d.clone(), d.clone()

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(c3, c3, 3, 3)
        f(c4, d4, 3, 3)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn("""L['a'] is L['b']""", failure_reason)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_with_global(self):
        z = None

        class F(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mean = torch.nn.Parameter(torch.randn(3, 3))

            def forward(self, a, b, e, f):
                a.trunc_()
                b.trunc_()
                return (a + b + z + self.mean) * e * f

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        z = a
        a1, a2 = a.clone(), a.clone()
        b1, b2 = b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a1, a1, 2, 2)
        f(a2, b2, 2, 2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg_list(self):
        class F(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mean = torch.nn.Parameter(torch.randn(3, 3))

            def forward(self, e, f, a, b):
                a.trunc_()
                b.trunc_()
                return (a + b + self.mean) * e[0] * f[0]

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        a1, a2 = a.clone(), a.clone()
        b1, b2 = b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f([3, 2, 1], [4, 5, 6], a1, a1)
        f([3, 2, 1], [4, 5, 6], a2, b2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

        torch._dynamo.reset()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        c = torch.randn(3, 3, requires_grad=True)
        d = torch.randn(3, 3, requires_grad=True)
        c3, c4 = c.clone(), c.clone()
        d3, d4 = d.clone(), d.clone()

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f([3, 2, 1], [4, 5, 6], c3, c3)
        f([3, 2, 1], [4, 5, 6], c4, d4)
        self.assertEqual(cc.frame_count, 2)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_args_param(self):
        class F(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mean = torch.nn.Parameter(torch.randn(3, 3))

            def forward(self, a, b):
                a.trunc_()
                b.trunc_()
                return a + b + self.mean

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        a1, a2 = a.clone(), a.clone()
        b1, b2 = b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a1, a1)
        f(a2, b2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

        torch._dynamo.reset()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        c = torch.randn(3, 3, requires_grad=True)
        d = torch.randn(3, 3, requires_grad=True)
        c3, c4 = c.clone(), c.clone()
        d3, d4 = d.clone(), d.clone()

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(c3, c3)
        f(c4, d4)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn("""L['a'] is L['b']""", failure_reason)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_args(self):
        class F(torch.nn.Module):
            def forward(self, a, b, c, d):
                a.trunc_()
                b.trunc_()
                c.trunc_()
                d.trunc_()
                return (a + b + c + d,)

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
        b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a1, a1, a1, a1)
        f(a2, b2, b2, b2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

        torch._dynamo.reset()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        c = torch.randn(3, 3, requires_grad=True)
        d = torch.randn(3, 3, requires_grad=True)
        c3, c4 = c.clone(), c.clone()
        d3, d4 = d.clone(), d.clone()

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a3, b3, c3, c3)
        f(a4, b4, c4, d4)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn("""L['c'] is L['d']""", failure_reason)

    def test_alias_inputs(self):
        def fn():
            a = torch.tensor([1])
            a = a[0:1]
            b = a.squeeze()
            a[0] = 0
            if a[0] < 1e5:
                pass
            a[0] = 2
            return b

        ref_output = fn()
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        actual_output = aot_fn()
        self.assertEqual(ref_output, actual_output)

    def test_grad_inputs_alias_inputs(self):
        class Test(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x)
                return y

            @staticmethod
            def backward(ctx, grad):
                (x,) = ctx.saved_tensors
                return x, grad

        def fn(x, y):
            return Test.apply(x, y)

        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        compiled_fn = torch.compile(fn, backend="aot_eager")
        out = compiled_fn(x, y)
        out.sum().backward()

    @expectedFailureDynamic  # https://github.com/pytorch/pytorch/issues/103539
    @torch._dynamo.config.patch(automatic_dynamic_shapes=False)
    @patch("torch._functorch.config.debug_assert", True)
    def test_multiple_aot_autograd_calls_dupe_args(self):
        # this is just dealing with the fact that
        # aot_module_simplified expects submods to always return tuples/lists
        class WrapperModule(torch.nn.Module):
            def __init__(self, mod):
                super().__init__()
                self.mod = mod

            def forward(self, *args):
                out = self.mod(*args)
                if isinstance(out, (list, tuple)):
                    return out
                return (out,)

        def compile_submod(input_mod, args):
            from functorch.compile import nop
            from torch._functorch.aot_autograd import aot_module_simplified

            class WrapperModule(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.original = input_mod
                    self.submod = aot_module_simplified(input_mod, args, nop)

                def forward(self, *args):
                    return self.submod(*args)

            return WrapperModule()

        def test_compile(fx_g, example_inps):
            split_gm = torch.fx.passes.split_module.split_module(
                fx_g, None, lambda node: 1 if "mul" in str(node) else 0
            )
            submod_1_inps = split_gm.submod_0(*example_inps)
            split_gm.submod_0 = compile_submod(
                WrapperModule(split_gm.submod_0), example_inps
            )
            split_gm.submod_1 = compile_submod(
                WrapperModule(split_gm.submod_1), submod_1_inps
            )
            return split_gm

        @torch._dynamo.optimize(test_compile)
        def f(a):
            b, c = torch.ops.custom.maybe_dupe_op(a)
            return (b.mul_(c),)

        f(torch.ones(4))
        f(torch.ones(6))

    def test_nn_parameter_construction(self):
        # https://github.com/pytorch/pytorch/issues/99569
        def fn(x):
            y = x.sin()
            z = torch.nn.Parameter(torch.ones(1))
            return y + z

        x = torch.rand((4, 4))

        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
        self.assertTrue(torch._dynamo.testing.same(fn(x), opt_fn(x)))

    def test_aot_sequence_nr(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels=16,
                    out_channels=16,
                    kernel_size=(1, 1),
                    stride=1,
                    padding="same",
                    bias=True,
                )
                self.bn1 = torch.nn.BatchNorm2d(num_features=16)
                self.relu1 = torch.nn.ReLU()
                self.fc1 = torch.nn.Linear(in_features=1638400, out_features=1)
                self.loss_fn = torch.nn.L1Loss()

            def forward(self, x, target):
                y = x
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu1(x)
                x = x + y
                x = torch.flatten(x)
                x = self.fc1(x)
                output = self.loss_fn(x, target)

                return (output,)

        mod = Model()
        mod.train()
        x = torch.rand(100, 16, 32, 32, requires_grad=True)
        target = torch.rand(1)

        # Use dynamo export to get the fx graph module
        g_mod, _ = torch._dynamo.export(mod, x, target)

        def _prepare_model_args():
            named_parameters = dict(g_mod.named_parameters(remove_duplicate=False))
            named_buffers = dict(g_mod.named_buffers(remove_duplicate=False))
            params_and_buffers = {
                **dict(named_parameters),
                **dict(named_buffers),
            }
            params_and_buffers_flat, params_spec = pytree.tree_flatten(
                params_and_buffers
            )
            params_len = len(params_and_buffers_flat)
            functional_call = create_functional_call(g_mod, params_spec, params_len)
            return params_and_buffers_flat, functional_call

        full_args, fn_to_trace = _prepare_model_args()
        param_and_buf_len = len(full_args)
        full_args.extend([x, target])

        # aot_export requires a graph mod input of fwd graph
        # returns the full fwd/bwd graph in graph mod format
        with torch.enable_grad(), fx_traceback.preserve_node_meta():
            fx_g, _, _, _ = _aot_export_function(
                fn_to_trace,
                full_args,
                decompositions=None,
                num_params_buffers=param_and_buf_len,
                no_tangents=True,
            )

        # Walk all the nodes in fx graph.
        # Write the resulting ops to a table
        min_seq_nr = -1
        seq_table = "SeqNr|OrigAten|SrcFn|FwdSrcFn\n"
        for node in fx_g.graph.nodes:
            if "call_" in node.op and "getitem" not in str(node.target):
                seq_nr = node.meta.get("seq_nr", -1)
                if seq_nr < 0:
                    continue
                if min_seq_nr < 0:
                    min_seq_nr = seq_nr
                source_fn_stack = node.meta.get("source_fn_stack", [])
                orig_aten = node.meta.get("original_aten", "")
                mod_name = ""
                if len(source_fn_stack) > 0:
                    mod_name = source_fn_stack[-1][0]
                # Make all seq_nr relative so it starts at 0
                seq_nr = seq_nr - min_seq_nr
                # For backward nodes, also test that metadata from the corresponding
                # forward node is copied over.
                fwd_source_fn_stack = node.meta.get("fwd_source_fn_stack", [])
                fwd_mod_name = ""
                if len(fwd_source_fn_stack):
                    fwd_mod_name = fwd_source_fn_stack[-1][0]
                seq_table = (
                    seq_table + f"{seq_nr}|{orig_aten}|{mod_name}|{fwd_mod_name}\n"
                )

        self.maxDiff = None
        self.assertExpectedInline(
            seq_table,
            dedent(
                """\
SeqNr|OrigAten|SrcFn|FwdSrcFn
0|aten.convolution.default|l__self___conv1|
0|aten.add.Tensor|l__self___bn1|
1|aten._native_batch_norm_legit_functional.default|l__self___bn1|
2|aten.relu.default|l__self___relu1|
2|aten.detach.default|l__self___relu1|
2|aten.detach.default|l__self___relu1|
3|aten.add.Tensor|add|
4|aten.view.default|flatten|
5|aten.view.default|l__self___fc1|
6|aten.t.default|l__self___fc1|
7|aten.addmm.default|l__self___fc1|
8|aten.view.default|l__self___fc1|
9|aten.sub.Tensor|l__self___loss_fn|
10|aten.abs.default|l__self___loss_fn|
11|aten.mean.default|l__self___loss_fn|
11|aten.ones_like.default||l__self___loss_fn
11|aten.expand.default||l__self___loss_fn
11|aten.div.Scalar||l__self___loss_fn
10|aten.sgn.default||l__self___loss_fn
10|aten.mul.Tensor||l__self___loss_fn
8|aten.view.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.mm.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.mm.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.sum.dim_IntList||l__self___fc1
7|aten.view.default||l__self___fc1
6|aten.t.default||l__self___fc1
5|aten.view.default||l__self___fc1
4|aten.view.default||
2|aten.detach.default||l__self___relu1
2|aten.detach.default||l__self___relu1
2|aten.threshold_backward.default||l__self___relu1
1|aten.native_batch_norm_backward.default||l__self___bn1
0|aten.convolution_backward.default||l__self___conv1
11|aten.add.Tensor||l__self___loss_fn
"""
            ),
        )

    def test_split_with_sizes_aot_autograd_cleans_up_traceback_meta(self):
        from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks

        def fn(result, split_sizes):
            rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist())
            return rs

        example_inputs = (
            torch.randn(32, requires_grad=True),
            torch.tensor((7, 16, 9)),
        )
        outs = fn(*example_inputs)
        setup_stacktrace_preservation_hooks([out.grad_fn for out in outs])
        with fx_traceback.preserve_node_meta():
            (outs[0].sum() + outs[1].sum() + outs[2].sum()).backward()

        self.assertNotIn("grad_fn_seq_nr", fx_traceback.current_meta)
        self.assertNotIn("in_grad_fn", fx_traceback.current_meta)

    # https://github.com/pytorch/pytorch/issues/110121
    def test_aot_export_joint_simple_repro(self):
        class Mod(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.linear = torch.nn.Linear(5, 7)

            def forward(self, x):
                return self.linear(x)

        def mini_backend(gm, sample_inputs):
            from torch._functorch.aot_autograd import aot_export_joint_simple

            fake_mode = torch._dynamo.utils.detect_fake_mode(sample_inputs)

            with patch.object(fake_mode, "allow_non_fake_inputs", True), fake_mode:
                return aot_export_joint_simple(gm, sample_inputs, trace_joint=False)

        sample_inputs = [torch.rand((3, 4, 5))]
        model = Mod()
        m_compiled = torch.compile(model, backend=mini_backend)

        out_ref = model(*sample_inputs)
        out_test = m_compiled(*sample_inputs)
        self.assertEqual(out_ref, out_test)

    def test_eager_sequence_nr(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels=16,
                    out_channels=16,
                    kernel_size=(1, 1),
                    stride=1,
                    padding="same",
                    bias=True,
                )
                self.bn1 = torch.nn.BatchNorm2d(num_features=16)
                self.relu1 = torch.nn.ReLU()
                self.fc1 = torch.nn.Linear(in_features=1638400, out_features=1)
                self.loss_fn = torch.nn.L1Loss()

            def forward(self, x, target):
                y = x
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu1(x)
                x = x + y
                x = torch.flatten(x)
                x = self.fc1(x)
                output = self.loss_fn(x, target)

                return (output,)

        def grad_with_create_graph(mod, x, target):
            y = mod(x, target)
            # Set create_graph=True to ensure that the sequence_nr
            # for backward ops continues to count down.
            (gx,) = torch.autograd.grad(
                y[0], x, create_graph=True, grad_outputs=grad_output
            )
            return gx

        x = torch.rand(100, 16, 32, 32, requires_grad=True)
        target = torch.rand(1)
        mod = Model()
        args = [mod, x, target]
        grad_output = torch.tensor(1.0, requires_grad=True)
        compiled_f1 = torch.compile(backend="aot_eager")(grad_with_create_graph)
        model_instance = compiled_f1
        with profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            record_shapes=True,
        ) as kineto_prof:
            res = model_instance(*args)
        bwd_set = set()
        prof_str = "SeqNr|Thread|FwdThread|Name\n"
        for event in kineto_prof.events():
            if event.sequence_nr >= 0:
                prof_str = (
                    prof_str + f"{event.sequence_nr}|{event.thread}"
                    f"|{event.fwd_thread}|{event.name}|\n"
                )
                if re.search(r"Backward[01]", event.name):
                    bwd_set.add(event.sequence_nr)
        self.assertTrue(len(bwd_set), 13)

    def test_aot_grad_mode_mutation(self):
        for compiler in ["aot_eager", "inductor"]:

            def f(x):
                y = x * x
                torch.set_grad_enabled(False)
                return y.clone(), y

            f_compiled = torch.compile(f, backend=compiler, fullgraph=True)

            torch.set_grad_enabled(True)
            x = torch.ones(3, requires_grad=True) * 3
            y_ref = f(x)
            self.assertEqual(torch.is_grad_enabled(), False)
            torch.set_grad_enabled(True)
            y = f_compiled(x)
            self.assertEqual(torch.is_grad_enabled(), False)
            torch.set_grad_enabled(True)
            self.assertEqual(y_ref, y)

            self.assertIsNone(y_ref[0].grad_fn)
            self.assertIsNone(y[0].grad_fn)

            self.assertIsNotNone(y_ref[1].grad_fn)
            self.assertIsNotNone(y[1].grad_fn)

            # Check that the grad computed for the inputs, given the input, is the same
            # The tangent to `y[0]`, which has grad_required=False, is irrelevant
            self.assertEqual(
                sum(y_ref[1].grad_fn(torch.tensor([-1.0, 2.0, 0.0]))),
                sum(
                    x
                    for x in y[1].grad_fn.apply(None, torch.tensor([-1.0, 2.0, 0.0]))
                    if x is not None
                ),
            )

    def test_aot_autograd_raises_invalid_leaf_set(self):
        @torch.compile
        def f(x):
            x.set_(torch.ones(2))

        # We still want to make sure that this raises
        x = torch.ones(2, requires_grad=True)
        with self.assertRaisesRegex(
            RuntimeError, "is being used in an in-place operation"
        ):
            f(x)

    def test_aot_autograd_expand_mutation_functionalizes(self):
        def fn(x):
            y = x.expand(3, *x.shape)
            y[0, 0].add_(5)
            return y

        opt_fn = torch.compile(fn, backend="aot_eager")

        x = torch.arange(6)
        x_opt = x.clone().detach()
        self.assertEqual(fn(x), opt_fn(x_opt))
        self.assertEqual(x, x_opt)

    def test_aot_autograd_expand_mutation_backwards(self):
        def fn(x, z):
            y = x.expand(3, *x.shape)
            y[1, 1].mul_(5)
            ret = y * z
            return ret

        opt_fn = torch.compile(fn, backend="aot_eager")

        x = torch.arange(6, dtype=torch.float)
        z = x.clone().detach()
        x_opt = x.clone().detach()
        z_opt = x.clone().detach()

        z.requires_grad = True
        z_opt.requires_grad = True

        res = fn(x, z)
        opt_res = opt_fn(x_opt, z_opt)

        self.assertEqual(res, opt_res)

        res.sum().backward()
        opt_res.sum().backward()

        self.assertEqual(x, x_opt)
        self.assertEqual(z.grad, z_opt.grad)

    def test_data_ptr_access_copy(self):
        import torch._functorch.config as _config

        with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
            with FakeTensorMode():
                x = torch.randn(3)
                y = copy.copy(x)
        self.assertEqual(y.shape, x.shape)

    def test_data_ptr_access_fails_in_forward(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)

            @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
            def _(x):
                x.data_ptr()
                return x.clone()

            x = torch.randn(3)

            def data_ptr_graph_input(x):
                r0 = torch.ops.mylib.foo(x)
                return r0

            def data_ptr_graph_intermediate(x):
                y = x.clone()
                r0 = torch.ops.mylib.foo(y)
                return r0

            tests = [data_ptr_graph_input, data_ptr_graph_intermediate]

            def ctx():
                return self.assertRaisesRegex(
                    RuntimeError, "Cannot access data pointer"
                )

            for f in tests:
                with ctx():
                    make_fx(f, tracing_mode="fake")(x)
                with ctx():
                    make_fx(f, tracing_mode="symbolic")(x)
                with ctx():
                    torch.compile(f, backend="eager", fullgraph=True)(x)

    def test_data_ptr_access_fails_in_backward(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)

            backward_called = False

            class Foo(torch.autograd.Function):
                @staticmethod
                def forward(ctx, x):
                    return x.clone()

                @staticmethod
                def backward(ctx, grad):
                    nonlocal backward_called
                    backward_called = True
                    grad.data_ptr()
                    return grad.clone()

            @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
            def _(x):
                return Foo.apply(x)

            def f(x):
                return torch.ops.mylib.foo(x)

            x = torch.randn(3, requires_grad=True)
            with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer"):
                y = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
            self.assertTrue(backward_called)

    # We don't know how to catch multiple mutations to the same memory location
    @unittest.expectedFailure
    def test_aot_autograd_expand_mutation_error(self):
        def fn(x):
            y = x.expand(3, *x.shape)
            y[0:3, 0].add_(5)
            return y

        opt_fn = torch.compile(fn, backend="aot_eager")

        x = torch.arange(6)
        x_opt = x.clone().detach()
        with self.assertRaises(Exception):
            fn(x)
        with self.assertRaises(Exception):
            opt_fn(x_opt)

    @torch._functorch.config.patch(donated_buffer=True)
    def test_donated_buffer1(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        @torch.compile()
        def relu(x):
            return torch.nn.functional.relu(x)

        with self.assertLogs(logger_name, level="INFO") as captured:
            relu(torch.rand([3, 3], requires_grad=True)).sum().backward()

        if is_dynamic_shape_test(self._testMethodName):
            # an extra symint exists
            expected_msg = "bw_donated_idxs=[1]"
        else:
            expected_msg = "bw_donated_idxs=[0]"

        # le is a donated buffer from relu
        FileCheck().check(expected_msg).run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer2(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        # we will re-use the graph for g across f1 and f2
        @torch.compile()
        def g(activation, param2):
            return torch.matmul(activation, param2)

        def f(inp, param1, param2):
            activation = inp + param1
            return g(activation, param2)

        inp = torch.ones(4, 4)
        param1 = torch.ones(4, 4, requires_grad=True)
        param2 = torch.ones(4, 4, requires_grad=True)

        with self.assertLogs(logger_name, level="INFO") as captured:
            f(inp, param1, param2).sum().backward()

        FileCheck().check("bw_donated_idxs=[]").run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer3(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        # we will re-use the graph for g across f1 and f2
        @torch.compile()
        def g(activation, param2):
            return torch.matmul(activation, param2)

        def f(inp, param1, param2):
            # exp saves it output (the activation) for bw
            activation = torch.exp(inp + param1)
            return g(activation, param2)

        inp = torch.ones(4, 4)
        param1 = torch.ones(4, 4, requires_grad=True)
        param2 = torch.ones(4, 4, requires_grad=True)

        with self.assertLogs(logger_name, level="INFO") as captured:
            f(inp, param1, param2).sum().backward()

        FileCheck().check("bw_donated_idxs=[]").run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer4(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([2, 2]))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.relu(x) + self.param

        mod = Mod()
        mod = torch.compile(mod)

        inp = torch.ones([2, 2], requires_grad=True)

        with self.assertLogs(logger_name, level="INFO") as captured:
            mod(inp).sum().backward()

        # Forward graph:
        #   %primals_1 : [num_users=1] = placeholder[target=primals_1]
        #   %primals_2 : [num_users=1] = placeholder[target=primals_2]
        #   %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%primals_2,), kwargs = {})
        #   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%relu, %primals_1), kwargs = {})
        #   %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
        #   return [add, le]
        #
        # `le` is a donated buffer
        FileCheck().check("bw_donated_idxs=[0]").run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer5(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        @torch.compile()
        def f(x, z):
            y = x.view(2, 3)
            z = torch.nn.functional.relu(z)
            return torch.mm(y, x) + z

        inp = [
            torch.rand([3, 2], requires_grad=True),
            torch.rand([2, 2], requires_grad=True),
        ]

        with self.assertLogs(logger_name, level="INFO") as captured:
            f(*inp).sum().backward()

        # Forward graph:
        #   %primals_1 : [num_users=3] = placeholder[target=primals_1]
        #   %primals_2 : [num_users=1] = placeholder[target=primals_2]
        #   %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%primals_1, [2, 3]), kwargs = {})
        #   %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%primals_2,), kwargs = {})
        #   %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %primals_1), kwargs = {})
        #   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mm, %relu), kwargs = {})
        #   %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
        #   return [add, primals_1, le]
        #
        # `le` is a donated buffer but primals_1 is not.
        FileCheck().check("bw_donated_idxs=[1]").run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph1(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        for _ in range(5):
            mod(inp).sum().backward()

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph2(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        out = mod(inp).sum()
        for _ in range(5):
            out.backward(retain_graph=True)
        out.backward()

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph3(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        mod(inp).sum().backward(create_graph=True)
        out = mod(inp).sum()
        for _ in range(5):
            out.backward(retain_graph=True)
        out.backward()

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph4(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        mod(inp).sum().backward()
        out = mod(inp).sum()
        with self.assertRaisesRegex(
            RuntimeError,
            r"This backward function was compiled with non-empty donated "
            r"buffers which requires create_graph=False and retain_graph=False. "
            r"Please keep backward\(create_graph=False, retain_graph=False\) "
            r"across all backward\(\) function calls, or set "
            r"torch._functorch.config.donated_buffer=False to disable "
            r"donated buffer.",
        ):
            out.backward(retain_graph=True)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
