# Owner(s): ["oncall: pt2"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import itertools
import unittest
import warnings
from contextlib import nullcontext
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Union
from unittest.mock import patch

from common_utils import decorate, decorateForModules, skip, skipOps, xfail

import torch
import torch._dynamo as torchdynamo
import torch.nn as nn
import torch.utils._pytree as pytree
from functorch import grad, jacrev, make_fx, vjp, vmap
from functorch.compile import (
    aot_function,
    aot_module,
    aot_module_simplified,
    compiled_function,
    compiled_module,
    default_decompositions,
    default_partition,
    get_aot_compilation_context,
    make_boxed_compiler,
    make_boxed_func,
    memory_efficient_fusion,
    min_cut_rematerialization_partition,
    nnc_jit,
    nop,
)
from functorch.experimental import control_flow
from torch._decomp import decomposition_table
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.codecache import compiled_fx_graph_hash
from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
from torch.fx.experimental.proxy_tensor import is_sym_node
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv
from torch.nn.utils.rnn import PackedSequence
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
    tol,
    toleranceOverride,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
    compare_equal_outs_and_grads,
    instantiate_parametrized_tests,
    IS_ARM64,
    IS_MACOS,
    IS_WINDOWS,
    IS_X86,
    outs_and_grads,
    parametrize,
    run_tests,
    skipIfRocm,
    skipIfTorchDynamo,
    TestCase,
    xfail_inherited_tests,
    xfailIfTorchDynamo,
)
from torch.testing._internal.custom_tensor import ConstantExtraMetadataTensor
from torch.testing._internal.hop_db import hop_db
from torch.testing._internal.optests import (
    _test_aot_autograd_forwards_backwards_helper,
    aot_autograd_check,
)
from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode


# Optional dependency: torchvision. USE_TORCHVISION records whether the import
# succeeded; tests that need it are guarded with unittest.skipIf on this flag.
# A missing package only produces a warning, never an import-time failure.
USE_TORCHVISION = False
try:
    import torchvision

    USE_TORCHVISION = True
except ImportError:
    warnings.warn(
        "Couldn't import torchvision. Some of our tests use it, try "
        "to install it with commands from pytorch.org, post-fixed with "
        "`--no-deps` to avoid overwriting the pytorch installation",
        UserWarning,
    )

# Optional dependency: networkx. USE_NETWORKX records whether the import
# succeeded; the module itself is not referenced here (hence the noqa).
USE_NETWORKX = False
try:
    import networkx  # noqa: F401

    USE_NETWORKX = True
except ImportError:
    warnings.warn("Some tests use networkx but it was not installed", UserWarning)

# NB: numpy is a testing dependency!


class AOTTestCase(TestCase):
    """Shared base class for the test classes defined in this file.

    Adds nothing on top of ``TestCase``.
    """

    pass


class TestPythonKey(AOTTestCase):
    """Smoke tests for ``make_fx`` tracing and the ``nnc_jit`` compiler.

    Every test method takes a ``device`` argument; only
    ``test_scalar_device`` and ``test_make_fx_functionalize`` place tensors
    on it, the remaining tests build their inputs on the default device.
    """

    def test_make_fx(self, device):
        # A traced pointwise function must agree with eager on a fresh input.
        def f(x):
            return torch.sin(x)

        inp = torch.randn(3)
        fx_f = make_fx(f)(inp)

        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_make_fx_grad(self, device):
        # Trace through functorch's grad transform.
        def f(x):
            return torch.sin(x).sum()

        inp = torch.randn(3)
        f = grad(f)
        fx_f = make_fx(f)(inp)

        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_scalar_device(self, device):
        # Mixes a tensor on `device` with a 0-dim tensor on the default device.
        def f(a, b):
            return a + b

        inps = [torch.randn(3, device=device), torch.tensor(5)]
        fx_f = make_fx(f)(*inps)
        self.assertEqual(fx_f(*inps), f(*inps))

    def test_make_fx_vmap(self, device):
        # Trace through functorch's vmap transform.
        def f(x):
            return torch.sin(x)

        inp = torch.randn(5, 3)
        f = vmap(f)
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(5, 3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_make_fx_jacrev(self, device):
        # Trace through a nested jacrev(jacrev(f)) transform.
        def f(x):
            return x.sin().sum()

        inp = torch.randn(3)
        f = jacrev(jacrev(f))
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_make_fx_vjp(self, device):
        # Trace the vjp closure. It is traced (and the traced graph is called)
        # with two extra positional True arguments, whereas the eager
        # reference call receives only the cotangent.
        def f(x):
            return torch.sin(x).sum()

        primals = torch.randn(3)
        _, vjp_fn = vjp(f, primals)
        cotangent = torch.randn(())
        fx_f = make_fx(vjp_fn)(cotangent, True, True)
        new_cotangent = torch.randn(())
        self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))

    def test_make_fx_functionalize(self, device):
        from functorch.experimental import functionalize

        def fn(a):
            a = a * 2
            a.relu_()
            return a

        a = torch.randn(3, device=device)
        # Symbolic tracing keeps the in-place method call as a "relu_" node.
        symbolic_gm = torch.fx.symbolic_trace(fn)
        includes_method_relu_ = any(
            str(n.target) == "relu_" for n in symbolic_gm.graph.nodes
        )
        self.assertTrue(includes_method_relu_)
        # Also verifies fix for https://github.com/pytorch/pytorch/issues/84570
        # After functionalization the graph must contain out-of-place aten.relu.
        gm = make_fx(functionalize(symbolic_gm))(a)
        includes_aten_relu = any(
            n.target == torch.ops.aten.relu.default for n in gm.graph.nodes
        )
        self.assertTrue(includes_aten_relu)

    # FIXME: skipped unconditionally. Previously this was expressed as
    # `return self.skipTest(...)` at the top of the body, which left the rest
    # of the test as unreachable code; the decorator states the intent directly.
    @unittest.skip("error: maximum recursion reached")
    def test_make_fx_no_decompose(self, device):
        def f(x):
            return torch.tanh(x).sum()

        # Without a decomposition table tanh_backward is expected in the graph.
        fx_f = make_fx(grad(f))(torch.randn(5))
        # Named `targets` so the imported `ops` decorator is not shadowed.
        targets = {i.target for i in fx_f.graph.nodes}

        self.assertEqual(torch.ops.aten.tanh_backward in targets, True)

        # With the decomposition table it is expected to be decomposed away.
        fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5))
        targets = {i.target for i in fx_f.graph.nodes}
        self.assertEqual(torch.ops.aten.tanh_backward in targets, False)

    def test_nnc_jit(self, device):
        # nnc_jit output must match eager for a 1-D input.
        def f(x):
            return torch.sin(x)

        jit_f = nnc_jit(f)

        inp = torch.randn(3)
        self.assertEqual(jit_f(inp), f(inp))

    def test_nnc_scalar(self, device):
        # Same as test_nnc_jit but with a 0-dim input.
        def f(x):
            return torch.sin(x)

        jit_f = nnc_jit(f)

        inp = torch.randn(())
        self.assertEqual(jit_f(inp), f(inp))

    def test_nnc_pytrees(self, device):
        # List-valued input and output.
        def f(x):
            return [torch.sin(x[0])]

        jit_f = nnc_jit(f)

        inp = [torch.randn(3)]
        self.assertEqual(jit_f(inp), f(inp))

    def test_external_calls(self, device):
        # Matrix-vector product through nnc_jit.
        def f(a, b):
            return torch.mv(a, b)

        jit_f = nnc_jit(f)
        inp = [torch.randn(3, 3), torch.randn(3)]
        self.assertEqual(jit_f(*inp), f(*inp))

    def test_nnc_passthrough(self, device):
        # Case 1: one of the outputs is an input returned unchanged.
        def f(x, y):
            return x + y, y

        inp = (torch.randn(3), torch.randn(3))
        jit_f = nnc_jit(f)
        self.assertEqual(jit_f(*inp), f(*inp))

        # Case 2: a dict input is updated in place and returned; the "b"
        # entry is passed through untouched.
        def f(x):
            x["a"] = x["a"] * 2
            return x

        inp = ({"a": torch.randn(3), "b": torch.randn(3)},)
        jit_f = nnc_jit(f)
        self.assertEqual(jit_f(*inp), f(*inp))

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self, device):
        mod = torchvision.models.resnet18()

        def f(x):
            out = mod(x)
            out.sum().backward()
            return [a.grad for a in mod.parameters()]

        inp = torch.randn(3, 3, 250, 250, requires_grad=True)
        grads = f(inp)

        # Recompute the parameter gradients directly and compare with the
        # ones collected through f.
        mod.zero_grad()
        mod(inp).sum().backward()
        grads2 = [a.grad for a in mod.parameters()]
        self.assertEqual(grads, grads2)


def get_base(t):
    """Return the tensor ``t`` is a view of, or ``t`` itself when it is not a view."""
    if t._is_view():
        return t._base
    return t


def is_in_base(t, maybe_tensors):
    """Report whether ``t`` shares its base tensor with any tensor in ``maybe_tensors``.

    Entries of ``maybe_tensors`` that are not ``torch.Tensor`` instances are
    ignored. Bases are compared by identity, not by value.
    """
    wanted_base = get_base(t)
    return any(
        isinstance(candidate, torch.Tensor) and wanted_base is get_base(candidate)
        for candidate in maybe_tensors
    )


def skipIfDynamoInput(reason):
    """
    Decorator factory: skip the wrapped test when it runs on a
    TestAOTAutogradWithDynamo instance, otherwise run it normally.

    ``reason`` is appended to the skip message.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            # Any other test class runs the wrapped test unchanged.
            if not isinstance(self, TestAOTAutogradWithDynamo):
                func(self, *args, **kwargs)
                return
            self.skipTest(
                f"Skipping {self._testMethodName} in TestAOTAutogradWithDynamo because {reason}"
            )

        return wrapper

    return decorator


class TestAOTAutograd(AOTTestCase):
    def run_autograd(
        self,
        f: Callable,
        fw_graph_cell: List[Optional[Callable]],
        decompositions: Optional[Dict],
        keep_input_mutations: bool,
        dynamic: bool,
    ):
        """
        Compile f with AOTAutograd and return the compiled callable.

        nn.Module instances go through aot_module and everything else through
        aot_function; both receive the same options. The forward compiler is
        extract_graph bound to fw_graph_cell, the backward compiler is nop.
        """
        aot_entry_point = aot_module if isinstance(f, nn.Module) else aot_function
        capture_fw_graph = make_boxed_compiler(
            partial(extract_graph, graph_cell=fw_graph_cell)
        )
        return aot_entry_point(
            f,
            fw_compiler=capture_fw_graph,
            bw_compiler=nop,
            decompositions=decompositions,
            keep_inference_input_mutations=keep_input_mutations,
            dynamic=dynamic,
        )

    # test_mutation will:
    # - Ensure that inputs are non-leaves, so our graphs can mutate them
    # - try to mutate outputs of the graph (to ensure that autograd meta is set properly on outputs)
    @patch("functorch.compile.config.debug_assert", True)
    def verify_aot_autograd(
        self,
        f,
        inp_: Union[Callable, List[Any]],
        *,
        test_mutation: bool = False,
        keep_inp_mutations: bool = False,
        decompositions: Optional[Dict] = None,
        dynamic: bool = False,
        # Only active when inp_ is Callable.
        # TODO: probably consolidate all tests to make inp a Callable.
        make_inputs_subclasses: bool = False,
    ):
        """
        Compile ``f`` via ``run_autograd`` and compare it against eager ``f``.

        Outputs, gradients and the (possibly mutated) inputs of the compiled
        function must match those of the eager function.

        Args:
            f: function or nn.Module under test.
            inp_: either a list of inputs, or a callable returning a
                ``(inputs, graph_inputs)`` pair.
            test_mutation: when ``inp_`` is a list, feed the graph ``x.add(1)``
                of every input (so graph inputs are non-leaves); additionally
                every tensor output is mutated in place and restored.
            keep_inp_mutations: if True only ``keep_input_mutations=True`` is
                exercised; otherwise both True and False are.
            decompositions: forwarded to ``run_autograd``.
            dynamic: forwarded to ``run_autograd``.
            make_inputs_subclasses: when ``inp_`` is a callable, invoke it
                under ``TwoTensorMode``.

        Returns:
            ``fw_graph_cell[0]`` from the last loop iteration, i.e. whatever
            the forward compiler stored in the cell (presumably the forward
            graph -- extract_graph is defined outside this method).
        """

        def make_inputs(inp_):
            # Some tests pass in a callable for inp, to generate the inputs
            # (useful if we want to generate complicated aliasing inputs)
            if isinstance(inp_, Callable):
                inp_callable = inp_
                # The callable should return a tuple of f_inputs, f_graph_inputs
                # (The idea is that we might want to compile a function with the graph inputs,
                # but test autograd backprop all the way through the actual inputs)
                with TwoTensorMode() if make_inputs_subclasses else nullcontext():
                    inp, graph_inps = inp_callable()
            else:
                inp = []
                # Our input clones need to mimic when inputs are duplicates of one another
                # (dupes_map: original input -> index of its copy in inp)
                dupes_map = {}
                for i, x in enumerate(inp_):
                    if x in dupes_map:
                        x_dupe_idx = dupes_map[x]
                        inp.append(inp[x_dupe_idx])
                    else:
                        dupes_map[x] = i
                        if not isinstance(x, torch.Tensor):
                            x_copy = x
                        else:
                            x_copy = x.clone().detach().requires_grad_(x.requires_grad)
                            # Second clone makes the copy a non-leaf, like the original.
                            if x.requires_grad and not x.is_leaf:
                                x_copy = x_copy.clone()

                        inp.append(x_copy)

                if test_mutation:
                    # For graphs where we mutate inputs, need our test to make sure inputs aren't leaves
                    graph_inps = [x.add(1) for x in inp]
                else:
                    graph_inps = inp

            return inp, graph_inps

        def check_results(
            ref_results,
            test_results,
            ref_graph_inps,
            test_graph_inps,
            ref_inp,
            test_inp,
        ):
            # Compares gradients, then outputs element-wise (value,
            # requires_grad, is_leaf, aliasing), and finally the inputs.
            ref_out, ref_grad = ref_results
            test_out, test_grad = test_results
            self.assertEqual(ref_grad, test_grad)
            if isinstance(ref_out, torch.Tensor):
                self.assertTrue(isinstance(test_out, torch.Tensor))
                ref_out, test_out = [ref_out], [test_out]
            for ref_o, test_o in zip(ref_out, test_out):
                if isinstance(ref_o, torch.Tensor):
                    self.assertEqual(ref_o.requires_grad, test_o.requires_grad)
                    self.assertEqual(ref_o.is_leaf, test_o.is_leaf)
                    # An output sharing its base with a graph input or with
                    # one of the outputs must do so in both eager and compiled.
                    ref_is_view_of_non_interm = is_in_base(
                        ref_o, ref_graph_inps
                    ) or is_in_base(ref_o, ref_out)
                    test_is_view_of_non_interm = is_in_base(
                        test_o, test_graph_inps
                    ) or is_in_base(test_o, test_out)
                    self.assertEqual(
                        ref_is_view_of_non_interm, test_is_view_of_non_interm
                    )
                    self.assertEqual(ref_o, test_o)
                    if test_mutation:
                        # This tests that autograd meta is set properly on the output,
                        # so that we can mutate it.
                        ref_o.add_(2)
                        test_o.add_(2)
                        self.assertEqual(ref_o, test_o)
                        # Reverse the modification
                        ref_o.sub_(2)
                        test_o.sub_(2)
                        self.assertEqual(ref_o, test_o)
            for ref_i, test_i in zip(ref_inp, test_inp):
                if isinstance(ref_i, torch.Tensor):
                    self.assertEqual(ref_i.requires_grad, test_i.requires_grad)
                self.assertEqual(ref_i, test_i)

        for keep_input_mutations in [True] if keep_inp_mutations else [True, False]:
            # Separate input sets for the eager reference and the compiled run.
            inp, graph_inps = make_inputs(inp_)
            test_inp, test_graph_inps = make_inputs(inp_)
            fw_graph_cell = [None]
            compiled_f = self.run_autograd(
                f, fw_graph_cell, decompositions, keep_input_mutations, dynamic
            )
            ref_results = outs_and_grads(f, graph_inps, inp)
            test_results = outs_and_grads(compiled_f, test_graph_inps, test_inp)

            check_results(
                ref_results, test_results, graph_inps, test_graph_inps, inp, test_inp
            )
            if isinstance(self, TestAOTAutogradWithCache):
                # When testing with cache, run compiled_f a second time
                cached_inp, cached_graph_inps = make_inputs(inp_)
                cached_results = outs_and_grads(
                    compiled_f, cached_graph_inps, cached_inp
                )
                check_results(
                    ref_results,
                    cached_results,
                    graph_inps,
                    cached_graph_inps,
                    inp,
                    cached_inp,
                )

        return fw_graph_cell[0]

    def test_non_tensor_and_none_inputs(self):
        # Argument kinds, in order: int, None, Tensor. The None is unused by f.
        def f(a, b, c):
            return a * c

        # Exercise the tensor argument with and without requires_grad.
        for needs_grad in (True, False):
            tensor_arg = torch.ones(
                3, 3, dtype=torch.float32, requires_grad=needs_grad
            )
            self.verify_aot_autograd(f, [2, None, tensor_arg])

    def test_single_output(self):
        # One tensor output; the first input is tried with and without grad.
        def f(a, b):
            return a + b

        for needs_grad in (True, False):
            args = [torch.randn(3, 3, requires_grad=needs_grad), torch.randn(3, 3)]
            self.verify_aot_autograd(f, args)

    def test_multi_output(self):
        # Tuple of two tensor outputs; first input tried with and without grad.
        def f(a, b):
            return a + b, a - b

        for needs_grad in (True, False):
            args = [torch.randn(3, 3, requires_grad=needs_grad), torch.randn(3, 3)]
            self.verify_aot_autograd(f, args)

    def test_multi_output_list(self):
        # Same as the tuple-output case, but the outputs come back in a list.
        def f(a, b):
            return [a + b, a - b]

        for needs_grad in (True, False):
            args = [torch.randn(3, 3, requires_grad=needs_grad), torch.randn(3, 3)]
            self.verify_aot_autograd(f, args)

    # Test for bug occurring at the intersection of fake tensors & functionalization.
    def test_squeeze_mutation(self):
        # In-place add on a squeezed clone of the input, traced with
        # dynamic=True, for inputs with and without requires_grad.
        def f(a):
            b = a.clone().squeeze(-1)
            b.add_(1.0)
            return a + b

        for needs_grad in (True, False):
            args = [torch.randn(3, 1, requires_grad=needs_grad)]
            self.verify_aot_autograd(f, args, dynamic=True)

    def test_complex_linear(self):
        # https://github.com/pytorch/pytorch/issues/93424
        class F(torch.nn.Module):
            # complex64 linear layer reduced to a real scalar via sum().abs()
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(10, 10, dtype=torch.complex64)

            def forward(self, x):
                projected = self.linear(x)
                return projected.sum().abs()

        # The input is drawn before the module is built, so random numbers
        # are consumed in the same order as before.
        complex_inp = [torch.randn(1, 10, 10, dtype=torch.complex64)]
        module = F()
        self.verify_aot_autograd(module, complex_inp)

    def test_embedding_bag_view_dynamic(self):
        # Backwards pass tries to wrap a sparse tensor in a FunctionalTensorWrapper;
        # test that this works even though the sparse tensor has no storage.

        class F(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.EmbeddingBag(100, 8, sparse=True)

            def forward(self, x, y):
                bagged = self.emb(x, y)
                return bagged.view(-1)

        x = torch.arange(3)
        y = torch.arange(3)
        # A fresh module per run: static shapes first, then dynamic shapes.
        for use_dynamic in (False, True):
            self.verify_aot_autograd(F(), [x, y], dynamic=use_dynamic)

    def test_input_mutation_simple(self):
        def f(a):
            a.mul_(2)
            return a * 3

        inp = [torch.ones(3, 3, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        inp = [torch.ones(3, 3, requires_grad=False)]
        self.verify_aot_autograd(f, inp, test_mutation=True)
        # Things to note:
        # - the extra clone is because we need to pass the pre-mutated input to grad(),
        #   but autograd operates above functionalization so we need to manually clone.
        #   Hopefully backends can optimize this easily.
        # - The extra return arg is because the compiled forward returns (mutated inputs + outputs)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
    mul = torch.ops.aten.mul.Tensor(clone, 2);  clone = None
    mul_1 = torch.ops.aten.mul.Tensor(mul, 3)
    return (mul, mul_1)""",
        )

    def test_input_mutation_set__input_mutation(self):
        # The input is re-pointed at a freshly created tensor via set_()
        # under no_grad, then multiplied with that tensor.
        def f(a):
            b = torch.arange(9, dtype=a.dtype).reshape(3, 3)
            with torch.no_grad():
                a.set_(b)
            return a * b

        for needs_grad in (True, False):
            args = [torch.ones(3, 3, requires_grad=needs_grad)]
            self.verify_aot_autograd(
                f, args, test_mutation=True, keep_inp_mutations=True
            )

    def test_set__steals_view_chain(self):
        # set_() on an intermediate (a_) whose source is a view of another
        # intermediate (b_), followed by an in-place op through a view of a_.
        def f(a, b):
            a_ = a.mul(2)
            b_ = b.mul(2)
            b_slice = b_[1].view(3, 3)
            # a_clone should inherit the view chain from b_slice
            # NOTE(review): "a_clone" presumably refers to a_ -- confirm.
            a_.set_(b_slice)
            # Also mutates b_, because b_slice is a view of b_
            a_.view(-1).mul_(2)
            return a_ * b_slice

        # Neither input requires grad; only keep_input_mutations=True is run.
        inp = [
            torch.ones(3, 3, requires_grad=False),
            torch.zeros(3, 9, requires_grad=False),
        ]
        self.verify_aot_autograd(f, inp, keep_inp_mutations=True)

    @skipIfDynamoInput(
        "Test doesn't make sense with dynamo, which changes order of mutations"
    )
    def test_set__and_data_mutation_good(self):
        # a.set_(b) followed by a data mutation of b, both under no_grad.
        def f(a, b):
            # The data mutation happens *after* the set_(). This is ok (see the graph below)
            with torch.no_grad():
                a.set_(b)
                b.mul_(2)
            return a + b

        # The graph checked below comes from this requires_grad=True run.
        inp = [
            torch.ones(3, 3, requires_grad=True),
            torch.ones(3, 3, requires_grad=True),
        ]
        fw_graph = self.verify_aot_autograd(
            f, inp, test_mutation=True, keep_inp_mutations=True
        )
        # Second run without grad; note the second input is zeros here.
        inp = [
            torch.ones(3, 3, requires_grad=False),
            torch.zeros(3, 3, requires_grad=False),
        ]
        self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True)
        # Important things to note:
        # - "return a.set_(b)" desugars into "return b"
        # - Both a and b are recorded as experiencing mutations,
        #   which is why we see "b_updated" (output of the mul) twice in the graph outputs.
        #   a is recorded as both a data mutation and a metadata mutation (due to set_ swapping its storage).
        # - the runtime epilogue for a is "a.set_(mul)"
        # - the runtime epilogue for b is "b.copy_(mul)"
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1, primals_2):
    mul = torch.ops.aten.mul.Tensor(primals_2, 2)
    add = torch.ops.aten.add.Tensor(mul, mul)
    set_ = torch.ops.aten.set_.source_Tensor(primals_1, mul);  primals_1 = set_ = None
    copy_ = torch.ops.aten.copy_.default(primals_2, mul);  primals_2 = mul = copy_ = None
    return (add,)""",
        )

    # This is a (hopefully) extremely rare case that is difficult to handle,
    # so we ban it.
    # https://github.com/pytorch/pytorch/issues/126236
    # https://github.com/pytorch/pytorch/pull/126113
    @xfailIfTorchDynamo
    def test_set__and_data_mutation_bad(self):
        # Mutating a pre-set_() view of the input after the set_() is expected
        # to be rejected with a RuntimeError.
        def f(a):
            a_view = a.view(-1)
            tmp = torch.ones(3, 3, requires_grad=True)
            # Now, any mutations on either a or tmp
            # will be tracked as graph input mutations.
            with torch.no_grad():
                a.set_(tmp)
                # BAD: a_view is now detached from every graph input,
                # so we won't recognize that this caused an input mutation!
                a_view.mul_(2)
            return a + tmp

        inp = [torch.ones(3, 3, requires_grad=True)]
        with self.assertRaisesRegex(
            RuntimeError, "cannot mutate tensors with frozen storage"
        ):
            self.verify_aot_autograd(
                f, inp, test_mutation=True, keep_inp_mutations=True
            )

    @skipIfDynamoInput(
        "Test doesn't make sense with dynamo, which changes order of mutations"
    )
    def test_set__not_allowed(self):
        # A set_() under no_grad followed by a grad-tracked in-place mutation
        # of the same input is expected to fail with an AssertionError.
        def f(a, b):
            with torch.no_grad():
                a.set_(b)
            # Mutating a will change a's grad_fn, which requires us to replay the mutation outside of the graph.
            # We currently ban this today, when the input also received a set_() input mutation.
            a.mul_(2)
            return a + b

        inp = [
            torch.ones(3, 3, requires_grad=True),
            torch.ones(3, 3, requires_grad=True),
        ]
        with self.assertRaisesRegex(
            AssertionError, "but the input has other mutations that we cannot"
        ):
            # The call is expected to raise, so its return value is not bound.
            self.verify_aot_autograd(
                f, inp, test_mutation=True, keep_inp_mutations=True
            )

    def test_input_mutation_set__nop(self):
        # Two set_() calls where the second restores an alias of the original
        # input.
        def f(a):
            b = torch.arange(9, dtype=a.dtype)
            a_old = torch.ops.aten.alias.default(a)
            with torch.no_grad():
                a.set_(b)
                a.set_(a_old)
            return a + b.reshape(3, 3)

        # The graph checked below comes from this requires_grad=True run.
        inp = [torch.ones(3, 3, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(
            f, inp, test_mutation=True, keep_inp_mutations=True
        )
        inp = [torch.ones(3, 3, requires_grad=False)]
        self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True)
        # Things to note:
        # - There are no set_() calls in the graph (we functionalize a.set_(b) into "b")
        # - There is only **1** graph output. We properly realized that the two set_() calls
        #   undo each other, and so effectively no inputs are mutated.
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    alias = torch.ops.aten.alias.default(primals_1);  primals_1 = None
    view = torch.ops.aten.view.default(arange, [3, 3]);  arange = None
    add = torch.ops.aten.add.Tensor(alias, view);  alias = view = None
    return (add,)""",
        )

    @unittest.skipIf(IS_WINDOWS, "TODO: need to fix the test case")
    @unittest.skipIf(IS_MACOS, "TODO: need to fix the test case")
    def test_input_mutation_fsdp_set__into_same_input(self):
        # Imported for its side effect; presumably this is what registers the
        # torch.ops.fsdp.set_ op used below -- TODO confirm.
        import torch.distributed._composable.fsdp._fsdp_param

        def f(a):
            b = torch.arange(9, dtype=a.dtype).view(3, 3)
            c = torch.arange(9, dtype=a.dtype).view(3, 3)
            # NOTE(review): d is never used; the third set_ below reuses c,
            # which matches the expected graph (view_1 appears twice).
            d = torch.arange(9, dtype=a.dtype).view(3, 3)
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a):
                torch.ops.fsdp.set_.default(a, b)
            x = a * a
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a):
                torch.ops.fsdp.set_.default(a, c)
            y = a * a
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a):
                torch.ops.fsdp.set_.default(a, c)
            z = a * a
            return x + y + z

        # The graph checked below comes from this requires_grad=True run.
        inp = [torch.ones(3, 3, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(
            f, inp, test_mutation=True, keep_inp_mutations=True
        )
        inp = [torch.ones(3, 3, requires_grad=False)]
        self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True)
        """
        Expected behavior:
        (1) When there are multiple set_() calls on the same graph input primal_X,
        we want those set_() calls to all show up with primal_X as the first arg in the graph.
        (2) Behavior (1) is not the case today with normal aten.set_ (blocked on #129892),
        but using a custom fsdp.set_ op with no returns is a simple workaround to achieve that behavior.
        """
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    view = torch.ops.aten.view.default(arange, [3, 3]);  arange = None
    arange_1 = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    view_1 = torch.ops.aten.view.default(arange_1, [3, 3]);  arange_1 = None
    set_ = torch.ops.fsdp.set_.default(primals_1, view);  view = set_ = None
    mul = torch.ops.aten.mul.Tensor(primals_1, primals_1)
    set__1 = torch.ops.fsdp.set_.default(primals_1, view_1);  set__1 = None
    mul_1 = torch.ops.aten.mul.Tensor(primals_1, primals_1)
    set__2 = torch.ops.fsdp.set_.default(primals_1, view_1);  view_1 = set__2 = None
    mul_2 = torch.ops.aten.mul.Tensor(primals_1, primals_1)
    add = torch.ops.aten.add.Tensor(mul, mul_1);  mul = mul_1 = None
    add_1 = torch.ops.aten.add.Tensor(add, mul_2);  add = mul_2 = None
    return (add_1, primals_1)""",
        )
        # End-to-end check through inductor, using the requires_grad=False input.
        self.assertEqual(torch.compile(f, backend="inductor")(*inp), f(*inp))

    def test_input_mutation_simple_with_none_and_nontensor(self):
        # The argument list mixes a Tensor, None and an int; only the tensor and
        # the int are used by the computation.
        def f(a, b, c):
            return a * c

        compiled = aot_function(f, nop)
        for requires_grad in (True, False):
            args = [torch.ones(3, 3, requires_grad=requires_grad), None, 3]
            expected = f(*args)
            actual = compiled(*args)
            self.assertEqual(expected, actual)

    # https://github.com/pytorch/pytorch/issues/93363
    def test_mutates_input_noncontiguous(self):
        # The mutated input is a step-2 slice of a larger tensor. After running
        # the eager and the compiled function on their own slices, the two
        # underlying base tensors must be equal.
        def f(a):
            a.add_(1)
            return ()

        f_compiled = aot_function(f, nop)
        # `+ 0` yields a non-leaf tensor, so its views can be mutated in place
        # (in-place ops on views of a leaf that requires grad raise in eager).
        ref = torch.ones(4, requires_grad=True) + 0
        ref_view = ref[0::2]

        test = torch.ones(4, requires_grad=True) + 0
        test_view = test[0::2]

        # f returns an empty tuple, so the calls are made purely for their
        # effect on the bases; nothing is bound to the results.
        f(ref_view)
        f_compiled(test_view)
        self.assertEqual(ref, test)

    def test_input_mutation_modifies_autograd_meta_of_aliases(self):
        def f(a):
            a.mul_(2)
            out = a + 1
            return out.detach()

        def make_base_and_alias():
            # A non-leaf base together with a view of it taken before any mutation.
            base = torch.ones(3, 3, requires_grad=True).clone()
            return base, base.view(3, 3)

        x_ref, x_ref_view = make_base_and_alias()
        x_test, x_test_view = make_base_and_alias()

        compiled = aot_function(f, nop, keep_inference_input_mutations=True)
        f(x_ref)
        compiled(x_test)

        # Mutating the input also affects its pre-existing alias, including the
        # alias' autograd metadata: values, version counter and grad_fn type
        # must all agree between eager and compiled.
        self.assertEqual(x_ref_view, x_test_view)
        self.assertEqual(x_ref_view._version, x_test_view._version)
        self.assertEqual(x_ref_view.grad_fn.__class__, x_test_view.grad_fn.__class__)

        # Gradients flowing through base and alias must agree as well.
        ref_loss = (x_ref * x_ref_view).sum()
        ref_loss.backward()
        test_loss = (x_test * x_test_view).sum()
        test_loss.backward()
        self.assertEqual(x_ref.grad, x_test.grad)
        self.assertEqual(x_ref_view.grad, x_test_view.grad)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470")
    def test_nested_subclasses(self):
        # A TwoTensor whose two components are themselves TwoTensors must keep
        # that nesting in the compiled output and in the gradient of the input.
        @torch.compile(backend="aot_eager")
        def f(x):
            return x.sin().cos()

        a = torch.ones(4, requires_grad=True)
        a2 = a.clone().detach().requires_grad_()
        a3 = a.clone().detach().requires_grad_()
        a4 = a.clone().detach().requires_grad_()
        aa = TwoTensor(a, a2)
        aa2 = TwoTensor(a3, a4)
        aaaa = TwoTensor(aa, aa2)
        out = f(aaaa)

        # assertIsInstance reports the offending type on failure, unlike
        # assertTrue(isinstance(...)) which only reports "False is not true".
        self.assertIsInstance(out, TwoTensor)
        for nested in (out.a, out.b):
            self.assertIsInstance(nested, TwoTensor)
        for leaf in (out.a.a, out.a.b, out.b.a, out.b.b):
            self.assertIsInstance(leaf, torch.Tensor)

        out.sum().backward()
        self.assertIsInstance(aaaa.grad, TwoTensor)
        self.assertIsInstance(aaaa.grad.a, TwoTensor)
        self.assertIsInstance(aaaa.grad.b, TwoTensor)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470")
    def test_nested_subclasses_non_nested_grad(self):
        # The compiled function sees a singly nested TwoTensor, but its output
        # is then combined with a doubly nested one; backward is expected to
        # raise because of the subclass mismatch.
        @torch.compile(backend="aot_eager")
        def f(x):
            return x.sin().cos()

        a = torch.ones(4, requires_grad=True)

        def fresh_leaf():
            return a.clone().detach().requires_grad_()

        a2 = fresh_leaf()
        a3 = fresh_leaf()
        a4 = fresh_leaf()
        flat_input = TwoTensor(a3, a4)
        aa = TwoTensor(a, a2)
        doubly_nested = TwoTensor(aa, aa.clone().detach().requires_grad_())

        combined = f(flat_input) + doubly_nested
        expected_message = (
            "The grad inputs should be same tensor subclass type as forward output"
        )
        with self.assertRaisesRegex(RuntimeError, expected_message):
            combined.sum().backward()

    @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470")
    def test_custom_tensor_metadata(self):
        # A doubly wrapped ConstantExtraMetadataTensor whose inner wrapper carries
        # a constant attribute that f reads; compiled and eager results must be
        # close and the gradient must keep both wrapper levels.
        def f(x):
            x_elem = x.elem
            x_elem_elem = x_elem.elem
            x_elem_metadata = x_elem.constant_attribute
            return x * x_elem * x_elem_elem * x_elem_metadata

        a = torch.ones(4, requires_grad=True)
        custom_a = ConstantExtraMetadataTensor(a)
        custom_a.constant_attribute = 6
        custom_aa = ConstantExtraMetadataTensor(custom_a)
        custom_aa.constant_attribute = 4

        # Separate copy for the compiled run; the inner attribute that f reads is
        # set again explicitly on the copy.
        custom_aa_compile = custom_aa.clone().detach().requires_grad_()
        custom_aa_compile.elem.constant_attribute = 6
        out_eager = f(custom_aa)

        compiled_f = torch.compile(f, backend="aot_eager")
        out = compiled_f(custom_aa_compile)

        self.assertTrue(torch.allclose(out_eager, out))

        out.sum().backward()

        # assertIsInstance reports the actual type on failure.
        self.assertIsInstance(custom_aa_compile.grad, ConstantExtraMetadataTensor)
        self.assertIsInstance(
            custom_aa_compile.grad.elem, ConstantExtraMetadataTensor
        )

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470")
    def test_nested_subclasses_complicated_inps(self):
        # Two doubly nested TwoTensor inputs plus a plain tensor. The compiled
        # output must match eager and keep its nesting, and the gradients of the
        # nested inputs must be nested and match eager on their `.a` component.
        def f(x, y, z):
            temp = x + y
            temp_plain = x.a + y.b
            res = temp.sum() + temp_plain.sum()
            return x.sin().cos() + res

        x = torch.ones(4, requires_grad=True)
        x2 = x.clone().detach().requires_grad_()
        xx = TwoTensor(x, x2)
        xx2 = xx.clone().detach().requires_grad_()

        x_nested = TwoTensor(xx, xx2)
        x_nested_compile = x_nested.clone().detach().requires_grad_()

        y_nested = x_nested.clone().detach().requires_grad_()
        y_nested_compile = y_nested.clone().detach().requires_grad_()

        z = x.clone().detach().requires_grad_()
        z_compile = z.clone().detach().requires_grad_()

        out_eager = f(x_nested, y_nested, z)
        compiled_f = torch.compile(f, backend="aot_eager")
        out = compiled_f(x_nested_compile, y_nested_compile, z_compile)
        self.assertTrue(torch.allclose(out_eager, out))

        # assertIsInstance reports the offending type on failure, unlike
        # assertTrue(isinstance(...)).
        self.assertIsInstance(out, TwoTensor)
        for nested in (out.a, out.b):
            self.assertIsInstance(nested, TwoTensor)
        for leaf in (out.a.a, out.a.b, out.b.a, out.b.b):
            self.assertIsInstance(leaf, torch.Tensor)

        out.sum().backward()
        out_eager.sum().backward()

        for grad in (x_nested_compile.grad, y_nested_compile.grad):
            self.assertIsInstance(grad, TwoTensor)
            self.assertIsInstance(grad.a, TwoTensor)
            self.assertIsInstance(grad.b, TwoTensor)

        for compiled_inp, eager_inp in (
            (x_nested_compile, x_nested),
            (y_nested_compile, y_nested),
        ):
            self.assertTrue(
                torch.allclose(compiled_inp.grad.a.a, eager_inp.grad.a.a)
            )
            self.assertTrue(
                torch.allclose(compiled_inp.grad.a.b, eager_inp.grad.a.b)
            )

    @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470")
    def test_nested_subclasses_complicated_inps_mixed(self):
        # Mixes two different subclasses as inputs: a doubly nested TwoTensor and
        # a doubly wrapped ConstantExtraMetadataTensor carrying constant attributes.
        def f(x, y):
            y_elem = y.elem
            y_elem_elem = y_elem.elem
            y_elem_metadata = y_elem.constant_attribute
            return y * y_elem * y_elem_elem * y_elem_metadata + x

        # Nested TwoTensor input and its copy for the compiled run.
        base = torch.ones(4, requires_grad=True)
        inner_two = TwoTensor(base, base.clone().detach().requires_grad_())
        x_nested = TwoTensor(inner_two, inner_two.clone().detach().requires_grad_())
        x_nested_compile = x_nested.clone().detach().requires_grad_()

        # Wrapped metadata tensor input: the inner wrapper gets 6, the outer 4.
        a = torch.ones(4, requires_grad=True)
        custom_a = ConstantExtraMetadataTensor(a)
        custom_a.constant_attribute = 6
        custom_aa = ConstantExtraMetadataTensor(custom_a)
        custom_aa.constant_attribute = 4

        # The copy for the compiled run has both attributes set again explicitly.
        custom_aa_compile = custom_aa.clone().detach().requires_grad_()
        custom_aa_compile.constant_attribute = 4
        custom_aa_compile.elem.constant_attribute = 6

        compiled_f = torch.compile(f, backend="aot_eager")
        out_eager = f(x_nested, custom_aa)
        out = compiled_f(x_nested_compile, custom_aa_compile)
        self.assertTrue(torch.allclose(out_eager, out))

        out.sum().backward()
        out_eager.sum().backward()

        grads_match_x = torch.allclose(x_nested_compile.grad, x_nested.grad)
        self.assertTrue(grads_match_x)
        grads_match_y = torch.allclose(custom_aa_compile.grad, custom_aa.grad)
        self.assertTrue(grads_match_y)

    @skipIfTorchDynamo("This test suite already uses dynamo")
    def test_composite_impl_compile(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, a):
                return self.linear(a)

        inp = [torch.ones(3, 3, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(Foo(), inp, test_mutation=True)
        inp = [torch.ones(3, 3, requires_grad=False)]
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1, primals_2, primals_3):
    t = torch.ops.aten.t.default(primals_1);  primals_1 = None
    addmm = torch.ops.aten.addmm.default(primals_2, primals_3, t);  primals_2 = None
    return (addmm, primals_3, t)""",
        )

        with torch.inference_mode():
            fw_graph = self.verify_aot_autograd(Foo(), inp, test_mutation=True)
            inp = [torch.ones(3, 3, requires_grad=False)]
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, arg0_1, arg1_1, arg2_1):
    t = torch.ops.aten.t.default(arg0_1);  arg0_1 = None
    addmm = torch.ops.aten.addmm.default(arg1_1, arg2_1, t);  arg1_1 = arg2_1 = t = None
    return (addmm,)""",
            )

    def test_outputs_are_aliased(self):
        # f returns a freshly computed tensor together with a view of it, so the
        # two outputs alias each other.
        def f(a):
            b = a.mul(2)
            c = b.view(-1)
            return b, c

        compiled = aot_function(f, nop)

        def assert_outputs_match(ref, test):
            for idx in (0, 1):
                self.assertEqual(ref[idx], test[idx])

        for requires_grad in (True, False):
            x = torch.ones(3, requires_grad=requires_grad)
            ref_outs = f(x)
            test_outs = compiled(x)
            assert_outputs_match(ref_outs, test_outs)
            # Mutate the first output in place on both sides ...
            ref_outs[0].mul_(3)
            test_outs[0].mul_(3)
            # ... and check both outputs again: if the aliasing was preserved,
            # the second output still matches eager after the mutation.
            assert_outputs_match(ref_outs, test_outs)

    def test_input_mutation_is_output(self):
        def f(a):
            a.mul_(2)
            return a

        inp = [torch.ones(3, 3, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        inp = [torch.ones(3, 3, requires_grad=False)]
        self.verify_aot_autograd(f, inp, test_mutation=True)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
    mul = torch.ops.aten.mul.Tensor(clone, 2);  clone = None
    return (mul, mul)""",
        )

    def test_input_mutation_multiple(self):
        def f(a, b, c):
            a.mul_(2)
            c.mul_(2)
            return a + b + c

        def create_inp(req_grad):
            return [
                torch.ones(3, 3, requires_grad=req_grad),
                torch.ones(3, 3, requires_grad=req_grad),
                torch.ones(3, 3, requires_grad=req_grad),
            ]

        self.verify_aot_autograd(f, create_inp(False), test_mutation=True)

        fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1, primals_2, primals_3):
    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
    clone_1 = torch.ops.aten.clone.default(primals_3);  primals_3 = None
    mul = torch.ops.aten.mul.Tensor(clone, 2);  clone = None
    mul_1 = torch.ops.aten.mul.Tensor(clone_1, 2);  clone_1 = None
    add = torch.ops.aten.add.Tensor(mul, primals_2);  primals_2 = None
    add_1 = torch.ops.aten.add.Tensor(add, mul_1);  add = None
    return (mul, mul_1, add_1)""",
        )

    def test_input_mutation_return(self):
        def f(a, b):
            return torch.sin(a, out=b)

        inp = [torch.randn(3, 3), torch.ones(3, 3)]

        fw_graph = self.verify_aot_autograd(
            f, inp, test_mutation=True, keep_inp_mutations=True
        )
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
    copy_ = torch.ops.aten.copy_.default(arg1_1, sin);  arg1_1 = sin = None
    return (copy_,)""",
        )

    def test_input_mutation_metadata(self):
        # transpose_ changes only the metadata of `a`, not its data.
        def f(a, b):
            a.transpose_(1, 0)
            return a + b

        def create_inp(req_grad):
            return [torch.ones(3, 3, requires_grad=req_grad) for _ in range(2)]

        # Exercise both the requires_grad and the no-grad variants.
        for req_grad in (True, False):
            self.verify_aot_autograd(f, create_inp(req_grad), test_mutation=True)

    def test_input_mutation_storage_resize_up(self):
        def f(a):
            torch.ops.inductor.resize_storage_bytes_(a, 32)
            # float32, 4 bytes per element, 32 bytes == 8 elements
            with torch.no_grad():
                a.copy_(torch.ones(8))
            return a + 1

        inp = torch.zeros(8, requires_grad=True)
        # Input starts with zero-size-storage
        inp.untyped_storage().resize_(0)

        fw_graph_cell = [None]
        compiled_f = aot_function(
            f,
            fw_compiler=make_boxed_compiler(
                partial(extract_graph, graph_cell=fw_graph_cell)
            ),
            bw_compiler=nop,
            decompositions={},
            keep_inference_input_mutations=True,
            dynamic=False,
        )
        out = compiled_f(inp)
        # Final functionalized graph has two mutation ops:
        # (1) a resize_() to resize input tensor up
        # (2) a copy_() to fill in the resized input with valid data
        self.assertExpectedInline(
            fw_graph_cell[0].code.strip(),
            """\
def forward(self, primals_1):
    resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 32);  resize_storage_bytes_ = None
    ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False)
    copy = torch.ops.aten.copy.default(primals_1, ones);  ones = None
    add = torch.ops.aten.add.Tensor(copy, 1)
    copy_ = torch.ops.aten.copy_.default(primals_1, copy);  primals_1 = copy = copy_ = None
    return (add,)""",
        )

    def test_input_mutation_storage_resize_down(self):
        def f(a):
            out = a.sin()
            torch.ops.inductor.resize_storage_bytes_(a, 0)
            return out

        inp = torch.zeros(8, requires_grad=True)

        fw_graph_cell = [None]
        compiled_f = aot_function(
            f,
            fw_compiler=make_boxed_compiler(
                partial(extract_graph, graph_cell=fw_graph_cell)
            ),
            bw_compiler=nop,
            decompositions={},
            keep_inference_input_mutations=True,
            dynamic=False,
        )
        out = compiled_f(inp)
        # Final functionalized graph has one mutation ops:
        # (1) a resize_() to resize input tensor down
        # Even though there was technically a "data mutation" on the input (from a.copy_()),
        # We don't include it in the graph since the final input size has zero storage
        self.assertExpectedInline(
            fw_graph_cell[0].code.strip(),
            """\
def forward(self, primals_1):
    sin = torch.ops.aten.sin.default(primals_1)
    resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 0);  resize_storage_bytes_ = None
    return (sin, primals_1)""",
        )

    #     def test_input_mutation_storage_resize_up_down(self):
    #         def f(a):
    #             torch.ops.inductor.resize_storage_bytes_(a, 32)
    #             # float32, 4 bytes per element, 32 bytes == 8 elements
    #             with torch.no_grad():
    #                 a.copy_(torch.ones(8))
    #             out = a.sin()
    #             torch.ops.inductor.resize_storage_bytes_(a, 0)
    #             return out

    #         inp = torch.zeros(8, requires_grad=True)
    #         # Input starts with zero-size-storage
    #         inp.untyped_storage().resize_(0)

    #         fw_graph_cell = [None]
    #         compiled_f = aot_function(
    #             f,
    #             fw_compiler=make_boxed_compiler(
    #                 partial(extract_graph, graph_cell=fw_graph_cell)
    #             ),
    #             bw_compiler=nop,
    #             decompositions={},
    #             keep_inference_input_mutations=True,
    #             dynamic=False,
    #         )
    #         out = compiled_f(inp)
    #         # Final graph has two interesting properties:
    #         # (1) no resizes in the functional graph, since the two resizes cancel out
    #         #     and the final size is zero
    #         # (2) no copy_ in the functional graph, even though we copied data into the input,
    #         #     because the input has no storage at the end of graph execution (so no data to copy)
    #         self.assertExpectedInline(
    #             fw_graph_cell[0].code.strip(),
    #             """\
    # def forward(self, primals_1):
    #     ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False)
    #     copy = torch.ops.aten.copy.default(primals_1, ones);  primals_1 = ones = None
    #     sin = torch.ops.aten.sin.default(copy)
    #     return [sin, copy]""",
    #         )

    def test_input_mutation_storage_resize_down_and_set_(self):
        # Combines a set_() of a graph input (via a custom autograd Function)
        # with a later resize of that input's storage to zero, and checks the
        # resulting forward graph.

        # Meant to mimic ppFSDP
        class TracableCreateParameter(torch.autograd.Function):
            @staticmethod
            def forward(ctx, tensor, placeholder):
                # The source tensor must not require grad; the placeholder is
                # re-pointed at it with set_() and returned.
                assert not tensor.requires_grad
                return placeholder.set_(tensor)

            @staticmethod
            def backward(ctx, grad):
                return None, grad  # grad flows to placeholder

        def f(dummy_param, param_shard):
            # simulate allgather
            with torch.no_grad():
                allgather_param = torch.cat([param_shard, param_shard])
            # simulate propagating grad state through dummy param, using data of allgather param
            # NOTE(review): the returned tensor is bound but never read; the call
            # is made for its set_() effect on dummy_param.
            dummy_param_with_grad_state = TracableCreateParameter.apply(
                allgather_param, dummy_param
            )
            out = dummy_param.sin()
            # Resize out dummy param, which now has the allgather data
            torch.ops.inductor.resize_storage_bytes_(dummy_param, 0)
            return out

        # Simulates the local shard of our param
        param_shard = torch.zeros(8, requires_grad=True)
        # The dummy, zero-sized allgathered param that autograd will actually compute gradients on
        dummy_param = torch.zeros(16, requires_grad=True)
        dummy_param.untyped_storage().resize_(0)

        # The forward graph is read back from fw_graph_cell[0] after the call.
        fw_graph_cell = [None]
        compiled_f = aot_function(
            f,
            fw_compiler=make_boxed_compiler(
                partial(extract_graph, graph_cell=fw_graph_cell)
            ),
            bw_compiler=nop,
            decompositions={},
            keep_inference_input_mutations=True,
            dynamic=False,
        )
        # NOTE(review): `out` is not used afterwards; only the graph is checked.
        out = compiled_f(dummy_param, param_shard)
        # Important stuff to point out:
        # (1) We save cat for backward (input to the sin()).
        #     While the original code was dummy_param.sin(),
        #     dummy_param actually contains the `cat` tensor due to the set_() call
        # (2) We emit a cat.resize_storage_(0) in the graph.
        #     After the set_(), cat is the actual data of dummy_param, which is what we call resize_() on
        self.assertExpectedInline(
            fw_graph_cell[0].code.strip(),
            """\
def forward(self, primals_1, primals_2):
    cat = torch.ops.aten.cat.default([primals_2, primals_2]);  primals_2 = None
    sin = torch.ops.aten.sin.default(cat)
    resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(cat, 0);  resize_storage_bytes_ = None
    set_ = torch.ops.aten.set_.source_Tensor(primals_1, cat);  primals_1 = set_ = None
    return (sin, cat)""",
        )

    def test_input_mutation_storage_resize_before_set_(self):
        # Smoke test without assertions: compiling and running a function that
        # resizes an input's storage to zero and then set_()s it must not raise.
        def f(a):
            with torch.no_grad():
                torch.ops.inductor.resize_storage_bytes_(a, 0)
                a.set_(torch.ones(2))

        inp = torch.zeros(8, requires_grad=True)

        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop,
            decompositions={},
            keep_inference_input_mutations=True,
            dynamic=False,
        )
        # f has no return statement, so there is no result worth binding.
        compiled_f(inp)

    # def test_input_mutation_storage_resize_not_supported(self):
    #     def f(a):
    #         a.mul_(2)
    #         torch.ops.inductor.resize_storage_bytes_(a, 0)
    #         return a

    #     inp = torch.zeros(8, requires_grad=True)

    #     with self.assertRaisesRegex(
    #         AssertionError, "the input has other mutations that we cannot"
    #     ):
    #         compiled_f = aot_function(
    #             f,
    #             fw_compiler=nop,
    #             bw_compiler=nop,
    #             decompositions={},
    #             keep_inference_input_mutations=True,
    #             dynamic=False,
    #         )
    #         out = compiled_f(inp)

    def test_input_output_aliase_custom_autograd_function(self):
        # A custom autograd Function whose forward returns its input unchanged
        # and whose backward halves the incoming gradient.
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out * 0.5

        def f(x):
            return Foo.apply(x)

        self.verify_aot_autograd(
            f, [torch.ones(2, 2, requires_grad=True)], test_mutation=False
        )

    def test_input_mutation_requires_grad_detach(self):
        # "a" requires grad and is mutated, so a copy_() gets appended to the end
        # of the graph. The mutation goes through a detach'd view, so it does not
        # take part in autograd; the copy_() must neither error nor participate
        # in autograd.
        def f(a):
            a.detach().mul_(2)
            return a + 3

        # test_mutation=True first does some compute on the input, so it is no
        # longer an autograd leaf by the time it becomes a graph input. Both
        # variants are covered, each with a fresh input.
        for test_mutation in (False, True):
            fresh_inp = [torch.ones(4, requires_grad=True)]
            self.verify_aot_autograd(f, fresh_inp, test_mutation=test_mutation)

    def test_input_mutation_hidden_from_autograd_aliasing(self):
        def f(a):
            a_alias = a.view(-1)
            with torch.no_grad():
                a_alias.mul_(2)
            return a + 1

        inp = [torch.ones(4, requires_grad=True)]
        # The important bit: we detected that the input mutation is safe
        # to include **inside** the graph, since it was under no_grad
        # (so all we need to do is use mark_dirty() on the input to bump the VC)
        fw_graph = self.verify_aot_autograd(
            f, inp, test_mutation=True, keep_inp_mutations=True
        )
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    view = torch.ops.aten.view.default(primals_1, [-1])
    mul = torch.ops.aten.mul.Tensor(view, 2);  view = None
    view_1 = torch.ops.aten.view.default(mul, [4]);  mul = None
    add = torch.ops.aten.add.Tensor(view_1, 1)
    copy_ = torch.ops.aten.copy_.default(primals_1, view_1);  primals_1 = view_1 = copy_ = None
    return (add,)""",
        )

    def test_input_mutation_requires_grad_no_grad(self):
        def f(a):
            with torch.no_grad():
                a.mul_(2)
            return a + 3

        inp = [torch.ones(4, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(
            f, inp, test_mutation=True, keep_inp_mutations=True
        )
        # Even though the input requires_grad, we expect the keep the input mutation in the graph
        # (Even though this is a training graph!)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    mul = torch.ops.aten.mul.Tensor(primals_1, 2)
    add = torch.ops.aten.add.Tensor(mul, 3)
    copy_ = torch.ops.aten.copy_.default(primals_1, mul);  primals_1 = mul = copy_ = None
    return (add,)""",
        )

    def test_input_mutation_requires_grad_no_grad_inference_graph(self):
        def f(a):
            with torch.no_grad():
                a.mul_(2)
                return a + 3

        inp = [torch.ones(4, requires_grad=True)]
        # Even though the input requires_grad, we expect the keep the input mutation in the graph
        fw_graph = self.verify_aot_autograd(
            f, inp, test_mutation=True, keep_inp_mutations=True
        )

        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1):
    mul = torch.ops.aten.mul.Tensor(arg0_1, 2)
    add = torch.ops.aten.add.Tensor(mul, 3)
    copy_ = torch.ops.aten.copy_.default(arg0_1, mul);  arg0_1 = mul = copy_ = None
    return (add,)""",
        )

    def test_input_mutation_requires_grad_no_grad_detach_mixed(self):
        # Perform a mix of mutations on a:
        # 1 on a detach'd tensor, 1 normal, 1 in no_grad.
        # Only the normal one (a.mul_(3)) should participate in gradient computation.
        def f(a):
            a.detach().mul_(2)
            a.mul_(3)
            with torch.no_grad():
                a.mul_(4)
            return a + 5

        inp = [torch.ones(4, requires_grad=True)]
        # The graph is not inspected here, so the returned graph is not bound.
        self.verify_aot_autograd(f, inp, test_mutation=True)

    def test_input_mutation_metadata2(self):
        # A metadata mutation (transpose_) followed by a data mutation (mul_) on
        # the same input.
        def f(a):
            a.transpose_(1, 0)
            a.mul_(2)
            return a + 1

        # Exercise both the requires_grad and the no-grad variants.
        for requires_grad in (True, False):
            args = [torch.ones(3, 3, requires_grad=requires_grad)]
            self.verify_aot_autograd(f, args, test_mutation=True)

    def test_input_mutation_batchnorm(self):
        from torch._decomp import get_decompositions

        def f(inpt, weight, bias, running_mean, running_var):
            # This is additionally a good test, because the input tensors that we mutate
            # are *also* saved for backwards.
            # This tests that what we save for the backward is actually cloned inputs,
            # and not the original inputs that got mutated.
            return torch._native_batch_norm_legit(
                inpt, weight, bias, running_mean, running_var, True, 0.5, 1e-5
            )

        def create_inp(req_grad):
            # Input, weight and bias follow req_grad; the two running statistics
            # are always created without requires_grad.
            differentiable = [
                torch.ones(2, 5, 5, 5, requires_grad=req_grad),
                torch.ones(5, requires_grad=req_grad),
                torch.ones(5, requires_grad=req_grad),
            ]
            running_stats = [torch.ones(5), torch.ones(5)]
            return differentiable + running_stats

        # This simulates what inductor does (running the fw + bw decompositions)
        ops_to_decompose = [
            torch.ops.aten._native_batch_norm_legit_functional,
            torch.ops.aten.native_batch_norm_backward,
        ]
        decompositions = get_decompositions(ops_to_decompose)

        for req_grad in (True, False):
            self.verify_aot_autograd(
                f,
                create_inp(req_grad),
                test_mutation=True,
                decompositions=decompositions,
            )

    def test_batchnorm_inference(self):
        # In eval mode, the traced forward graph of a BatchNorm2d module must not
        # contain any copy_() call.

        # NOTE(review): the second positional argument of BatchNorm2d is `eps`,
        # so this builds a module with eps=4. Kept as in the original test since
        # only the graph structure is checked -- confirm whether it was intended.
        m = torch.nn.BatchNorm2d(4, 4)
        m.eval()

        # The forward graph is read back from fw_graph_cell[0] after the call.
        fw_graph_cell = [None]
        compiled_m = aot_module(
            m,
            fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
            bw_compiler=nop,
            keep_inference_input_mutations=True,
        )
        inp = torch.ones(4, 4, 4, 4)
        with torch.no_grad():
            # Only the captured graph is inspected, so the result is not bound.
            compiled_m(inp)
        # expectation: there are no copy_() calls in the decomposed batch norm when running under training=False (eval mode)
        code = fw_graph_cell[0].code.strip()
        self.assertNotIn("copy_", code)

    def test_input_output_view_simple(self):
        def f(a):
            return a.view(-1)

        inp = [torch.ones(2, 2, requires_grad=False).add(1)]
        self.verify_aot_autograd(f, inp, test_mutation=True)
        inp = [torch.ones(2, 2, requires_grad=True).add(1)]
        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        # Outputs that alias inputs are pulled out of the graph entirely, so we don't compile anything here
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1):
    view = torch.ops.aten.view.default(arg0_1, [-1]);  arg0_1 = None
    return (view,)""",
        )

    def test_input_output_view_mutate_multiple(self):
        def f(a, b, c):
            a.mul_(2)
            c.mul_(3)
            return b.view(2, 2), c.view(2, 2)

        def create_inp(req_grad):
            return [
                torch.ones(2, 2, requires_grad=req_grad).add(1),
                torch.ones(2, 2, requires_grad=req_grad).add(1),
                torch.ones(2, 2, requires_grad=req_grad).add(1),
            ]

        self.verify_aot_autograd(f, create_inp(False), test_mutation=True)
        fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
        # The original function returned two outputs, both of which aliased inputs.
        # We expect two outputs in the functional graph, a_updated and c_updated.
        # The actual aliased outputs themselves aren't in the compiled forward graph;
        # Instead, they're generated outside of  the graph.
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1, primals_2, primals_3):
    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
    clone_1 = torch.ops.aten.clone.default(primals_3);  primals_3 = None
    mul = torch.ops.aten.mul.Tensor(clone, 2);  clone = None
    mul_1 = torch.ops.aten.mul.Tensor(clone_1, 3);  clone_1 = None
    view = torch.ops.aten.view.default(primals_2, [2, 2]);  primals_2 = None
    view_2 = torch.ops.aten.view.default(mul_1, [2, 2])
    return (mul, mul_1, view, view_2)""",
        )

    def test_input_output_view_metadata_mutate_multiple(self):
        def f(a, b, c):
            b.mul_(3)
            c.t_()
            return a.view(2, 2), b.view(2, 2), c.view(2, 2)

        def create_inp(req_grad):
            return [
                torch.ones(2, 2, requires_grad=req_grad).add(1),
                torch.ones(2, 2, requires_grad=req_grad).add(1),
                torch.ones(2, 2, requires_grad=req_grad).add(1),
            ]

        self.verify_aot_autograd(f, create_inp(False), test_mutation=True)
        fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
        # Important thing to check here: of the three inputs:
        # Only the b.mul_(3) should show up in the graph (we functionalize it and return it).
        # Everything else that does not show up in the graph includes:
        # - The metadata mutation on c (we do it outside the graph)
        # - All 3 original fw outputs, which are aliases of inputs (we regenerate them outside of the graph)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1, primals_2, primals_3):
    clone = torch.ops.aten.clone.default(primals_2);  primals_2 = None
    view = torch.ops.aten.view.default(primals_3, [2, 2]);  primals_3 = None
    mul = torch.ops.aten.mul.Tensor(clone, 3);  clone = None
    t = torch.ops.aten.t.default(view);  view = None
    view_1 = torch.ops.aten.view.default(primals_1, [2, 2]);  primals_1 = None
    view_3 = torch.ops.aten.view.default(t, [2, 2])
    view_4 = torch.ops.aten.view.default(mul, [2, 2])
    return (mul, t, view_1, view_4, view_3)""",
        )

    def test_input_mutation_and_output_view(self):
        def f(a):
            a.add_(1)
            return a.view(-1)

        inp = [torch.ones(2, 2, requires_grad=False).add(1)]
        self.verify_aot_autograd(f, inp, test_mutation=True)
        inp = [torch.ones(2, 2, requires_grad=True).add(1)]
        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        # Here, total # of outputs is 1 because:
        # - num_mutated_inps = 1 (a_updated)
        # - num_fw_outputs = 0 (the output is an alias of the input, so we move it outside the compiled fw)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
    view_1 = torch.ops.aten.view.default(add, [-1])
    return (add, view_1)""",
        )

    def test_input_mutation_output_view_multiple(self):
        def f(a, b, c, d):
            b.transpose_(1, 0)
            c.add_(1)
            return d + 1, b.diagonal(), a + c

        def create_inp(req_grad):
            return [
                torch.arange(4, requires_grad=req_grad, dtype=torch.float32)
                .view(2, 2)
                .add(1),
                torch.arange(4, requires_grad=req_grad, dtype=torch.float32)
                .view(2, 2)
                .add(1),
                torch.ones(2, 2, requires_grad=req_grad).add(1),
                torch.ones(2, 2, requires_grad=req_grad).add(1),
            ]

        self.verify_aot_autograd(f, create_inp(False), test_mutation=True)
        fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1, primals_2, primals_3, primals_4):
    view = torch.ops.aten.view.default(primals_2, [2, 2]);  primals_2 = None
    clone = torch.ops.aten.clone.default(primals_3);  primals_3 = None
    transpose = torch.ops.aten.transpose.int(view, 1, 0);  view = None
    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
    add_1 = torch.ops.aten.add.Tensor(primals_4, 1);  primals_4 = None
    diagonal = torch.ops.aten.diagonal.default(transpose)
    add_2 = torch.ops.aten.add.Tensor(primals_1, add);  primals_1 = None
    return (transpose, add, add_1, diagonal, add_2)""",
        )

    def test_output_aliases_intermediate_single(self):
        def f(a):
            out = torch.mul(a, 3)
            return out.view(-1)

        inp = [torch.ones(3, 3, requires_grad=False)]
        self.verify_aot_autograd(f, inp, test_mutation=True)
        inp = [torch.ones(3, 3, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        # In AOTAutograd, we are obligated to make the compiled forward directly return `out`,
        # and reconstruct `out.view(-1)` as a fresh output.
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
    view = torch.ops.aten.view.default(mul, [-1]);  mul = None
    return (view,)""",
        )

    def test_output_aliases_input_multi_output_view_should_raise_autograd_error(self):
        # Outputs of a multi-output view op (unbind) that alias a graph input must
        # keep raising autograd's "multi-output view" error after compilation.
        # Each of the three scenarios below uses its own input and its own
        # compiled outputs so that no scenario can pass on the back of another.
        def f1(a):
            return list(a.unbind(0))

        f1_compiled = aot_function(f1, nop)

        inp1 = torch.ones(3, 3, requires_grad=True).clone()
        inp2 = torch.ones(3, 3, requires_grad=True).clone()
        inp3 = torch.ones(3, 3, requires_grad=True).clone()

        with self.assertRaisesRegex(
            RuntimeError, "Such functions do not allow the output views"
        ):
            out_test1 = f1_compiled(inp1)
            # This raises a runtime error from autograd in eager mode
            out_test1[0].mul_(2)

        with self.assertRaisesRegex(
            RuntimeError, "Such functions do not allow the output views"
        ):
            out_test2 = f1_compiled(inp2)
            inp2.mul_(2)
            # In eager mode, if we mutate a tensor, any multi-output-view aliases
            # get their grad_fn replaced with error nodes, so accessing grad_fn should error
            _ = out_test2[0].grad_fn

        with self.assertRaisesRegex(
            RuntimeError, "Such functions do not allow the output views"
        ):
            out_test3 = f1_compiled(inp3)
            # Mutate through a detached alias of this scenario's own output.
            # (Previously this mutated out_test1 and read out_test2's grad_fn,
            # which was already poisoned by the scenario above, so the detached
            # case was never actually exercised.)
            out_test3[0].detach().mul_(2)
            # The above case also applies to detached aliases (they turn the multi-output-view
            # alias's grad_fns into error nodes)
            _ = out_test3[0].grad_fn

    def test_output_aliases_input_multi_output_view(self):
        # Scenario 1: every aliased output comes from a multi-output view (unbind)
        # of the input.
        def f1(a):
            return list(a.unbind(0))

        x_test = torch.ones(3, 3, requires_grad=True)
        x_ref = torch.ones(3, 3, requires_grad=True)
        compiled_f1 = aot_function(f1, nop)

        eager_outs = f1(x_ref)
        compiled_outs = compiled_f1(x_test)
        # Each compiled output must hang off CompiledFunctionBackward
        # (rather than a regenerated view such as AsStridedBackward).
        # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
        for o in compiled_outs:
            self.assertIn("CompiledFunctionBackward", str(o.grad_fn))

        sum(eager_outs).sum().backward()
        sum(compiled_outs).sum().backward()
        self.assertEqual(x_ref.grad, x_test.grad)

        # Scenario 2: the unbind outputs share an alias set with `a` and with
        # `a.view(a.shape)`, both of which are user-visible. The check below
        # expects the first three compiled outputs to carry UnbindBackward,
        # i.e. the aliases are regenerated instead of being hidden from autograd.
        def f3(a):
            return *list(a.unbind(0)), a.view(a.shape)

        y_test = torch.ones(3, 3, requires_grad=True)
        y_ref = torch.ones(3, 3, requires_grad=True)
        compiled_f3 = aot_function(f3, nop)

        # Non-leaf clones are what get passed in, so they can be mutated later on.
        y_ref_clone = y_ref.clone()
        y_test_clone = y_test.clone()
        eager_outs3 = f3(y_ref_clone)
        compiled_outs3 = compiled_f3(y_test_clone)
        for o in compiled_outs3[:3]:
            self.assertIn("UnbindBackward", str(o.grad_fn))

        # The last output is not from a multi-output view, so autograd will let us mutate it.
        eager_outs3[-1].mul_(2)
        compiled_outs3[-1].mul_(2)
        # Also mutate the input, which should affect the aliased output.
        y_ref_clone.view(-1).mul_(3)
        y_test_clone.view(-1).mul_(3)
        # Backward through both the original leaf and the mutated alias.
        (y_ref + eager_outs3[-1]).sum().backward()
        (y_test + compiled_outs3[-1]).sum().backward()
        self.assertEqual(y_ref.grad, y_test.grad)

    def test_output_aliases_intermediate_multi_output_view(self):
        # Four scenarios (f1..f4) in which outputs alias a graph intermediate via a
        # multi-output view (unbind). Each one runs the eager function on `inp_ref`
        # and the compiled function on `inp`, then compares the input gradients.

        # All aliased outs are from multi-output views, so AOTAutograd will hide the aliasing from autograd.
        def f1(a):
            out = torch.mul(a, 3)
            return list(out.unbind(0))

        inp = torch.ones(3, 3, requires_grad=True)
        inp_ref = torch.ones(3, 3, requires_grad=True)
        f1_compiled = aot_function(f1, nop)

        out_ref = f1(inp_ref)
        out_test = f1_compiled(inp)
        # Assert that we get CompiledFunctionBackward in the backward graph,
        # and not AsStridedBackward. No view-regeneration necessary for this multi-output view case.
        # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
        self.assertTrue(
            all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test)
        )

        sum(out_ref).sum().backward()
        sum(out_test).sum().backward()
        self.assertEqual(inp_ref.grad, inp.grad)

        # All aliased outs but one are from multi-output views, so AOTAutograd will hide the aliasing from autograd.
        # Here the extra output is the intermediate `out` itself.
        def f2(a):
            out = torch.mul(a, 3)
            return *list(out.unbind(0)), out

        inp = torch.ones(3, 3, requires_grad=True)
        inp_ref = torch.ones(3, 3, requires_grad=True)
        f2_compiled = aot_function(f2, nop)

        out_ref = f2(inp_ref)
        out_test = f2_compiled(inp)
        # Assert that we get CompiledFunctionBackward in the backward graph,
        # and not AsStridedBackward. No view-regeneration necessary for this multi-output view case.
        # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
        self.assertTrue(
            all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test)
        )

        # The last output is not from a multi-output view, so autograd will let us mutate it.
        out_ref[-1].mul_(2)
        out_test[-1].mul_(2)
        out_ref[-1].sum().backward()
        out_test[-1].sum().backward()
        self.assertEqual(inp_ref.grad, inp.grad)

        # All aliased outs but one are from multi-output views, so AOTAutograd will hide the aliasing from autograd.
        # Here the extra output is an ordinary view of the intermediate.
        def f3(a):
            out = torch.mul(a, 3)
            return *list(out.unbind(0)), out.view(out.shape)

        inp = torch.ones(3, 3, requires_grad=True)
        inp_ref = torch.ones(3, 3, requires_grad=True)
        f3_compiled = aot_function(f3, nop)

        out_ref = f3(inp_ref)
        out_test = f3_compiled(inp)
        # Assert that we get CompiledFunctionBackward in the backward graph,
        # and not AsStridedBackward. No view-regeneration necessary for this multi-output view case.
        # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
        self.assertTrue(
            all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test)
        )

        # The last output is not from a multi-output view, so autograd will let us mutate it.
        out_ref[-1].mul_(2)
        out_test[-1].mul_(2)
        out_ref[-1].sum().backward()
        out_test[-1].sum().backward()
        self.assertEqual(inp_ref.grad, inp.grad)

        # There are 5 outputs that all alias each other.
        # 3 of them come from multi-output views, but the other 2 are "ordinary" aliases.
        # Therefore, AOTAutograd will not attempt the multi-output-view optimization,
        # and apply the intermediate_base logic to all aliases.
        # (In theory we could probably get AOTAutograd to only apply the intermediate base
        # logic to the last 2 outputs and not the first 3. We should probably
        # just do the graph partitioning defined in this doc instead though).
        # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit
        def f4(a):
            out = torch.mul(a, 3)
            # also return the graph intermediate directly,
            # which will force AOTAutograd to do the "intermediate base" logic.
            # (Why? The user can mutate "out", which should change the autograd metadata
            #  of the other aliased outputs)
            return *list(out.unbind(0)), out, out.view(out.shape)

        inp = torch.ones(3, 3, requires_grad=True)
        inp_ref = torch.ones(3, 3, requires_grad=True)
        f4_compiled = aot_function(f4, nop)

        out_ref = f4(inp_ref)
        out_test = f4_compiled(inp)
        # Mutate the last output of f4 (autograd will allow this, since it is not a multi-output view,
        # as long as *only* the non-multi-output views participate in the backward)
        # Note: We could probably try to hide **only** the multi-output views from autograd here
        # and only do the intermediate base logic for the last two aliases.
        # Longer term solution of graph partitioning is probably cleaner though (see the note).
        out_ref[-1].mul_(2)
        out_test[-1].mul_(2)

        # Only the two non-multi-output-view outputs take part in the backward pass.
        out_ref_sum = out_ref[-1] + out_ref[-2]
        out_test_sum = out_test[-1] + out_test[-2]
        out_ref_sum.sum().backward()
        out_test_sum.sum().backward()
        self.assertEqual(inp_ref.grad, inp.grad)

    def test_output_aliases_intermediate_mutation_linear(self):
        # use inductor's decomps (which will e.g. turn _unsafe_view() into view())
        from torch._inductor.decomposition import decompositions

        # The output is a view of an intermediate; it is mutated in place afterwards.
        def f(x):
            return (x + 1).view(-1)

        x = torch.ones(3, 3, requires_grad=True)
        compiled_f = aot_function(f, nop, decompositions=decompositions)

        # The same input tensor feeds both the eager and the compiled function.
        expected = f(x)
        actual = compiled_f(x)

        # Mutating the returned views in place must work and give matching values.
        expected.mul_(2)
        actual.mul_(2)
        self.assertEqual(expected, actual)

    def test_output_aliases_intermediate_no_grad(self):
        def f(a, b):
            out = torch.mul(a, 3)
            # First output is an alias of an intermediate that doesn't require grad
            return out.view(-1), b.add(1)

        inp = [torch.ones(3, 3), torch.ones(3, 3, requires_grad=False)]
        self.verify_aot_autograd(f, inp, test_mutation=True)
        inp = [torch.ones(3, 3), torch.ones(3, 3, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        # important bit: we don't bother generating an intermediate base as an output in the graph,
        # because the intermediate base itself didn't require gradients.
        # (the only problematic case is when both the base and the aliasesed output require gradients).
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1, primals_2):
    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
    view = torch.ops.aten.view.default(mul, [-1]);  mul = None
    add = torch.ops.aten.add.Tensor(primals_2, 1);  primals_2 = None
    return (view, add)""",
        )

    def test_output_aliases_intermediate_returned_multiple_times(self):
        # The same intermediate is returned twice, together with a view of it.
        def f(a):
            out = torch.mul(a, 3)
            out_view = out.view(-1)
            return out, out_view, out

        # No graph is pinned here; the test only requires verification to succeed
        # both without and with requires_grad.
        for needs_grad in (False, True):
            args = [torch.ones(3, 3, requires_grad=needs_grad)]
            self.verify_aot_autograd(f, args, test_mutation=True)

    def test_output_aliases_intermediate_multiple(self):
        def f(a):
            out = torch.mul(a, 3)
            # AOTAutograd should manually generate these two output views in the epilogue.
            return out.view(-1), out.view(-1)

        inp = [torch.ones(3, 3, requires_grad=False)]
        self.verify_aot_autograd(f, inp, test_mutation=True)
        inp = [torch.ones(3, 3, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
    view = torch.ops.aten.view.default(mul, [-1])
    view_1 = torch.ops.aten.view.default(mul, [-1])
    return (view, view_1, mul)""",
        )

    def test_output_aliases_intermediate_and_returned(self):
        def f(a):
            out = torch.mul(a, 3)
            # AOTAutograd should manually generate the first output (a view of an intermediate)
            # but not the second (which is itself the intermediate for the first)
            return out.view(-1), out

        inp = [torch.ones(3, 3, requires_grad=False)]
        self.verify_aot_autograd(f, inp, test_mutation=True)
        inp = [torch.ones(3, 3, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
    view = torch.ops.aten.view.default(mul, [-1])
    return (view, mul)""",
        )

    def test_output_aliases_intermediate_and_returned_flipped(self):
        def f(a):
            out = torch.mul(a, 3)
            # AOTAutograd should manually generate the first output (a view of an intermediate)
            # but not the second (which is itself the intermediate for the first)
            return out, out.view(-1)

        inp = [torch.ones(3, 3, requires_grad=False)]
        self.verify_aot_autograd(f, inp, test_mutation=True)
        inp = [torch.ones(3, 3, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
    view = torch.ops.aten.view.default(mul, [-1])
    return (mul, view)""",
        )

    def test_output_aliases_intermediate_and_returned_different_grad(self):
        def f(a):
            out = torch.mul(a, 3)
            # AOTAutograd should manually generate the first output (a view of an intermediate)
            # but not the second (which is itself the intermediate for the first)
            return out.view(-1), out, out[0].detach()

        inp = [torch.ones(3, 3, requires_grad=False)]
        self.verify_aot_autograd(f, inp, test_mutation=True)
        inp = [torch.ones(3, 3, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
    view = torch.ops.aten.view.default(mul, [-1])
    select = torch.ops.aten.select.int(mul, 0, 0)
    detach = torch.ops.aten.detach.default(select);  select = None
    detach_1 = torch.ops.aten.detach.default(detach);  detach = None
    detach_2 = torch.ops.aten.detach.default(detach_1);  detach_1 = None
    return (view, mul, detach_2)""",
        )

    def test_output_aliases_intermediate_inplace_view(self):
        # Scenario: an intermediate gets an in-place metadata mutation (t_())
        # and is then returned.
        # NOTE: the verification call at the bottom is commented out, so this test
        # currently only builds `f` and `inp` and passes without checking anything.
        def f(a):
            out = torch.mul(a, 3)
            out.t_()
            return out

        inp = [torch.ones(2, 4, requires_grad=True)]

        # TODO: fix this test.
        # See https://github.com/pytorch/pytorch/issues/90507
        # self.verify_aot_autograd(f, inp, test_mutation=True)

    def test_output_aliases_intermediate_inplace_view_with_detach(self):
        def f(a):
            out = torch.mul(a, 3)
            out.t_()
            out.detach_()
            # Thanks to the detach_() AOT Autograd doesn't need to do anything.
            # `out` will show up as having OutputType.non_alias,
            # and ._is_view() == False
            return out, a + 1

        inp = [torch.ones(2, 4, requires_grad=False)]
        self.verify_aot_autograd(f, inp, test_mutation=True)
        inp = [torch.ones(2, 4, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    mul = torch.ops.aten.mul.Tensor(primals_1, 3)
    t = torch.ops.aten.t.default(mul);  mul = None
    add = torch.ops.aten.add.Tensor(primals_1, 1);  primals_1 = None
    return (t, add)""",
        )

    def test_output_aliases_intermediate_inplace_view_and_view(self):
        # Scenario: views of an intermediate are taken both before and after an
        # in-place metadata mutation (t_()) of that intermediate, and all three
        # tensors are returned.
        # NOTE: the verification call at the bottom is commented out, so this test
        # currently only builds `f` and `inp` and passes without checking anything.
        def f(a):
            out = torch.mul(a, 3)
            out_view = out.unsqueeze(0)
            out.t_()
            out_view2 = out.unsqueeze(0)
            return out_view, out, out_view2

        inp = [torch.ones(2, 4, requires_grad=True)]

        # TODO: fix this test.
        # See <github issue link>
        # NOTE(review): the issue link above is an unfilled placeholder; presumably
        # this is tracked together with the sibling inplace-view test - confirm
        # and fill in the real link.
        # self.verify_aot_autograd(f, inp, test_mutation=True)

    def test_output_aliases_intermediate_multiple_mixed(self):
        def f(a):
            out1 = torch.mul(a, 3)
            out2 = torch.mul(a, 4)
            # AOTAutograd should manually generate these two output views in the epilogue.
            return out1.view(-1), out2.transpose(1, 0), out1.transpose(1, 0)

        inp = [torch.ones(3, 3, requires_grad=False)]
        self.verify_aot_autograd(f, inp, test_mutation=True)
        inp = [torch.ones(3, 3, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    mul = torch.ops.aten.mul.Tensor(primals_1, 3)
    mul_1 = torch.ops.aten.mul.Tensor(primals_1, 4);  primals_1 = None
    view = torch.ops.aten.view.default(mul, [-1])
    transpose = torch.ops.aten.transpose.int(mul_1, 1, 0);  mul_1 = None
    transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0)
    return (view, transpose, transpose_1, mul)""",
        )

    def test_output_all_alias_types(self):
        # There are 3 types of aliasing that require us to return metadata in the compiled fw:
        # (1) outputs that are views of inputs
        # (2) outputs that are views of intermediates
        # (3) inputs that get metadata mutations
        # test all 3 of them here
        def f(a):
            a.transpose_(1, 0)
            tmp = a.mul(2)
            return tmp.squeeze(), tmp.transpose(1, 0), a.unsqueeze(0)

        def inp_callable(req_grad):
            x = torch.ones(1, 2, 4, requires_grad=req_grad).clone()
            return [(x,), (x,)]

        self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=False), test_mutation=True
        )
        fw_graph = self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=True), test_mutation=True
        )
        # TODO: make this test run with dynamic shapes so it is more meaningful
        # metadata output order: (a_updated_meta, out1_meta, out2_meta, out3_meta)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    view = torch.ops.aten.view.default(primals_1, [1, 2, 4]);  primals_1 = None
    transpose = torch.ops.aten.transpose.int(view, 1, 0);  view = None
    mul = torch.ops.aten.mul.Tensor(transpose, 2)
    squeeze = torch.ops.aten.squeeze.default(mul)
    transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0)
    unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0)
    return (transpose, squeeze, transpose_1, unsqueeze, mul)""",
        )

    @parametrize("req_grad", [False, True])
    def test_subclass_metadata_mutation(self, req_grad):
        # A metadata mutation (transpose_) on an input is expected to be rejected
        # with a RuntimeError when inputs are turned into tensor subclasses.
        def f(a):
            a.transpose_(1, 0)
            tmp = a.mul(2)
            return tmp.transpose(1, 0)

        def make_inputs(needs_grad):
            # The same cloned tensor is placed in both returned tuples.
            x = torch.ones(1, 2, 4, requires_grad=needs_grad).clone()
            return [(x,), (x,)]

        inputs_fn = partial(make_inputs, needs_grad=req_grad)
        expected_message = (
            "Metadata mutations are currently not allowed on tensor subclasses"
        )

        # See https://github.com/pytorch/pytorch/issues/114975
        with self.assertRaisesRegex(RuntimeError, expected_message):
            self.verify_aot_autograd(
                f,
                inputs_fn,
                test_mutation=True,
                make_inputs_subclasses=True,
            )

    def test_input_data_and_metadata_mutation(self):
        def f(a):
            a.t_()
            a[0].mul_(2)
            return a.view(a.shape)

        inp = [torch.ones(3, 3, requires_grad=False)]
        self.verify_aot_autograd(f, inp, test_mutation=True)
        inp = [torch.ones(3, 3, requires_grad=True)]
        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
    t = torch.ops.aten.t.default(clone)
    select = torch.ops.aten.select.int(t, 0, 0);  t = None
    mul = torch.ops.aten.mul.Tensor(select, 2);  select = None
    t_1 = torch.ops.aten.t.default(clone);  clone = None
    select_scatter = torch.ops.aten.select_scatter.default(t_1, mul, 0, 0);  t_1 = mul = None
    t_2 = torch.ops.aten.t.default(select_scatter);  select_scatter = None
    t_4 = torch.ops.aten.t.default(t_2)
    t_6 = torch.ops.aten.t.default(t_2);  t_2 = None
    view_1 = torch.ops.aten.view.default(t_6, [3, 3]);  t_6 = None
    return (t_4, view_1)""",
        )

    def test_view_and_inplace_view(self):
        def f(a, b):
            a.t_()
            return b.view(b.shape), a.view(a.shape)

        def create_inp(req_grad):
            return [
                torch.ones(3, 3, requires_grad=req_grad),
                torch.ones(3, 3, requires_grad=req_grad),
            ]

        self.verify_aot_autograd(f, create_inp(False), test_mutation=True)
        fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    t = torch.ops.aten.t.default(arg0_1);  arg0_1 = None
    view = torch.ops.aten.view.default(arg1_1, [3, 3]);  arg1_1 = None
    view_1 = torch.ops.aten.view.default(t, [3, 3])
    return (t, view, view_1)""",
        )

    def test_view_detach(self):
        # A detached alias of the input is taken before the input is mutated in
        # place; both the input and the detached alias are returned.
        def f(a):
            tmp = a.detach()
            a.mul_(2)
            return a, tmp

        # No graph is pinned; verification just has to succeed for both settings
        # (requires_grad=True is exercised first, then False).
        for needs_grad in (True, False):
            args = [torch.ones(3, 3, requires_grad=needs_grad)]
            self.verify_aot_autograd(f, args, test_mutation=True)

    def test_input_inplace_requires_grad_true(self):
        def f(a, b):
            a.requires_grad_(True)
            return a.mul(3), b.mul(4)

        inp = [
            # First inp doesnt require grad, but we switch it on
            torch.ones(3, 3, requires_grad=False),
            torch.ones(3, 3, requires_grad=True),
        ]

        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1, primals_2):
    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(primals_2, 4);  primals_2 = None
    return (mul, mul_1)""",
        )

    # This is a torture test:
    # a and b get turned into a synthetic base in the compiled graph
    # One gets a data mutation, the other gets a metadata mutation.
    # We need to make sure that the metadata mutation gets propagated
    # back to the original input.
    @skipIfDynamoInput("Dynamo removes runtime error")
    def test_input_data_and_metadata_mutation_aliases_other_input(self):
        # a and b are aliased
        def f(a, b):
            a.mul_(2)
            b.t_()
            return a.mul(b)

        def make_inputs(req_grad):
            base = torch.ones(2, 2, requires_grad=req_grad)
            # The add() matters: the graph inputs must be non-leaves so that they
            # can be mutated.
            shifted = base.add(1)
            return [base], [shifted[0], shifted[0]]

        # Plain tensor inputs are expected to verify cleanly.
        for req_grad in (False, True):
            self.verify_aot_autograd(
                f, partial(make_inputs, req_grad=req_grad), test_mutation=True
            )

        # With tensor-subclass inputs the same program is expected to be rejected.
        for req_grad in (False, True):
            with self.assertRaisesRegex(
                RuntimeError,
                "Encountered aliased inputs that are mutated in the graph, but",
            ):
                self.verify_aot_autograd(
                    f,
                    partial(make_inputs, req_grad=req_grad),
                    test_mutation=True,
                    make_inputs_subclasses=True,
                )

    # https://github.com/pytorch/pytorch/issues/106456
    def test_input_mutation_noncontiguous(self):
        def f(a):
            a.mul_(2)
            return a + 1

        def make_inputs(req_grad):
            base = torch.ones(2, 2, requires_grad=req_grad)
            shifted = base.add(1)
            # A column slice is a non-contiguous view; that is what gets handed to
            # the compiler as the graph input.
            column = shifted[:, 0]
            return [base], [column]

        # Plain tensor inputs are expected to verify cleanly.
        for req_grad in (False, True):
            self.verify_aot_autograd(
                f, partial(make_inputs, req_grad=req_grad), test_mutation=True
            )

        # With tensor-subclass inputs the mutation is expected to be rejected.
        for req_grad in (False, True):
            with self.assertRaisesRegex(
                RuntimeError,
                "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses",
            ):
                self.verify_aot_autograd(
                    f,
                    partial(make_inputs, req_grad=req_grad),
                    test_mutation=True,
                    make_inputs_subclasses=True,
                )

    def test_backward_mutation_data(self):
        # Custom autograd.Function whose backward multiplies its saved forward
        # input in place (a data mutation performed during the backward).
        class BwMutation(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x.clone()

            @staticmethod
            def backward(ctx, grad_output):
                (x,) = ctx.saved_tensors
                # bw mutation
                x.mul_(2)
                return grad_output.clone()

        def f(a, b):
            return a * BwMutation.apply(b)

        # First run: the tensor mutated in the backward does not require grad,
        # which is allowed. Second run: it does require grad.
        for mutated_needs_grad in (False, True):
            sample = [
                torch.ones(3, 3, requires_grad=True),
                torch.ones(3, 3, requires_grad=mutated_needs_grad),
            ]
            self.verify_aot_autograd(f, sample, test_mutation=True)

    def test_backward_mutation_metadata(self):
        # Custom autograd.Function whose backward transposes its saved forward
        # input in place (a metadata mutation); this is expected to be rejected.
        class BwMutation(torch.autograd.Function):
            @staticmethod
            def forward(ctx, a, b):
                ctx.save_for_backward(b)
                return a.clone(), b.clone()

            @staticmethod
            def backward(ctx, grad_a, grad_b):
                (b,) = ctx.saved_tensors
                # bw metadata mutation
                b.transpose_(1, 0)
                return grad_a.clone(), grad_b.clone()

        def f(a, b):
            cloned_a, cloned_b = BwMutation.apply(a, b)
            return cloned_a * cloned_b

        # Only the first input requires grad.
        sample = [torch.ones(3, 3, requires_grad=flag) for flag in (True, False)]

        with self.assertRaisesRegex(
            AssertionError, "input that had its metadata mutated in the backward"
        ):
            self.verify_aot_autograd(f, sample, test_mutation=True)

    def test_backward_mutation_on_grad_out(self):
        # Custom autograd.Function whose backward mutates the incoming gradient in
        # place before returning a copy of it.
        class BwMutation(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, grad_output):
                grad_output.mul_(2)
                return grad_output.clone()

        def f(a, b):
            return BwMutation.apply(a * b)

        sample = [torch.ones(3, 3, requires_grad=True) for _ in range(2)]
        compiled = aot_function(f, nop)
        # The assertion is expected from calling the compiled function itself;
        # no explicit backward() is issued by this test.
        with self.assertRaisesRegex(
            AssertionError, "input to the backward that was mutated during the backward"
        ):
            compiled(*sample)

    def test_backward_mutation_forward_inputs(self):
        # Registers a custom op whose backward zeroes one of the op's forward inputs
        # in place, then compares an eager run against an aot_function-compiled run:
        # outputs and inputs after the forward, and again after the backward,
        # followed by the .grad attributes.
        @torch.library.custom_op("_test::_clone", mutates_args={})
        def f(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
            return x.clone()

        # Fake (meta) implementation: only the output's shape/dtype matter.
        def f_fake(x, x1):
            return torch.empty_like(x)

        def backward(ctx, grad):
            # The mutation under test: x1 (stashed on ctx by setup_context) is
            # zeroed during the backward, with autograd recording disabled.
            with torch.no_grad():
                ctx.x1.zero_()
            # Gradient for x only; x1 gets no gradient.
            return grad * 2, None

        def setup_context(ctx, inputs, output):
            # Stash both forward inputs on ctx so backward can reach x1.
            (x, x1) = inputs
            ctx.x = x
            ctx.x1 = x1

        f.register_fake(f_fake)
        f.register_autograd(backward, setup_context=setup_context)

        def fn(x: torch.Tensor, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
            # x2 is mutated in the forward; x1 is only mutated in the op's backward.
            x2.mul_(5)
            return torch.ops._test._clone(x, x1) + x2

        inp_x, inp_x1, inp_x2 = (
            torch.randn(3, requires_grad=True),
            torch.randn(3, requires_grad=False),
            torch.randn(3, requires_grad=False),
        )

        # Eager reference run on its own clones of the inputs.
        ref_x, ref_x1, ref_x2 = inp_x.clone(), inp_x1.clone(), inp_x2.clone()
        ref_y = fn(ref_x, ref_x1, ref_x2)

        compiled_f = aot_function(fn, nop, keep_inference_input_mutations=True)

        # Compiled run on a separate set of clones.
        x, x1, x2 = inp_x.clone(), inp_x1.clone(), inp_x2.clone()
        y = compiled_f(x, x1, x2)

        # Verify mutation in forward applied and mutation in backward is not in forward
        self.assertEqual(ref_x, x)
        self.assertEqual(ref_x1, x1)
        self.assertEqual(ref_x2, x2)
        self.assertEqual(ref_y, y)

        ref_y.sum().backward()
        y.sum().backward()

        # Verify mutations in backward applied
        self.assertEqual(ref_x, x)
        self.assertEqual(ref_x1, x1)
        self.assertEqual(ref_x2, x2)
        self.assertEqual(ref_y, y)

        # NOTE(review): ref_x/x (and the x1/x2 counterparts) are clones rather than
        # the tensors created above, so these comparisons may be None-vs-None --
        # confirm whether the grads of the original leaves were meant to be checked.
        self.assertEqual(ref_x.grad, x.grad)
        self.assertEqual(ref_x1.grad, x1.grad)
        self.assertEqual(ref_x2.grad, x2.grad)

    def test_backward_mutation_forward_inputs_create_graph(self):
        # Custom op whose backward zeroes a grad-requiring forward input in place.
        # Requesting grads with create_graph=True through the compiled function is
        # expected to raise.
        @torch.library.custom_op("_test::_clone_create_graph", mutates_args={})
        def f(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
            return x.clone()

        def f_fake(x, x1):
            return torch.empty_like(x)

        def backward(ctx, grad):
            # The mutation under test: zero x1 with autograd recording disabled.
            with torch.no_grad():
                ctx.x1.zero_()
            return grad * 2, None

        def setup_context(ctx, inputs, output):
            # Keep both forward inputs reachable from backward.
            ctx.x, ctx.x1 = inputs

        f.register_fake(f_fake)
        f.register_autograd(backward, setup_context=setup_context)

        def fn(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
            return torch.ops._test._clone_create_graph(x, x1)

        inp_x = torch.randn(3, requires_grad=True)
        inp_x1 = torch.randn(3, requires_grad=True)

        # Eager run of the op, including a plain backward.
        eager_x, eager_x1 = inp_x.clone(), inp_x1.clone()
        eager_out = f(eager_x, eager_x1)
        eager_out.sum().backward()

        # Compiled run on fresh clones; the forward and the reduction happen
        # outside the assertRaises scope so only the grad() call may raise.
        x, x1 = inp_x.clone(), inp_x1.clone()
        compiled_fn = aot_function(fn, nop)
        compiled_out = compiled_fn(x, x1)
        loss = compiled_out.sum()
        with self.assertRaisesRegex(
            RuntimeError,
            "aot_autograd does not support input mutations with requires_grad in backward for create_graph=True",
        ):
            torch.autograd.grad(loss, inp_x, create_graph=True)
        # Eager and compiled values are not compared: the exception is the expectation.

    # Partially addresses https://github.com/pytorch/pytorch/issues/106457
    def test_input_mutation_false_aliasing(self):
        def f(a, b):
            a.mul_(3)
            b.mul_(2)
            return a.clone().view(-1) + b.clone().view(-1)

        # No overlap, contiguous
        def inp_callable1(req_grad):
            base = torch.ones(4, 4, requires_grad=req_grad)
            x = base.add(1)
            # create two views that share storage, but are actually non-overlapping
            a = x[0:2]
            b = x[2:4]
            return [base], [a, b]

        fw_graph = self.verify_aot_autograd(
            f, partial(inp_callable1, req_grad=False), test_mutation=True
        )
        self.verify_aot_autograd(
            f, partial(inp_callable1, req_grad=True), test_mutation=True
        )
        self.verify_aot_autograd(
            f,
            partial(inp_callable1, req_grad=False),
            test_mutation=True,
            make_inputs_subclasses=True,
        )
        # Input mutations on subclasses with training graphs fail backward guards today.
        with self.assertRaisesRegex(
            AssertionError,
            "attempted to compile the backward with incorrect subclass metadata",
        ):
            self.verify_aot_autograd(
                f,
                partial(inp_callable1, req_grad=True),
                test_mutation=True,
                make_inputs_subclasses=True,
            )

        # Important characteristic: the graph takes in 2 inputs!
        # That shows that we didn't try to run our complicated synthetic base logic,
        # because we successfully detected false aliasing across the two inputs.
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    mul = torch.ops.aten.mul.Tensor(arg0_1, 3);  arg0_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(arg1_1, 2);  arg1_1 = None
    clone = torch.ops.aten.clone.default(mul)
    view = torch.ops.aten.view.default(clone, [-1]);  clone = None
    clone_1 = torch.ops.aten.clone.default(mul_1)
    view_1 = torch.ops.aten.view.default(clone_1, [-1]);  clone_1 = None
    add = torch.ops.aten.add.Tensor(view, view_1);  view = view_1 = None
    return (mul, mul_1, add)""",
        )

        # No overlap, non-contiguous: first tensor ends before second tensor start
        def inp_callable2(req_grad):
            base = torch.ones(256, requires_grad=req_grad)
            x = base.add(1)
            a = x.as_strided((4, 4), (8, 1), storage_offset=0)
            b = x.as_strided((4, 4), (8, 1), storage_offset=28)
            return [base], [a, b]

        # No overlap, non-contiguous: tensors are perfectly interleaved
        def inp_callable3(req_grad):
            base = torch.ones(4, 4, requires_grad=req_grad)
            x = base.add(1)
            a = x[:, 0:2]
            b = x[:, 2:4]
            return [base], [a, b]

        # No overlap, non-contiguous
        def inp_callable4(req_grad):
            base = torch.ones(256, requires_grad=req_grad)
            x = base.add(1)
            a = x.as_strided((4, 4), (9, 1), storage_offset=0)
            b = x.as_strided((4, 4), (9, 1), storage_offset=22)
            return [base], [a, b]

        # No overlap, non-contiguous
        def inp_callable5(req_grad):
            base = torch.ones(256, requires_grad=req_grad)
            x = base.add(1)
            a = x.as_strided((4, 4), (9, 1), storage_offset=0)
            b = x.as_strided((4, 4), (9, 1), storage_offset=23)
            return [base], [a, b]

        # No overlap, non-contiguous
        def inp_callable6(req_grad):
            base = torch.ones(256, requires_grad=req_grad)
            x = base.add(1)
            # a's last element is at offset 195 (24 total elements)
            a = x.as_strided((2, 4, 3), (110, 24, 4), storage_offset=5)
            # b's first element is at offset 196: no overlap
            b = x[196 : 196 + a.numel()]
            return [base], [a, b]

        # overlap! non-contiguous
        def inp_callable_overlap1(req_grad):
            base = torch.ones(256, requires_grad=req_grad)
            x = base.add(1)
            a = x.as_strided((4, 4), (9, 1), storage_offset=0)
            b = x.as_strided((4, 4), (9, 1), storage_offset=24)
            return [base], [a, b]

        # overlap! non-contiguous
        def inp_callable_overlap2(req_grad):
            base = torch.ones(256, requires_grad=req_grad)
            x = base.add(1)
            a = x.as_strided((4, 4), (9, 1), storage_offset=0)
            b = x.as_strided((4, 4), (9, 1), storage_offset=25)
            return [base], [a, b]

        # overlap! non-contiguous
        def inp_callable_overlap3(req_grad):
            base = torch.ones(256, requires_grad=req_grad)
            x = base.add(1)
            # a's last element is at offset 195 (24 total elements)
            a = x.as_strided((2, 4, 3), (110, 24, 4), storage_offset=5)
            # b's first element is at offset 195: overlap!
            b = x[195 : 195 + a.numel()]
            return [base], [a, b]

        fw_graph2 = self.verify_aot_autograd(
            f, partial(inp_callable2, req_grad=False), test_mutation=True
        )
        fw_graph3 = self.verify_aot_autograd(
            f, partial(inp_callable3, req_grad=False), test_mutation=True
        )
        fw_graph4 = self.verify_aot_autograd(
            f, partial(inp_callable4, req_grad=False), test_mutation=True
        )
        fw_graph5 = self.verify_aot_autograd(
            f, partial(inp_callable5, req_grad=False), test_mutation=True
        )
        fw_graph6 = self.verify_aot_autograd(
            f, partial(inp_callable6, req_grad=False), test_mutation=True
        )

        fw_graph_overlap1 = self.verify_aot_autograd(
            f, partial(inp_callable_overlap2, req_grad=False), test_mutation=True
        )
        fw_graph_overlap2 = self.verify_aot_autograd(
            f, partial(inp_callable_overlap1, req_grad=False), test_mutation=True
        )

        # All non-overlap graphs should be the same since we detected false aliasing
        self.assertEqual(str(fw_graph.code), str(fw_graph2.code))
        self.assertEqual(str(fw_graph.code), str(fw_graph3.code))
        self.assertEqual(str(fw_graph.code), str(fw_graph4.code))
        self.assertEqual(str(fw_graph.code), str(fw_graph5.code))
        self.assertEqual(str(fw_graph.code), str(fw_graph6.code))

        # All overlap graphs should be the same since we detected real aliasing
        self.assertNotEqual(str(fw_graph.code), str(fw_graph_overlap1.code))
        self.assertNotEqual(str(fw_graph.code), str(fw_graph_overlap2.code))
        self.assertTrue("as_strided_scatter" in str(fw_graph_overlap1.code))
        self.assertTrue("as_strided_scatter" in str(fw_graph_overlap2.code))

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
    def test_mem_leak_from_save_for_bw(self):
        # See a full diagnosis at this issue: https://github.com/pytorch/pytorch/issues/94990
        # Note [Detaching saved tensors in AOTAutograd]
        # This program creates a ref-cycle. Long term, we should fix this ref cycle
        # (since it can arise, naturally albeit rarely, from uses of autograd.Function).
        # But AOTAutograd makes it more likely to show up from tracing user programs,
        # so we deal with it by manually detaching the tensors that we save for backward.
        # This is completely wrong and would give wrong results if we were to do double backward.
        # Fortunately today, double backward is explicitly banned in AOTAutograd.
        def f(a, b):
            add = a + a
            split = torch.functional.split(add, [4, 4], dim=1)
            getitem_2 = split[1]
            unsqueeze = getitem_2.unsqueeze(-1)
            mul = unsqueeze * b
            return (getitem_2, mul)

        f_compiled = aot_function(f, nop)
        inps = [
            torch.ones(8, 8, device="cuda", requires_grad=True),
            torch.ones(1, 4, 1, device="cuda", requires_grad=True),
        ]
        # The outputs are deliberately discarded: once the call returns, allocated
        # CUDA memory is expected to be back at its starting level.
        mem_before = torch.cuda.memory_allocated()
        f_compiled(*inps)
        mem_after = torch.cuda.memory_allocated()
        # assertEqual (instead of assertTrue(a == b)) reports both byte counts when
        # the check fails, which is what is needed to size a leak.
        self.assertEqual(mem_after, mem_before)

    def test_output_aliases_multiple_inputs_get_correct_one(self):
        # a and b are aliased, but have different shapes
        # The first output should view off the first input, the 2nd output should view off the 2nd input
        def f(a, b):
            return a.view(a.shape), b.view(b.shape)

        def make_inputs(req_grad):
            base = torch.ones(2, 2, requires_grad=req_grad)
            # Both graph inputs are views of the intermediate base.mul(2), so they
            # are non-leaf tensors that alias each other.
            doubled = base.mul(2)
            return [base], [doubled.view(-1), doubled[0]]

        # Plain tensors first, then tensor-subclass inputs; each configuration is
        # run without and then with grad.
        for extra_kwargs in ({}, {"make_inputs_subclasses": True}):
            for req_grad in (False, True):
                self.verify_aot_autograd(
                    f,
                    partial(make_inputs, req_grad=req_grad),
                    test_mutation=True,
                    **extra_kwargs,
                )

    def test_input_mutation_aliases_other_input(self):
        def f(a, b):
            a.add_(1)
            return a + b

        def inp_callable(req_grad):
            base = torch.ones(4, 2, requires_grad=req_grad)
            # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them.
            x = base.add(1)
            inp1 = x[0]
            inp2 = x[0]
            return [base], [inp1, inp2]

        self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=False), test_mutation=True
        )
        fw_graph = self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=True), test_mutation=True
        )
        # Important parts of the graph:
        # - the compiled graph takes in a base, and we generate a and b (the views) off of the base
        # - clone() is still in the graph, because we need to call grad() on the original (non-mutated) inputs
        # - We re-generate the views *after* the clone, to preserve view relationships.
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0)
    add = torch.ops.aten.add.Tensor(as_strided, 1);  as_strided = None
    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0);  clone = add = None
    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
    add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5);  as_strided_2 = as_strided_5 = None
    return (as_strided_scatter, add_1)""",
        )  # noqa: B950

    def test_input_mutation_aliases_other_input2(self):
        def f(a, b):
            a.add_(1)
            return a + b

        def inp_callable(req_grad):
            base = torch.ones(2, 2, requires_grad=req_grad)
            x = base.add(1)
            inp1 = x[0]
            # Here, one of the aliased inputs is the base itself
            inp2 = x
            return [base], [inp1, inp2]

        self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=False), test_mutation=True
        )
        fw_graph = self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=True), test_mutation=True
        )
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0)
    add = torch.ops.aten.add.Tensor(as_strided, 1);  as_strided = None
    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0);  clone = add = None
    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0)
    add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5);  as_strided_2 = as_strided_5 = None
    return (as_strided_scatter, add_1)""",
        )  # noqa: B950

    def test_input_mutation_aliases_and_output_alias(self):
        def f(a, b):
            # Here, we need to take care:that because and b are aliased
            # since a and b are aliased, we generate a view off of "updated b"
            a.add_(1)
            return b.view(b.shape)

        def inp_callable(req_grad):
            base = torch.ones(2, 2, requires_grad=req_grad)
            x = base.add(1)
            return [base], [x.view(-1), x.view(-1)]

        self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=False), test_mutation=True
        )
        fw_graph = self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=True), test_mutation=True
        )
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
    add = torch.ops.aten.add.Tensor(as_strided, 1);  as_strided = None
    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0);  clone = add = None
    as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
    view_1 = torch.ops.aten.view.default(as_strided_8, [4]);  as_strided_8 = None
    return (as_strided_scatter, view_1)""",
        )  # noqa: B950

    def test_input_aliased_with_mutation_output_alias(self):
        def f(a, b, c):
            # a and c alias
            c.mul_(2)
            # The main thing we're testing here is that
            # (1) We need to reconstruct c.view(-1) from the 3rd input to the forward
            # (2) But we need to be careful to do this *before* converting aliased inputs into synthetic bases.
            #     The original fw takes in 3 args, but the compiled fw takes in only 2 args.
            return b.add(1), c.view(-1)

        def inp_callable(req_grad):
            base1 = torch.ones(2, 2, requires_grad=req_grad)
            base2 = torch.ones(2, 2, requires_grad=req_grad)
            x = base1.add(1)
            y = base2.add(1)
            return [base1, base2], [x.view(-1), y, x.view(-1)]

        self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=False), test_mutation=True
        )
        fw_graph = self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=True), test_mutation=True
        )
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1, primals_2):
    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
    as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
    mul = torch.ops.aten.mul.Tensor(as_strided_1, 2);  as_strided_1 = None
    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0);  clone = mul = None
    add = torch.ops.aten.add.Tensor(primals_2, 1);  primals_2 = None
    as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
    view_1 = torch.ops.aten.view.default(as_strided_7, [-1]);  as_strided_7 = None
    return (as_strided_scatter, add, view_1)""",
        )  # noqa: B950

    def test_input_metadata_mutation_aliases(self):
        def f(a, b):
            # a and b alias, and we do a metadata mutation on a
            # Since we're not mutating data, then b isn't affected at all.
            # We expect aot autograd to not bother with constructing a synthetic base.
            a.t_()
            return a + b

        def inp_callable(req_grad):
            base = torch.ones(2, 2, requires_grad=req_grad)
            x = base.add(1)
            return [base], [x.view(-1), x.view(-1)]

        self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=False), test_mutation=True
        )
        fw_graph = self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=True), test_mutation=True
        )
        # Expectation: fwd() takes in 2 args, and we don't construct a synthetic base.
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1, primals_2):
    t = torch.ops.aten.t.default(primals_1);  primals_1 = None
    add = torch.ops.aten.add.Tensor(t, primals_2);  t = primals_2 = None
    return (add,)""",
        )

    def test_input_mutation_aliases_and_none_require_gradients(self):
        def f(a, b, c):
            # a and b alias, but neither require gradients (so they don't have a _base)
            # aot autograd should construct the synthetic base from `torch.Tensor(a.storage())`
            a.mul_(2)
            return b + 1, c + 1

        def inp_callable(req_grad):
            base = torch.ones(2, 2)
            c_arg = torch.ones(2, 2, requires_grad=req_grad)
            x = base.add(1)
            return [base, c_arg], [x.view(-1), x.view(-1), c_arg]

        self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=False), test_mutation=True
        )

        with self.assertRaisesRegex(
            RuntimeError, "is a tensor subclass. This is not supported today"
        ):
            self.verify_aot_autograd(
                f,
                partial(inp_callable, req_grad=False),
                test_mutation=True,
                make_inputs_subclasses=True,
            )

        fw_graph = self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=True), test_mutation=True
        )
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1, primals_2):
    as_strided = torch.ops.aten.as_strided.default(primals_1, [4], [1], 0)
    mul = torch.ops.aten.mul.Tensor(as_strided, 2);  as_strided = None
    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(primals_1, mul, [4], [1], 0);  primals_1 = mul = None
    as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
    add = torch.ops.aten.add.Tensor(as_strided_3, 1);  as_strided_3 = None
    add_1 = torch.ops.aten.add.Tensor(primals_2, 1);  primals_2 = None
    return (as_strided_scatter, add, add_1)""",
        )  # noqa: B950

    @skipIfDynamoInput("Fails with dynamo")
    def test_input_mutation_aliases_bases_out_of_order(self):
        """Check the compiled forward when the aliased, mutated inputs are not adjacent.

        ``b`` and ``d`` (arguments 1 and 3) alias each other and are both
        mutated; ``a`` and ``c`` (arguments 0 and 2) alias each other but are
        only read.
        """
        # This tests our calling convention: if b and d are aliased, then the outer calling convention
        # that we send to the compiled forward becomes:
        # (b_d_base, a, c)
        # Importantly, even though a and c alias in our test, neither input is mutated,
        # so we don't need to do the base construction / deconstruction
        def f(a, b, c, d):
            b.add_(1)  # data mutation
            d.unsqueeze_(0)  # metadata-only mutation
            return a + c + d, b.view(-1)

        def inp_callable(req_grad):
            base1 = torch.ones(2, 2, requires_grad=req_grad)
            base2 = torch.ones(2, 2, requires_grad=req_grad)
            x1 = base1.add(1)
            x2 = base2.add(1)
            # a and c alias, b and d alias
            return [base1, base2], [x1.view(-1), x2.view(-1), x1.view(-1), x2.view(-1)]

        # Inputs that do not require grad.
        self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=False), test_mutation=True
        )

        # With make_inputs_subclasses=True the same run must be rejected, because
        # d.unsqueeze_(0) is a metadata mutation.
        with self.assertRaisesRegex(
            RuntimeError,
            "Metadata mutations are currently not allowed on tensor subclasses",
        ):
            self.verify_aot_autograd(
                f,
                partial(inp_callable, req_grad=False),
                test_mutation=True,
                make_inputs_subclasses=True,
            )

        # Inputs that require grad: inspect the forward graph that was produced.
        fw_graph = self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=True), test_mutation=True
        )
        # 3 graph inputs: (b_d_base, a, c)
        # 2 returns: (b_updated, a+c+d)
        # (there are 2 original fw outs, but one is a view of b so it's not part of the graph)
        # (there are also 2 input mutations, but one is a metadata-only mutation so the compiled forward doesn't return it)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1, primals_2, primals_3):
    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
    add = torch.ops.aten.add.Tensor(as_strided, 1);  as_strided = None
    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0);  clone = add = None
    add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3);  primals_2 = primals_3 = None
    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_5, 0);  as_strided_5 = None
    add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1);  add_1 = None
    as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
    view_2 = torch.ops.aten.view.default(as_strided_14, [-1]);  as_strided_14 = None
    return (as_strided_scatter, add_2, view_2, unsqueeze_1)""",
        )  # noqa: B950

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
    def test_synthetic_base_base_attribute_is_none(self):
        """Mutate one of two CUDA inputs that come from the same tensor but have no ``._base``.

        ``a`` and ``b`` are detached rows of ``base``; ``f`` mutates ``a`` in
        place and then reads both.
        """

        def f(a, b):
            a.add_(1)
            return a + b

        def inp_callable():
            base = torch.ones(4, 4, device="cuda")
            # detach() so that none of the inputs have a ._base attribute.
            a = base[0].detach()
            b = base[1].detach()
            return [base], [a, b]

        self.verify_aot_autograd(f, inp_callable, test_mutation=True)

    def test_input_mutation_alias_everything(self):
        """Combine input aliasing, input mutation and output aliasing in one function."""
        # Mondo test that tests a combination of:
        # input is mutated, that aliases another input (so we make a synthetic base)
        # an output is an alias of another output
        # an output is an alias of an intermediate
        # a and c are aliased
        def f(a, b, c):
            c.mul_(2)  # mutates c
            b.t_()  # metadata mutate b
            tmp = a + c
            out1 = tmp.view(-1)
            out2 = b.t()
            out3 = out1.unsqueeze(0)
            # out1 and out3 are aliases of an intermediate, and alias each other!
            # out2 aliases an input, so we don't return it
            return out1, out2, out3

        def inp_callable(req_grad):
            base1 = torch.ones(2, 2, requires_grad=req_grad)
            base2 = torch.ones(2, 2, requires_grad=req_grad)
            # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them.
            base1_ = base1.add(1)
            base2_ = base2.add(1)
            # a and c are two separate views of the same tensor.
            a = base1_.view(-1)
            b = base2_
            c = base1_.view(-1)
            return [base1, base2], [a, b, c]

        # Run once without grad-requiring inputs, then with them; only the second
        # run's forward graph is inspected below.
        self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=False), test_mutation=True
        )
        fw_graph = self.verify_aot_autograd(
            f, partial(inp_callable, req_grad=True), test_mutation=True
        )
        # Expected:
        # - 2 inputs in the forward: synthetic_base_a_c, b
        # - 1 output in the forward: "tmp"
        #   out2 is an alias of an input, and will be generated off of b outside of the compiled fn
        #   out1 and out3 are aliases of tmp, that we generate outside of the compiled function
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1, primals_2):
    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
    view = torch.ops.aten.view.default(primals_2, [2, 2]);  primals_2 = None
    as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
    mul = torch.ops.aten.mul.Tensor(as_strided_1, 2);  as_strided_1 = None
    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0);  clone = mul = None
    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
    t = torch.ops.aten.t.default(view);  view = None
    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
    add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2);  as_strided_5 = as_strided_2 = None
    view_1 = torch.ops.aten.view.default(add, [-1])
    t_1 = torch.ops.aten.t.default(t)
    unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0)
    return (as_strided_scatter, t, view_1, t_1, unsqueeze, add)""",
        )  # noqa: B950

    def test_dynamic_shape_output_not_in_bw_graph(self):
        def f(x):
            return [x + 1, x.shape[0]]

        inp = torch.ones(5, requires_grad=True)
        bw_graph_cell = [None]
        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
            decompositions={},
            keep_inference_input_mutations=False,
            dynamic=True,
        )
        out = compiled_f(inp)
        out[0].sum().backward()
        # The important bit: the forward fn returns 2 outputs,
        # but one of them is a symint so we should only see
        # 1 grad_output as an input to the backward graph.
        # (Otherwise, autograd will plumb a None as the value of the grad_output,
        # which causes inductor to complain).
        self.assertExpectedInline(
            bw_graph_cell[0].code.strip(),
            """\
def forward(self, tangents_1):
    return (tangents_1,)""",
        )

    def test_no_grad_input_output(self):
        """Verify ``f`` for every requires_grad combination of its two inputs."""

        def f(a, b):
            return a.cos(), b.cos(), a * b

        def make_input(requires_grad):
            return torch.randn(5, requires_grad=requires_grad)

        # (True, True), (True, False), (False, True), (False, False)
        for grad_flags in itertools.product([True, False], repeat=2):
            inputs = [make_input(flag) for flag in grad_flags]
            self.verify_aot_autograd(f, inputs)

    def test_some_output_requires_grad_input_doesnt(self):
        """The output is a view of a no-grad input that gets requires_grad set inside ``f``."""

        def f(a, b):
            # b is intentionally unused.
            flat = a.view(-1)
            flat.requires_grad_(True)
            return flat

        no_grad_inp = torch.randn(3, 3)
        grad_inp = torch.randn(3, 3, requires_grad=True)
        self.verify_aot_autograd(f, [no_grad_inp, grad_inp])

    def test_some_outputs_dont_require_grad_view(self):
        """One output is a detached input, the other is an input returned as-is."""

        def f(a, b):
            return a.detach(), b

        inputs = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
        self.verify_aot_autograd(f, inputs)

    def test_some_outputs_dont_require_grad_non_view(self):
        """One output is a detached intermediate (``a + 1``), the other is an input."""

        def f(a, b):
            return a.add(1).detach(), b

        inputs = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
        self.verify_aot_autograd(f, inputs)

    def test_inner_grad(self):
        """The traced function itself calls ``torch.autograd.grad``."""

        def foo(x):
            exp_x = torch.exp(x)
            return torch.autograd.grad(exp_x, x)

        scalar_inp = torch.randn((), requires_grad=True)
        self.verify_aot_autograd(foo, [scalar_inp])

    def test_grad_context(self):
        """Check when the compiler passed as aot_function's third argument runs.

        ``get_graph_size`` records the node count of the graph it is given.
        It must not run when the call happens with grad disabled, and with grad
        enabled it must run only once ``backward()`` is invoked.
        """

        def foo(x):
            return x * 2

        inps = [torch.randn((), requires_grad=True)]
        graph_size = None

        def get_graph_size(fx_g, _):
            nonlocal graph_size
            graph_size = len(fx_g.graph.nodes)
            return fx_g

        # Grad disabled: nothing should have been recorded after the call.
        f = aot_function(foo, nop, get_graph_size)
        with torch.set_grad_enabled(False):
            f(*inps)
        self.assertIsNone(graph_size)

        # Grad enabled: still nothing after the forward call, recorded after backward.
        f = aot_function(foo, nop, get_graph_size)
        with torch.set_grad_enabled(True):
            out = f(*inps)
            self.assertIsNone(graph_size)
            out.sum().backward()
            # assertGreater reports both operands on failure, unlike assertTrue(a > b).
            self.assertGreater(graph_size, 2)

    def test_output_dict(self):
        """Exercise dict outputs, and finally a dict input as well."""

        # Same tensor under two keys.
        def f(x):
            return {"a": x, "b": x}

        single_inp = [torch.randn(3, 3, requires_grad=True)]
        self.verify_aot_autograd(f, single_inp)

        # One value is an input, the other a computed tensor.
        def f(x, y):
            return {"a": x, "b": y + x}

        mixed_inp = [torch.randn(3, requires_grad=True), torch.randn(3)]
        self.verify_aot_autograd(f, mixed_inp)

        # Dict in, dict out: every value is doubled.
        def f(x):
            return {k: x[k] * 2 for k in x}

        a = torch.randn(3, requires_grad=True)
        b = torch.randn(3, requires_grad=True)

        def inp_callable():
            # The very same list object is returned in both positions.
            dict_args = [{"a": a, "b": b}]
            return dict_args, dict_args

        self.verify_aot_autograd(f, inp_callable)

    def test_module(self):
        """Compare outputs and parameter gradients of an eager module and its compiled wrapper.

        Both the eager and the compiled run backpropagate into the parameters
        of the same ``mod``. The eager gradients are therefore snapshotted and
        the ``.grad`` fields cleared before the compiled run; otherwise the
        second backward would accumulate into the very tensors held as the
        reference and the gradient comparison would compare tensors with
        themselves.
        """
        mod = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
        compiled_mod = compiled_module(mod, nop, nop)
        inp = torch.randn(32, 32)

        ref_out = mod(inp)
        ref_out.sum().backward()
        # Clone so the reference does not share storage with p.grad.
        ref_grads = sorted(
            (name, p.grad.detach().clone()) for name, p in mod.named_parameters()
        )
        # Start the compiled run from empty gradients.
        mod.zero_grad()

        out = compiled_mod(inp)
        out.sum().backward()
        grads = sorted((name, p.grad) for name, p in mod.named_parameters())
        self.assertEqual((out, grads), (ref_out, ref_grads))

    def test_batchnorm(self):
        """Smoke test: forward and backward through a compiled BatchNorm2d must not raise."""
        compiled_bn = compiled_module(nn.BatchNorm2d(4), nop, nop)
        out = compiled_bn(torch.ones(1, 4, 2, 2))
        loss = out.sum()
        loss.backward()

    def test_list_codegen(self):
        """Smoke test for a compiler whose result takes its inputs as a single list."""

        def list_nop(gm, _):
            # Wrap the graph so that it is called with one list argument.
            def boxed(args):
                return gm(*args)

            boxed._boxed_call = True
            return boxed

        def f(a, b, c):
            return a.sin() * b.cos() * c.sin()

        compiled_f = aot_function(f, list_nop)
        inputs = [torch.randn(5, requires_grad=True) for _ in range(3)]
        loss = compiled_f(*inputs).sum()
        loss.backward()

    @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count)
    def test_compilation_context(self, counter):
        """Record the compilation context and graph size seen by each compiler call.

        AOT_COUNTER is patched with a fresh ``itertools.count`` so the graph
        ids in the expected string start at 0.
        """

        def f(x):
            return x.sin().sin()

        # One (context[0], node count) entry per compiler invocation, in call order.
        count = []

        def compiler(fx_g, _):
            context = get_aot_compilation_context()
            count.append((context[0], len(fx_g.graph.nodes)))
            return fx_g

        # First wrap: called with an input that requires grad.
        f = aot_function(f, compiler)
        out = f(torch.randn(5, requires_grad=True))
        # Second wrap (around the already wrapped f): called with a no-grad input.
        f = aot_function(f, compiler)
        f(torch.randn(5))
        # Backward of the first call happens last, hence '0_backward' is the final entry.
        out.sum().backward()
        self.assertExpectedInline(
            str(count),
            """[(['0_forward'], 4), (['1_inference'], 4), (['0_backward'], 8)]""",
        )

    def test_dupe_arg(self):
        """Pass the same tensor object for both arguments."""

        def f(x, y):
            return x + y

        shared = torch.randn(3, 3, requires_grad=True)
        self.verify_aot_autograd(f, [shared] * 2)

    def test_dupe_arg_torture(self):
        """Same tensor passed twice, with an in-place metadata change applied through each name."""

        def f(x, y):
            x.t_()
            y.unsqueeze_(0)
            return x + y

        # clone() makes the input a non-leaf tensor.
        shared = torch.randn(3, 3, requires_grad=True).clone()
        self.verify_aot_autograd(f, [shared] * 2)

    # See https://github.com/pytorch/pytorch/issues/100224
    def test_dupe_arg_returned_as_output(self):
        """A duplicated argument is mutated through one name and returned through the other."""

        def f(a, b, a_):
            a[0].add_(1)
            return a_

        # Eager reference run.
        ref_a, ref_b = torch.ones(2), torch.ones(2)
        out_ref = f(ref_a, ref_b, ref_a)

        # Compiled run on identical fresh inputs.
        f_compiled = aot_function(f, nop)
        test_a, test_b = torch.ones(2), torch.ones(2)
        out_test = f_compiled(test_a, test_b, test_a)

        self.assertEqual(out_ref, out_test)
        # The in-place update must be visible on the compiled run's input as well.
        self.assertEqual(ref_a, test_a)

    @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count)
    @patch("torch._functorch.config.debug_assert", True)
    def test_invalid_dupe_left_bias(self, counter):
        """Compile for duplicated inputs, then call with distinct inputs and expect the debug assertion."""
        # This test checks that, just because only the first
        # argument did a metadata mutation, we still correctly
        # switch to strategy 2 (deduplicate)
        # See: https://github.com/pytorch/pytorch/pull/89896#discussion_r1036224447
        class F(torch.nn.Module):
            def forward(self, x, y):
                x.t_()
                return (x + y,)

        x = torch.randn(3, 3, requires_grad=True).clone()
        y = torch.randn(3, 3, requires_grad=True)
        self.verify_aot_autograd(F(), [x, x])

        # Compiled with the same tensor in both argument positions ...
        fxx = aot_module_simplified(F(), (x, x), nop)
        # ... so calling it with two different tensors must trip the assertion.
        # The graph id in the message depends on the compilations performed above.
        self.assertExpectedRaisesInline(
            AssertionError,
            lambda: fxx(x, y),
            """At compilation time, graph 2 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case.  This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""",  # noqa: B950
        )

    @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count)
    @patch("torch._functorch.config.debug_assert", True)
    def test_invalid_dupe(self, counter):
        """Run the invalid-duplicate scenario, compiling with real tensors as example inputs."""
        self._test_invalid_dupe(counter, fake=False)

    # See Note: Dynamo recompilation guarding invalid grad for why this test exists
    @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count)
    @patch("torch._functorch.config.debug_assert", True)
    def test_invalid_dupe_fake(self, counter):
        """Run the invalid-duplicate scenario, compiling with fake tensors as example inputs."""
        self._test_invalid_dupe(counter, fake=True)

    def _test_invalid_dupe(self, counter, fake):
        """Shared body of the invalid-duplicate tests.

        A module compiled for two distinct inputs may be called with a
        duplicated input, but a module compiled for a duplicated input must
        raise when it is later called with two distinct inputs.

        Args:
            counter: the patched AOT_COUNTER supplied by the caller's decorator
                (not read here).
            fake: when true, the example inputs handed to
                ``aot_module_simplified`` are fake tensors; the runtime calls
                always use real tensors.
        """

        class F(torch.nn.Module):
            def forward(self, x, y):
                x.unsqueeze_(0)
                y.unsqueeze_(0)
                return (x + y,)

        def fresh_pair():
            # Non-leaf tensors, so the in-place unsqueeze_ is permitted.
            first = torch.randn(3, 3, requires_grad=True).clone()
            second = torch.randn(3, 3, requires_grad=True).clone()
            return first, second

        x, y = fresh_pair()

        if fake:
            fake_mode = FakeTensorMode(shape_env=ShapeEnv())
            fake_x = fake_mode.from_tensor(x)
            fake_y = fake_mode.from_tensor(y)
            fxy = aot_module_simplified(F(), (fake_x, fake_y), nop)
        else:
            fxy = aot_module_simplified(F(), (x, y), nop)

        fxy(x, y)
        x, y = fresh_pair()
        fxy(x, x)  # is ok!

        dupe_example = (fake_x, fake_x) if fake else (x, x)
        fxx = aot_module_simplified(F(), dupe_example, nop)

        x, y = fresh_pair()
        fxx(x, x)
        # Note This should not raise! Once we have guards in place here,
        # we will have this working correctly, as it should recompile.
        x, y = fresh_pair()
        self.assertExpectedRaisesInline(
            AssertionError,
            lambda: fxx(x, y),
            """At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case.  This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""",  # noqa: B950
        )

    @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count)
    @patch("torch._functorch.config.debug_assert", True)
    def test_invalid_requires_grad(self, counter):
        """Run the requires_grad-mismatch scenario, compiling with real tensors as example inputs."""
        self._test_invalid_requires_grad(counter, fake=False)

    # See Note: Dynamo recompilation guarding invalid grad for why this test exists
    @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count)
    @patch("torch._functorch.config.debug_assert", True)
    def test_invalid_requires_grad_fake(self, counter):
        """Run the requires_grad-mismatch scenario, compiling with fake tensors as example inputs."""
        self._test_invalid_requires_grad(counter, fake=True)

    def _test_invalid_requires_grad(self, counter, fake):
        """Shared body of the requires_grad-mismatch tests.

        A module compiled with a second input that does not require grad must
        raise when it is later called with one that does.

        Args:
            counter: the patched AOT_COUNTER supplied by the caller's decorator
                (not read here).
            fake: when true, the example inputs handed to
                ``aot_module_simplified`` are fake tensors; the runtime calls
                always use real tensors.
        """

        class F(torch.nn.Module):
            def forward(self, x, y):
                return (x + y,)

        x = torch.randn(3, 3, requires_grad=True)
        y = torch.randn(3, 3, requires_grad=True)
        z = torch.randn(3, 3, requires_grad=False)

        # Example inputs used for compilation only.
        if fake:
            fake_mode = FakeTensorMode(shape_env=ShapeEnv())
            ex_x, ex_y, ex_z = (fake_mode.from_tensor(t) for t in (x, y, z))
        else:
            ex_x, ex_y, ex_z = x, y, z

        # Compiled with both inputs requiring grad: either call pattern must match eager.
        fxy = aot_module_simplified(F(), (ex_x, ex_y), nop)
        for runtime_args in ((x, y), (x, z)):
            compare_equal_outs_and_grads(self, F(), fxy, runtime_args)

        # Compiled with a second input that does not require grad.
        fxz = aot_module_simplified(F(), (ex_x, ex_z), nop)
        compare_equal_outs_and_grads(self, F(), fxz, (x, z))

        self.assertExpectedRaisesInline(
            AssertionError,
            lambda: fxz(x, y),
            """At compilation time, graph 1 was compiled under the assumption that input 1 would not require grad, but at runtime this was not the case.  This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""",  # noqa: B950
        )

    def test_custom_autograd(self):
        """Verify a function that goes through a user-defined ``torch.autograd.Function``."""

        class CustomFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.clone()

            @staticmethod
            def backward(ctx, grad):
                # Deliberately not the true gradient of clone().
                return grad + 1

        def f(x):
            return CustomFn.apply(x)

        inputs = [torch.randn(3)]
        self.verify_aot_autograd(f, inputs)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
    def test_autocast_disable_guard(self):
        """Inside ``torch._C._DisableAutocast()`` a CUDA matmul result must be float32."""
        with torch._C._DisableAutocast():
            mat = torch.rand([4, 4]).cuda()
            product = torch.matmul(mat, mat)
            self.assertEqual(product.dtype, torch.float32)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
    def test_nonidempotent_amp(self):
        """Verify an einsum + log_softmax function under CUDA autocast with mixed-dtype inputs.

        The function is checked twice: first with inputs that do not require
        grad, then with the same tensors after ``requires_grad_(True)``.
        """

        def f(self_s_emb, add_3):
            einsum_2 = torch.functional.einsum("ah,th->t", self_s_emb, add_3)
            log_softmax_2 = einsum_2.log_softmax(-1)
            return (log_softmax_2,)

        args = [
            torch.rand((1, 256), dtype=torch.float32, device="cuda"),
            torch.rand((30, 256), dtype=torch.float16, device="cuda"),
        ]
        # torch.amp.autocast("cuda", ...) replaces the deprecated torch.cuda.amp.autocast(...).
        with torch.amp.autocast("cuda", enabled=True):
            self.verify_aot_autograd(f, args)

        args = [e.requires_grad_(True) for e in args]
        with torch.amp.autocast("cuda", enabled=True):
            self.verify_aot_autograd(f, args)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
    @unittest.skipIf(not torch.backends.cudnn.is_available(), "CUDNN is unavailable")
    @skipIfRocm  # https://github.com/pytorch/pytorch/issues/96560
    def test_batch_norm_amp(self):
        """Compare ``aten.cudnn_batch_norm`` with its decomposition for mixed dtypes.

        The input is float16 while weight, bias and running statistics are
        float32. Every output of the decomposed, ``make_fx``-traced function
        must be close to the corresponding output of the direct op call.
        """
        device = "cuda"
        input_dtype = torch.float16
        param_dtype = torch.float32
        weight, bias = (
            torch.ones(64, device=device, dtype=param_dtype, requires_grad=True)
            for _ in range(2)
        )
        running_mean, running_var = (
            torch.ones(64, device=device, dtype=param_dtype) for _ in range(2)
        )

        def bn(x):
            return torch.ops.aten.cudnn_batch_norm(
                x,
                weight,
                bias,
                running_mean,
                running_var,
                False,
                0.1,
                1e-05,
            )

        inp = torch.ones(
            torch.Size([16, 64, 112, 112]), dtype=input_dtype, device=device
        )

        ref = bn(inp)
        cudnn_batch_norm_decomp = torch._decomp.get_decompositions(
            {torch.ops.aten.cudnn_batch_norm}
        )
        aot_fn = make_fx(bn, decomposition_table=cudnn_batch_norm_decomp)(inp)
        res = aot_fn(inp)
        for a, b in zip(ref, res):
            # A unittest assertion is not stripped under `python -O` and is
            # reported as a test failure rather than a bare AssertionError.
            self.assertTrue(torch.allclose(a, b))

    def test_output_op_depending_on_symint(self):
        """
        End-to-end regression test; what it guards against is not obvious from the code, and it
        should probably become a more focused unit test.

        The problem with the program below was that the expand op ended up depending on a symint whose
        proxy was wrongly associated with one of the grad tensors instead of the input tensors.  That
        broke the partitioner logic, so aot_function threw an exception instead of producing a function.
        """

        def f(x):
            return x.expand(x.shape)

        inp = torch.randn(5, requires_grad=True)

        # TODO(whc) make this work (test setup is wrong somehow)
        # joint_forward_backward = create_joint_forward_backward(f)
        # out = f(inp)
        # joint_inputs =  ([inp], [out.detach().contiguous()])
        # fx_g = make_fx(joint_forward_backward)(*joint_inputs)
        # TODO: assert outputs of fwd graph trace to correct symint

        # e2e test that fails without symint clone fix
        min_cut_for_inductor = partial(
            min_cut_rematerialization_partition, compiler="inductor"
        )
        compiled_f = aot_function(
            f,
            nop,
            partition_fn=min_cut_for_inductor,
            dynamic=True,
        )
        actual = compiled_f(inp)
        self.assertEqual(actual, f(inp))

    def test_inference_mode(self):
        """Eager and compiled Linear must agree when both run under ``torch.inference_mode()``."""
        linear = torch.nn.Linear(4, 4)
        x = torch.randn(4, 4)
        compiled_linear = aot_module(linear, fw_compiler=nop)

        with torch.inference_mode():
            expected = linear(x)
            actual = compiled_linear(x)

        self.assertEqual(expected, actual)

    def test_default_partitioner_saves_symints_not_tensors_for_bw(self):
        """
        In this test, the important thing is that primals_1 is **only** needed in the backward
        in order to grab its sizes.
        We need to assert that what we save for the backward are the tensor's sizes, and not the tensor itself.

        The way this test is set up, it will actually fail if we try to save the input tensor for backward.
        Why?
        b.masked_fill_(c, 0) has a backward that requires knowing a's sizes
        b.masked_fill_(c, 0) **also** mutates a (because b and a are aliased)
        The autograd engine yells at us if we save "a" for backward, and then try to mutate it.
        """

        def f(a):
            b = a[0]
            c = torch.ones_like(b, dtype=torch.bool)
            d = b.masked_fill_(c, 0)
            return d

        compiled_f = aot_function(f, nop, dynamic=True)
        inp_ref = torch.ones(2, 2, requires_grad=True)
        inp_test = torch.ones(2, 2, requires_grad=True)

        # Pass clones: f mutates its argument in place, which is not allowed
        # on a leaf tensor that requires grad.
        out_ref = f(inp_ref.clone())
        out_test = compiled_f(inp_test.clone())

        self.assertEqual(out_ref, out_test)

        out_ref.sum().backward()
        out_test.sum().backward()

        self.assertEqual(inp_ref.grad, inp_test.grad)

    def test_buffer_copied_in_graph(self):
        """A buffer mutated in ``forward`` is written back via ``copy_`` inside the forward graph.

        Also checks that outputs, parameter gradients and input gradients of
        the compiled module match an eager copy of the same model.
        """

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(torch.zeros(1))
                self.w1 = torch.nn.Parameter(torch.zeros(1))
                self.w2 = torch.nn.Parameter(torch.zeros(1))

            def forward(self, x):
                # In-place update of the buffer on every call.
                self.buf.add_(1)
                return (self.w1 * x * self.w2).sum() + self.buf.sum()

        # Separate copies so the eager and compiled runs do not share state.
        model_for_eager = MyModel()
        model_for_compile = copy.deepcopy(model_for_eager)

        # The forward compiler stores the graph it receives in this one-slot list.
        fw_graph_cell = [None]
        compiled_f = aot_module(
            model_for_compile,
            fw_compiler=make_boxed_compiler(
                partial(extract_graph, graph_cell=fw_graph_cell)
            ),
            bw_compiler=nop,
            keep_inference_input_mutations=True,
        )
        inp_ref = torch.ones(1, requires_grad=True)
        inp_test = torch.ones(1, requires_grad=True)

        out_ref = model_for_eager(inp_ref.clone())
        out_test = compiled_f(inp_test.clone())

        # The buffer update appears as the copy_ into primals_3 at the end of the graph.
        self.assertExpectedInline(
            fw_graph_cell[0].code.strip(),
            """\
def forward(self, primals_1, primals_2, primals_3, primals_4):
    add = torch.ops.aten.add.Tensor(primals_3, 1)
    mul = torch.ops.aten.mul.Tensor(primals_1, primals_4)
    mul_1 = torch.ops.aten.mul.Tensor(mul, primals_2)
    sum_1 = torch.ops.aten.sum.default(mul_1);  mul_1 = None
    sum_2 = torch.ops.aten.sum.default(add)
    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
    copy_ = torch.ops.aten.copy_.default(primals_3, add);  primals_3 = add = copy_ = None
    return (add_1, primals_1, primals_2, primals_4, mul)""",
        )

        self.assertEqual(out_ref, out_test)

        out_ref.sum().backward()
        out_test.sum().backward()

        eager_grads = [p.grad for _, p in model_for_eager.named_parameters()]
        compile_grads = [p.grad for _, p in model_for_compile.named_parameters()]

        self.assertEqual(eager_grads, compile_grads)
        self.assertEqual(inp_ref.grad, inp_test.grad)

    def test_buffer_copied_in_graph_with_different_shapes(self):
        """Same buffer ``copy_`` check, with buffer, parameter and input of differing shapes.

        The buffer is 4x4, the weight 4x2 and the input 2x4.
        """

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(torch.ones(4, 4))
                self.w = torch.nn.Parameter(
                    torch.Tensor([[4, 5], [1, 2], [6, 7], [8, 9]])
                )

            def forward(self, x):
                # In-place update of the buffer on every call.
                self.buf.add_(1)
                return (self.w @ x).sum() + self.buf.sum()

        # Separate copies so the eager and compiled runs do not share state.
        model_for_eager = MyModel()
        model_for_compile = copy.deepcopy(model_for_eager)

        # The forward compiler stores the graph it receives in this one-slot list.
        fw_graph_cell = [None]
        compiled_f = aot_module(
            model_for_compile,
            fw_compiler=make_boxed_compiler(
                partial(extract_graph, graph_cell=fw_graph_cell)
            ),
            bw_compiler=nop,
            keep_inference_input_mutations=True,
        )
        inp_ref = torch.ones(2, 4, requires_grad=True)
        inp_test = torch.ones(2, 4, requires_grad=True)

        out_ref = model_for_eager(inp_ref.clone())
        out_test = compiled_f(inp_test.clone())

        # The buffer update appears as the copy_ into primals_2 at the end of the graph.
        self.assertExpectedInline(
            fw_graph_cell[0].code.strip(),
            """\
def forward(self, primals_1, primals_2, primals_3):
    add = torch.ops.aten.add.Tensor(primals_2, 1)
    mm = torch.ops.aten.mm.default(primals_1, primals_3)
    sum_1 = torch.ops.aten.sum.default(mm);  mm = None
    sum_2 = torch.ops.aten.sum.default(add)
    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
    copy_ = torch.ops.aten.copy_.default(primals_2, add);  primals_2 = add = copy_ = None
    return (add_1, primals_1, primals_3)""",
        )
        self.assertEqual(out_ref, out_test)

        out_ref.sum().backward()
        out_test.sum().backward()

        eager_grads = [p.grad for _, p in model_for_eager.named_parameters()]
        compile_grads = [p.grad for _, p in model_for_compile.named_parameters()]

        self.assertEqual(eager_grads, compile_grads)

        self.assertEqual(inp_ref.grad, inp_test.grad)

    def test_buffer_batch_norm(self):
        """Check forward and backward graphs for a ``BatchNorm1d`` wrapped in a module.

        The expected forward graph ends with three ``copy_`` ops into graph
        inputs; the expected backward graph returns ``None`` in three of its six
        output slots.  Outputs and gradients are also compared with an eager copy.
        """

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = torch.nn.BatchNorm1d(100)

            def forward(self, x):
                return self.m(x)

        # Separate copies so the eager and compiled runs do not share state.
        model_for_eager = MyModel()
        model_for_compile = copy.deepcopy(model_for_eager)

        # Each compiler stores the graph it receives in its one-slot list.
        fw_graph_cell = [None]
        bw_graph_cell = [None]
        compiled_f = aot_module(
            model_for_compile,
            fw_compiler=make_boxed_compiler(
                partial(extract_graph, graph_cell=fw_graph_cell)
            ),
            bw_compiler=make_boxed_compiler(
                partial(extract_graph, graph_cell=bw_graph_cell)
            ),
            keep_inference_input_mutations=True,
        )
        inp_ref = torch.ones(20, 100, requires_grad=True)
        inp_test = torch.ones(20, 100, requires_grad=True)

        out_ref = model_for_eager(inp_ref.clone())
        out_test = compiled_f(inp_test.clone())

        self.assertExpectedInline(
            fw_graph_cell[0].code.strip(),
            """\
def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6):
    add = torch.ops.aten.add.Tensor(primals_5, 1)
    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(primals_6, primals_1, primals_2, primals_3, primals_4, True, 0.1, 1e-05);  primals_2 = None
    getitem = _native_batch_norm_legit_functional[0]
    getitem_1 = _native_batch_norm_legit_functional[1]
    getitem_2 = _native_batch_norm_legit_functional[2]
    getitem_3 = _native_batch_norm_legit_functional[3]
    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
    copy_ = torch.ops.aten.copy_.default(primals_3, getitem_3);  primals_3 = copy_ = None
    copy__1 = torch.ops.aten.copy_.default(primals_4, getitem_4);  primals_4 = copy__1 = None
    copy__2 = torch.ops.aten.copy_.default(primals_5, add);  primals_5 = add = copy__2 = None
    return (getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4)""",  # noqa: B950
        )

        self.assertEqual(out_ref, out_test)

        out_ref.sum().backward()
        out_test.sum().backward()

        eager_grads = [p.grad for _, p in model_for_eager.named_parameters()]
        compile_grads = [p.grad for _, p in model_for_compile.named_parameters()]
        self.assertEqual(eager_grads, compile_grads)

        # The backward graph is only available after backward() has run above.
        self.assertExpectedInline(
            bw_graph_cell[0].code.strip(),
            """\
def forward(self, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4, tangents_1):
    native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(tangents_1, primals_6, primals_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]);  tangents_1 = primals_6 = primals_1 = getitem_3 = getitem_4 = getitem_1 = getitem_2 = None
    getitem_5 = native_batch_norm_backward[0]
    getitem_6 = native_batch_norm_backward[1]
    getitem_7 = native_batch_norm_backward[2];  native_batch_norm_backward = None
    return (getitem_6, getitem_7, None, None, None, getitem_5)""",  # noqa: B950
        )

        self.assertEqual(inp_ref.grad, inp_test.grad)

    def test_new_inp_requires_grad_now(self):
        """In-place add of a grad-requiring tensor into an input that does not require grad.

        ``x`` is created with ``requires_grad=False`` and ``y`` with
        ``requires_grad=True``; ``f`` returns the result of ``x.add_(y)``.
        """

        def f(x, y):
            return x.add_(y)

        # Each compiler stores the graph it receives in its one-slot list.
        fw_graph_cell = [None]
        bw_graph_cell = [None]
        compiled_f = aot_function(
            f,
            fw_compiler=make_boxed_compiler(
                partial(extract_graph, graph_cell=fw_graph_cell)
            ),
            bw_compiler=make_boxed_compiler(
                partial(extract_graph, graph_cell=bw_graph_cell)
            ),
            keep_inference_input_mutations=True,
        )

        inp_ref = (
            torch.ones(20, 100, requires_grad=False),
            torch.ones(20, 100, requires_grad=True),
        )
        inp_test = (
            torch.ones(20, 100, requires_grad=False),
            torch.ones(20, 100, requires_grad=True),
        )

        out_ref = f(*inp_ref)
        out_test = compiled_f(*inp_test)

        # The forward graph contains no copy_ op; the updated value is returned instead.
        self.assertExpectedInline(
            fw_graph_cell[0].code.strip(),
            """\
def forward(self, primals_1, primals_2):
    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
    add = torch.ops.aten.add.Tensor(clone, primals_2);  clone = primals_2 = None
    return (add, add)""",
        )  # noqa: B950

        self.assertEqual(out_ref, out_test)

        out_ref.sum().backward()
        out_test.sum().backward()

        # No gradient for the first input; the incoming gradient passes through to the second.
        self.assertExpectedInline(
            bw_graph_cell[0].code.strip(),
            """\
def forward(self, tangents_1):
    return (None, tangents_1)""",
        )  # noqa: B950

    def test_real_weights_in_symbolic_mode(self):
        # Traces a module that owns a real Linear layer with make_fx in symbolic
        # mode (once directly, once through functionalize), checks both traced
        # graphs against eager, checks that each graph has exactly one
        # placeholder, and checks that _allow_non_fake_inputs=False raises.
        from functorch.experimental import functionalize

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                x = self.linear(x)
                return x

        def num_placeholders(graph_module):
            # Number of "placeholder" nodes in the traced graph.
            return sum(
                1 for node in graph_module.graph.nodes if node.op == "placeholder"
            )

        m = M().eval()

        sample = torch.randn(2, 5)

        plain_tracer = make_fx(
            m, tracing_mode="symbolic", _allow_non_fake_inputs=True
        )
        traced = plain_tracer(sample)
        self.assertEqual(traced(torch.ones(2, 5)), m(torch.ones(2, 5)))

        functionalized_tracer = make_fx(
            functionalize(traced),
            tracing_mode="symbolic",
            _allow_non_fake_inputs=True,
        )
        traced_functionalized = functionalized_tracer(sample)
        self.assertEqual(
            traced_functionalized(torch.ones(2, 5)), m(torch.ones(2, 5))
        )

        # No more param lifting
        self.assertEqual(num_placeholders(traced), 1)

        # No more param lifting
        self.assertEqual(num_placeholders(traced_functionalized), 1)

        with self.assertRaisesRegex(
            Exception, "Please convert all Tensors to FakeTensors"
        ):
            make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=False)(
                torch.randn(2, 5)
            )

    def test_real_weights_in_symbolic_mode_with_inplace_ops(self):
        # The module's forward mutates and resizes its own buffer; tracing it
        # with make_fx in symbolic mode is expected to raise an exception whose
        # message matches "Can't call metadata".
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.ones(4, 5))

            def forward(self, x):
                y = self.buffer.add_(3)
                y.resize_([20])
                assert y.shape == self.buffer.shape
                return x.sum() + self.buffer.sum()

        module = M().eval()
        sample = torch.randn(2, 5)
        # inplace mutation on attr is not allowed
        with self.assertRaisesRegex(Exception, "Can't call metadata"):
            tracer = make_fx(
                module, tracing_mode="symbolic", _allow_non_fake_inputs=True
            )
            tracer(sample)

    def _compile_and_erase_bases(self, *output_view_indices):
        # Returns a decorator that compiles a function with aot_function, using
        # a forward compiler that re-wraps the outputs at `output_view_indices`
        # in a tensor subclass whose `_base` and `_view_func` read as None.

        # Overrides _base and _view_func tensor attributes, so as to avoid the view-replay
        # execution path when reconstructing views.
        class NoViewReplayTensor(torch.Tensor):
            @property
            def _base(self):
                return None

            @property
            def _view_func(self):
                return None

        # Wraps the outputs that are views of the FX graph 'graph' with NoViewReplayTensor,
        # since they are the only ones that will get reconstructed.
        def wrap_view_outputs(graph, *args, **kwargs):
            results = list(graph(*args, **kwargs))
            for idx in output_view_indices:
                results[idx] = NoViewReplayTensor(results[idx])
            return tuple(results)

        def fw_compiler(graph, _example_inputs):
            # The second argument handed to the compiler is ignored.
            return partial(wrap_view_outputs, graph)

        def decorator(f):
            return aot_function(f, fw_compiler=fw_compiler)

        return decorator

    def test_output_aliases_input_view_meta_replay(self):
        # The single output is a view of the input; with _base/_view_func erased
        # on that output, its grad_fn is expected to be ViewBackward0.
        def f(a):
            return a.view(-1)

        compiled_f = self._compile_and_erase_bases(0)(f)

        inp = torch.ones(2, 2, requires_grad=True)
        result = compiled_f(inp)

        grad_fn = result.grad_fn
        self.assertIsNotNone(grad_fn)
        self.assertExpectedInline(
            str(grad_fn.__class__), """<class 'ViewBackward0'>"""
        )

    def test_output_aliases_intermediate_view_meta_replay(self):
        # Both outputs are views of the same intermediate (a clone of the
        # input); with _base/_view_func erased on both, each grad_fn is expected
        # to be ViewBackward0.
        def f(a):
            b = a.clone()
            return b.view(-1), b.view(-1)

        compiled_f = self._compile_and_erase_bases(0, 1)(f)

        inp = torch.ones(2, 2, requires_grad=True)
        first, second = compiled_f(inp)

        first_grad_fn = first.grad_fn
        self.assertIsNotNone(first_grad_fn)
        self.assertExpectedInline(
            str(first_grad_fn.__class__), """<class 'ViewBackward0'>"""
        )

        second_grad_fn = second.grad_fn
        self.assertIsNotNone(second_grad_fn)
        self.assertExpectedInline(
            str(second_grad_fn.__class__), """<class 'ViewBackward0'>"""
        )

    def test_output_aliases_output_view_meta_replay(self):
        # The second output is a view of the first output; only the second one
        # gets _base/_view_func erased. The two outputs are expected to compare
        # equal on untyped_storage() and the view's grad_fn to be ViewBackward0.
        def f(a):
            b = a.add(10)
            return b, b.view(-1)

        compiled_f = self._compile_and_erase_bases(1)(f)

        inp = torch.ones(2, 2, requires_grad=True)
        base_out, view_out = compiled_f(inp)

        self.assertEqual(base_out.untyped_storage(), view_out.untyped_storage())
        view_grad_fn = view_out.grad_fn
        self.assertIsNotNone(view_grad_fn)
        self.assertExpectedInline(
            str(view_grad_fn.__class__), """<class 'ViewBackward0'>"""
        )

    @skipIfTorchDynamo()
    @patch("torch._dynamo.config.assume_static_by_default", False)
    def test_dynamic_output_aliases_input_view_meta_replay(self):
        # - torch.compile: using it so we can have a SymInt in the FX graph.
        # - Compiling with inductor, so that tensor._base isn't tracked.
        #
        # This should force the use of as_strided in the view reconstruction path.
        # The first 2 view-replay paths won't be taken because:
        #   - target_functional_tensor will be symbolic (_functionalize_is_symbolic call)
        #   - tensor._base will be None
        @torch.compile(backend="inductor")
        def f(a, sz):
            return a.view(sz), a.view(-1)

        inp = torch.ones(2, 2, requires_grad=True)
        sized_view, flat_view = f(inp, (4,))

        # View built from the size argument `sz`.
        sized_grad_fn = sized_view.grad_fn
        self.assertIsNotNone(sized_grad_fn)
        self.assertExpectedInline(
            str(sized_grad_fn.__class__), """<class 'AsStridedBackward0'>"""
        )

        # View built from the literal -1.
        flat_grad_fn = flat_view.grad_fn
        self.assertIsNotNone(flat_grad_fn)
        self.assertExpectedInline(
            str(flat_grad_fn.__class__), """<class 'ViewBackward0'>"""
        )


def extract_graph(fx_g, _, graph_cell):
    """Store ``fx_g`` in slot 0 of ``graph_cell`` and return it unchanged.

    The second positional argument is ignored. ``graph_cell`` must support
    assignment at index 0 (the callers in this file pass a one-element list).
    """
    captured = fx_g
    graph_cell[0] = captured
    return captured


def get_ins_outs(fx_g):
    """Collect the placeholder nodes and the output values of ``fx_g``.

    Returns ``(ins, outs)`` where ``ins`` is the list of nodes whose ``op`` is
    ``"placeholder"`` (in graph order) and ``outs`` is ``tuple(node.args[0])``
    of the last node whose ``op`` is ``"output"``. When the graph has no
    output node, ``outs`` is an empty list.
    """
    nodes = list(fx_g.graph.nodes)
    ins = [node for node in nodes if node.op == "placeholder"]
    outs = []
    for node in nodes:
        if node.op == "output":
            outs = tuple(node.args[0])
    return ins, outs


def get_num_ins_outs(fx_g):
    """Return ``(number of inputs, number of outputs)`` as reported by ``get_ins_outs``."""
    ins, outs = get_ins_outs(fx_g)
    return (len(ins), len(outs))


def get_fw_bw_graph(
    f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False
):
    """Compile ``f`` with ``aot_function`` and return the captured graphs.

    ``f`` is compiled with ``extract_graph`` as both forward and backward
    compiler, ``partitioner`` as the partition function and
    ``default_decompositions``. The compiled function is called on ``*inps``
    and ``.sum().backward()`` is run on its result before the cells are read.

    Returns ``(forward_graph, backward_graph)``; an entry is ``None`` if the
    corresponding compiler was never invoked.
    """
    fw_cell = [None]
    bw_cell = [None]
    compiled_f = aot_function(
        f,
        fw_compiler=partial(extract_graph, graph_cell=fw_cell),
        bw_compiler=partial(extract_graph, graph_cell=bw_cell),
        partition_fn=partitioner,
        decompositions=default_decompositions,
        dynamic=dynamic,
    )
    result = compiled_f(*inps)
    result.sum().backward()
    return (fw_cell[0], bw_cell[0])


class TestMod(torch.nn.Module):
    """Module wrapper that feeds its own parameter ``p`` to a callable.

    ``p`` is a ``torch.nn.Parameter`` initialised to ``torch.ones(2)``.
    ``forward(*args)`` returns ``fn(p, *args)``, so ``fn`` receives the
    parameter as its first positional argument.
    """

    def __init__(self, fn):
        super().__init__()
        initial_value = torch.ones(2, requires_grad=True)
        self.p = torch.nn.Parameter(initial_value)
        self.fn = fn

    def forward(self, *args):
        param = self.p
        return self.fn(param, *args)


class TestAOTExport(AOTTestCase):
    def test_aot_export_ban_dropout_mut_pre_dispatch(self):
        def fn(p, x):
            y = torch.ops.aten.dropout.default(x, 0.1, train=False)
            y.add_(1)
            return (y,)

        mod = TestMod(fn)
        inp = torch.randn(2, 2)

        with self.assertRaisesRegex(
            RuntimeError, "cannot mutate tensors with frozen storage"
        ):
            aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)

        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    clone = torch.ops.aten.clone.default(arg1_1);  arg1_1 = None
    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
    return (add,)""",
        )

        fw_graph_cell = [None]
        bw_graph_cell = [None]

        compiled_outs = aot_function(
            fn,
            fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
            bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
            partition_fn=default_partition,
            decompositions=default_decompositions,
            dynamic=True,
        )(*inp)
        fw_graph = fw_graph_cell[0]
        bw_graph = bw_graph_cell[0]

        self.assertExpectedInline(
            str(fw_graph.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    clone = torch.ops.aten.clone.default(arg1_1);  arg1_1 = None
    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
    return (add,)""",
        )

    def test_aot_export_predispatch_func_simple(self):
        def fn(p, x):
            y = x + 2
            with torch.no_grad():
                y.add_(2)
            return (x * 2 + y,)

        mod = TestMod(fn)
        inp = torch.randn(2, 2)

        with torch.no_grad():
            gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg1_1, 2)
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    add_1 = torch.ops.aten.add.Tensor(add, 2);  add = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(False);  _set_grad_enabled_1 = None
    mul = torch.ops.aten.mul.Tensor(arg1_1, 2);  arg1_1 = None
    add_2 = torch.ops.aten.add.Tensor(mul, add_1);  mul = add_1 = None
    return (add_2,)""",
        )

    def test_aot_export_predispatch_func_composite_implicit(self):
        def fn(p, x):
            with torch.enable_grad():
                y = x @ x
            y.add_(2)
            return (x.sum() + y.sum(),)

        mod = TestMod(fn)
        inp = torch.randn(2, 2)

        with torch.no_grad():
            gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    _set_grad_enabled = torch._C._set_grad_enabled(True);  _set_grad_enabled = None
    matmul = torch.ops.aten.matmul.default(arg1_1, arg1_1)
    _set_grad_enabled_1 = torch._C._set_grad_enabled(False);  _set_grad_enabled_1 = None
    add = torch.ops.aten.add.Tensor(matmul, 2);  matmul = None
    sum_1 = torch.ops.aten.sum.default(arg1_1);  arg1_1 = None
    sum_2 = torch.ops.aten.sum.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
    return (add_1,)""",
        )

    def test_aot_export_predispatch_composite_implicit_inplace(self):
        def fn(x, p):
            return (torch.ops.aten.absolute_.default(x.clone()),)

        mod = TestMod(fn)
        inp = torch.randn(2, 2)

        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
    abs_1 = torch.ops.aten.abs.default(clone);  clone = None
    return (abs_1,)""",
        )

    def test_aot_export_predispatch_composite_implicit_linear(self):
        class MM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                return (self.linear(x),)

        mod = MM()
        inp = torch.randn(2, 2)

        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    linear = torch.ops.aten.linear.default(arg2_1, arg0_1, arg1_1);  arg2_1 = arg0_1 = arg1_1 = None
    return (linear,)""",
        )

    @unittest.expectedFailure
    def test_aot_export_predispatch_outdtype(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                y = x + 2
                y.add_(5)
                return (
                    out_dtype(torch.ops.aten.mm.default, torch.int32, y, self.weight),
                )

        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        mod = M(weight)
        inp = torch.randint(-128, 127, (5, 5), dtype=torch.int8)

        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    _set_grad_enabled = torch._C._set_grad_enabled(True);  _set_grad_enabled = None
    mm = torch.ops.aten.mm.default(arg1_1, arg1_1)
    _set_grad_enabled_1 = torch._C._set_grad_enabled(False);  _set_grad_enabled_1 = None
    add = torch.ops.aten.add.Tensor(mm, 2);  mm = None
    sum_1 = torch.ops.aten.sum.default(arg1_1);  arg1_1 = None
    sum_2 = torch.ops.aten.sum.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
    return (add_1,)""",
        )

    def test_aot_export_predispatch_func_view(self):
        def fn(p, x):
            y = x @ x
            y.add_(2)
            return (x.sum() + y.view(1, 4).sum(),)

        mod = TestMod(fn)
        inp = torch.randn(2, 2)

        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    matmul = torch.ops.aten.matmul.default(arg1_1, arg1_1)
    add = torch.ops.aten.add.Tensor(matmul, 2);  matmul = None
    sum_1 = torch.ops.aten.sum.default(arg1_1);  arg1_1 = None
    view_1 = torch.ops.aten.view.default(add, [1, 4]);  add = None
    sum_2 = torch.ops.aten.sum.default(view_1);  view_1 = None
    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
    return (add_1,)""",
        )

    def test_aot_export_predispatch_buffer_mutation_metadata(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Buffer(torch.zeros(2, 2))

            def forward(self, x):
                self.foo.add_(4)
                return (x.sum() + self.foo.sum(),)

        inp = torch.randn(2, 2)

        gm, graph_sig = aot_export_module(
            Foo(), [inp], trace_joint=False, pre_dispatch=True
        )
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 4);  arg0_1 = None
    sum_1 = torch.ops.aten.sum.default(arg1_1);  arg1_1 = None
    sum_2 = torch.ops.aten.sum.default(add)
    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
    return (add, add_1)""",
        )
        eager_mod = Foo()
        output_1, output_2 = gm(torch.zeros(2, 2), inp)
        eager_output = eager_mod(inp)
        self.assertTrue(torch.allclose(output_2, eager_output[0]))

        _, output_2 = gm(output_1, inp)
        eager_output = eager_mod(inp)
        self.assertTrue(torch.allclose(output_2, eager_output[0]))
        self.assertTrue("foo" in graph_sig.buffers)
        self.assertTrue(graph_sig.inputs_to_buffers["arg0_1"] == "foo")

    def test_aot_export_predispatch_with_autograd_op(self):
        def foo(p, x):
            with torch.enable_grad():
                y = x + 5
                y.add_(5)
                y.add_(7)
                return (x.cos() + y.sin(),)

        inp = torch.randn(2, 2)
        mod = TestMod(foo)

        with torch.no_grad():
            gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    _set_grad_enabled = torch._C._set_grad_enabled(True);  _set_grad_enabled = None
    add = torch.ops.aten.add.Tensor(arg1_1, 5)
    add_1 = torch.ops.aten.add.Tensor(add, 5);  add = None
    add_2 = torch.ops.aten.add.Tensor(add_1, 7);  add_1 = None
    cos = torch.ops.aten.cos.default(arg1_1);  arg1_1 = None
    sin = torch.ops.aten.sin.default(add_2);  add_2 = None
    add_3 = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(False);  _set_grad_enabled_1 = None
    return (add_3,)""",
        )

    @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
    @unittest.skipIf(
        not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported"
    )
    def test_aot_export_predispatch_with_cond_nested(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                def true_fn(x):
                    y = x.sin()
                    y.add_(5)

                    def true_true_fn(x):
                        y = x.sin()
                        y.add_(7)
                        return y.sin()

                    def true_false_fn(x):
                        return x.cos()

                    return torch.cond(
                        y.cos().sum() > 5, true_true_fn, true_false_fn, [y.cos()]
                    )

                def false_fn(x):
                    z = x.cos()
                    z.add_(6)
                    return z.sin()

                a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
                return (a + 3, a + 4)

        inp = torch.randn(2, 2)
        gm, _ = aot_export_module(M(), [inp], trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1):
    sum_1 = torch.ops.aten.sum.default(arg0_1)
    gt = torch.ops.aten.gt.Scalar(sum_1, 4);  sum_1 = None
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
    getitem = cond[0];  cond = None
    add = torch.ops.aten.add.Tensor(getitem, 3)
    add_1 = torch.ops.aten.add.Tensor(getitem, 4);  getitem = None
    return (add, add_1)""",  # noqa: B950
        )

        self.assertExpectedInline(
            str(gm.true_graph_0.code).strip(),
            """\
def forward(self, arg0_1):
    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
    add = torch.ops.aten.add.Tensor(sin, 5);  sin = None
    cos = torch.ops.aten.cos.default(add)
    sum_1 = torch.ops.aten.sum.default(cos);  cos = None
    gt = torch.ops.aten.gt.Scalar(sum_1, 5);  sum_1 = None
    cos_1 = torch.ops.aten.cos.default(add);  add = None
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [cos_1]);  gt = true_graph_0 = false_graph_0 = cos_1 = None
    getitem = cond[0];  cond = None
    return (getitem,)""",  # noqa: B950
        )

        self.assertExpectedInline(
            str(gm.true_graph_0.true_graph_0.code).strip(),
            """\
def forward(self, arg0_1):
    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
    add = torch.ops.aten.add.Tensor(sin, 7);  sin = None
    sin_1 = torch.ops.aten.sin.default(add);  add = None
    return (sin_1,)""",
        )

    @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
    @unittest.skipIf(
        not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported"
    )
    def test_aot_export_predispatch_map_1(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                def true_fn(x, r):
                    y = x.sin()
                    y.add_(5)
                    return y.cos() + r.sum()

                def false_fn(x, r):
                    z = x.cos()

                    def f(x, y):
                        a = x.cos()
                        a.add_(5)
                        return a + y

                    return (
                        z
                        + control_flow.map(f, z, r).sum()
                        + control_flow.map(f, z, r).sum()
                    )

                a = torch.cond(x.sum() > 4, true_fn, false_fn, [x, y])
                return (a + 3, a + 4)

        inps = [torch.randn(2, 2), torch.ones(2)]
        gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    sum_1 = torch.ops.aten.sum.default(arg0_1)
    gt = torch.ops.aten.gt.Scalar(sum_1, 4);  sum_1 = None
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
    getitem = cond[0];  cond = None
    add = torch.ops.aten.add.Tensor(getitem, 3)
    add_1 = torch.ops.aten.add.Tensor(getitem, 4);  getitem = None
    return (add, add_1)""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(gm.true_graph_0.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
    add = torch.ops.aten.add.Tensor(sin, 5);  sin = None
    cos = torch.ops.aten.cos.default(add);  add = None
    sum_1 = torch.ops.aten.sum.default(arg1_1);  arg1_1 = None
    add_1 = torch.ops.aten.add.Tensor(cos, sum_1);  cos = sum_1 = None
    return (add_1,)""",
        )
        self.assertExpectedInline(
            str(gm.false_graph_0.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    cos = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
    select = torch.ops.aten.select.int(cos, 0, 0);  select = None
    body_graph_0 = self.body_graph_0
    map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]);  body_graph_0 = None
    getitem = map_impl[0];  map_impl = None
    sum_1 = torch.ops.aten.sum.default(getitem);  getitem = None
    add = torch.ops.aten.add.Tensor(cos, sum_1);  sum_1 = None
    select_1 = torch.ops.aten.select.int(cos, 0, 0);  select_1 = None
    body_graph_1 = self.body_graph_1
    map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]);  body_graph_1 = cos = arg1_1 = None
    getitem_1 = map_impl_1[0];  map_impl_1 = None
    sum_2 = torch.ops.aten.sum.default(getitem_1);  getitem_1 = None
    add_1 = torch.ops.aten.add.Tensor(add, sum_2);  add = sum_2 = None
    return (add_1,)""",
        )
        self.assertExpectedInline(
            str(gm.false_graph_0.body_graph_0.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    cos = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
    add = torch.ops.aten.add.Tensor(cos, 5);  cos = None
    add_1 = torch.ops.aten.add.Tensor(add, arg1_1);  add = arg1_1 = None
    return (add_1,)""",
        )

    def test_aot_export_predispatch_map_2(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                z = x.cos()

                def f(x, y):
                    a = x.cos()
                    a.add_(5)
                    return a + y

                return (z + control_flow.map(f, z, y).sum(),)

        inps = [torch.randn(2, 2), torch.ones(2)]
        gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    cos = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
    body_graph_0 = self.body_graph_0
    map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]);  body_graph_0 = arg1_1 = None
    getitem = map_impl[0];  map_impl = None
    sum_1 = torch.ops.aten.sum.default(getitem);  getitem = None
    add = torch.ops.aten.add.Tensor(cos, sum_1);  cos = sum_1 = None
    return (add,)""",
        )  # noqa: B950
        self.assertExpectedInline(
            str(gm.body_graph_0.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    cos = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
    add = torch.ops.aten.add.Tensor(cos, 5);  cos = None
    add_1 = torch.ops.aten.add.Tensor(add, arg1_1);  add = arg1_1 = None
    return [add_1]""",
        )

    @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
    @unittest.skipIf(
        not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported"
    )
    def test_aot_export_predispatch_with_cond(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                def true_fn(x):
                    y = x.sin()
                    z = torch.ops.aten.linear.default(y, torch.randn(2, 2))
                    z.add_(5)
                    return z.cos()

                def false_fn(x):
                    z = x.cos()
                    z.add_(6)
                    return z.sin()

                a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
                return (a + 3, a + 4)

        inp = torch.randn(2, 2)
        gm, _ = aot_export_module(M(), [inp], trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1):
    sum_1 = torch.ops.aten.sum.default(arg0_1)
    gt = torch.ops.aten.gt.Scalar(sum_1, 4);  sum_1 = None
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
    getitem = cond[0];  cond = None
    add = torch.ops.aten.add.Tensor(getitem, 3)
    add_1 = torch.ops.aten.add.Tensor(getitem, 4);  getitem = None
    return (add, add_1)""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(gm.true_graph_0.code).strip(),
            """\
def forward(self, arg0_1):
    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
    randn = torch.ops.aten.randn.default([2, 2], device = device(type='cpu'), pin_memory = False)
    linear = torch.ops.aten.linear.default(sin, randn);  sin = randn = None
    add = torch.ops.aten.add.Tensor(linear, 5);  linear = None
    cos = torch.ops.aten.cos.default(add);  add = None
    return (cos,)""",
        )

    def test_aot_export_predispatch_conv_and_bn(self):
        class ConvBatchnorm(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 3, 1, 1)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return (x,)

        mod = ConvBatchnorm()
        mod.train()
        inp = torch.randn(1, 1, 3, 3)

        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
    conv2d = torch.ops.aten.conv2d.default(arg7_1, arg0_1, arg1_1);  arg7_1 = arg0_1 = arg1_1 = None
    add = torch.ops.aten.add.Tensor(arg6_1, 1);  arg6_1 = None
    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05);  conv2d = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None
    getitem = _native_batch_norm_legit_functional[0]
    getitem_3 = _native_batch_norm_legit_functional[3]
    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
    return (getitem_3, getitem_4, add, getitem)""",  # noqa: B950
        )

    def test_aot_export_predispatch_reshape(self):
        class Reshape(torch.nn.Module):
            def forward(self, x):
                y = x.reshape(4, 4)
                return (y.sum(),)

        mod = Reshape()
        inp = torch.randn(2, 8)

        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1):
    view = torch.ops.aten.view.default(arg0_1, [4, 4]);  arg0_1 = None
    sum_1 = torch.ops.aten.sum.default(view);  view = None
    return (sum_1,)""",
        )  # noqa: B950

    def test_aot_export_predispatch_contiguous(self):
        class Cont(torch.nn.Module):
            def forward(self, x):
                y = torch.ops.aten.contiguous.default(x)
                return (y.sum(),)

        mod = Cont()
        inp = torch.randn(2, 8)

        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1):
    sum_1 = torch.ops.aten.sum.default(arg0_1);  arg0_1 = None
    return (sum_1,)""",
        )  # noqa: B950

    def test_aot_export_module_joint(self):
        """Export a conv + batchnorm + relu module and pin the resulting graphs.

        First exports the joint (forward + backward) graph with the loss at
        output index 0 and checks its readable form and signature fields, then
        exports the inference graph with ``trace_joint=False`` and checks its
        readable form.
        """

        class ConvBatchnormRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 3, 1, 1)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                user_out = torch.nn.functional.relu(x)
                loss = user_out.sum()
                return loss, user_out.detach()

        mod = ConvBatchnormRelu()
        mod.train()
        inp = torch.randn(1, 1, 3, 3)
        o_ref = mod(inp)
        fx_g, signature = aot_export_module(
            mod, [inp], trace_joint=True, output_loss_index=0
        )
        # Some important characteristics of the exported graph below:
        # 8 arguments: 2 params from conv, 2 params from batchnorm, 3 buffers from 1 batchnorm, 1 user input
        # 9 outputs: 3 mutated buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters)
        # Drop stack traces so the readable output does not depend on source locations.
        for node in fx_g.graph.nodes:
            node.meta.pop("stack_trace", None)
        self.assertExpectedInline(
            fx_g.print_readable(print_output=False),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"):
        # No stacktrace found for following nodes
        convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  arg1_1 = None
        add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1);  arg6_1 = None
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05);  arg3_1 = arg4_1 = arg5_1 = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_1: "f32[3]" = _native_batch_norm_legit_functional[1]
        getitem_2: "f32[3]" = _native_batch_norm_legit_functional[2]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
        relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem);  getitem = None
        detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu);  detach = None
        detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu)
        detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1);  detach_1 = None
        detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_2);  detach_2 = None
        detach_4: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_3);  detach_3 = None
        sum_1: "f32[]" = torch.ops.aten.sum.default(relu)
        detach_5: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu);  relu = None
        detach_6: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_5);  detach_5 = None
        detach_7: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_6);  detach_6 = None
        detach_8: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_7);  detach_7 = None
        detach_9: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_8);  detach_8 = None
        detach_10: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_9);  detach_9 = None
        ones_like: "f32[]" = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
        expand: "f32[1, 3, 3, 3]" = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]);  ones_like = None
        detach_11: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_4);  detach_4 = None
        detach_12: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_11);  detach_11 = None
        detach_13: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_12);  detach_12 = None
        detach_14: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_13);  detach_13 = None
        threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_14, 0);  expand = detach_14 = None
        native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]);  threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None
        getitem_5: "f32[1, 3, 3, 3]" = native_batch_norm_backward[0]
        getitem_6: "f32[3]" = native_batch_norm_backward[1]
        getitem_7: "f32[3]" = native_batch_norm_backward[2];  native_batch_norm_backward = None
        convolution_backward = torch.ops.aten.convolution_backward.default(getitem_5, arg7_1, arg0_1, [3], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]);  getitem_5 = arg7_1 = arg0_1 = None
        getitem_8 = convolution_backward[0];  getitem_8 = None
        getitem_9: "f32[3, 1, 1, 1]" = convolution_backward[1]
        getitem_10: "f32[3]" = convolution_backward[2];  convolution_backward = None
        return (getitem_3, getitem_4, add, sum_1, detach_10, getitem_9, getitem_10, getitem_6, getitem_7)
        """,  # noqa: B950
        )

        self.assertExpectedInline(
            str(signature.parameters),
            """['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias']""",
        )
        self.assertExpectedInline(
            str(signature.buffers),
            """['bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']""",
        )
        self.assertExpectedInline(str(signature.user_inputs), """['arg7_1']""")
        self.assertExpectedInline(
            str(signature.inputs_to_parameters),
            """{'arg0_1': 'conv.weight', 'arg1_1': 'conv.bias', 'arg2_1': 'bn.weight', 'arg3_1': 'bn.bias'}""",
        )  # noqa: B950
        self.assertExpectedInline(
            str(signature.inputs_to_buffers),
            """{'arg4_1': 'bn.running_mean', 'arg5_1': 'bn.running_var', 'arg6_1': 'bn.num_batches_tracked'}""",
        )  # noqa: B950
        self.assertExpectedInline(
            str(signature.buffers_to_mutate),
            """{'getitem_3': 'bn.running_mean', 'getitem_4': 'bn.running_var', 'add': 'bn.num_batches_tracked'}""",
        )  # noqa: B950
        self.assertExpectedInline(
            str(signature.backward_signature.gradients_to_parameters),
            """{'getitem_9': 'conv.weight', 'getitem_10': 'conv.bias', 'getitem_6': 'bn.weight', 'getitem_7': 'bn.bias'}""",
        )  # noqa: B950
        self.assertExpectedInline(
            str(signature.backward_signature.gradients_to_user_inputs), """{}"""
        )
        # NOTE(review): forward() returns the loss first and the joint graph above
        # returns it as `sum_1`, while the string expected here is `getitem_3`
        # (the first graph output) -- confirm this is the intended value.
        self.assertExpectedInline(
            str(signature.backward_signature.loss_output), """getitem_3"""
        )

        # Also check the inference graph
        # Main important thing here is that there are 5 total outputs: 3 total mutated buffers (from batchnorm), 2 user outputs.
        fx_g_inference, signature_inference = aot_export_module(
            mod, [inp], trace_joint=False
        )
        for node in fx_g_inference.graph.nodes:
            node.meta.pop("stack_trace", None)
        self.assertExpectedInline(
            fx_g_inference.print_readable(print_output=False),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"):
        # No stacktrace found for following nodes
        convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  arg7_1 = arg0_1 = arg1_1 = None
        add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1);  arg6_1 = None
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05);  convolution = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
        relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem);  getitem = None
        sum_1: "f32[]" = torch.ops.aten.sum.default(relu)
        detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu);  relu = None
        detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach);  detach = None
        detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1);  detach_1 = None
        return (getitem_3, getitem_4, add, sum_1, detach_2)
        """,  # noqa: B950
        )
        # Recap of the joint graph checked earlier in this test:
        # 8 arguments: 2 params from conv, 2 params from batchnorm, 3 buffers from 1 batchnorm, 1 user input
        # 9 outputs: 3 mutated buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters)

    def test_aot_export_simplified_basic(self):
        """Export a two-output function with ``aot_export_joint_simple`` and compare with eager.

        The forward-only graph is checked first; then the joint graph is split
        with ``default_partition`` and both halves are compared against eager
        outputs and ``torch.autograd.grad``.
        """

        def f(x, y):
            return x * y, y * y.detach()

        def leaf_copy(t):
            # New leaf tensor holding the same values as ``t``.
            return t.clone().detach().requires_grad_(True)

        lhs = torch.randn(2, requires_grad=True)
        rhs = torch.randn(2, requires_grad=True)

        fw_only = aot_export_joint_simple(f, [lhs, rhs], trace_joint=False)
        eager_fw = f(lhs, rhs)
        # The traced graph is invoked with the same calling convention as ``f``.
        traced_fw = fw_only(lhs, rhs)
        self.assertEqual(eager_fw, traced_fw)

        # Backward check: fresh inputs, with separate leaf copies for the eager
        # run and for the partitioned graphs.
        lhs = torch.randn(2, requires_grad=True)
        rhs = torch.randn(2, requires_grad=True)
        lhs_eager, rhs_eager = leaf_copy(lhs), leaf_copy(rhs)
        lhs_traced, rhs_traced = leaf_copy(lhs), leaf_copy(rhs)

        joint = aot_export_joint_simple(f, [lhs, rhs], trace_joint=True)
        n_fw_outs = 2
        fw_part, bw_part = default_partition(
            joint, [lhs, rhs], num_fwd_outputs=n_fw_outs
        )

        eager_outs = f(lhs_eager, rhs_eager)
        all_fw_outs = fw_part(lhs_traced, rhs_traced)
        # The first ``n_fw_outs`` entries are the user outputs; the remainder is
        # handed to the backward graph.
        user_outs = all_fw_outs[:n_fw_outs]
        saved = all_fw_outs[n_fw_outs:]
        self.assertEqual(eager_outs, user_outs)

        # Run the traced backward graph with a mocked-up grad_output of ones.
        cotangents = [torch.ones_like(out) for out in eager_outs]
        eager_grads = torch.autograd.grad(
            eager_outs, [lhs_eager, rhs_eager], grad_outputs=cotangents
        )
        traced_grads = bw_part(*saved, *cotangents)
        for want, got in zip(eager_grads, traced_grads):
            self.assertEqual(want, got)

    def test_aot_export_metadata_mutation_banned(self):
        """Exporting a function that calls ``t_()`` on an input must raise ``RuntimeError``."""

        def fn(p, x):
            x.t_()
            return (x * 2,)

        module = TestMod(fn)
        sample = torch.randn(2, 4)
        expected_msg = "Found an input that received a metadata mutation"
        # NOTE(review): the with-block is left at the first call that raises, so
        # any call listed after it is not executed.
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            aot_export_joint_simple(fn, [module.p, sample], trace_joint=False)
            aot_export_joint_simple(fn, [module.p, sample], trace_joint=True)
            aot_export_module(module, [sample], trace_joint=False)

    def test_aot_export_forward_mutation_no_buffer_mut(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))

            def forward(self, x):
                x.add_(4)
                return (x.cos().sum() + self.buffer1.sum(),)

        mod = M()
        inp = torch.ones(6, 4)
        gm, sig = aot_export_module(mod, [inp], trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg1_1, 4);  arg1_1 = None
    cos = torch.ops.aten.cos.default(add)
    sum_1 = torch.ops.aten.sum.default(cos);  cos = None
    sum_2 = torch.ops.aten.sum.default(arg0_1);  arg0_1 = None
    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
    return (add, add_1)""",
        )  # noqa: B950
        self.assertEqual(sig.user_inputs_to_mutate, {"add": "arg1_1"})

    def test_aot_export_forward_mutation_multiple_mut(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))

            def forward(self, x, y):
                y.add_(4)
                self.buffer1.add_(5)
                return (
                    x.cos().sum() + y.sin().sum(),
                    self.buffer1.sum(),
                )

        mod = M()
        inp = [torch.ones(6, 4), torch.zeros(6, 4)]
        gm, sig = aot_export_module(mod, inp, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    add = torch.ops.aten.add.Tensor(arg2_1, 4);  arg2_1 = None
    add_1 = torch.ops.aten.add.Tensor(arg0_1, 5);  arg0_1 = None
    cos = torch.ops.aten.cos.default(arg1_1);  arg1_1 = None
    sum_1 = torch.ops.aten.sum.default(cos);  cos = None
    sin = torch.ops.aten.sin.default(add)
    sum_2 = torch.ops.aten.sum.default(sin);  sin = None
    add_2 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
    sum_3 = torch.ops.aten.sum.default(add_1)
    return (add_1, add, add_2, sum_3)""",
        )  # noqa: B950
        self.assertEqual(sig.user_inputs_to_mutate, {"add": "arg2_1"})
        self.assertEqual(sig.buffers_to_mutate, {"add_1": "buffer1"})

    def test_aot_export_input_mutation_on_input_requiring_grad_banned(self):
        """Exporting an in-place add on an input created with ``requires_grad=True`` must raise."""

        class M(torch.nn.Module):
            def forward(self, x):
                x.add_(4)
                return (x,)

        module = M()
        grad_input = torch.randn(2, requires_grad=True)
        expected_msg = (
            "Found a graph input that requires gradients, and received a mutation"
        )
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            aot_export_module(module, [grad_input], trace_joint=False)

    def test_aot_export_input_mutation_on_parameter_banned(self):
        """Exporting a function that multiplies its first argument in place must raise.

        The first argument passed here is ``TestMod(fn).p``.
        """

        def fn(p, x):
            p.mul_(2)
            return (p + x,)

        module = TestMod(fn)
        sample = torch.randn(2)
        expected_msg = (
            "Found a graph input that requires gradients, and received a mutation"
        )
        # NOTE(review): the with-block is left at the first call that raises, so
        # any call listed after it is not executed.
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            aot_export_joint_simple(fn, [module.p, sample], trace_joint=False)
            aot_export_joint_simple(fn, [module.p, sample], trace_joint=True)
            aot_export_module(module, [sample], trace_joint=False)

    def test_aot_export_synthetic_bases_banned(self):
        """Exporting with two aliased inputs, one of which is mutated, must raise.

        The second input is a ``view(-1)`` of the first.
        """

        def fn(p, x, y):
            x.mul_(2)
            return (x + y,)

        module = TestMod(fn)
        base = torch.randn(2)
        alias = base.view(-1)
        expected_msg = "Encountered aliased inputs that are mutated"
        # NOTE(review): the with-block is left at the first call that raises, so
        # any call listed after it is not executed.
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            aot_export_joint_simple(fn, [module.p, base, alias], trace_joint=False)
            aot_export_joint_simple(fn, [module.p, base, alias], trace_joint=True)
            aot_export_module(module, [base, alias], trace_joint=False)

    def test_aot_export_input_dupes_banned(self):
        """Exporting with the same tensor passed twice, where it is mutated, must raise."""

        def fn(p, x, y):
            x.mul_(2)
            return (x + y,)

        module = TestMod(fn)
        shared = torch.randn(2)
        expected_msg = "Encountered duplicated inputs that are mutated in the graph"
        # NOTE(review): the with-block is left at the first call that raises, so
        # any call listed after it is not executed.
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            aot_export_joint_simple(fn, [module.p, shared, shared], trace_joint=False)
            aot_export_joint_simple(fn, [module.p, shared, shared], trace_joint=True)
            aot_export_module(module, [shared, shared], trace_joint=False)

    def test_aot_export_multiple_outputs_require_grad_banned(self):
        """Joint export must raise when ``p * x`` is returned alongside the loss.

        ``output_loss_index=1`` selects ``out.sum()`` as the loss.
        """

        def fn(p, x):
            out = p * x
            return out, out.sum()

        module = TestMod(fn)
        sample = torch.randn(2)
        expected_msg = (
            "Found an output of the forward that requires gradients, that was not"
        )
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            aot_export_module(module, [sample], trace_joint=True, output_loss_index=1)

    @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
    @unittest.skipIf(
        not torch._dynamo.is_dynamo_supported(), "Cond needs dynamo to run"
    )
    def test_aot_export_with_torch_cond(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                def true_fn(x):
                    y = x + 4
                    y.add_(5)
                    return x.cos()

                def false_fn(x):
                    y = x + 5
                    y.add_(6)
                    return x.sin()

                a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
                return (a + 3, a + 4)

        inp = torch.randn(3, 4)
        gm, _ = aot_export_module(M(), (inp,), trace_joint=False)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1):
    sum_1 = torch.ops.aten.sum.default(arg0_1)
    gt = torch.ops.aten.gt.Scalar(sum_1, 4);  sum_1 = None
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
    getitem = cond[0];  cond = None
    add = torch.ops.aten.add.Tensor(getitem, 3)
    add_1 = torch.ops.aten.add.Tensor(getitem, 4);  getitem = None
    return (add, add_1)""",  # noqa: B950
        )

        self.assertExpectedInline(
            gm.true_graph_0.code.strip(),
            """\
def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 4)
    add_1 = torch.ops.aten.add.Tensor(add, 5);  add = add_1 = None
    cos = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
    return (cos,)""",
        )

        self.assertExpectedInline(
            gm.false_graph_0.code.strip(),
            """\
def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 5)
    add_1 = torch.ops.aten.add.Tensor(add, 6);  add = add_1 = None
    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
    return (sin,)""",
        )

    def test_aot_export_simplified_pytrees_banned(self):
        """``aot_export_joint_simple`` must reject an argument that is a list of tensors."""

        def fn(inps):
            return (inps[0] + inps[1],)

        nested = [torch.randn(2), torch.randn(2)]
        expected_msg = (
            "aot_export_joint_simple requires individual inputs not to be pytrees"
        )
        # NOTE(review): the with-block is left at the first call that raises, so
        # any call listed after it is not executed.
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            aot_export_joint_simple(fn, [nested], trace_joint=False)
            aot_export_joint_simple(fn, [nested], trace_joint=True)

    def test_aot_export_functionalized_rng_banned(self):
        """Export must raise while ``functionalize_rng_ops`` is patched to ``True``."""

        def fn(p, x):
            return (p + x,)

        module = TestMod(fn)
        sample = torch.randn(2)
        expected_msg = "Functionalized RNG is not currently supported in the aot_export"
        with patch("functorch.compile.config.functionalize_rng_ops", True):
            # NOTE(review): the inner with-block is left at the first call that
            # raises, so any call listed after it is not executed.
            with self.assertRaisesRegex(RuntimeError, expected_msg):
                aot_export_joint_simple(fn, [module.p, sample], trace_joint=False)
                aot_export_joint_simple(fn, [module.p, sample], trace_joint=True)
                aot_export_module(module, [sample], trace_joint=False)

    def test_aot_export_unbacked_arg(self):
        class M(torch.nn.Module):
            def forward(self):
                full = torch.full((), 11)
                i0 = full.item()
                return (torch.full((i0,), 0),)

        gm, _ = aot_export_module(
            mod=M(), args=(), trace_joint=False, dynamic_shapes=True
        )
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self):
    full = torch.ops.aten.full.default([], 11, device = device(type='cpu'), pin_memory = False)
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(full);  full = None
    full_1 = torch.ops.aten.full.default([_local_scalar_dense], 0, device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    return (full_1,)""",  # noqa: B950
        )


class TestPartitioning(AOTTestCase):
    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
    def test_recompute_partitioning(self):
        """Compare ``compiled_function`` with min-cut partitioning against eager.

        Evaluates ``sin(sin(a)) + b`` eagerly and through ``compiled_function``
        using ``min_cut_rematerialization_partition``, then checks that the
        outputs and both input gradients agree within ``1e-3`` tolerances.
        """

        def fn(a, b):
            return torch.sin(torch.sin(a)) + b

        # Reference calculation
        ref_a = torch.rand(10, 10, requires_grad=True)
        ref_b = torch.rand(10, 10, requires_grad=True)
        ref = fn(ref_a, ref_b)
        ref.sum().backward()

        # Compiled function calculation
        res_a = ref_a.clone().detach().requires_grad_(True)
        res_b = ref_b.clone().detach().requires_grad_(True)

        def compile_fn(x, _):
            # Returns its first argument unchanged.
            return x

        compiled_fn = compiled_function(
            fn, compile_fn, compile_fn, min_cut_rematerialization_partition
        )
        res = compiled_fn(res_a, res_b)
        res.sum().backward()

        # unittest assertions instead of bare ``assert``: they are not stripped
        # under ``python -O`` and failures are reported through the test case.
        self.assertTrue(torch.allclose(ref, res, atol=1e-3, rtol=1e-3))
        self.assertTrue(torch.allclose(ref_a.grad, res_a.grad, atol=1e-3, rtol=1e-3))
        self.assertTrue(torch.allclose(ref_b.grad, res_b.grad, atol=1e-3, rtol=1e-3))

    def test_meta_tensor_inplace_op(self):
        """Check that every non-output node of the forward graph carries ``tensor_meta``.

        The module below results in inplace ops while tracing; the forward
        compiler callback inspects the graph it is handed.
        """

        # Following module results in inplace ops while tracing. The test checks
        # that the meta tensor information is stored for inplace ops.
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.randn(3072, 768, requires_grad=True)
                )
                self.bias = torch.nn.Parameter(torch.randn(3072, requires_grad=True))

            def forward(self, add_4):
                linear_4 = torch.nn.functional.linear(
                    add_4, self.weight, bias=self.bias
                )
                gelu = torch.nn.functional.gelu(linear_4)
                return gelu

        def check_meta_tensor(fx_g, _):
            for node in fx_g.graph.nodes:
                if node.op != "output":
                    # assertIn instead of a bare ``assert``: it is not stripped
                    # under ``python -O`` and the message names the node.
                    self.assertIn(
                        "tensor_meta",
                        node.meta,
                        f"node {node} is missing tensor_meta",
                    )
            return fx_g

        inp0 = torch.randn(16, 128, 768, requires_grad=True)
        inputs = [
            inp0,
        ]
        mod = MockModule().to(device="cpu")
        aot_mod = aot_module(mod, fw_compiler=check_meta_tensor)
        aot_mod(*inputs)

    def test_default_partitioner_getitem(self):
        """Pin forward/backward graph input and output counts for ``layer_norm`` under ``default_partition``."""
        layer = nn.LayerNorm([10])

        def f(x, mod_weight, mod_bias):
            return torch.nn.functional.layer_norm(
                x, [10], mod_weight, mod_bias, eps=1e-6
            )

        args = [torch.randn(3, 10, requires_grad=True), layer.weight, layer.bias]
        fw_graph, bw_graph = get_fw_bw_graph(f, args, partitioner=default_partition)

        fw_counts = get_num_ins_outs(fw_graph)
        bw_counts = get_num_ins_outs(bw_graph)
        self.assertEqual(fw_counts, (3, 6))
        self.assertEqual(bw_counts, (6, 3))

    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
    def test_min_cut_partitioner_save_shape(self):
        """Check which forward outputs are produced when tracing with ``dynamic=True``.

        Part one reduces along ``dim=1`` and pins the names of the three forward
        outputs. Part two mixes several size computations and checks that every
        forward output after the first is a symbolic node.
        """

        def f(x):
            s = x.sum(dim=1)
            return s

        args = [torch.ones([10, 10], requires_grad=True)]
        fw_graph, bw_graph = get_fw_bw_graph(f, args, dynamic=True)
        _, fw_output = get_ins_outs(fw_graph)
        self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
        self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))
        # make sure we don't do the suboptimal thing of saving the bigger primals input to sum,
        # rather than saving the sizes of the primals input for use in backward expand
        expected_names = ("sum_1", "sym_size_int", "sym_size_int_1")
        for position, name in enumerate(expected_names):
            self.assertEqual(str(fw_output[position]), name)

        args = [
            torch.randn(10, requires_grad=True),
            torch.randn((3, 10), requires_grad=True),
            torch.randn((2, 10), requires_grad=True),
        ]

        def f(a, b, c):
            # tried to test what happens if we save a size tuple in the graph;
            # turns out we never will due to how we trace, but this is probably
            # still a good test case for various size manipulations
            sb = torch.ops.aten.sym_size(b)
            sc = c.size()
            x = sb[0] + sc[0]
            a_sz = (x, a.size(0))
            return torch.cat([a.expand(a_sz), b, c])

        fw_graph, bw_graph = get_fw_bw_graph(f, args, dynamic=True)
        self.assertEqual(get_num_ins_outs(fw_graph), (3, 4))
        self.assertEqual(get_num_ins_outs(bw_graph), (4, 3))
        _, fw_output = get_ins_outs(fw_graph)
        saved_are_symbolic = all(is_sym_node(node) for node in fw_output[1:])
        self.assertTrue(saved_are_symbolic)

    def test_default_partitioner_output_tensor_shape_tensor(self):
        """Mix tensors and sizes in the outputs and partition with ``default_partition``.

        Traces ``f`` with ``dynamic=True`` and pins the forward/backward graph
        input and output counts, which forward outputs are symbolic nodes, and
        that compiled results equal eager results.
        """
        inp = [
            torch.randn(10, requires_grad=True),
            torch.randn((3, 10), requires_grad=True),
            torch.randn((2, 10), requires_grad=True),
            torch.randn((10, 1), requires_grad=True),
        ]

        def f(a, b, c, d):
            # Try to force symints intermixed with outputs in the function's returns
            b_size = b.size()
            c_size = c.size()
            rows = b_size[0] + c_size[0]
            cat = torch.cat([a.expand((rows, a.size(0))), b, c])
            mm = torch.mm(cat, d)
            # this saves 4 new ints for backward. why?
            # and what do i have to do to make it save a tensor for backward?
            mm2 = torch.mm(mm, a.view(mm.size(1), a.size(0)))
            return cat, b_size, c, mm2

        fw_cell = [None]
        bw_cell = [None]
        compiled_f = aot_function(
            f,
            fw_compiler=partial(extract_graph, graph_cell=fw_cell),
            bw_compiler=partial(extract_graph, graph_cell=bw_cell),
            partition_fn=default_partition,
            decompositions=default_decompositions,
            dynamic=True,
        )
        compiled_outs = compiled_f(*inp)
        fw_graph = fw_cell[0]
        loss = compiled_outs[0].sum() + compiled_outs[2].sum()
        loss.backward()
        bw_graph = bw_cell[0]

        # in the fwd graph, 13 outs because:
        # - 5 original outputs (sb is a tuple, gets expanded to 2 symints)
        # - 8 saved outputs for backward
        # NOTE(review): this comment previously read "5 tensors, 3 symints" for the
        # saved outputs, but the mask asserted below has 4 non-symbolic followed
        # by 4 symbolic entries -- confirm which is intended.
        self.assertEqual(get_num_ins_outs(fw_graph), (4, 13))
        # in the bwd graph, 10 inputs (grad outs) because:
        # - The fwd graph had 13 outputs
        # - 1 was a view of an input, which gets regenerated outside of the graph
        #   and doesn't participate in the backward
        # - 2 user outs were symints (b.size()), which don't get tangents in the backward
        self.assertEqual(get_num_ins_outs(bw_graph), (10, 4))

        _, fw_out_nodes = get_ins_outs(fw_graph)
        # fw outputs include b.size() which expands to 2 symints,
        #
        # TODO(whc)- are the saved-tensors/saved-symints correct here?
        # i just made the test pass based on what default partition did
        # Of the 5 original forward outputs, the 4th (c) is an input,
        # which won't show up in the compiled forward graph
        expected_sym_mask = (
            [False, True, True, False, False] + [False] * 4 + [True] * 4
        )
        actual_sym_mask = [is_sym_node(node) for node in fw_out_nodes]
        self.assertEqual(expected_sym_mask, actual_sym_mask)

        real_outs = f(*inp)
        self.assertEqual(compiled_outs, real_outs)
        self.assertIsInstance(real_outs[1], torch.Size)

        # TODO(whc) we should learn to return torch.Sizes
        self.assertNotIsInstance(compiled_outs[1], torch.Size)

    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
    def test_min_cut_partitioner_output_tensor_shape_tensor(self):
        """Mix tensors and sizes in the outputs and partition with the min-cut partitioner.

        Traces ``f`` with ``dynamic=True`` and pins the forward/backward graph
        input and output counts, which forward outputs are symbolic nodes, and
        that compiled results equal eager results.
        """
        inp = [
            torch.randn(10, requires_grad=True),
            torch.randn((3, 10), requires_grad=True),
            torch.randn((2, 10), requires_grad=True),
            torch.randn((10, 1), requires_grad=True),
        ]

        def f(a, b, c, d):
            # Try to force symints intermixed with outputs in the function's returns
            b_size = b.size()
            c_size = c.size()
            rows = b_size[0] + c_size[0]
            cat = torch.cat([a.expand((rows, a.size(0))), b, c])
            mm = torch.mm(cat, d)
            # this saves 4 new ints for backward. why?
            # and what do i have to do to make it save a tensor for backward?
            mm2 = torch.mm(mm, a.view(mm.size(1), a.size(0)))
            return cat, b_size, c, mm2

        fw_cell = [None]
        bw_cell = [None]
        compiled_f = aot_function(
            f,
            fw_compiler=partial(extract_graph, graph_cell=fw_cell),
            bw_compiler=partial(extract_graph, graph_cell=bw_cell),
            partition_fn=min_cut_rematerialization_partition,
            decompositions=default_decompositions,
            dynamic=True,
        )
        compiled_outs = compiled_f(*inp)
        fw_graph = fw_cell[0]
        loss = compiled_outs[0].sum() + compiled_outs[2].sum()
        loss.backward()
        bw_graph = bw_cell[0]

        self.assertEqual(get_num_ins_outs(fw_graph), (4, 12))
        self.assertEqual(get_num_ins_outs(bw_graph), (9, 4))

        _, fw_out_nodes = get_ins_outs(fw_graph)
        # fw outputs include b.size() which expands to 2 symints,
        # then 4 tensors (transposes of matrices used for mm) are saved
        # finally 3 symints are saved
        expected_sym_mask = (
            [False, True, True, False, False] + [False] * 4 + [True] * 3
        )
        actual_sym_mask = [is_sym_node(node) for node in fw_out_nodes]
        self.assertEqual(expected_sym_mask, actual_sym_mask)

        real_outs = f(*inp)
        self.assertEqual(compiled_outs, real_outs)
        self.assertIsInstance(real_outs[1], torch.Size)

        # TODO(whc) we should learn to return torch.Sizes
        self.assertNotIsInstance(compiled_outs[1], torch.Size)

    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
    def test_min_cut_partitioner(self):
        """Pin forward/backward graph input and output counts for two small functions."""

        def triple_cos(x):
            return x.cos().cos().cos()

        single_arg = [torch.randn(3, requires_grad=True)]
        fw_graph, bw_graph = get_fw_bw_graph(triple_cos, single_arg)
        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
        self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))

        def sum_then_double_cos(a, b, c, d):
            x = a + b + c + d
            return x.cos().cos()

        four_args = [torch.randn(3, requires_grad=True) for _ in range(4)]
        fw_graph, bw_graph = get_fw_bw_graph(sum_then_double_cos, four_args)
        self.assertEqual(get_num_ins_outs(fw_graph), (4, 2))
        self.assertEqual(get_num_ins_outs(bw_graph), (2, 4))

    def test_contiguous(self):
        """Run backward through a ``view`` followed by ``t()`` compiled with ``nop``.

        The test simulates the condition where transpose followed by view
        happens in the backward pass.
        https://discuss.pytorch.org/t/error-on-transpose-and-view/434
        """

        def f(x):
            return x.view(2, 3).t()

        flat_input = torch.randn(6, requires_grad=True)
        compiled_f = aot_function(f, nop)
        transposed = compiled_f(flat_input)
        grad_output = torch.randn(3, 2)
        torch.autograd.grad(transposed, flat_input, grad_output)

    def test_preserve_random(self):
        """Check that dropout under ``aot_function`` matches eager for the same seed.

        The seed is reset to 0 before the eager call and again before the
        compiled function is built and called.
        """

        def fn(x):
            return torch.nn.functional.dropout(x, 0.5) + x

        x = torch.randn(4)

        torch.manual_seed(0)
        ref = fn(x)

        torch.manual_seed(0)
        aot_fn = aot_function(fn, nop)
        res = aot_fn(x)

        # assertTrue instead of a bare ``assert``: it is not stripped under
        # ``python -O`` and the failure is reported through the test case.
        self.assertTrue(torch.allclose(ref, res))

    # https://github.com/pytorch/pytorch/issues/110666
    def test_generate_gives_inference_graph(self):
        """A function wrapped in ``torch.no_grad()`` should reach the inference compiler.

        The inference compiler stores the graph it receives in a cell; the test
        passes when that cell has been filled after one call.
        """

        # We expect this to give an inference graph
        def generate(x):
            with torch.no_grad():
                return torch.mul(x, x)

        captured = [None]
        record_graph = partial(extract_graph, graph_cell=captured)
        aot_fn = aot_function(
            generate, nop, inference_compiler=make_boxed_compiler(record_graph)
        )
        # Even though x requires grad, we should still get an inference graph
        grad_input = torch.randn(4, requires_grad=True)
        aot_fn(grad_input)
        self.assertIsNotNone(captured[0])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_autocast(self):
        """Run a resnet18 forward and backward through ``memory_efficient_fusion`` under CUDA autocast."""
        mod = torchvision.models.resnet18().cuda()
        mod.train()

        x = torch.randn(16, 3, 32, 32, device="cuda")
        aot_mod = memory_efficient_fusion(mod)

        # Ensure that AOT Autograd works with AMP.
        # ``torch.cuda.amp.autocast(...)`` is deprecated; the device-generic
        # ``torch.amp.autocast("cuda", ...)`` is the equivalent spelling.
        with torch.amp.autocast("cuda", enabled=True):
            res = aot_mod(x)
        res.sum().backward()


class TestAOTDispatch(AOTTestCase):
    # Tests to add cases for (non-exhaustive list, mostly for my notes):
    # - subclass / mode introduced in the middle of the compiled fn
    # - various input mutation / intermediate base tests
    # - input mutation that changes a tensor into a subclass
    # - metadata mutation? (TBD)
    # - guard tests (fw guards *and* bw guards)
    # - subclass test involving _indices_of_inps_to_detach
    def test_aot_dispatch_simple(self):
        """Compile a function mixing a TwoTensor input and a plain tensor input.

        Compares forward outputs and input gradients against eager, and pins
        the exact forward and backward graphs handed to the compilers.
        """
        # a is a subclass, b is not
        def f(a, b):
            aa = torch.mul(a, 6)
            bb = torch.div(b, 2)
            return aa + bb

        a1_ref = torch.ones(3, 3, requires_grad=True)
        a2_ref = torch.ones(3, 3, requires_grad=True)
        a_ref = TwoTensor(a1_ref, a2_ref)
        b_ref = torch.ones(3, 3, requires_grad=True)

        # Independent copies of the inputs for the compiled run, so that
        # gradients accumulate separately from the eager reference.
        a1_test = a1_ref.clone().detach().requires_grad_(True)
        a2_test = a2_ref.clone().detach().requires_grad_(True)
        a_test = TwoTensor(a1_test, a2_test)
        b_test = b_ref.clone().detach().requires_grad_(True)

        # One-element cells passed to extract_graph for the fw / bw compilers;
        # they are read back below to inspect the generated graph code.
        fw_graph_cell = [None]
        bw_graph_cell = [None]

        compiled_f = aot_function(
            f,
            fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
            bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
            partition_fn=min_cut_rematerialization_partition,
        )
        out_ref = f(a_ref, b_ref)
        out_test = compiled_f(a_test, b_test)

        # Output is a TwoTensor (check both inner tensors)
        self.assertEqual(out_ref.a, out_test.a)
        self.assertEqual(out_ref.b, out_test.b)

        out_ref.sum().backward()
        out_test.sum().backward()
        # Both grad_inputs are TwoTensor
        self.assertEqual(a_ref.grad.a, a_test.grad.a)
        self.assertEqual(a_ref.grad.b, a_test.grad.b)
        self.assertEqual(b_ref.grad.a, b_test.grad.a)
        self.assertEqual(b_ref.grad.b, b_test.grad.b)

        # Important pieces of the graph:
        # - mul() and add() show up twice, because they were called on a TwoTensor
        #   (once per inner dense tensor)
        # - div() shows up once, because it was called on a plain Tensor
        # - The user forward() fn returns 1 output (the result of add),
        #   while the graph itself returns two outputs (add, add_1)
        # - add, add_1 correspond to the two inner dense tensors that will be wrapped
        #   into a single TwoTensor output.
        self.assertExpectedInline(
            fw_graph_cell[0].code.strip(),
            """\
def forward(self, primals_1, primals_2, primals_3):
    mul = torch.ops.aten.mul.Tensor(primals_1, 6);  primals_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(primals_2, 6);  primals_2 = None
    div = torch.ops.aten.div.Tensor(primals_3, 2);  primals_3 = None
    add = torch.ops.aten.add.Tensor(mul, div);  mul = None
    add_1 = torch.ops.aten.add.Tensor(mul_1, div);  mul_1 = div = None
    return (add, add_1)""",
        )

        # Important pieces of the graph:
        # - 4 total dense outputs.
        #   This corresponds to the fact that each user fwd input (a, b)
        #   will get a gradient that is a TwoTensor subclass,
        #   so (mul_2, mul_3) will be wrapped into a.grad
        #   and (div_1, div_2) will be wrapped into b.grad
        self.assertExpectedInline(
            bw_graph_cell[0].code.strip(),
            """\
def forward(self, tangents_1, tangents_2):
    div_1 = torch.ops.aten.div.Tensor(tangents_1, 2)
    div_2 = torch.ops.aten.div.Tensor(tangents_2, 2)
    mul_2 = torch.ops.aten.mul.Tensor(tangents_1, 6);  tangents_1 = None
    mul_3 = torch.ops.aten.mul.Tensor(tangents_2, 6);  tangents_2 = None
    return (mul_2, mul_3, div_1, div_2)""",
        )

    def test_aot_dispatch_inference(self):
        """Compiled output matches eager when no input requires grad and one is a TwoTensor."""

        # `sub` is a subclass (TwoTensor), `dense` is not
        def scale_and_add(sub, dense):
            return torch.mul(sub, 6) + torch.div(dense, 2)

        ref_inner = [torch.ones(3, 3), torch.ones(3, 3)]
        ref_sub = TwoTensor(*ref_inner)
        ref_dense = torch.ones(3, 3)

        test_inner = [t.clone() for t in ref_inner]
        test_sub = TwoTensor(*test_inner)
        test_dense = ref_dense.clone()

        compiled = aot_function(
            scale_and_add,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        expected = scale_and_add(ref_sub, ref_dense)
        actual = compiled(test_sub, test_dense)

        # Output is a TwoTensor (check both inner tensors)
        for inner_name in ("a", "b"):
            self.assertEqual(getattr(expected, inner_name), getattr(actual, inner_name))

    def test_aot_dispatch_incorrect_backward(self):
        """Backward must raise when a grad_output turns out to be a subclass unexpectedly."""

        # `sub` is a subclass (TwoTensor), `dense` is not
        def two_outputs(sub, dense):
            doubled = torch.mul(sub, 2)
            shifted = torch.add(dense, 3)
            # When creating the joint, we assume that the second grad_out
            # is not a subclass.
            # In the below test case though, we end up being wrong.
            # This would require re-tracing and recompiling the backward.
            return torch.div(doubled, shifted), torch.add(dense, dense)

        ref_inner = [torch.ones(3, 3, requires_grad=True) for _ in range(2)]
        ref_sub = TwoTensor(*ref_inner)
        ref_dense = torch.ones(3, 3, requires_grad=True)

        test_inner = [t.clone().detach().requires_grad_(True) for t in ref_inner]
        test_sub = TwoTensor(*test_inner)
        test_dense = ref_dense.clone().detach().requires_grad_(True)

        compiled = aot_function(
            two_outputs,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        expected = two_outputs(ref_sub, ref_dense)
        actual = compiled(test_sub, test_dense)
        # First out is a TwoTensor, second is an ordinary tensor
        self.assertEqual(expected[0].a, actual[0].a)
        self.assertEqual(expected[0].b, actual[0].b)
        self.assertEqual(expected[1], actual[1])

        # We compiled our graph assuming type(grad_out[1]) == torch.Tensor,
        # but we were wrong: in the below tests, it is a subclass.
        # This will eventually require a repartition + recompile
        expected_message = "incorrectly attempted to compile the backward with incorrect subclass metadata"
        with self.assertRaisesRegex(AssertionError, expected_message):
            (actual[0] + actual[1]).sum().backward()

    def test_aot_dispatch_output_alias(self):
        """A view of a TwoTensor input returned as an output matches eager, forward and backward."""

        # `dense` is a tensor, `sub` is a TwoTensor
        def alias_and_product(dense, sub):
            return sub.view(sub.shape), dense * sub

        ref_inner = [torch.ones(3, 3, requires_grad=True) for _ in range(2)]
        ref_sub = TwoTensor(*ref_inner)
        ref_dense = torch.ones(3, 3, requires_grad=True)

        test_inner = [t.clone().detach().requires_grad_(True) for t in ref_inner]
        test_sub = TwoTensor(*test_inner)
        test_dense = ref_dense.clone().detach().requires_grad_(True)

        compiled = aot_function(
            alias_and_product,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        ref_alias, ref_product = alias_and_product(ref_dense, ref_sub)
        test_alias, test_product = compiled(test_dense, test_sub)
        self.assertEqual(ref_alias, test_alias)
        self.assertEqual(ref_product.a, test_product.a)
        self.assertEqual(ref_product.b, test_product.b)

        (ref_alias + ref_product).sum().backward()
        (test_alias + test_product).sum().backward()
        # Both grad_inputs are TwoTensor: compare the inner tensors of each.
        for ref_input, test_input in ((ref_dense, test_dense), (ref_sub, test_sub)):
            self.assertEqual(ref_input.grad.a, test_input.grad.a)
            self.assertEqual(ref_input.grad.b, test_input.grad.b)

    def test_aot_dispatch_input_mutation(self):
        """In-place data mutation of a plain and a TwoTensor input matches eager.

        Checks the forward output, the mutated inputs themselves, and the
        gradients that reach the leaf tensors the inputs were derived from.
        """

        def mutate_then_add(dense, sub):
            dense.mul_(2)
            sub.mul_(3)
            return dense + sub

        # Leaves ("bases") that accumulate gradients; the function inputs are
        # non-leaf tensors derived from them via `+ 1`.
        ref_inner = [torch.ones(3, 3, requires_grad=True) for _ in range(2)]
        ref_sub_base = TwoTensor(*ref_inner)
        ref_dense_base = torch.ones(3, 3, requires_grad=True)
        ref_sub = ref_sub_base + 1
        ref_dense = ref_dense_base + 1

        test_inner = [t.clone().detach().requires_grad_(True) for t in ref_inner]
        test_sub_base = TwoTensor(*test_inner)
        test_dense_base = ref_dense_base.clone().detach().requires_grad_(True)
        test_sub = test_sub_base + 1
        test_dense = test_dense_base + 1

        compiled = aot_function(
            mutate_then_add,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        expected = mutate_then_add(ref_dense, ref_sub)
        actual = compiled(test_dense, test_sub)
        self.assertEqual(expected.a, actual.a)
        self.assertEqual(expected.b, actual.b)

        # confirm input mutations worked
        self.assertEqual(test_dense, ref_dense)
        self.assertEqual(test_sub.a, ref_sub.a)
        self.assertEqual(test_sub.b, ref_sub.b)

        # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward.
        (ref_sub * expected).sum().backward()
        (test_sub * actual).sum().backward()
        # Both grad_inputs are TwoTensor
        for ref_base, test_base in (
            (ref_dense_base, test_dense_base),
            (ref_sub_base, test_sub_base),
        ):
            self.assertEqual(ref_base.grad.a, test_base.grad.a)
            self.assertEqual(ref_base.grad.b, test_base.grad.b)

    # NB: Metadata mutation for subclasses is currently broken and disabled
    # See https://github.com/pytorch/pytorch/issues/114975
    @unittest.expectedFailure
    def test_aot_dispatch_input_metadata_mutation(self):
        """Metadata-only in-place mutations (t_, unsqueeze_) on a plain and a TwoTensor input.

        Marked expectedFailure: the test body is currently expected to fail.
        """

        def f(a, b):
            a.t_()
            b.unsqueeze_(0)
            return a + b

        # Leaf "base" tensors; the actual function inputs below are derived
        # from them with `+ 1`, and gradients are compared on the bases.
        b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
        b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
        b_ref_base = TwoTensor(b1_ref, b2_ref)
        a_ref_base = (
            torch.arange(9, dtype=torch.float32)
            .reshape(3, 3)
            .detach()
            .requires_grad_(True)
        )
        b_ref = b_ref_base + 1
        a_ref = a_ref_base + 1

        # Independent copies of the same inputs for the compiled run.
        b1_test = b1_ref.clone().detach().requires_grad_(True)
        b2_test = b2_ref.clone().detach().requires_grad_(True)
        b_test_base = TwoTensor(b1_test, b2_test)
        a_test_base = a_ref_base.clone().detach().requires_grad_(True)
        b_test = b_test_base + 1
        a_test = a_test_base + 1

        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        out_ref = f(a_ref, b_ref)
        out_test = compiled_f(a_test, b_test)
        self.assertEqual(out_ref.a, out_test.a)
        self.assertEqual(out_ref.b, out_test.b)

        # confirm input mutations worked
        self.assertEqual(a_test, a_ref)
        self.assertEqual(b_test.a, b_ref.a)
        self.assertEqual(b_test.b, b_ref.b)

        # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward.
        (b_ref * out_ref).sum().backward()
        (b_test * out_test).sum().backward()
        # Both grad_inputs are TwoTensor
        self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
        self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)
        self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a)
        self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b)

    # NB: Metadata mutation for subclasses is currently broken and disabled
    # See https://github.com/pytorch/pytorch/issues/114975
    @unittest.expectedFailure
    def test_aot_dispatch_input_data_and_metadata_mutation(self):
        """Combined metadata (t_, unsqueeze_) and data (mul_) mutations on both inputs.

        Marked expectedFailure: the test body is currently expected to fail.
        """

        def f(a, b):
            a.t_()
            b.unsqueeze_(0)
            a.mul_(2)
            b.mul_(3)
            return a + b

        # Leaf "base" tensors; the actual function inputs below are derived
        # from them with `+ 1`, and gradients are compared on the bases.
        b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
        b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
        b_ref_base = TwoTensor(b1_ref, b2_ref)
        a_ref_base = (
            torch.arange(9, dtype=torch.float32)
            .reshape(3, 3)
            .detach()
            .requires_grad_(True)
        )
        b_ref = b_ref_base + 1
        a_ref = a_ref_base + 1

        # Independent copies of the same inputs for the compiled run.
        b1_test = b1_ref.clone().detach().requires_grad_(True)
        b2_test = b2_ref.clone().detach().requires_grad_(True)
        b_test_base = TwoTensor(b1_test, b2_test)
        a_test_base = a_ref_base.clone().detach().requires_grad_(True)
        b_test = b_test_base + 1
        a_test = a_test_base + 1

        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        out_ref = f(a_ref, b_ref)
        out_test = compiled_f(a_test, b_test)
        self.assertEqual(out_ref.a, out_test.a)
        self.assertEqual(out_ref.b, out_test.b)

        # confirm input mutations worked
        self.assertEqual(a_test, a_ref)
        self.assertEqual(b_test.a, b_ref.a)
        self.assertEqual(b_test.b, b_ref.b)

        # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward.
        (b_ref * out_ref).sum().backward()
        (b_test * out_test).sum().backward()
        # Both grad_inputs are TwoTensor
        self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
        self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)
        self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a)
        self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b)

    def test_aot_dispatch_input_mutation_and_output_alias(self):
        """Input data mutation combined with an output aliasing the mutated TwoTensor input.

        Compares both outputs, the mutated inputs, and the gradient reaching
        the plain-tensor base against eager.
        """

        def mutate_and_alias(dense, sub):
            dense.mul_(2)
            sub.mul_(3)
            return sub.view(sub.shape), dense + sub

        def make_inner():
            return torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(
                3, 3
            )

        ref_inner = [make_inner(), make_inner()]
        ref_sub_base = TwoTensor(*ref_inner)
        ref_dense_base = (
            torch.arange(9, dtype=torch.float32)
            .reshape(3, 3)
            .detach()
            .requires_grad_(True)
        )
        ref_sub = ref_sub_base + 1
        ref_dense = ref_dense_base + 1

        test_inner = [t.clone().detach().requires_grad_(True) for t in ref_inner]
        test_sub_base = TwoTensor(*test_inner)
        test_dense_base = ref_dense_base.clone().detach().requires_grad_(True)
        test_sub = test_sub_base + 1
        test_dense = test_dense_base + 1

        compiled = aot_function(
            mutate_and_alias,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        ref_alias, ref_sum = mutate_and_alias(ref_dense, ref_sub)
        test_alias, test_sum = compiled(test_dense, test_sub)
        for ref_out, test_out in ((ref_alias, test_alias), (ref_sum, test_sum)):
            self.assertEqual(ref_out.a, test_out.a)
            self.assertEqual(ref_out.b, test_out.b)

        # confirm input mutations worked
        self.assertEqual(test_dense, ref_dense)
        self.assertEqual(test_sub.a, ref_sub.a)
        self.assertEqual(test_sub.b, ref_sub.b)

        (ref_alias * ref_sum).sum().backward()
        (test_alias * test_sum).sum().backward()
        # Only the gradient of the plain-tensor base is compared here; it is
        # accessed through its .a / .b inner tensors.
        self.assertEqual(ref_dense_base.grad.a, test_dense_base.grad.a)
        self.assertEqual(ref_dense_base.grad.b, test_dense_base.grad.b)

    def test_aot_dispatch_output_requires_grad_in_no_grad(self):
        """Outputs of a compiled fn called under no_grad keep eager's requires_grad flags.

        ``fn`` re-enables grad for one of its two outputs. For inputs with and
        without ``requires_grad``, each output's ``requires_grad`` must match
        eager; when the input requires grad, the input gradients after a
        backward pass are compared as well.
        """

        def fn(x):
            out1 = x.sin()
            with torch.enable_grad():
                out2 = x.cos()
            return out1, out2

        # Factories rather than tensors: each call site needs a fresh leaf.
        inp_fns = [
            lambda: torch.ones(10, requires_grad=True),
            lambda: torch.ones(10, requires_grad=False),
        ]

        compiled_f = aot_function(fn, nop)
        for inp_fn in inp_fns:
            with torch.no_grad():
                ref_x = inp_fn()
                ref_out = fn(ref_x)
                x = inp_fn()
                out = compiled_f(x)
                for r, o in zip(ref_out, out):
                    self.assertEqual(r.requires_grad, o.requires_grad)
            if ref_x.requires_grad:
                with torch.enable_grad():
                    (ref_out[0] + ref_out[1]).sum().backward()
                    (out[0] + out[1]).sum().backward()
                    self.assertEqual(ref_x.grad, x.grad)

    def test_aot_dispatch_output_requires_grad_in_no_grad_views(self):
        """requires_grad of a view output and a non-view output must match eager under no_grad."""

        # view-type ops preserve requires_grad even in no_grad.
        def view_and_sin(t):
            return t.view(-1), t.sin()

        graph_cell = [None]
        capture_graph = partial(extract_graph, graph_cell=graph_cell)
        compiled = aot_function(
            view_and_sin, nop, inference_compiler=make_boxed_compiler(capture_graph)
        )

        leaf = torch.ones(2, 3, requires_grad=True)
        # Clone in no_grad will make requires_grad=False tensors, keep clone outside of no_grad
        eager_input = leaf.clone()
        compiled_input = leaf.clone()
        with torch.no_grad():
            eager_view, eager_sin = view_and_sin(eager_input)

            compiled_view, compiled_sin = compiled(compiled_input)
            # Assert that we executed inference graph
            self.assertTrue(graph_cell[0] is not None)

            self.assertEqual(eager_view.requires_grad, compiled_view.requires_grad)
            self.assertEqual(eager_sin.requires_grad, compiled_sin.requires_grad)


class TestAOTModuleSimplified(AOTTestCase):
    def test_aot_module_simplified(self):
        """aot_module_simplified on a Linear-plus-add module matches eager outputs and input grads."""

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(20, 30)

            def forward(self, x, y):
                projected = self.linear(x)
                return (projected + y,)

        module = MockModule()
        module.zero_grad()

        eager_inputs = [
            torch.randn(128, 20, requires_grad=True),
            torch.randn(128, 30, requires_grad=True),
        ]
        compiled_inputs = [t.detach().clone().requires_grad_(True) for t in eager_inputs]

        expected = module(*eager_inputs)
        expected[0].sum().backward()

        compiled = aot_module_simplified(module, compiled_inputs, nop)
        module.zero_grad()
        actual = compiled(*compiled_inputs)
        actual[0].sum().backward()

        assert torch.allclose(expected[0], actual[0])
        for eager_in, compiled_in in zip(eager_inputs, compiled_inputs):
            assert torch.allclose(eager_in.grad, compiled_in.grad)

    def test_aot_module_simplified_dynamic(self):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(20, 30)

            def forward(self, x, y):
                return (self.linear(x) + y,)

        mod = MockModule()

        shape_env = ShapeEnv()
        fake_mode = FakeTensorMode(shape_env=shape_env)

        x = torch.randn(128, 20, requires_grad=True)
        y = torch.randn(128, 30, requires_grad=True)

        inputs = [x, y]
        fake_inputs = [fake_mode.from_tensor(x) for x in inputs]
        compiled_f = aot_module_simplified(mod, fake_inputs, nop)

        ref = mod(*inputs)
        ref[0].sum().backward()

        cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs]
        res = compiled_f(*cloned_inputs)
        res[0].sum().backward()

        self.assertExpectedInline(
            shape_env.format_guards(),
            """\
 - Eq(s1, 20)
 - Eq(s2, 30)""",
        )

        assert torch.allclose(ref[0], res[0])
        assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad)
        assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)

    # https://github.com/pytorch/pytorch/issues/105327
    def test_lift_fresh_copy_in_graph(self):
        """A module that calls aten.lift_fresh_copy on a fresh constant compiles and matches eager."""

        class MyMod(torch.nn.Module):
            def forward(self, x):
                constant = torch.tensor([1])
                lifted = torch.ops.aten.lift_fresh_copy.default(constant)
                return (x.mul(lifted),)

        module = MyMod()
        mode = FakeTensorMode(shape_env=ShapeEnv())
        real_input = torch.ones(4, requires_grad=True)
        fake_inputs = [mode.from_tensor(t) for t in [real_input]]
        compiled = aot_module_simplified(module, fake_inputs, nop)

        expected = module(real_input)
        actual = compiled(real_input)
        self.assertEqual(expected[0].detach(), actual[0].detach())

    def test_inference_python_dispatcher(self):
        """Compiling a bilinear Upsample module for inference with fake inputs must not raise.

        Only compilation is exercised; the compiled callable is never invoked.
        """

        # Extracted from unet
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.upsample = torch.nn.Upsample(
                    scale_factor=2, mode="bilinear", align_corners=True
                )

            def forward(self, x):
                upsampled = self.upsample(x)
                return (upsampled,)

        module = MockModule()
        mode = FakeTensorMode(shape_env=ShapeEnv())
        # NB: must not require grad
        real_input = torch.randn(2, 512, 40, 59)
        fake_inputs = [mode.from_tensor(t) for t in [real_input]]
        compiled_f = aot_module_simplified(module, fake_inputs, nop)

    def test_aot_module_simplified_preserves_stack_trace(self):
        """Stack traces recorded by the FX tracer must survive into the fw and bw graphs.

        Every non-output node of the traced module, and every node that is
        neither output nor placeholder in the graphs handed to the compilers,
        must carry a stack trace mentioning this test file.
        """

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(20, 30)

            def forward(self, x, y):
                z = self.linear(x)
                z = z + y
                z = z.relu()
                return (z,)

        fx_tracer = torch.fx.Tracer()
        fx_tracer.record_stack_traces = True
        traced_graph = fx_tracer.trace(MockModule())
        traced_mod = torch.fx.GraphModule(fx_tracer.root, traced_graph)

        for node in traced_mod.graph.nodes:
            if node.op != "output":
                self.assertTrue(node.stack_trace is not None)
                assert "test_aotdispatch.py" in node.stack_trace

        def assert_compiler(gm: torch.fx.GraphModule, _):
            for node in gm.graph.nodes:
                if node.op in ("output", "placeholder"):
                    continue
                self.assertTrue(node.stack_trace is not None)
                assert "test_aotdispatch.py" in node.stack_trace
            return gm.forward  # return a python callable

        inputs = [
            torch.randn(128, 20, requires_grad=True),
            torch.randn(128, 30, requires_grad=True),
        ]

        compiled = aot_module_simplified(
            traced_mod, inputs, fw_compiler=assert_compiler, bw_compiler=assert_compiler
        )
        outputs = compiled(*inputs)
        outputs[0].sum().backward()

    def test_aot_module_simplified_preserves_stack_trace_from_mutation(self):
        """The copy_ node emitted for an input mutation must keep the mutation's stack trace.

        NOTE: the compiler callback below matches on the literal source text
        ``x_view.mul_(2)``, so that line in ``forward`` must not be reworded.
        """

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                # Mutate the input through a view of it.
                x_view = x[0]
                x_view.mul_(2)
                return (x + x,)

        tracer = torch.fx.Tracer()
        tracer.record_stack_traces = True
        graph = tracer.trace(MockModule())
        mod = torch.fx.GraphModule(tracer.root, graph)

        # Every traced node except the output must carry a stack trace that
        # points back at this test file.
        for node in mod.graph.nodes:
            if node.op == "output":
                continue
            self.assertTrue(node.stack_trace is not None)
            assert "test_aotdispatch.py" in node.stack_trace

        def assert_compiler(gm: torch.fx.GraphModule, _):
            # The graph must contain a copy_ node, and each copy_ node must
            # carry the stack trace of the mutating line.
            assert torch.ops.aten.copy_.default in [x.target for x in gm.graph.nodes]
            for node in gm.graph.nodes:
                if node.target == torch.ops.aten.copy_.default:
                    assert "stack_trace" in node.meta
                    assert "x_view.mul_(2)" in node.meta["stack_trace"]
            return gm.forward  # return a python callable

        x = torch.randn(128, 20)
        inputs = [x]

        # The returned callable is not invoked; the checks live in
        # assert_compiler, which is passed as both fw and bw compiler.
        aot_module_simplified(
            mod,
            inputs,
            fw_compiler=assert_compiler,
            bw_compiler=assert_compiler,
            keep_inference_input_mutations=True,
        )

    def test_aot_module_simplified_fake_tensor_gm_raises(self):
        """A fake tensor captured as a free variable by the module must raise "Unexpected fake"."""
        mode = FakeTensorMode()
        real_arg = torch.randn(4, requires_grad=True)
        fake_arg = mode.from_tensor(real_arg)
        real_captured = torch.randn(4)
        fake_captured = mode.from_tensor(real_captured)

        class MockModule(torch.nn.Module):
            def forward(self, x):
                # Accessing a free variable fake tensor will look like a
                # constant to make_fx, and result in the tensor being traced
                # into the graph, which is an error condition.  Make sure we
                # report adequately in this case.
                total = x + fake_captured
                return (total,)

        module = MockModule()
        with self.assertRaisesRegex(AssertionError, "Unexpected fake"):
            aot_module_simplified(module, (fake_arg,), nop)

    def test_aot_test_subclasses_with_tensor_factories(self):
        """torch.compile with the aot_eager backend on a SubclassWithTensorFactory input matches eager."""
        from torch.testing._internal.common_subclass import SubclassWithTensorFactory

        subclass_input = SubclassWithTensorFactory(torch.zeros(3, 5))

        def double(t):
            return 2 * t

        expected = double(subclass_input)
        compiled_double = torch.compile(double, backend="aot_eager", fullgraph=True)
        actual = compiled_double(subclass_input)
        self.assertEqual(expected, actual)


# entries in here don't work and need to be fixed.
# Each one of these is a bug (or needs to be investigated)
# Set of xfail / skip / decorate entries, keyed by op name and optional
# variant name. Each op is listed once.
aot_autograd_failures = {
    # data-dependent control flow
    xfail("cov"),
    xfail("nn.functional.gaussian_nll_loss"),
    xfail("tensor_split"),
    xfail("corrcoef"),
    xfail("quantile"),
    xfail("nanquantile"),
    xfail("narrow"),
    xfail("istft"),
    xfail("linalg.eig"),
    skip("as_strided_scatter"),
    skip("as_strided", "partial_views"),  # flaky
    # Given input size: (s0xs1x2). Calculated output size: ...
    skip("max_pool2d_with_indices_backward"),
    skip("nn.functional.nll_loss", ""),  # UBSAN failure!
    # Misc
    xfail("to_sparse"),
    xfail("chalf"),  # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
    xfail("sparse.sampled_addmm"),
    xfail("sparse.mm", "reduce"),
    skip("nn.functional.binary_cross_entropy_with_logits"),  # seems to fail sometimes?
    skip("nn.functional.margin_ranking_loss"),  # seems flaky
    skip("linalg.lu_solve"),  # flaky
    decorate("matmul", decorator=unittest.skipIf(IS_ARM64, "flaky")),
    decorate("__rmatmul__", decorator=unittest.skipIf(IS_ARM64, "flaky")),
    # overrides atol=1e-4, rtol=1e-5 would do as well
    decorate(
        "svd_lowrank",
        decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)}),
    ),
    decorate(
        "linalg.householder_product",
        decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"),
    ),
    decorate(
        "linalg.pinv",
        "singular",
        decorator=toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)}),
    ),
    decorate(
        "nn.functional.interpolate",
        "bicubic",
        decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)}),
    ),
    # conv2d sometimes nondeterministic in this config?
    decorate("nn.functional.conv2d", decorator=unittest.skipIf(IS_ARM64, "flaky")),
}

# Set of xfail / skip / decorate entries; the comment above each entry (or
# group of entries) records the observed failure.
symbolic_aot_autograd_failures = {
    # aten.masked_select.default
    xfail("combinations", ""),
    # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail("index_fill", ""),
    # aten.linalg_lstsq.default - couldn't find symbolic meta function/decomposition
    xfail("linalg.lstsq", ""),
    # aten.linalg_lstsq.default - couldn't find symbolic meta funct...
    xfail("linalg.lstsq", "grad_oriented"),
    # aten.linalg_lu_solve.default - couldn't find symbolic meta function/deco...
    xfail("linalg.lu_solve", ""),
    # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te..
    skip("nn.functional.batch_norm", ""),
    # aten.fill_.Scalar - couldn't find symbolic meta funct...
    xfail("nn.functional.binary_cross_entropy", ""),
    # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail("nn.functional.cross_entropy", ""),
    # aten._ctc_loss.Tensor - couldn't find symbolic meta function/deco...
    xfail("nn.functional.ctc_loss", ""),
    # rand() received an invalid combination of arguments - g...
    xfail("nn.functional.fractional_max_pool3d", ""),
    # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail("nn.functional.group_norm", ""),
    # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail("nn.functional.nll_loss", ""),
    # aten.segment_reduce.default - couldn't find symbolic meta functio...
    xfail("_segment_reduce", "lengths"),
    # aten.segment_reduce.default - couldn't find symbolic meta functio...
    xfail("_segment_reduce", "offsets"),
    # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail("trace", ""),
    # RuntimeError: isIntList() INTERNAL ASSERT FAILED  Expected IntList but got GenericList
    xfail("_upsample_bilinear2d_aa"),
    decorate(
        "linalg.householder_product",
        decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"),
    ),
    # many complex operators incorrect striding, metadata
    xfail("fft.fft", ""),
    xfail("fft.hfft2", ""),
    xfail("fft.hfft", ""),
    xfail("fft.hfftn", ""),
    xfail("fft.ifft", ""),
    xfail("fft.ihfft2", ""),
    xfail("fft.ihfft", ""),
    xfail("fft.ihfftn", ""),
    xfail("fft.irfft2", ""),
    xfail("fft.irfft", ""),
    xfail("fft.irfftn", ""),
    xfail("fft.rfft2", ""),
    xfail("fft.rfft", ""),
    xfail("fft.rfftn", ""),
    # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail("stft", ""),
}


def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False):
    """Run ``aot_autograd_check`` on every sample input of ``op``.

    The test is skipped when the op does not support autograd, or when a
    ``DynamicOutputShapeException`` is raised during the check. A
    ``GuardOnDataDependentSymNode`` leads to a skip only for ``__getitem__``;
    for any other op it is re-raised.

    Args:
        self: the test case; provides ``skipTest``, ``assertRaisesRegex`` and
            ``assertEqual``.
        device: device passed to ``op.sample_inputs``.
        dtype: dtype passed to ``op.sample_inputs``.
        op: the op under test; ``supports_autograd``, ``name``,
            ``sample_inputs`` and ``op`` are read from it.
        dynamic: forwarded positionally to ``aot_autograd_check``.
    """
    if not op.supports_autograd:
        self.skipTest("Op does not support autograd")

    # aot_autograd_check is able to check data specialization by
    # randomizing the inputs. Here's a list of ops that really do not
    # like random inputs for which we want to disable that.
    cant_check_data_specialization = {
        "nn.functional.max_unpool1d",
        "nn.functional.max_unpool2d",
        "nn.functional.max_unpool3d",
    }
    try_check_data_specialization = op.name not in cant_check_data_specialization

    for sample in op.sample_inputs(device, dtype, requires_grad=True):
        t_args = [sample.input, *sample.args]
        t_kwargs = sample.kwargs
        try:
            aot_autograd_check(
                op.op,
                t_args,
                t_kwargs,
                dynamic,
                self.assertRaisesRegex,
                self.assertEqual,
                check_gradients=True,
                try_check_data_specialization=try_check_data_specialization,
            )
        except DynamicOutputShapeException:
            self.skipTest("Dynamic output shape operation in trace")
        except GuardOnDataDependentSymNode:
            # Carveout for getitem; I don't want to xfail the entire test
            # because that will reject known to be good tests see
            # https://github.com/pytorch/pytorch/issues/94705
            if op.name != "__getitem__":
                raise
            self.skipTest("Dynamic output shape operation in trace")


def _test_aot_autograd_module_helper(
    self, device, dtype, training, module_info, *, dynamic=False
):
    """Exercise a module through ``aot_function`` for each of its sample inputs.

    For every entry produced by ``module_info.module_inputs_func`` that has a
    forward input, the module is constructed, moved to ``device``/``dtype``
    and put in the requested ``training`` mode. A functional wrapper ``f``
    around ``torch.func.functional_call`` is then built, compiled with
    ``aot_function(f, nop, ...)``, and both are handed to
    ``_test_aot_autograd_forwards_backwards_helper``.

    Args:
        self: the running test case; supplies ``assertRaisesRegex`` and
            ``assertEqual``.
        device: device for the module and its sample inputs.
        dtype: dtype for the module and its sample inputs.
        training: passed to ``module_inputs_func`` and ``m.train``.
        module_info: ModuleInfo-like entry; ``module_cls`` and
            ``module_inputs_func`` are read from it.
        dynamic: forwarded to ``aot_function``.

    Samples without a forward input, and samples whose flattened forward
    arguments contain a ``PackedSequence``, are skipped.
    """
    module_cls = module_info.module_cls
    module_inputs = module_info.module_inputs_func(
        module_info, device=device, dtype=dtype, requires_grad=True, training=training
    )

    # Unique placeholder marking "a tensor goes here" in the flattened argument
    # template. It is compared by identity, so it can never collide with a
    # genuine non-tensor argument (an int sentinel compared with `==` could).
    tensor_slot = object()

    for module_input in module_inputs:
        if module_input.forward_input is None:
            continue

        ctor_args, ctor_kwargs = (
            module_input.constructor_input.args,
            module_input.constructor_input.kwargs,
        )
        m = module_cls(*ctor_args, **ctor_kwargs)
        m.to(device).to(dtype)
        m.train(training)

        fwd_args, fwd_kwargs = (
            module_input.forward_input.args,
            module_input.forward_input.kwargs,
        )
        flat_args, args_spec = pytree.tree_flatten((fwd_args, fwd_kwargs))

        # PackedSequence is only used for RNNs. It might be possible to fake-ify if they're pytrees but
        # torchdynamo already doesn't support RNNs
        if any(isinstance(flat_arg, PackedSequence) for flat_arg in flat_args):
            continue

        # Lazy modules need to see an input first to initialize params.
        if issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin):
            with torch.no_grad():
                m(*fwd_args, **fwd_kwargs)

        # Split the flattened inputs into a template (tensors replaced by the
        # placeholder) and the list of tensors that will be fed to `f`.
        flat_template = [
            tensor_slot if isinstance(arg, torch.Tensor) else arg for arg in flat_args
        ]
        tensor_args = [arg for arg in flat_args if isinstance(arg, torch.Tensor)]

        def f(params_buffers_args):
            named_params, named_buffers, args = params_buffers_args
            # Re-insert the incoming tensors into their original positions.
            tensors = iter(args)
            cur_flat_args = [
                next(tensors) if slot is tensor_slot else slot
                for slot in flat_template
            ]
            c_args, c_kwargs = pytree.tree_unflatten(cur_flat_args, args_spec)
            params_and_buffers = {**named_params, **named_buffers}
            return torch.func.functional_call(m, params_and_buffers, c_args, c_kwargs)

        # remove_duplicate=False keeps every name, including tied entries.
        named_params = dict(m.named_parameters(remove_duplicate=False))
        named_buffers = dict(m.named_buffers(remove_duplicate=False))
        num_params_buffers = len(named_params) + len(named_buffers)
        compiled_f = aot_function(
            f, nop, num_params_buffers=num_params_buffers, dynamic=dynamic
        )
        params_buffers_args = [named_params, named_buffers, tensor_args]
        _test_aot_autograd_forwards_backwards_helper(
            f,
            compiled_f,
            params_buffers_args,
            self.assertRaisesRegex,
            self.assertEqual,
            # NOTE(review): trailing positional flag; presumably enables the
            # data-specialization check -- confirm against the helper's signature.
            True,
        )


class TestEagerFusionOpInfo(AOTTestCase):
    """Per-operator AOTAutograd tests driven by ``op_db + hop_db``.

    Both tests are parametrized by ``@ops`` over ``torch.float`` only and
    delegate to ``_test_aot_autograd_helper``; they differ in the ``dynamic``
    flag and in the failure set handed to ``skipOps``.
    """

    # Default (dynamic=False) run; entries in aot_autograd_failures are
    # handled by skipOps.
    @ops(op_db + hop_db, allowed_dtypes=(torch.float,))
    @skipOps(
        "TestEagerFusionOpInfo", "test_aot_autograd_exhaustive", aot_autograd_failures
    )
    def test_aot_autograd_exhaustive(self, device, dtype, op):
        _test_aot_autograd_helper(self, device, dtype, op)

    # dynamic=True run. functorch.compile.config.debug_assert is patched to
    # True for the duration of the test, and the failure set is the union of
    # the default and the symbolic-only failures.
    @ops(op_db + hop_db, allowed_dtypes=(torch.float,))
    @patch("functorch.compile.config.debug_assert", True)
    @skipOps(
        "TestEagerFusionOpInfo",
        "test_aot_autograd_symbolic_exhaustive",
        aot_autograd_failures | symbolic_aot_autograd_failures,
    )
    def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op):
        _test_aot_autograd_helper(self, device, dtype, op, dynamic=True)


# Module classes wrapped in unittest.expectedFailure by the module-info tests
# below. Each entry is preceded by the error it currently produces.
aot_autograd_module_failures = {
    # torch._subclasses.fake_tensor.DynamicOutputShapeException: aten._ctc_loss.default
    torch.nn.CTCLoss,
    # RuntimeError: It appears that you're trying to get value out of a
    # tracing tensor with aten._local_scalar_dense.default - erroring out!
    # It's likely that this is caused by data-dependent control flow or similar.
    torch.nn.GaussianNLLLoss,
    # AssertionError: The values for attribute 'shape' do not match:
    # torch.Size([1]) != torch.Size([]). Outputs of the operator are different
    # in eager-mode PyTorch vs AOTAutograd. This means the operator will have
    # incorrect output underneath torch.compile. This could be because the
    # operator's implementation not traceable or that there is a bug in
    # AOTAutograd.
    torch.nn.MultiLabelMarginLoss,
    # DataDependentOutputException: aten.eq compares a mask input to a causal
    # mask tensor, to see if Boolean is_causal should be set for
    # TransformerEncoder layers, MHA and sdp custom kernels
    torch.nn.TransformerEncoder,
    # DataDependentOutputException: aten.equal compares a mask input to a
    # causal mask tensor, to see if Boolean is_causal should be set for
    # TransformerEncoder layers, MHA and sdp custom kernels
    # (this bubbles up to Transformer)
    torch.nn.Transformer,
}

# Module classes that additionally fail when the module-info test runs with
# dynamic=True. Each entry is preceded by the error it currently produces.
symbolic_aot_autograd_module_failures = {
    # new_size = _infer_size(target.size(), weight.size())
    # RuntimeError: expected int at position 0, but got: SymInt
    torch.nn.BCELoss,
    # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
    torch.nn.FractionalMaxPool3d,
    # NotImplementedError: local_scalar_dense/item NYI for torch.bool
    torch.nn.GaussianNLLLoss,
    # in native_group_norm_backward cpg, _rem = divmod(C, group)
    # TypeError: unsupported operand type(s) for divmod(): 'SymInt' and 'int'
    torch.nn.GroupNorm,
    # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
    torch.nn.Transformer,
    # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
    torch.nn.TransformerEncoder,
}


class TestEagerFusionModuleInfo(AOTTestCase):
    """Per-module AOTAutograd tests driven by ``module_db``.

    Both tests are parametrized by ``@modules`` over ``torch.float`` only and
    delegate to ``_test_aot_autograd_module_helper``; they differ in the
    ``dynamic`` flag and in which module classes are decorated with
    ``unittest.expectedFailure``.
    """

    # Default (dynamic=False) run; classes in aot_autograd_module_failures
    # are decorated with unittest.expectedFailure.
    @modules(module_db, allowed_dtypes=(torch.float,))
    @decorateForModules(unittest.expectedFailure, aot_autograd_module_failures)
    def test_aot_autograd_module_exhaustive(self, device, dtype, training, module_info):
        _test_aot_autograd_module_helper(self, device, dtype, training, module_info)

    # dynamic=True run; the expected-failure set is the union of the default
    # and the symbolic-only module failures.
    @modules(module_db, allowed_dtypes=(torch.float,))
    @decorateForModules(
        unittest.expectedFailure,
        aot_autograd_module_failures | symbolic_aot_autograd_module_failures,
    )
    def test_aot_autograd_symbolic_module_exhaustive(
        self, device, dtype, training, module_info
    ):
        _test_aot_autograd_module_helper(
            self, device, dtype, training, module_info, dynamic=True
        )


# Expand the parametrized tests declared on TestAOTAutograd.
instantiate_parametrized_tests(TestAOTAutograd)
# Device type passed as `only_for` to every instantiate_device_type_tests
# call below, so only CPU variants of these classes are generated.
only_for = "cpu"
# Each call registers the device-specific variants of the given
# device-generic test class into this module's globals().
instantiate_device_type_tests(
    TestPythonKey,
    globals(),
    only_for=only_for,
)
instantiate_device_type_tests(TestEagerFusionOpInfo, globals(), only_for=only_for)
instantiate_device_type_tests(TestEagerFusionModuleInfo, globals(), only_for=only_for)


@xfail_inherited_tests(
    [
        "test_set__and_data_mutation_bad",
        "test_subclass_metadata_mutation_req_grad_True",
        "test_subclass_metadata_mutation_req_grad_False",
    ]
)
@skipIfTorchDynamo("This test suite already uses dynamo")
class TestAOTAutogradWithDynamo(TestAOTAutograd):
    """
    Re-runs the TestAOTAutograd tests with dynamo in front: the function under
    test goes through torch.compile, whose backend hands the captured graph
    module to aot_module_simplified.
    """

    def assertExpectedInline(self, *args, **kwargs):
        """Deliberately a no-op.

        Dynamo returns a different graph module, so the expected-graph strings
        of the base tests would not match. When testing with dynamo we only
        care that the outputs match, etc.
        """

    def make_compiler(self, graph_cell):
        """Return a boxed compiler that records the graph into ``graph_cell``."""
        extractor = partial(extract_graph, graph_cell=graph_cell)
        return make_boxed_compiler(extractor)

    def run_autograd(
        self,
        f: Callable,
        fw_graph_cell: List[Optional[Callable]],
        decompositions: Optional[Dict],
        keep_input_mutations: bool,
        dynamic: bool,
    ):
        """
        Runs dynamo and aot_autograd with the specified settings.

        Returns a callable that resets dynamo, compiles ``f`` with
        torch.compile and invokes it with the given arguments.
        """

        # Backend handed to torch.compile.
        def dynamo_compiler(gm, inputs, **kwargs):
            # `dynamic` is not forwarded here: it is calculated from whether
            # the inputs have fake tensors.
            return aot_module_simplified(
                gm,
                inputs,
                fw_compiler=self.make_compiler(fw_graph_cell),
                bw_compiler=self.make_compiler([None]),
                decompositions=decompositions,
                keep_inference_input_mutations=keep_input_mutations,
            )

        def torch_compile_wrapper(*args, **kwargs):
            torch._dynamo.reset()
            compiled_fn = torch.compile(f, backend=dynamo_compiler)
            try:
                return compiled_fn(*args, **kwargs)
            except torch._dynamo.exc.BackendCompilerFailed as e:
                # Surface the original error so that assertRaises works properly.
                raise e.inner_exception from e

        return torch_compile_wrapper


class MockFXGraphCache:
    """
    Dict-backed, in-memory stand-in for FXGraphCache, so tests can run
    without the real FXGraphCache.
    """

    def __init__(self) -> None:
        # Maps cache key -> the graph module that was stored under it.
        self.cache = {}

    def save(self, key, gm):
        """Store ``gm`` under ``key``, replacing any previous entry."""
        self.cache[key] = gm

    def load(self, gm, inputs):
        """Return a boxed callable for ``gm``, tagged with its cache key.

        The key comes from ``compiled_fx_graph_hash``. ``gm`` is stored only
        when its key is not cached yet; either way the returned callable
        wraps the ``gm`` passed in, not the stored one.
        """
        key, _ = compiled_fx_graph_hash(gm, inputs, {}, {})
        if key not in self.cache:
            self.save(key, gm)
        boxed = make_boxed_func(gm)
        boxed._fx_graph_cache_key = key
        return boxed

    def _lookup_graph(self, key, inputs, local, remote_cache):
        """Return the boxed cached graph for ``key``, or None on a miss.

        ``inputs``, ``local`` and ``remote_cache`` are accepted but unused.
        """
        cached = self.cache.get(key)
        return cached if cached is None else make_boxed_func(cached)

    def post_compile(self, gm, inputs, cudagraphs):
        """No-op; all arguments are ignored and None is returned."""
        pass


# Names of inherited tests that fail in strict caching mode, i.e. they bypass
# the cache or miss it instead of hitting it. The tuple is passed to
# xfail_inherited_tests on TestAOTAutogradWithCache.
# Original note: "They will be fixed in the PRs above this."
FAILING_CACHE_TESTS = (
    # BypassAOTAutogradCache: unsupported nodes
    "test_backward_mutation_data",  # Custom Autograd Function
    "test_backward_mutation_metadata",  # Custom Autograd Function
    "test_custom_autograd",  # Custom Autograd Function
    "test_input_output_aliase_custom_autograd_function",
)


@xfail_inherited_tests(FAILING_CACHE_TESTS)
class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
    """
    Re-runs the TestAOTAutogradWithDynamo tests with the AOTAutograd cache
    enabled in strict mode. The compilers go through an in-memory
    MockFXGraphCache, and FxGraphCache's lookup and post-compile hooks are
    patched to use that same mock.
    """

    def make_compiler(self, fw_graph_cell):
        """Return a compiler that loads through the mock cache.

        Reads ``self.inductor_cache``, which ``verify_aot_autograd`` assigns.
        The compiler also records the incoming graph module in
        ``fw_graph_cell[0]``.
        """
        inductor_cache = self.inductor_cache

        def compiler(gm, inputs):
            compiled = inductor_cache.load(gm, inputs)
            fw_graph_cell[0] = gm
            return compiled

        return compiler

    def run_autograd(
        self,
        f: Callable,
        fw_graph_cell: List[Optional[Callable]],
        decompositions: Optional[Dict],
        keep_input_mutations: bool,
        dynamic: bool,
    ):
        """Pass-through to the parent implementation with the same arguments."""
        return super().run_autograd(
            f, fw_graph_cell, decompositions, keep_input_mutations, dynamic
        )

    @torch._functorch.config.patch(
        {
            "enable_autograd_cache": True,
            "strict_autograd_cache": True,
            "view_replay_for_aliased_outputs": False,
        }
    )
    @torch._inductor.config.patch("fx_graph_cache", True)
    def verify_aot_autograd(
        self,
        f,
        inp_: Union[Callable, List[Any]],
        *,
        test_mutation: bool = False,
        keep_inp_mutations: bool = False,
        decompositions: Optional[Dict] = None,
        dynamic: bool = False,
        # Only active when inp_ is Callable.
        # TODO: probably consolidate all tests to make inp a Callable.
        make_inputs_subclasses: bool = False,
    ):
        """Run the parent's verification against a fresh mock cache.

        A new MockFXGraphCache is installed on ``self`` and AOTAutogradCache
        is cleared first, so no state leaks in from earlier calls. All
        arguments are forwarded unchanged to the parent implementation.
        """
        mock_cache = MockFXGraphCache()
        self.inductor_cache = mock_cache
        AOTAutogradCache.clear()

        # Route FxGraphCache lookups and post-compile hooks to the mock.
        lookup_patch = patch(
            "torch._inductor.codecache.FxGraphCache._lookup_graph",
            new=mock_cache._lookup_graph,
        )
        post_compile_patch = patch(
            "torch._inductor.codecache.FxGraphCache.post_compile",
            new=mock_cache.post_compile,
        )
        with lookup_patch, post_compile_patch:
            return super().verify_aot_autograd(
                f,
                inp_,
                test_mutation=test_mutation,
                keep_inp_mutations=keep_inp_mutations,
                decompositions=decompositions,
                dynamic=dynamic,
                make_inputs_subclasses=make_inputs_subclasses,
            )

    def test_input_mutation_false_aliasing(self):
        # Skipped outright instead of xfailed: the test fails in strict cache
        # mode, and running it as an expected failure causes undefined
        # behavior for ASAN.
        self.skipTest("Skipping because it fails in strict cache mode")


# Invoke run_tests() only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
