# Owner(s): ["module: dynamo"]
"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_export_persist_assert)
"""
import copy
import functools
import inspect
import io
import operator
import unittest
from enum import Enum
from typing import Dict, List, Sequence
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from functorch.experimental.control_flow import cond
from torch._dynamo import config
from torch._dynamo.exc import UserError
from torch._dynamo.testing import normalize_gm
from torch._higher_order_ops.out_dtype import out_dtype
from torch._subclasses import fake_tensor
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
    ConstraintViolationError,
    DimDynamic,
    ShapeEnv,
    StatelessSymbolicContext,
)
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import TEST_CUDA


class ExportTests(torch._dynamo.test_case.TestCase):
    # TODO(voz): Refactor to a shared test function.
    # The tests in this file are a little redundant: they all take a func,
    # run it with eager, then export it, then compare the results.
    def test_export(self):
        def pre_attention_state_ops(input, mems, state):
            lc_key = state[0]
            lc_val = state[1]
            bar = []
            for i in range(0, 4):
                bar2 = []
                for j in range(0, 3):
                    bar2.append(
                        lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
                    )
                bar.append(bar2)

            return bar

        def func():
            mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
            state = [
                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
            ]
            i = torch.tensor(
                [
                    [0.0313, -0.1487, -0.3846, -0.5321],
                    [-1.7073, 1.3331, -0.0890, -1.4935],
                    [-0.8314, -0.1862, -0.5935, 1.5232],
                ]
            )
            return pre_attention_state_ops(i, mems, state)

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func()

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)()
        out_graph = exported[0]

        dynamo_result = out_graph()
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_no_tensor_computation_fail(self):
        with self.assertRaisesRegex(
            AssertionError,
            "Failed to produce a graph",
        ):
            inp = [torch.randn(3)]
            inp2 = 2
            inps = [inp, inp2]

            def func(x, y):
                return x

            exported = torch._dynamo.export(func, same_signature=False)(*inps)

    def test_no_tensor_computation(self):
        inp = [torch.randn(3)]
        inp2 = 2
        inps = [inp, inp2]

        def func(x, y):
            return x

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, x, y):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
    x = arg0
    return pytree.tree_unflatten([x], self._out_spec)""",
        )

    def test_no_tensor_computation_2(self):
        inp = torch.randn(3)
        inp2 = 2
        inps = [inp, inp2]

        def func(x, y):
            return y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, x, y):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
    x = arg0
    return pytree.tree_unflatten([2], self._out_spec)""",
        )

    def test_export_mismatched_out(self):
        def func(x):
            y = x + 1
            return ([x, x], (y, y))

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]]))
        out_graph = exported[0]

        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_shape_control_flow_1(self):
        def func(x):
            if x.shape[0] > 10:
                return x.cos()
            return x.sin()

        opt_func = torch._dynamo.optimize("eager")(func)
        real_result = opt_func(torch.ones(6, 4))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(torch.ones(6, 4))
        out_graph, out_guards = exported

        dynamo_result = out_graph(torch.ones(6, 4))

        from torch._guards import GuardSource

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        hit = False
        for guard in out_guards:
            if guard.source == GuardSource.SHAPE_ENV:
                hit = True
                self.assertExpectedInline(
                    guard.code_list,
                    """["L['x'].stride()[0] == L['x'].size()[1]", "L['x'].stride()[1] == 1", "L['x'].storage_offset() == 0", "2 <= L['x'].size()[0] <= 10", "2 <= L['x'].size()[1]"]""",  # noqa: B950
                )
                break

        self.assertTrue(hit)

    def test_export_control_flow_with_getattr(self):
        class Animal(Enum):
            COW = "moo"

        class MyModule(torch.nn.Module):
            def __init__(self, a):
                super().__init__()
                self.a = a

            def forward(self, x):
                if self.a == Animal.COW.value:
                    return x * x
                else:
                    raise ValueError("bad")

        module = MyModule("moo")
        input = (torch.ones(4, 3),)
        resA = module(*input)
        graph, _ = torch._dynamo.export(module)(*input)
        resB = graph(*input)
        self.assertTrue(torch._dynamo.utils.same(resA, resB))

    def test_export_graph_bypass(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return first * second

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_list_unpack(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return x[0], first * second, x[1], x[2]

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_with_shallow_list_copy_wo_side_effects(self):
        def f(x):
            y = x.copy()
            return y[0] + y[1]

        inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])]
        gm = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
            inp
        ).graph_module
        self.assertTrue(torch._dynamo.utils.same(gm(inp), f(inp)))

    def test_export_with_shallow_list_copy_with_side_effects(self):
        def f(x):
            y = x.copy()
            x[0] = x[1]
            y.append(torch.tensor([[100]]))
            return x[0] + x[1], y[0] + y[1], y[2]

        inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])]
        gm = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
            inp
        ).graph_module
        res = gm(inp)
        ref = f(inp)
        self.assertTrue(torch._dynamo.utils.same(res, ref))
        self.assertEqual(res[0], res[1])

    def test_export_mismatched_out_2(self):
        def func(x):
            y = x + 1
            return ([x, x], (y, y))

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]]))
        out_graph = exported[0]

        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_with_list(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
            torch.tensor([0.4, 0.4]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return first * second, x

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_with_complex_reorder(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
            torch.tensor([0.4, 0.4]),
        ]

        def func(x):
            first = x[0]
            second = x[1]
            third = x[2]
            return third, first, second, first * second, first * third

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes(self):
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            y = x + 1
            return y, y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_2(self):
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            y = x + 1
            return y, y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.4, 0.4])
        inps = [inp, inp2]

        def func(x, z):
            y = x + 1
            return y, y, z

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_with_non_tensor_arg(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return y, y, z

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_reorder_with_non_tensor_arg(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return z, y, y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_dupes_and_bypass_with_non_tensor_output(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return y[0].item(), y, z

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_zeroes_in_and_out_different_shape_on_test(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return [[a], [b, c], [a + b], [[c + c]]]

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_zeroes_in_new_shape_scalar_out(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return a[0].item() + b[0].item() + c[0].item()

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_zeroes_in_new_shape_scalar_out_permute(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return b[0].item() + c[0].item() + a[0].item() + a[0].item()

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return a, b[0].item() + c[0].item() + a[0].item() + a[0].item(), a

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_func_return(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            x = a + b + c

            def func2(y):
                return x * y

            return func2(x)

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dict_return(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            x = a + b + c
            return {"a": x}

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_with_aten_graph(self):
        def pre_attention_state_ops(input, mems, state):
            lc_key = state[0]
            lc_val = state[1]
            bar = []
            for i in range(0, 4):
                bar2 = []
                for j in range(0, 3):
                    bar2.append(
                        lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
                    )
                bar.append(bar2)

            return bar

        def func():
            mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
            state = [
                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
            ]
            i = torch.tensor(
                [
                    [0.0313, -0.1487, -0.3846, -0.5321],
                    [-1.7073, 1.3331, -0.0890, -1.4935],
                    [-0.8314, -0.1862, -0.5935, 1.5232],
                ]
            )
            return pre_attention_state_ops(i, mems, state)

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func()

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)()
        out_graph = exported[0]

        dynamo_result = out_graph()
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_no_tensor_computation_with_aten_graph(self):
        inp = [torch.randn(3)]
        inp2 = 2
        inps = [inp, inp2]

        def func(x, y):
            return x

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, x, y):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
    arg0_1 = arg0
    return pytree.tree_unflatten([arg0_1], self._out_spec)""",
        )

    def test_no_tensor_computation_2_with_aten_graph(self):
        inp = torch.randn(3)
        inp2 = 2
        inps = [inp, inp2]

        def func(x, y):
            return y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, x, y):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
    arg0_1 = arg0
    return pytree.tree_unflatten([2], self._out_spec)""",
        )

    def test_export_mismatched_out_with_aten_graph(self):
        def func(x):
            y = x + 1
            return ([x, x], (y, y))

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(
            torch.tensor([[[1.3737, 0.1]]])
        )
        out_graph = exported[0]

        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_bypass_with_aten_graph(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return first * second

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_list_unpack_with_aten_graph(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return x[0], first * second, x[1], x[2]

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_mismatched_out_2_with_aten_graph(self):
        def func(x):
            y = x + 1
            return ([x, x], (y, y))

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(
            torch.tensor([[[1.3737, 0.1]]])
        )
        out_graph = exported[0]

        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_with_list_with_aten_graph(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
            torch.tensor([0.4, 0.4]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return first * second, x

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_with_complex_reorder_with_aten_graph(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
            torch.tensor([0.4, 0.4]),
        ]

        def func(x):
            first = x[0]
            second = x[1]
            third = x[2]
            return third, first, second, first * second, first * third

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            y = x + 1
            return y, y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_2_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            y = x + 1
            return y, y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.4, 0.4])
        inps = [inp, inp2]

        def func(x, z):
            y = x + 1
            return y, y, z

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_with_non_tensor_arg_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return y, y, z

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_reorder_with_non_tensor_arg_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return z, y, y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_dupes_and_bypass_with_non_tensor_output_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return y[0].item(), y, z

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_zeroes_in_and_out_different_shape_on_test_with_aten_graph(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return [[a], [b, c], [a + b], [[c + c]]]

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_func_return_with_aten_graph(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            x = a + b + c

            def func2(y):
                return x * y

            return func2(x)

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dict_return_with_aten_graph(self):
        """Export a function returning a dict and replay it on random inputs.

        The aten graph exported from all-zero inputs must match a dynamic
        "eager"-backend compile of the same function.
        """
        zero_inputs = [torch.zeros(10) for _ in range(3)]
        random_inputs = [torch.randn(10) for _ in range(3)]

        def func(a, b, c):
            x = a + b + c
            return {"a": x}

        compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        expected = compiled(*random_inputs)

        # Reset dynamo state so the export below does not reuse the compile above.
        torch._dynamo.reset()

        graph_module, _ = torch._dynamo.export(func, aten_graph=True)(*zero_inputs)
        actual = graph_module(*random_inputs)

        self.assertTrue(torch._dynamo.utils.same(expected, actual))

    def test_export_with_stack_trace(self):
        """Exported nodes carry stack-trace and module-stack metadata.

        Checked for the torch-level graph (every node except placeholders and
        the output) and for the aten-level graph (call_function nodes only),
        where "val" and "original_aten" metadata are also required.
        """
        inp = torch.randn(4, 4)

        class MyBlock(torch.nn.Module):
            def forward(self, x):
                x = torch.nn.functional.linear(x, torch.randn(4, 4))
                return torch.cos(x).relu() + 1

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.block = MyBlock()

            def forward(self, x):
                out = self.block(x)
                return out

        gm, _ = torch._dynamo.export(MyModule(), aten_graph=False)(inp)
        for node in gm.graph.nodes:
            if node.op in {"placeholder", "output"}:
                continue
            self.assertIsNotNone(node.stack_trace)
            for key in ("nn_module_stack", "source_fn_stack"):
                self.assertIsNotNone(node.meta[key])

        torch._dynamo.reset()

        gm, _ = torch._dynamo.export(MyModule(), aten_graph=True)(inp)
        aten_keys = ("nn_module_stack", "source_fn_stack", "val", "original_aten")
        for node in gm.graph.nodes:
            if node.op != "call_function":
                continue
            self.assertIsNotNone(node.stack_trace)
            for key in aten_keys:
                self.assertIsNotNone(node.meta[key])

    def test_export_preserves_nn_module_stack_for_get_attr(self):
        """get_attr nodes keep nn_module_stack metadata in both export modes.

        MyBlock owns one parameter and one buffer; each export (torch-level and
        aten-level) must yield exactly two get_attr nodes, each carrying a
        non-None nn_module_stack.
        """
        inp = torch.randn(4, 4)

        class MyBlock(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(1, 1))
                self.buffer = torch.nn.Buffer(torch.ones(1, 1))

            def forward(self, x):
                x = torch.nn.functional.linear(x, torch.randn(4, 4))
                return torch.cos(x).relu() + self.weight + self.buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.block = MyBlock()

            def forward(self, x):
                out = self.block(x)
                return out

        m = MyModule()
        for aten_graph in (False, True):
            if aten_graph:
                # Start the aten-level export from a clean dynamo state.
                torch._dynamo.reset()
            gm, _ = torch._dynamo.export(m, aten_graph=aten_graph)(inp)
            get_attr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
            for node in get_attr_nodes:
                self.assertIsNotNone(node.meta["nn_module_stack"])
            self.assertEqual(len(get_attr_nodes), 2)

    def test_export_compare_optimize_with_make_fx(self):
        """aten-graph export agrees with make_fx, both via a backend and directly.

        The backend re-traces the dynamo graph with make_fx and runs it; a
        direct make_fx trace of ``func`` is compared as well.
        """
        inp = torch.tensor([0.1, 0.1])
        linear = torch.nn.Linear(2, 2)

        def func(x):
            x = x + 1
            y = x.t()
            y = y.relu()
            y = linear(y)
            return y

        exported_gm, _ = torch._dynamo.export(func, aten_graph=True)(inp)
        export_result = exported_gm(inp)

        torch._dynamo.reset()

        def make_fx_backend(gm, sample_inputs):
            # Each call re-traces the dynamo graph with make_fx, then runs it.
            def run_traced(*args):
                return make_fx(gm)(*args)(*args)

            return run_traced

        compiled = torch._dynamo.optimize(
            make_fx_backend, nopython=True, dynamic=True
        )(func)
        backend_result = compiled(inp)
        direct_result = make_fx(func)(inp)(inp)

        for candidate in (backend_result, direct_result):
            self.assertTrue(torch._dynamo.utils.same(candidate, export_result))

    def test_export_with_constant_method_on_module(self):
        """An assume_constant_result method's output is frozen at export time.

        Export traces an all-zero input; the graph must still reproduce the
        eager result computed from a different input, for each probe below.
        """

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return torch.nonzero(x)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = self.helper_fn(x)
                return y

        expected = MyModule()(torch.tensor([[1.0, 0], [0, 0]]))
        gm, _ = torch._dynamo.export(MyModule())(torch.tensor([[0.0, 0], [0, 0]]))
        probes = (
            torch.tensor([[1.0, 0.0], [0, 0]]),
            torch.tensor([[1, 0], [0.25, 0.25]]),
        )
        for probe in probes:
            self.assertTrue(torch._dynamo.utils.same(gm(probe), expected))

    def test_export_with_constant_method_on_module_invoke_twice(self):
        """Calling the same assume_constant_result method twice still freezes it.

        Export traces an all-zero input; each probe must reproduce the eager
        result computed from a different input.
        """

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return torch.nonzero(x)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = self.helper_fn(x) + self.helper_fn(x)
                return y

        expected = MyModule()(torch.tensor([[1.0, 0], [0, 0]]))
        gm, _ = torch._dynamo.export(MyModule())(torch.tensor([[0.0, 0], [0, 0]]))
        probes = (
            torch.tensor([[1.0, 0.0], [0, 0]]),
            torch.tensor([[1, 0], [0.25, 0.25]]),
        )
        for probe in probes:
            self.assertTrue(torch._dynamo.utils.same(gm(probe), expected))

    def test_export_with_constant_free_function(self):
        """Mix a free assume_constant_result function with a same-named method.

        Both results are frozen at export time (traced on an all-zero input);
        each probe must reproduce the eager result.
        """

        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return torch.nonzero(x)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = helper_fn(x) + self.helper_fn(x)
                return y

        expected = MyModule()(torch.tensor([[1.0, 0], [0, 0]]))
        gm, _ = torch._dynamo.export(MyModule())(torch.tensor([[0.0, 0], [0, 0]]))
        probes = (
            torch.tensor([[1.0, 0.0], [0, 0]]),
            torch.tensor([[1, 0], [0.25, 0.25]]),
        )
        for probe in probes:
            self.assertTrue(torch._dynamo.utils.same(gm(probe), expected))

    def test_export_with_constant_free_function_and_class_method(self):
        """A free assume_constant_result function called from forward is frozen.

        Export traces an all-zero input; each probe must reproduce the eager
        result computed from a different input.
        """

        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = helper_fn(x)
                return y

        expected = MyModule()(torch.tensor([[1.0, 0], [0, 0]]))
        gm, _ = torch._dynamo.export(MyModule())(torch.tensor([[0.0, 0], [0, 0]]))
        probes = (
            torch.tensor([[1.0, 0.0], [0, 0]]),
            torch.tensor([[1, 0], [0.25, 0.25]]),
        )
        for probe in probes:
            self.assertTrue(torch._dynamo.utils.same(gm(probe), expected))

    def test_export_with_constant_free_function_and_class_method_multiarg(self):
        """Two calls of a free assume_constant_result function on two inputs.

        Both calls are frozen at export time; each probe pair must reproduce
        the eager result.
        """

        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x, z):
                y = torch.sin(x)
                x = self.linear(x)
                y = helper_fn(x) + helper_fn(z)
                return y

        expected = MyModule()(
            torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
        )
        gm, _ = torch._dynamo.export(MyModule())(
            torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
        )
        probe_pairs = (
            (torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[1.0, 0.0], [0, 0]])),
            (
                torch.tensor([[1, 0], [0.25, 0.25]]),
                torch.tensor([[1, 0], [0.25, 0.25]]),
            ),
        )
        for x_probe, z_probe in probe_pairs:
            self.assertTrue(torch._dynamo.utils.same(gm(x_probe, z_probe), expected))

    def test_export_with_constant_free_function_and_class_method_multiarg_diff(self):
        """Two calls of a free assume_constant_result function on differing inputs.

        The exported graph is fed probe pairs unlike the export-time inputs and
        is compared against the eager result.
        """

        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def forward(self, x, z):
                y = helper_fn(x) + helper_fn(z)
                return y

        expected = MyModule()(
            torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
        )
        gm, _ = torch._dynamo.export(MyModule())(
            torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[0.0, 0], [0.5, 0]])
        )
        probe_pairs = (
            (torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[0.0, 1.0], [0, 0]])),
            (
                torch.tensor([[1, 0], [0.25, 0.25]]),
                torch.tensor([[0.33, 0.33], [0.25, 0.25]]),
            ),
        )
        for x_probe, z_probe in probe_pairs:
            self.assertTrue(torch._dynamo.utils.same(gm(x_probe, z_probe), expected))

    def test_export_with_constant_tuple_nonzero(self):
        """A tuple returned by an assume_constant_result method is frozen.

        forward reads its input only through helper_fn, so once export freezes
        helper_fn's tuple, the graph's output ignores whatever tensor it is fed.
        """

        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return (torch.nonzero(x), torch.nonzero(x))

            def forward(self, x):
                y = torch.tensor([0.5])
                elements = self.helper_fn(x)
                all_y = []
                for element in elements:
                    for item in element:
                        all_y.append(y * item)
                return all_y

        module = MyModule()
        expected = module(torch.tensor([1.0, 1.0]))
        gm, _ = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))

        # A differently shaped input still yields the export-time result.
        actual = gm(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(actual, expected))

    def test_export_with_constant_list_nonzero(self):
        """A list returned by an assume_constant_result method is frozen.

        forward reads its input only through helper_fn, so once export freezes
        helper_fn's list, the graph's output ignores whatever tensor it is fed.
        """

        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return [torch.nonzero(x), torch.nonzero(x)]

            def forward(self, x):
                y = torch.tensor([0.5])
                elements = self.helper_fn(x)
                all_y = []
                for element in elements:
                    for item in element:
                        all_y.append(y * item)
                return all_y

        module = MyModule()
        expected = module(torch.tensor([1.0, 1.0]))
        gm, _ = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))

        # A differently shaped input still yields the export-time result.
        actual = gm(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(actual, expected))

    def test_export_with_constant_list_nonzero_free_function(self):
        """A list returned by a free assume_constant_result function is frozen.

        forward reads its input only through helper_fn, so once export freezes
        helper_fn's list, the graph's output ignores whatever tensor it is fed.
        """

        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return [torch.nonzero(x), torch.nonzero(x)]

        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor([0.5])
                elements = helper_fn(x)
                all_y = []
                for element in elements:
                    for item in element:
                        all_y.append(y * item)
                return all_y

        module = MyModule()
        expected = module(torch.tensor([1.0, 1.0]))
        gm, _ = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))

        # A differently shaped input still yields the export-time result.
        actual = gm(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(actual, expected))

    def test_export_with_constant_dict_values(self):
        """A dict returned by an assume_constant_result method is frozen.

        forward reads its input only through helper_fn's dict values, so the
        exported graph reproduces the export-time result for any input fed here.
        """

        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return {"x": x, "x^2": x * x}

            def forward(self, x):
                y = torch.tensor([0.5])
                elements = self.helper_fn(x)
                y = y * elements["x"]
                y = y * elements["x^2"]
                return y

        module = MyModule()
        expected = module(torch.tensor([2.0, 2.0]))
        gm, _ = torch._dynamo.export(module)(torch.tensor([2.0, 2.0]))

        # A differently shaped input still yields the export-time result.
        actual = gm(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(actual, expected))

    def test_export_with_constant_none_control_flow(self):
        """A None returned by an assume_constant_result method is baked in.

        Exporting with a negative input makes helper_fn return None, so the
        graph takes the ``x is None`` branch even when later fed a positive input.
        """

        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                if x.item() < 0:
                    return None
                else:
                    return x

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        eager_out = module(torch.tensor([-1]))
        # -1 < 0, so helper_fn returns None and forward returns y unchanged.
        self.assertEqual(eager_out, torch.tensor([0.5]))

        gm, _ = torch._dynamo.export(module)(torch.tensor([-1]))
        # 2 is positive, but the graph keeps the None seen at export time.
        self.assertTrue(torch._dynamo.utils.same(gm(torch.tensor([2])), eager_out))

    def test_export_with_constant_not_none_control_flow(self):
        """A tensor returned by an assume_constant_result method is baked in.

        Exporting with a positive input makes helper_fn return x, so the graph
        computes ``y * x`` from that frozen value even for a negative input.
        """

        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                if x.item() < 0:
                    return None
                else:
                    return x

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        eager_out = module(torch.tensor([2]))
        # 2 is not negative, so helper_fn returns x and forward gives 0.5 * 2.
        self.assertEqual(eager_out, torch.tensor([1.0]))

        gm, _ = torch._dynamo.export(module)(torch.tensor([2]))
        # -0.5 is negative, but the graph keeps the x captured at export time.
        self.assertTrue(
            torch._dynamo.utils.same(gm(torch.tensor([-0.5])), eager_out)
        )

    def test_export_with_constant_none_control_flow_free_func(self):
        """A None returned by a free assume_constant_result function is baked in.

        Exporting with a negative input makes helper_fn return None, so the
        graph takes the ``x is None`` branch even when later fed a positive input.
        """

        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            if x.item() < 0:
                return None
            else:
                return x

        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor([0.5])
                x = helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        eager_out = module(torch.tensor([-1]))
        # -1 < 0, so helper_fn returns None and forward returns y unchanged.
        self.assertEqual(eager_out, torch.tensor([0.5]))

        gm, _ = torch._dynamo.export(module)(torch.tensor([-1]))
        # 2 is positive, but the graph keeps the None seen at export time.
        self.assertTrue(torch._dynamo.utils.same(gm(torch.tensor([2])), eager_out))

    def test_export_with_constant_not_none_control_flow_pos(self):
        """Positive-input variant: helper_fn's returned x is frozen at export.

        The graph computes ``y * x`` from the export-time x even when it is
        later fed a negative input.
        """

        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                if x.item() < 0:
                    return None
                else:
                    return x

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        eager_out = module(torch.tensor([2]))
        # 2 is not negative, so helper_fn returns x and forward gives 0.5 * 2.
        self.assertEqual(eager_out, torch.tensor([1.0]))

        gm, _ = torch._dynamo.export(module)(torch.tensor([2]))
        # -0.5 is negative, but the graph keeps the x captured at export time.
        self.assertTrue(
            torch._dynamo.utils.same(gm(torch.tensor([-0.5])), eager_out)
        )

    def test_export_with_constant_not_none_control_flow_free_func(self):
        """A tensor returned by a free assume_constant_result function is baked in.

        The graph computes ``y * x`` from the export-time x even when it is
        later fed a negative input.
        """

        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            if x.item() < 0:
                return None
            else:
                return x

        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor([0.5])
                x = helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        eager_out = module(torch.tensor([2]))
        # 2 is not negative, so helper_fn returns x and forward gives 0.5 * 2.
        self.assertEqual(eager_out, torch.tensor([1.0]))

        gm, _ = torch._dynamo.export(module)(torch.tensor([2]))
        # -0.5 is negative, but the graph keeps the x captured at export time.
        self.assertTrue(
            torch._dynamo.utils.same(gm(torch.tensor([-0.5])), eager_out)
        )

    def test_export_with_constant_not_return_const(self):
        """A string from an assume_constant_result method is frozen at export.

        helper_fn returns ``self.val``; changing ``val`` after export must not
        change what the exported graph returns.
        """

        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return self.val

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x == "A":
                    return y
                return -1

        module = MyModule()
        module.val = "A"
        before = module(torch.tensor([2]))
        gm, _ = torch._dynamo.export(module)(torch.tensor([2]))

        # Mutating the attribute after export must not affect the graph.
        module.val = "B"
        after = gm(torch.tensor([2]))
        self.assertTrue(torch._dynamo.utils.same(before, after))

    def test_export_with_builtin_op_on_assume_constant(self):
        """Compare an assume_constant_result value with a Python float in forward.

        The exported graph must produce the same output as eager execution.
        """

        @torch._dynamo.assume_constant_result
        def get_y(y) -> torch.Tensor:
            return y

        class Bob(torch.nn.Module):
            def __init__(self, p, val) -> None:
                super().__init__()
                self.p = p
                self.y = torch.nn.Parameter(torch.tensor(val))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # This only looks dynamic but it's actually a constant value
                if get_y(self.y) < self.p:
                    return torch.cat([x, x])
                else:
                    return x

        model = Bob(0.5, 0.3)
        inp = torch.ones(3, 4)
        gm, _ = torch._dynamo.export(model)(inp)
        self.assertEqual(model(inp), gm(inp))

    def test_export_with_constant_in_unspecialized_nn_module(self):
        """assume_constant_result on a method of a module mutated in forward.

        forward assigns ``self.device``; per the inline note this makes dynamo
        track the module as an UnspecializedNNModuleVariable. The exported
        graph must match eager execution.
        """

        class Module(torch.nn.Module):
            def __init__(self, y):
                super().__init__()
                self.y = y

            @torch._dynamo.assume_constant_result
            def check(self):
                return self.y[0].item() == 1

            def forward(self, x):
                # This line leads to module obj being tracked as UnspecializedNNModuleVariable in dynamo
                self.device = x.device

                if self.check():
                    return x + 1
                else:
                    return x + 2

        model = Module(torch.tensor([1]))
        inp = torch.ones(3, 4)
        gm, _ = torch._dynamo.export(model)(inp)
        self.assertEqual(model(inp), gm(inp))

    def test_export_decomp(self):
        """A decomposition_table entry removes aten.t from the aten graph.

        With ``nop`` registered for aten.t no aten.t nodes remain; with
        ``decomposition_table=None`` both transposes stay in the graph.
        """

        def f(x):
            return x.t() + x.t()

        def nop(x):
            return x.cos()

        def count_transposes(gm):
            return sum(
                1 for node in gm.graph.nodes if node.target == torch.ops.aten.t.default
            )

        decomposed, _ = torch._dynamo.export(
            f,
            aten_graph=True,
            decomposition_table={torch.ops.aten.t.default: nop},
        )(torch.randn(5))
        self.assertEqual(count_transposes(decomposed), 0)

        plain, _ = torch._dynamo.export(f, aten_graph=True, decomposition_table=None)(
            torch.randn(5)
        )
        self.assertEqual(count_transposes(plain), 2)

    def test_export_decomp_asserts_bad_args(self):
        """A decomposition_table without aten_graph=True is rejected.

        Uses the ``export(f, **options)(*inputs)`` calling convention, like the
        rest of this file; passing example inputs positionally to ``export``
        itself is the deprecated form.
        """

        def f(x):
            return x.t() + x.t()

        def nop(x):
            return x.cos()

        with self.assertRaises(AssertionError):
            torch._dynamo.export(
                f,
                aten_graph=False,
                decomposition_table={torch.ops.aten.t.default: nop},
            )(torch.randn(5))

    @config.patch(capture_scalar_outputs=True)
    def test_export_with_module_layer(self):
        """Export a cond whose branches call a submodule (``self.linear``).

        The graph is exported once, then re-run with x negated. That flips the
        predicate (unless x[0][0] is exactly zero), showing the branch choice
        was not specialized at export time.
        """
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, pred, x):
                def true_fn(val):
                    return self.linear(val) * torch.tensor(2)

                def false_fn(val):
                    return self.linear(val) * torch.tensor(-1)

                return cond(pred, true_fn, false_fn, [x])

        mod = Module()
        x = torch.randn([3, 3])
        pred = torch.tensor(x[0][0].item() < 0)
        expected = mod.forward(pred, x)

        torch._dynamo.reset()

        gm, _ = torch._dynamo.export(mod.forward)(pred, x)
        self.assertTrue(torch._dynamo.utils.same(expected, gm(pred, x)))

        # Opposite-sign input, same exported graph.
        x = x * -1
        pred = torch.tensor(x[0][0].item() < 0)
        expected_flipped = mod.forward(pred, x)
        self.assertTrue(torch._dynamo.utils.same(expected_flipped, gm(pred, x)))

    @config.patch(capture_scalar_outputs=True)
    def test_export_with_cond_branches_calling_methods(self):
        """Export a cond whose branches are bound methods calling other methods.

        The exported graph must match eager execution for the same inputs.
        """
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def t(self, val):
                return val + 1

            def f(self, val):
                return val - 1

            def true_fn(self, val):
                return self.linear(val) + self.t(val)

            def false_fn(self, val):
                return self.linear(val) - self.f(val)

            def forward(self, pred, x):
                return cond(pred, self.true_fn, self.false_fn, [x])

        mod = Module()
        x = torch.randn([3, 3])
        pred = torch.tensor(x[0][0].item() < 0)
        expected = mod.forward(pred, x)
        gm, _ = torch._dynamo.export(mod.forward)(pred, x)
        self.assertTrue(torch._dynamo.utils.same(expected, gm(pred, x)))

    @config.patch(capture_scalar_outputs=True)
    def test_export_with_cond_closure(self):
        """Export cond with locally defined branch closures.

        Covers plain operands (Foo), a computed operand (Bar), and two operands
        plus a submodule captured via ``self`` (FooBar). Each exported graph
        must match eager execution on a requires_grad input.
        """
        from functorch.experimental.control_flow import cond

        class Foo(torch.nn.Module):
            def forward(self, pred, x):
                def true_fn(x):
                    return x * 2

                def false_fn(x):
                    return x - 2

                return cond(pred, true_fn, false_fn, [x])

        class Bar(torch.nn.Module):
            def forward(self, pred, x):
                def true_fn(x):
                    return x * 2

                def false_fn(x):
                    return x - 2

                return cond(pred, true_fn, false_fn, [x + 1])

        class FooBar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, pred, x):
                y = x + x

                def true_fn(x, y):
                    return self.linear(x) * (x + y)

                def false_fn(x, y):
                    return x * (y - x)

                return cond(pred, true_fn, false_fn, [x, y])

        for module_cls in (Foo, Bar, FooBar):
            mod = module_cls()
            x = torch.randn([3, 3], requires_grad=True)
            pred = torch.tensor(x[0][0].item() < 0)
            expected = mod.forward(pred, x)
            gm, _ = torch._dynamo.export(mod.forward)(pred, x)
            self.assertTrue(torch._dynamo.utils.same(expected, gm(pred, x)))

    def test_export_with_cond_with_closed_function(self):
        """Export cond whose branches call functions from the enclosing scope.

        The predicate is a 0-d bool tensor (``x[0] > 0``); the exported graph
        must match eager execution.
        """

        def hello(x):
            return x + 1

        def hi(x):
            return x + 2

        def foo(pred, x):
            def true_fn(x):
                return hello(x)

            def false_fn(x):
                return hi(x)

            return cond(pred, true_fn, false_fn, [x])

        x = torch.randn(5)
        pred = x[0] > 0
        expected = foo(pred, x)
        gm, _ = torch._dynamo.export(foo)(pred, x)
        self.assertTrue(torch._dynamo.utils.same(expected, gm(pred, x)))

    def test_export_with_cond_dynamic_shape_pred(self):
        """Export ``cond`` whose predicate is computed from an input's shape.

        The same module is written twice: ``Module`` passes the operands to
        ``cond`` as a list and ``Module2`` passes them as a tuple. For a (2, 2)
        input, the exported top-level graph and both branch subgraphs must
        match the inline expectations. For a (3, 2) input, export must fail
        with ``UncapturedHigherOrderOpError``.
        """
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            def forward(self, x):
                def true_fn(x):
                    return x + x

                def false_fn(x):
                    return x[:2]

                # Operands passed to cond as a list.
                return cond(x.shape[0] <= 2, true_fn, false_fn, [x])

        class Module2(torch.nn.Module):
            def forward(self, x):
                def true_fn(x):
                    return x + x

                def false_fn(x):
                    return x[:2]

                # Operands passed to cond as a tuple.
                return cond(x.shape[0] <= 2, true_fn, false_fn, (x,))

        mods = [Module(), Module2()]
        for mod in mods:
            # For a (2, 2) input both branches return a (2, 2) tensor. The
            # list and tuple variants are expected to export to identical code.
            x = torch.randn(2, 2)
            out_graph, guards = torch._dynamo.export(mod)(x)
            self.assertExpectedInline(
                out_graph.code.strip(),
                """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    size = l_x_.size()
    getitem = size[0];  size = None
    le = getitem <= 2;  getitem = None
    cond_true_0 = self.cond_true_0
    cond_false_0 = self.cond_false_0
    cond = torch.ops.higher_order.cond(le, cond_true_0, cond_false_0, [l_x_]);  le = cond_true_0 = cond_false_0 = l_x_ = None
    getitem_2 = cond[0];  cond = None
    return pytree.tree_unflatten([getitem_2], self._out_spec)""",
            )
            self.assertExpectedInline(
                out_graph.cond_true_0.code.strip(),
                """\
def forward(self, l_x_):
    l_x__1 = l_x_
    add = l_x__1 + l_x__1;  l_x__1 = None
    return (add,)""",
            )
            self.assertExpectedInline(
                out_graph.cond_false_0.code.strip(),
                """\
def forward(self, l_x_):
    l_x__1 = l_x_
    getitem = l_x__1[slice(None, 2, None)];  l_x__1 = None
    return (getitem,)""",
            )
            with self.assertRaisesRegex(
                torch._dynamo.exc.UncapturedHigherOrderOpError,
                "Cond doesn't work unless it is captured completely with torch.compile",
            ):
                # True branch and false branch return tensors of different shape
                # (for a (3, 2) input: x + x is (3, 2) while x[:2] is (2, 2)).
                torch._dynamo.export(mod)(torch.randn(3, 2))

            # We specialize into one of the branches since predicate is a python boolean.
            # NOTE(review): the eager result is not checked; this only verifies
            # that calling the module eagerly does not raise.
            test_x = torch.randn(3, 2)
            mod(test_x)

    def test_export_with_map_cond(self):
        """``map`` over a body that calls ``cond``.

        The module is exported with one predicate and input shape, then
        replayed with a different predicate and shape. The replayed result
        must still match eager execution.
        """
        from functorch.experimental.control_flow import cond, map as cf_map

        class Module(torch.nn.Module):
            def inner(self, x, pred):
                def true_fn(x):
                    return x + x

                def false_fn(x):
                    return x * x

                return cond(pred, true_fn, false_fn, [x])

            def forward(self, pred, xs):
                def body(x, pred):
                    return self.inner(x, pred)

                return cf_map(body, xs, pred)

        module = Module()
        trace_xs = torch.randn(3, 2, 1)
        trace_pred = torch.tensor(True)

        check_xs = torch.randn(4, 3, 2)
        check_pred = torch.tensor(False)
        expected = module(check_pred, check_xs)

        exported_gm = torch._dynamo.export(module)(trace_pred, trace_xs)[0]
        self.assertEqual(expected, exported_gm(check_pred, check_xs))

    def test_export_with_map_zero_sized_tensor(self):
        """Exporting ``map`` over an input whose leading dim is 0 is rejected."""
        from functorch.experimental.control_flow import map as cf_map

        class Module(torch.nn.Module):
            def forward(self, xs):
                def body(x):
                    return x + 1

                return cf_map(body, xs)

        module = Module()
        empty_xs = torch.randn(0, 2)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "zero-sized tensor",
        ):
            torch._dynamo.export(module)(empty_xs)

    def test_export_meta_val(self):
        """Every placeholder of an aten-level export carries a ``"val"`` meta entry."""

        def f(x, y, z):
            return x * y + z

        example = (torch.ones(3, 2), torch.zeros(3, 2), torch.ones(3, 2))
        gm, _ = torch._dynamo.export(f, aten_graph=True)(*example)

        placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
        for placeholder in placeholders:
            self.assertIn("val", placeholder.meta)

    def test_input_container_type(self):
        """A list-of-tensors argument and a dict return value survive export."""

        def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
            return {"a": x.sum() + sum(y).sum()}

        first = torch.randn(6, 5)
        rest = [torch.randn(6, 5), torch.randn(6, 5)]

        exported, _ = torch._dynamo.export(f, aten_graph=True)(first, rest)
        self.assertEqual(exported(first, rest), f(first, rest))

    @config.patch(assume_static_by_default=False)
    def test_export_symbolic_shape(self):
        """With shapes dynamic by default, reading ``x.shape[0]`` must leave an
        ``aten.sym_size.int`` node in the aten graph."""

        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.empty(x.shape[0] * 2)

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(6, 5))

        self.assertTrue(
            any(node.target is torch.ops.aten.sym_size.int for node in gm.graph.nodes)
        )

    @config.patch(assume_static_by_default=False)
    def test_dynamic_slicing(self):
        """Slicing with shape-derived bounds exports in aten and torch mode.

        Both graphs are traced on a (4, 5) input and replayed on a (6, 7)
        input. The replayed output shape must match eager.
        """

        def f(x):
            return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]

        gm_aten_mode, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))

        inp = torch.randn(6, 7)
        self.assertEqual(gm_aten_mode(inp).shape, f(inp).shape)

        # The aten graph flattens the tensor indexing into aten.slice.Tensor
        # kernel calls. Two are expected, one per sliced dimension of f.
        aten_slice_count = sum(
            1
            for node in gm_aten_mode.graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.aten.slice.Tensor
        )
        self.assertEqual(aten_slice_count, 2)

        gm_torch_mode, _ = torch._dynamo.export(f, aten_graph=False)(torch.randn(4, 5))

        # In torch mode the whole multi-dimensional slice is recorded as a
        # single operator.getitem call (Tensor has its own __getitem__, which
        # is only translated to aten.slice when lowering to an aten graph).
        # The size lookups used for the bounds must not add further
        # operator.getitem nodes, so exactly one is expected.
        getitem_count = sum(
            1
            for node in gm_torch_mode.graph.nodes
            if node.op == "call_function" and node.target == operator.getitem
        )
        self.assertEqual(getitem_count, 1)
        self.assertEqual(gm_torch_mode(inp).shape, f(inp).shape)

    def test_dynamic_slicing_invalid(self):
        """A slice bound read from tensor data (``y``) is rejected by export."""

        def g(x, y):
            return x[y : x.shape[0]]

        data = torch.randn(4, 5)
        start = torch.tensor(2)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "Dynamic slicing on data-dependent value is not supported",
        ):
            torch._dynamo.export(g, aten_graph=True)(data, start)

    @config.patch(capture_scalar_outputs=True)
    def test_dynamic_slicing_simple(self):
        """A full slice ``x[:]`` round-trips through an aten-level export."""

        def f(x):
            return x[slice(None, None, None)]

        exported_gm = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))[0]

        probe = torch.randn(6, 7)
        self.assertEqual(exported_gm(probe), f(probe))

    def test_pre_dispatch_simple(self):
        """Pre-dispatch aten export of ``ones_like`` followed by ``matmul``.

        The graph is traced with fake tensors on a (5, 5) input and then run
        on a (6, 6) input. This checks that it works on a shape other than the
        one it was traced with. The expected code shows ``aten.matmul.default``
        kept as a single call.
        """

        def f(x):
            y = torch.ones_like(x)
            return torch.matmul(x, y)

        gm, _ = torch._dynamo.export(
            f,
            aten_graph=True,
            pre_dispatch=True,
            tracing_mode="fake",
        )(
            torch.randn(5, 5),
        )

        inp = torch.randn(6, 6)
        self.assertEqual(gm(inp), f(inp))
        # Exact generated code; placeholder names derive from f's parameter `x`.
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    arg0_1 = arg0
    ones_like = torch.ops.aten.ones_like.default(arg0_1, pin_memory = False)
    matmul = torch.ops.aten.matmul.default(arg0_1, ones_like);  arg0_1 = ones_like = None
    return pytree.tree_unflatten([matmul], self._out_spec)""",
        )

    @config.patch(capture_scalar_outputs=True)
    def test_export_cond_in_aten_symbolic(self):
        """Export a module whose ``cond`` branches are bound methods.

        The aten-level graph is exported with a False predicate, and its
        output must match eager for the same inputs.
        """

        class ConditionOp(torch.nn.Module):
            def true_fn(self, x, y):
                return x * y

            def false_fn(self, x, y):
                return x + y

            def forward(self, pred, x, y):
                return cond(pred, self.true_fn, self.false_fn, [x, y])

        model = ConditionOp()
        inp = (
            torch.tensor(False),
            torch.randn(4, 4),
            torch.randn(4, 4),
        )
        gm, _ = torch._dynamo.export(model, aten_graph=True)(*inp)

        # Still exercise readable rendering of a graph that contains cond
        # subgraphs, but keep the generated code off the test's stdout.
        gm.print_readable(print_output=False)

        self.assertEqual(gm(*inp), model(*inp))

    def test_export_with_kwargs(self):
        """Positional, tuple, ``*args``, keyword and ``**kwargs`` inputs keep
        recognisable names in the exported graph's signature."""

        def fn_with_kwargs(pos0, tuple0, *myargs, mykw0=None, **mykwargs):
            out = pos0
            for arg in tuple0:
                out *= arg
            for arg in myargs:
                out *= arg
            out *= mykw0
            out *= mykwargs["input0"] * mykwargs["input1"]
            return out

        mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
        tuple0 = (torch.randn(4), torch.randn(4))
        mykw0 = torch.randn(4)
        pos0 = torch.randn(4)
        myargs = [torch.randn(4), torch.randn(4)]

        # *myargs entries get indexed names; **mykwargs entries keep their keys.
        expected_argument_names = ["pos0", "tuple0"]
        expected_argument_names += [f"myargs_{i}" for i in range(len(myargs))]
        expected_argument_names += ["mykw0", *mykwargs]

        self._test_export_preserving_original_signature(
            fn_with_kwargs,
            expected_argument_names,
            pos0,
            tuple0,
            *myargs,
            mykw0=mykw0,
            **mykwargs,
        )

    def test_export_with_kwargs_and_empty_args(self):
        """A function with only defaulted and ``**kwargs`` parameters keeps
        its argument names after export."""

        def fn_with_kwargs(mykw0=None, **mykwargs):
            out = mykw0
            out *= mykwargs["input0"] * mykwargs["input1"]
            return out

        mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
        mykw0 = torch.randn(4)

        self._test_export_preserving_original_signature(
            fn_with_kwargs, ["mykw0", *mykwargs], mykw0, **mykwargs
        )

    def test_export_with_args_and_empty_kwargs(self):
        """Positional, tuple and ``*args`` inputs (no keywords) keep their names."""

        def fn_with_kwargs(pos0, tuple0, *myargs):
            out = pos0
            for arg in tuple0:
                out *= arg
            for arg in myargs:
                out *= arg
            return out

        tuple0 = (torch.randn(4), torch.randn(4))
        pos0 = torch.randn(4)
        myargs = [torch.randn(4), torch.randn(4)]

        expected = ["pos0", "tuple0"] + [f"myargs_{i}" for i, _ in enumerate(myargs)]
        self._test_export_preserving_original_signature(
            fn_with_kwargs, expected, pos0, tuple0, *myargs
        )

    @common_utils.parametrize(
        "default_value",
        [
            common_utils.subtest(None, name="None"),
            common_utils.subtest(42.0, name="float"),
            common_utils.subtest(
                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
                torch.randn(4),
                name="tensor",
                decorators=[unittest.expectedFailure],
            ),
            common_utils.subtest(
                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
                (torch.randn(4),),
                name="tuple",
                decorators=[unittest.expectedFailure],
            ),
        ],
    )
    def test_export_with_args_with_default(self, default_value):
        """A positional parameter left at its default is not part of the
        exported signature; only ``pos0`` is expected."""

        def fn(pos0, pos1_default=default_value):
            out = pos0
            if pos1_default is None:
                pos1_default = torch.randn(4)
            if isinstance(pos1_default, tuple):
                pos1_default = pos1_default[0]
            out *= pos1_default
            return out

        self._test_export_preserving_original_signature(fn, ["pos0"], torch.randn(4))

    @common_utils.parametrize(
        "default_value",
        [
            common_utils.subtest(None, name="None"),
            common_utils.subtest(42.0, name="float"),
            common_utils.subtest(
                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
                torch.randn(4),
                name="tensor",
                decorators=[unittest.expectedFailure],
            ),
            common_utils.subtest(
                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
                (torch.randn(4),),
                name="tuple",
                decorators=[unittest.expectedFailure],
            ),
        ],
    )
    def test_export_with_kwargs_with_default(self, default_value):
        """A keyword-only parameter left at its default is not part of the
        exported signature; keywords that are passed keep their names."""

        def fn(pos0, *, kw0, kw1_default=default_value, **kwargs):
            out = pos0
            out += kw0
            if kw1_default is None:
                kw1_default = torch.randn(4)
            elif isinstance(kw1_default, tuple):
                kw1_default = kw1_default[0]
            out += kw1_default
            out += kwargs["kw2"]
            return out

        pos0, kw0, kw2 = torch.randn(4), torch.randn(4), torch.randn(4)
        call_kwargs = {"kw0": kw0, "kw2": kw2}

        self._test_export_preserving_original_signature(
            fn, ["pos0", *call_kwargs], pos0, **call_kwargs
        )

    def test_export_with_wrapped_fn(self):
        """Export stays robust when the callable is a plain forwarding wrapper.

        ``wrapped_fn`` hides ``_fn``'s signature behind ``*args, **kwargs``,
        so ``inspect`` cannot recover the original names. Positional inputs
        are therefore expected as ``args_<i>``, and keyword inputs under
        their keys.
        """

        def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
            out = pos0
            out += pos1
            out += kw0
            out += kw1
            for arg in args:
                out += arg
            for kwarg in kwargs.values():
                out += kwarg
            return out

        def wrapped_fn(*args, **kwargs):
            return _fn(*args, **kwargs)

        pos0 = torch.randn(4)
        kw0 = torch.randn(4)
        positional = (pos0, torch.randn(4), torch.randn(4))
        keywords = {"kw0": kw0, "kw2": torch.randn(4)}
        expected = [*(f"args_{i}" for i in range(len(positional))), *keywords]

        self._test_export_preserving_original_signature(
            wrapped_fn, expected, *positional, **keywords
        )

    def test_export_with_functools_wrapped_method(self):
        """A bound method decorated with ``functools.wraps`` keeps the wrapped
        method's parameter names; the surplus positional becomes ``args_0``."""

        def passthrough(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x

            @passthrough
            def method_to_test(self, pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
                out = pos0
                out += pos1
                out += kw0
                out += kw1
                for arg in args:
                    out += arg
                for kwarg in kwargs.values():
                    out += kwarg
                return out

        pos0 = torch.randn(4)
        pos1 = torch.randn(4)
        unnamed_pos = torch.randn(4)
        kw0 = torch.randn(4)
        positional = (pos0, pos1, unnamed_pos)
        keywords = {"kw0": kw0, "kw2": torch.randn(4), "unnamed_kw": torch.randn(4)}

        # pos0/pos1 are named in the wrapped signature; the third positional
        # lands in *args and is expected as args_0.
        expected = ["pos0", "pos1", "args_0", *keywords]

        self._test_export_preserving_original_signature(
            MyModule().method_to_test, expected, *positional, **keywords
        )

    def test_export_with_functools_wrapped_fn(self):
        """A forwarding wrapper around a ``functools.wraps``-decorated function.

        The outer ``wrapped_fn`` is not itself decorated with ``wraps``, so
        positional inputs are expected as ``args_<i>`` and keyword inputs
        under their keys.
        """

        def passthrough(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

        @passthrough
        def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
            out = pos0
            out += pos1
            out += kw0
            out += kw1
            for arg in args:
                out += arg
            for kwarg in kwargs.values():
                out += kwarg
            return out

        def wrapped_fn(*args, **kwargs):
            return _fn(*args, **kwargs)

        pos0 = torch.randn(4)
        kw0 = torch.randn(4)
        positional = (pos0, torch.randn(4), torch.randn(4))
        keywords = {"kw0": kw0, "kw2": torch.randn(4)}
        expected = [f"args_{i}" for i, _ in enumerate(positional)] + [*keywords]

        self._test_export_preserving_original_signature(
            wrapped_fn, expected, *positional, **keywords
        )

    def _test_export_preserving_original_signature(
        self, fn, expected_argument_names: Sequence[str], *args, **kwargs
    ):
        """Export ``fn``, compare it with eager, and check its argument names.

        Args:
            fn: callable exported with ``aten_graph=False``.
            expected_argument_names: parameter names expected on the exported
                graph's ``forward`` (excluding ``self``), in order.
            *args, **kwargs: example inputs, used both for tracing and for
                comparing the exported graph against eager ``fn``.
        """
        torch._dynamo.reset()
        # Use the export(fn, ...)(*args, **kwargs) form, like the other tests.
        # Passing example inputs directly to export() is deprecated, and the
        # user kwargs could collide with export's own keyword options.
        out_graph, _ = torch._dynamo.export(fn, aten_graph=False)(*args, **kwargs)

        # NOTE(review): several callers' functions update their first tensor
        # input in place (`out = pos0; out *= ...`), so both results below may
        # be the same tensor object. That makes this comparison weaker than it
        # looks -- confirm before relying on it.
        dynamo_result = out_graph(*args, **kwargs)
        real_result = fn(*args, **kwargs)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

        # Check that the exported graph preserves the same argument names
        # (index 0 is `self`).
        self.assertEqual(
            inspect.getfullargspec(out_graph.forward).args[1:], expected_argument_names
        )

    def test_dataclass_input_output(self):
        """A plain dataclass is rejected with ``UserError`` both as an export
        input and as an export output."""
        from dataclasses import dataclass

        @dataclass
        class Tensors:
            x: torch.Tensor
            y: torch.Tensor

        def consumes_tensors(t):
            return t.x + t.y

        def produces_tensors(x, y):
            return Tensors(x=x.sin(), y=y.cos())

        with self.assertRaisesRegex(
            UserError,
            "It looks like one of the inputs with type .*Tensors.* "
            "is not supported or pytree-flattenable",
        ):
            torch._dynamo.export(consumes_tensors, aten_graph=False)(
                Tensors(x=torch.randn(10), y=torch.randn(10))
            )

        with self.assertRaisesRegex(
            UserError,
            "It looks like one of the outputs with type .*Tensors.* "
            "is not supported or pytree-flattenable",
        ):
            torch._dynamo.export(produces_tensors, aten_graph=False)(
                torch.randn(10), torch.randn(10)
            )

    def test_empty(self):
        """Graphs with no computation: an identity function, and a module that
        returns a constant attribute without taking inputs."""

        def f(x):
            return x

        identity_gm = torch._dynamo.export(f)(torch.randn(3, 3))[0]
        probe = torch.randn(3, 3)
        self.assertTrue(torch._dynamo.utils.same(probe, identity_gm(probe)))

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(3, 3)

            def forward(self):
                return self.a

        constant_gm = torch._dynamo.export(M())()[0]
        self.assertTrue(torch._dynamo.utils.same(torch.ones(3, 3), constant_gm()))

    @unittest.skipIf(not TEST_CUDA, "No CUDA available.")
    def test_export_with_parameters(self):
        """``nn.Parameter`` objects survive a ``torch.export`` save/load round trip."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.features = torch.nn.Sequential(
                    torch.nn.Conv2d(
                        3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
                    ),
                    torch.nn.ReLU(inplace=True),
                )

            def forward(self, x):
                return self.features(x)

        model = MyModule().eval().cuda()
        example_inputs = (torch.rand([32, 3, 32, 32]).to("cuda"),)
        batch = torch.export.Dim("dim_x", min=1, max=32)
        program = torch.export.export(
            model, example_inputs, dynamic_shapes={"x": {0: batch}}
        )

        # Save to an in-memory buffer, load it back, and check that the conv
        # weight is still a Parameter.
        buffer = io.BytesIO()
        torch.export.save(program, buffer)
        restored = torch.export.load(buffer)
        weight = restored.module().get_parameter("features.0.weight")
        self.assertIsInstance(weight, torch.nn.Parameter)

    def test_export_fast_binary_broadcast_check(self):
        """Regression test for FakeTensor's binary-op fast path.

        This covers the case where a guard was erroneously created when
        checking the equality of the operands' shapes and the output shape.
        """

        class MyModel(torch.nn.Module):
            def forward(self, a, b):
                # final shape is (dim0, 4, 8)
                # order matters since a & the output have the same shape
                return b + a

        a = torch.randn(100, 4, 8)
        b = torch.randn(4, 8)
        # MyModel has no parameters or buffers and the example inputs are CPU
        # tensors, so no device placement is needed (and no CUDA dependency).
        model = MyModel().eval()
        batchsize = torch.export.Dim("dim0", min=3, max=1024)
        dynamic_shape_spec = {"a": [batchsize, None, None], "b": [None, None]}

        torch.export.export(model, (a, b), dynamic_shapes=dynamic_shape_spec)

    def test_export_fast_binary_broadcast_check_unbacked(self):
        """Broadcast a (1,)-shaped tensor against a tensor whose length is read
        from tensor data (``numel.item()``); export must succeed."""

        class MyModel(torch.nn.Module):
            def forward(self, numel, scalar):
                u0 = numel.item()
                torch._check_is_size(u0)
                x = torch.ones(u0 + 1)
                return scalar - x

        # MyModel has no parameters or buffers and the example inputs are CPU
        # tensors, so no device placement is needed (and no CUDA dependency).
        model = MyModel().eval()
        numel = torch.tensor(10)
        scalar = torch.randn(1)
        torch.export.export(model, (numel, scalar))

    def test_export_meta(self):
        """A module built under the meta device exports and matches eager."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.ones(2, 3))

            def forward(self, x):
                return self.p + x

        with torch.device("meta"):
            m = MyModule()

        meta_input = torch.ones(2, 3, device="meta")
        exported_gm, _ = torch._dynamo.export(m)(meta_input)
        self.assertEqual(exported_gm(meta_input), m(meta_input))

    def test_constraint_violation_error_messages(self):
        """Check the ``UserError`` text for three violated dynamic-dim specs.

        * ``Foo``: an equality between two dims (``dim0 == 2 * dim1``).
        * ``Bar``: a dim pinned to one value by ``== 5``.
        * ``Qux``: a dim narrowed to a sub-range by ``> 5 and < 10``.
        """

        class Foo(torch.nn.Module):
            def forward(self, x):
                if x.shape[0] == x.shape[1] * 2:
                    return x + 1
                else:
                    return x + 2

        class Bar(torch.nn.Module):
            def forward(self, x):
                if x.shape[0] == 5:
                    return x + 1
                else:
                    return x + 2

        class Qux(torch.nn.Module):
            def forward(self, x):
                if x.shape[0] > 5 and x.shape[0] < 10:
                    return x + 1
                else:
                    return x + 2

        cases = (
            (
                Foo(),
                torch.zeros([8, 4]),
                (
                    torch.export.Dim("dim0", min=3, max=10),
                    torch.export.Dim("dim1"),
                ),
                "Constraints violated .*!(.*\n)*.*"
                "by dim0 = 2\\*dim1(.*\n)*.*"
                "Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*",
            ),
            (
                Bar(),
                torch.zeros([5]),
                (torch.export.Dim("dim0", min=3, max=8),),
                "Not all values.*valid.*inferred to be a constant",
            ),
            (
                Qux(),
                torch.zeros([7]),
                (torch.export.Dim("dim0", min=3, max=8),),
                "Not all values.*satisfy the generated guard",
            ),
        )
        for module, example, x_dims, message in cases:
            with self.assertRaisesRegex(torch._dynamo.exc.UserError, message):
                torch.export.export(module, (example,), dynamic_shapes={"x": x_dims})

    def test_untracked_inputs_in_constraints(self):
        """Dynamic dims declared on an input the forward never reads (``x``).

        Calling the exported module with a new size for ``x`` must not raise
        a specialization error.
        """
        from copy import copy as shallow_copy

        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return y + 1

        x = torch.randn(2)
        y = torch.randn(5, 4)
        dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y")

        ep = torch.export.export(
            Foo(),
            (shallow_copy(x), y),
            dynamic_shapes={"x": {0: dim0_x}, "y": {0: dim0_y}},
        )
        # x was traced with size 2; size 3 must be accepted.
        ep.module()(torch.randn(3), y)

    def test_export_raise_guard_full_constraint(self):
        """Branching on ``x.shape[0] == 3`` contradicts a dynamic dim 0."""
        example = torch.randn([3, 3, 3])

        def my_dyn_fn(x):
            if x.shape[0] == 3:
                return x.sin()
            return x.cos()

        # Without dynamic dims the export succeeds.
        torch._dynamo.export(my_dyn_fn)(example)

        dynamic_dim0 = ({0: torch.export.Dim("dimx")},)
        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_dim0)(example)

    def test_export_module_specify_constraints_signature(self):
        """Exporting a module (rather than a function) with a contradicted
        dynamic dim reports the specialization ``dimx = 3``."""
        example = torch.randn([3, 3, 3])

        class Mod(torch.nn.Module):
            def forward(self, x):
                if x.shape[0] == 3:
                    return x.sin()
                return x.cos()

        mod = Mod()
        torch._dynamo.export(mod)(example)

        dynamic_dim0 = ({0: torch.export.Dim("dimx")},)
        with self.assertRaisesRegex(ConstraintViolationError, "dimx = 3"):
            torch._dynamo.export(mod, dynamic_shapes=dynamic_dim0)(example)

    def test_export_raise_guard_partial_constraint(self):
        """Branching on ``x.shape[0] > 3`` narrows a dynamic dim 0 and raises."""
        example = torch.randn([3, 3, 3])

        def my_dyn_fn(x):
            if x.shape[0] > 3:
                return x.sin()
            return x.cos()

        # Without dynamic dims the export succeeds.
        torch._dynamo.export(my_dyn_fn)(example)

        dynamic_dim0 = ({0: torch.export.Dim("dimx")},)
        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_dim0)(example)

    def test_export_raise_on_relationship(self):
        """Sharing one Dim is only valid on the axes the code compares."""
        example = torch.randn([3, 3, 3])

        def my_dyn_fn(a, b, c):
            if a.shape[0] == b.shape[1] == c.shape[2]:
                return a.sin()

            return a.cos()

        torch._dynamo.export(my_dyn_fn)(example, example, example)
        dim = torch.export.Dim("dim")

        # The branch compares a[0] with b[1] and c[2]; with only axis 0 of
        # each input dynamic, that comparison cannot hold for every value of dim.
        axis0_everywhere = ({0: dim}, {0: dim}, {0: dim})
        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(my_dyn_fn, dynamic_shapes=axis0_everywhere)(
                example, example, example
            )

        # Putting dim exactly on the compared axes is consistent.
        compared_axes = ({0: dim}, {1: dim}, {2: dim})
        torch._dynamo.export(my_dyn_fn, dynamic_shapes=compared_axes)(
            example, example, example
        )

    def test_export_no_raise(self):
        """A shared dynamic dim 0 is fine when the code only branches on dim 1."""
        example = torch.randn([3, 3, 3])

        def my_dyn_fn(a, b, c):
            if a.shape[1] == 3:
                return a.cos()
            return a * b * c

        torch._dynamo.export(my_dyn_fn)(example, example, example)

        dim = torch.export.Dim("dim")
        shared_dim0 = tuple({0: dim} for _ in range(3))
        torch._dynamo.export(my_dyn_fn, dynamic_shapes=shared_dim0)(
            example, example, example
        )

    def test_export_multi_dynamic_dim_unsafe_relationship(self):
        """An equality between independent dynamic dims raises unless the same
        Dim is used for both compared axes."""
        x = torch.randn([3, 3, 3])
        y = torch.randn([2, 2, 2])
        z = torch.randn([3, 3, 3])

        def my_dyn_fn(a, b, c):
            if a.shape[0] == c.shape[0]:
                return a.cos()
            return a * c, b

        torch._dynamo.export(my_dyn_fn)(x, y, z)
        dimx, dimy, dimz = torch.export.dims("dimx", "dimy", "dimz")

        independent = ({0: dimx}, {0: dimy}, {0: dimz})
        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(my_dyn_fn, dynamic_shapes=independent)(x, y, z)

        # Reusing dimx for c's dim 0 declares the equality the branch relies on.
        tied = ({0: dimx}, {0: dimy}, {0: dimx})
        torch._dynamo.export(my_dyn_fn, dynamic_shapes=tied)(x, y, z)

    def test_remove_redundant_dynamic_dim_in_error_message(self):
        """When two dynamic dims are forced equal, the error names a single
        suggested equality (``dim0_b = dim0_a``)."""

        class Foo(torch.nn.Module):
            def forward(self, x, y):
                if x.shape[0] == y["k"].shape[0]:
                    return x + 1
                else:
                    return x - 1

        a, b = torch.randn(3), torch.randn(3)
        dim0_a, dim0_b = torch.export.dims("dim0_a", "dim0_b")
        shapes = {"x": {0: dim0_a}, "y": {"k": {0: dim0_b}}}

        with self.assertRaisesRegex(torch._dynamo.exc.UserError, "dim0_b = dim0_a"):
            torch.export.export(Foo(), (a, {"k": b}), dynamic_shapes=shapes)

    def test_enforce_equalities(self):
        """Dims reused across inputs must agree on the example sizes.

        When the sizes agree, the placeholders share the same symbols.
        """

        class Bar(torch.nn.Module):
            def forward(self, x, y):
                return torch.matmul(x, y)

        bar = Bar()
        batch, size = torch.export.dims("batch", "size")
        dynamic_shapes = {"x": (batch, size, size), "y": (batch, size, size)}

        x = torch.randn(10, 3, 3)
        mismatched_y = torch.randn(10, 3, 4)
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            ".*y.*size.*2.* = 4 is not equal to .*x.*size.*1.* = 3",
        ):
            torch.export.export(bar, (x, mismatched_y), dynamic_shapes=dynamic_shapes)

        y = torch.randn(10, 3, 3)
        ebar = torch.export.export(bar, (x, y), dynamic_shapes=dynamic_shapes)

        placeholder_shapes = [
            str(node.meta["val"].shape)
            for node in ebar.graph_module.graph.nodes
            if node.op == "placeholder"
        ]
        self.assertEqual(placeholder_shapes, ["torch.Size([s0, s1, s1])"] * 2)

    @torch._dynamo.config.patch(
        capture_dynamic_output_shape_ops=True,
        specialize_int=True,
        capture_scalar_outputs=True,
    )
    def test_export_preserve_constraints_as_metadata_tensor(self):
        """Symbolic aten export of a ``nonzero`` output whose length is bounded
        with ``torch._check``; the test passes if export does not raise."""

        def f(x):
            b = x.nonzero()
            torch._check(b.shape[0] >= 2)
            torch._check(b.shape[0] <= 5)
            return b

        exporter = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")
        exporter(torch.tensor([8, 8, 6]))

    @config.patch(
        capture_dynamic_output_shape_ops=True,
        specialize_int=True,
        capture_scalar_outputs=True,
    )
    def test_exported_graph_serialization(self):
        """A symbolic aten export whose output size is read from tensor data
        can be serialized with ``torch.save``."""

        def f(x, y):
            b = x.item()
            torch._check_is_size(b)
            return torch.empty((b, y.shape[0]))

        example_inputs = [torch.tensor([3]), torch.randn([8, 8, 6])]
        dimy = torch.export.Dim("dimy", min=6, max=10)
        gm, _ = torch._dynamo.export(
            f,
            dynamic_shapes=(None, {0: dimy}),
            aten_graph=True,
            tracing_mode="symbolic",
        )(*example_inputs)

        # The graph module with metadata must be serializable; the metadata
        # itself won't be saved in the serialized module.
        torch.save(gm, io.BytesIO())

    def test_export_dynamic_dim_not_1(self):
        """Branching on ``a.shape[0] != 1`` with a size-1 example conflicts
        with marking dim 0 dynamic."""
        example = torch.randn([1, 1, 1])

        def my_dyn_fn(a):
            if a.shape[0] != 1:
                return a.cos()
            return a * a

        torch._dynamo.export(my_dyn_fn)(example)

        dynamic_dim0 = ({0: torch.export.Dim("dimx")},)
        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_dim0)(example)

    def test_symbool(self):
        """A shape comparison turned into a tensor with ``scalar_tensor``.

        The graph is traced with shape[0] == 6 (comparison true) and replayed
        with 3 (comparison false); the result must still match eager.
        """

        def f(x):
            a = torch.scalar_tensor(x.shape[0] > 4)
            return x.sin().sum() + a.sum()

        exported_gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4))
        probe = torch.ones(3, 4)
        self.assertEqual(exported_gm(probe), f(probe))

    def test_export_multi_dynamic_dim_constraint(self):
        """Making all of a's dims dynamic raises until c's dim 0 is tied to
        a's dim 0, as the branch requires."""
        x = torch.randn([3, 3, 3])
        y = torch.randn([2, 2, 2])
        z = torch.randn([3, 3, 3])

        def my_dyn_fn(a, b, c):
            if a.shape[0] == c.shape[0]:
                return a.cos()
            return a * c, b

        torch._dynamo.export(my_dyn_fn)(x, y, z)
        dimx_0, dimx_1, dimx_2 = torch.export.dims("dimx_0", "dimx_1", "dimx_2")
        a_dims = {0: dimx_0, 1: dimx_1, 2: dimx_2}

        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(my_dyn_fn, dynamic_shapes=(a_dims, None, None))(
                x, y, z
            )

        torch._dynamo.export(
            my_dyn_fn, dynamic_shapes=(dict(a_dims), None, {0: dimx_0})
        )(x, y, z)

    def test_export_dynamic_dim_range_constraint(self):
        """A dim declared in [5, 6] passes a branch that holds on the whole
        range, and fails one that splits the range."""
        x = torch.ones(6, 4, 4)
        dynamic_shapes = ({0: torch.export.Dim("dimx", min=5, max=6)},)

        def foo(x):
            if x.shape[0] > 3:  # ok
                return x.sin()
            return x.cos()

        def bar(x):
            if x.shape[0] > 5:  # error
                return x.sin()
            return x.cos()

        # x.shape[0] > 3 holds for every value in [5, 6].
        torch._dynamo.export(foo, dynamic_shapes=dynamic_shapes, aten_graph=True)(x)

        # x.shape[0] > 5 is false at 5 but true at 6.
        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(bar, dynamic_shapes=dynamic_shapes, aten_graph=True)(
                x
            )

    def test_trivial_constraint(self):
        """Only divisibility conditions that are not trivially true violate a
        dynamic dim.

        * ``Foo``: ``(2x + 3) % (x - 3) == 0`` holds for x = 12 but not in general.
        * ``Bar``: ``(2x + 2) % (x + 1) == 0`` is always true.
        * ``Qux``: ``3x % 2 == 0`` restricts x to even values.
        """

        class Foo(torch.nn.Module):
            def forward(self, x):
                # complex divisibility condition
                if (2 * x.shape[0] + 3) % (x.shape[0] - 3) == 0:
                    return x + 1
                else:
                    return x - 1

        class Bar(torch.nn.Module):
            def forward(self, x):
                # trivially true
                if (2 * x.shape[0] + 2) % (x.shape[0] + 1) == 0:
                    return x + 1
                else:
                    return x - 1

        class Qux(torch.nn.Module):
            def forward(self, x):
                # simple divisibility condition (not trivially true)
                if (3 * x.shape[0]) % 2 == 0:
                    return x + 1
                else:
                    return x - 1

        x = torch.randn(12)
        dynamic_shapes = {"x": (torch.export.Dim("dim0", max=100),)}
        violation = r"Constraints violated \(dim0\)"

        with self.assertRaisesRegex(torch._dynamo.exc.UserError, violation):
            torch.export.export(Foo(), (x,), dynamic_shapes=dynamic_shapes)

        torch.export.export(Bar(), (x,), dynamic_shapes=dynamic_shapes)

        with self.assertRaisesRegex(torch._dynamo.exc.UserError, violation):
            torch.export.export(Qux(), (x,), dynamic_shapes=dynamic_shapes)

    def test_list_contains(self):
        """An ``assert size in [list]`` exports and matches the compiled result."""

        def func(x):
            assert x.size(-1) in [4, 5, 6], "bad"
            return x + x

        inps = (torch.randn(1, 5),)
        compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = compiled(*inps)

        # Reset dynamo so the export below does not reuse state from the
        # optimize() run above.
        torch._dynamo.reset()

        out_graph, _ = torch._dynamo.export(func, aten_graph=True)(*inps)
        dynamo_result = out_graph(*inps)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_list_not_contains(self):
        """``not in`` asserts over a size and over strings export and match."""

        def func(x):
            assert x.size(0) not in [4, 5, 6], "bad1"
            assert "monkey" not in ["cow", "pig"], "bad2"
            return x + x

        inps = (torch.randn(1, 5),)
        compiled = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = compiled(*inps)

        # Reset dynamo so the export below does not reuse state from the
        # optimize() run above.
        torch._dynamo.reset()

        out_graph, _ = torch._dynamo.export(func, aten_graph=True)(*inps)
        dynamo_result = out_graph(*inps)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_identity(self):
        """The exported identity function returns a value equal to its input."""
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            return x

        torch._dynamo.reset()
        graph_module = torch._dynamo.export(func)(inp).graph_module
        self.assertTrue(torch._dynamo.utils.same(inp, graph_module(inp)))

    def test_export_specialized_int(self):
        """Int/float module attributes read in forward do not become graph inputs."""

        class Foo(torch.nn.Module):
            def __init__(
                self,
                input_dim,
            ):
                super().__init__()
                self.torch_module = torch.nn.LayerNorm(
                    input_dim, eps=1e-5, elementwise_affine=True
                )
                self.int_val = 100

            def forward(self, input):
                return input.cos() * self.int_val * self.torch_module.eps

        mod = Foo(128)
        inp = torch.randn(3, 128)

        # In export, int & float in forward should always be specialized, so the
        # graph is expected to have exactly one placeholder.
        gm = torch._dynamo.export(mod, aten_graph=True)(inp).graph_module
        num_placeholders = sum(node.op == "placeholder" for node in gm.graph.nodes)
        self.assertEqual(num_placeholders, 1)

    def test_export_with_nonzero_static(self):
        """``nonzero_static`` exports and yields exactly ``static_size`` rows.

        Each case is also compared against eager ``torch.nonzero_static``.
        """

        class BasicModule(torch.nn.Module):
            def __init__(self, static_size):
                super().__init__()
                self.static_size = static_size

            def forward(self, x):
                return torch.nonzero_static(x, size=self.static_size)

        cases = (
            (torch.tensor([6, 8]), 3),
            (torch.zeros(2, 3), 4),
        )
        for input_tensor, static_size in cases:
            gm = torch._dynamo.export(BasicModule(static_size), aten_graph=True)(
                input_tensor
            ).graph_module
            res = gm(input_tensor)
            self.assertEqual(res.size(0), static_size)
            expected = torch.nonzero_static(input_tensor, size=static_size)
            self.assertTrue(torch._dynamo.utils.same(res, expected))

    def test_export_pass_arg_by_name(self):
        """The exported graph accepts ``forward``'s argument by keyword."""

        class BasicModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_lin = torch.nn.Linear(3, 4, bias=True)

            def forward(self, x):
                return self.my_lin(x)

        mod = BasicModule()
        input_tensor = torch.randn(2, 3)
        gm = torch._dynamo.export(mod, aten_graph=True)(input_tensor).graph_module
        expected = mod(x=input_tensor)
        actual = gm(x=input_tensor)
        self.assertTrue(torch._dynamo.utils.same(expected, actual))

    def test_export_pass_arg_by_name_star_args(self):
        """A ``forward(*args)`` module exports and matches eager.

        Both arguments are passed positionally.
        """

        class BasicModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_lin = torch.nn.Linear(3, 4, bias=True)

            def forward(self, *args):
                return self.my_lin(args[0]) * self.my_lin(args[1])

        mod = BasicModule()
        input_tensor = torch.randn(2, 3)
        input_tensor2 = torch.randn(2, 3)
        gm = torch._dynamo.export(mod, aten_graph=True)(
            input_tensor, input_tensor2
        ).graph_module
        expected = mod(input_tensor, input_tensor2)
        actual = gm(input_tensor, input_tensor2)
        self.assertTrue(torch._dynamo.utils.same(expected, actual))

    def test_export_mark_dynamic_conflict_dynamic_dim(self):
        """``mark_dynamic`` plus an explicit ``Dim`` still reports violations.

        The function branches on the marked dim, so export must raise a
        "Constraints violated" error.
        """
        inp = torch.randn([3, 3, 3])

        def my_dyn_fn(x):
            if x.shape[0] > 3:
                return x.sin()
            return x.cos()

        torch._dynamo.mark_dynamic(inp, 0)
        dynamic_shapes = ({0: torch.export.Dim("dim")},)
        with self.assertRaisesRegex(RuntimeError, "Constraints violated"):
            torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(inp)

    def test_export_dynamic_dim_cleanup(self):
        """Exporting with a ``Dim`` on an axis the function never branches on succeeds."""
        y = torch.randn([3, 3, 3])

        def my_dyn_fn(x):
            return x.cos()

        dim = torch.export.Dim("dim")
        exporter = torch._dynamo.export(my_dyn_fn, dynamic_shapes=({0: dim},))
        exporter(y)

    @config.patch(capture_dynamic_output_shape_ops=True)
    def test_export_dynamic_control_flow_error(self):
        """Branching on a data-dependent value (``nonzero``) is a UserError."""

        def f(x):
            if x.nonzero() > 3:
                return x.cos()
            return x.sin()

        expected_msg = "Dynamic control flow is not supported at the moment"
        with self.assertRaisesRegex(torch._dynamo.exc.UserError, expected_msg):
            torch._dynamo.export(f, aten_graph=True)(torch.randn(5, 6))

    @config.patch(assume_static_by_default=False)
    def test_export_persist_assert(self):
        """A user ``assert`` becomes an ``_assert_async`` node that survives DCE.

        The exported graph must also raise the user's message at runtime.
        """

        def f(x):
            assert x[0].sum() > 4, "Shape must be more than 4"
            return x.cos() + x.sin()

        gm, _ = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
            torch.ones(5, 4, 6)
        )

        assert_op = torch.ops.aten._assert_async.msg

        def contains_assert(graph_module):
            return any(node.target == assert_op for node in graph_module.graph.nodes)

        self.assertTrue(contains_assert(gm))

        # Dead-code elimination must not drop the side-effecting assert.
        gm.graph.eliminate_dead_code()
        gm.recompile()
        self.assertTrue(contains_assert(gm))

        with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"):
            gm(torch.zeros(3, 4, 5))

    @common_utils.parametrize(
        "type_fn",
        [
            common_utils.subtest(type, name="builtin"),
            common_utils.subtest(lambda obj: obj.__class__, name="attr"),
        ],
    )
    def test_access_class_method_from_user_class(self, type_fn):
        """A classmethod reached through a user object's class is traceable.

        ``type_fn`` covers both ``type(obj)`` and ``obj.__class__``.
        """

        class A:
            @classmethod
            def func(cls):
                return torch.Tensor([4, 5])

        def f(x):
            a = A()
            return x.sum() + type_fn(a).func().sum()

        inp = torch.ones(6, 4)
        gm = torch._dynamo.export(f, aten_graph=True)(inp).graph_module
        self.assertEqual(f(inp), gm(inp))

    def test_not_functionalize(self):
        """Dynamo export keeps the in-place ``add_`` instead of functionalizing it."""

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.ones(6, 2))

            def forward(self, x):
                x.add_(2)
                return x.sum() + self.buffer1.sum()

        example_inputs = (torch.ones(1, 2, 3),)
        gm = torch._dynamo.export(
            Foo(),
            aten_graph=True,
            tracing_mode="symbolic",
        )(*example_inputs).graph_module

        inplace_adds = [
            node
            for node in gm.graph.nodes
            if node.target == torch.ops.aten.add_.Tensor
        ]
        self.assertEqual(len(inplace_adds), 1)

        # forward mutates its argument, so the exported and eager runs each get
        # their own fresh input.
        exported_inp = (torch.ones(1, 2, 3),)
        eager_inp = (torch.ones(1, 2, 3),)
        self.assertEqual(gm(*exported_inp), Foo()(*eager_inp))

    def test_round_dynamic_shapes(self):
        """Slicing with ``round(size / 2)`` exports and matches eager."""

        def f(x):
            return x[: round(x.shape[0] / 2)]

        inp = torch.ones(6, 4)
        gm = torch._dynamo.export(f, aten_graph=True)(inp).graph_module
        self.assertEqual(f(inp), gm(inp))

    def test_cond_supported_pred_types(self):
        """``cond`` accepts size-comparison, tensor and compound boolean predicates.

        Each variant is exported and compared against eager on the same input.
        """

        def true_fn(x):
            return x.cos()

        def false_fn(x):
            return x.sin()

        def f_pred_traced_as_symnode_var(x):
            return cond(x.shape[0] > 2, true_fn, false_fn, [x])

        def f_pred_traced_as_tensor_var(x):
            return cond(x.all(), true_fn, false_fn, [x])

        def f_pred_complex_expression_traced_as_symnode_var(x):
            return cond(
                x.dim() > 1 and x.shape[1] > 5 and x.shape[1] <= 10,
                true_fn,
                false_fn,
                [x],
            )

        example_inputs = (torch.rand(5, 8),)
        candidates = (
            f_pred_traced_as_symnode_var,
            f_pred_traced_as_tensor_var,
            f_pred_complex_expression_traced_as_symnode_var,
        )
        for fn in candidates:
            exported, _ = torch._dynamo.export(fn, aten_graph=True)(*example_inputs)
            self.assertEqual(exported(*example_inputs), fn(*example_inputs))

    @unittest.expectedFailure  # TODO: Not sure why dynamo creates a new input for self.a
    def test_sum_param(self):
        """Export a module whose forward sets a new attribute and reads a tensor attr.

        Marked as an expected failure; see the TODO on the decorator.
        """
        # Setting a new attribute inside forward()
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Plain tensor attribute: neither a Parameter nor a Buffer.
                self.a = torch.randn(3, 2)

            def forward(self, x):
                # New Python int attribute created during tracing.
                self.b = 2
                return x.sum() + self.a.sum() + self.b

        torch._dynamo.export(Foo())(torch.randn(3, 2))

    def test_mixed_real_and_fake_inputs(self):
        """Export a conv + batch-norm module whose forward mixes module state and input.

        The forward rescales the conv weight by ``bn.weight / sqrt(running_var + eps)``,
        runs the conv with a zero bias, divides the scaling back out, re-adds the
        conv bias and applies the batch norm. The test only checks that export
        completes without raising.
        """
        class _TestPattern(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.bn = torch.nn.BatchNorm2d(1)

            def forward(self, input):
                running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
                scale_factor = self.bn.weight / running_std
                # Reshape targets: broadcast along dim 0 of the conv weight ...
                weight_shape = [1] * len(self.conv.weight.shape)
                weight_shape[0] = -1
                # ... and along dim 1 of the conv output.
                bias_shape = [1] * len(self.conv.weight.shape)
                bias_shape[1] = -1
                scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
                zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype)
                conv = self.conv._conv_forward(input, scaled_weight, zero_bias)
                # Undo the weight scaling on the output, then add the real bias.
                conv_orig = conv / scale_factor.reshape(bias_shape)
                conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
                conv = self.bn(conv_orig)
                return conv

        example_inputs = (torch.randn(1, 1, 3, 3),)
        torch._dynamo.export(
            _TestPattern(),
            aten_graph=True,
        )(*example_inputs)

    @config.patch(
        capture_dynamic_output_shape_ops=True,
        capture_scalar_outputs=True,
        assume_static_by_default=False,
    )
    def test_sym_contains(self):
        """``int in tensor`` with a symbolic int exports and matches eager.

        ``f`` checks whether ``x``'s leading size occurs among ``y``'s values. The
        exported graph is compared with eager on one input where that holds and
        one where it does not.
        """

        def f(x, y):
            return x.size(0) in y

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(2), torch.ones(3))

        # x.size(0) == 3 for both inputs: ``y`` filled with 3s contains it and
        # ``y`` filled with 6s does not.
        true_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(2))
        false_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(5))
        # Pin that the two inputs really exercise both outcomes in eager.
        self.assertTrue(f(*true_inp))
        self.assertFalse(f(*false_inp))
        self.assertEqual(gm(*true_inp), f(*true_inp))
        self.assertEqual(gm(*false_inp), f(*false_inp))

    def test_cond_raise_user_error_on_missing_args(self):
        """Calling ``cond`` without ``operands`` is a plain Python TypeError.

        ``f`` is invoked directly rather than through export.
        """

        def true_fn(x):
            return x.cos()

        def false_fn(x):
            return x.sin()

        def f(x):
            return cond(x.shape[0] > 10, true_fn, false_fn)

        missing_arg_msg = r"cond\(\) missing 1 required positional argument: 'operands'"
        with self.assertRaisesRegex(TypeError, missing_arg_msg):
            f(torch.rand(5))

    def test_cond_raise_user_error_on_unsupported_pred(self):
        """A predicate that is neither bool nor tensor is rejected with RuntimeError.

        The error message must include the offending predicate (``Module()``).
        """

        def f_unsupported_pred(x):
            pred = torch.nn.Module()
            return cond(pred, lambda x: x.sin(), lambda x: x.cos(), [x])

        example_inputs = (torch.rand(5),)
        # The parentheses are escaped: unescaped, ``Module()`` is "Module" followed
        # by an empty regex group, so the "()" was never actually checked.
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected pred to be bool or tensor, but got Module\(\)",
        ):
            f_unsupported_pred(*example_inputs)

    def test_cond_raise_user_error_on_non_list_operands(self):
        """Passing a bare tensor as ``operands`` instead of a list is rejected."""

        def f_non_list_operands(x):
            return cond(torch.tensor(True), lambda x: x.sin(), lambda x: x.cos(), x)

        operands_msg = r"Expect operands to be a tuple of possibly nested dict/list/tuple"
        with self.assertRaisesRegex(RuntimeError, operands_msg):
            f_non_list_operands(torch.rand(5))

    def test_cond_raise_user_error_on_non_tensor_operands(self):
        """A Python float among the ``cond`` operands is rejected with RuntimeError."""

        def f_non_tensor_operands(x):
            a: float = 3.14
            return cond(
                torch.tensor(1234), lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a]
            )

        operands_msg = r"Expect operands to be a tuple of possibly nested dict/list/tuple"
        with self.assertRaisesRegex(RuntimeError, operands_msg):
            f_non_tensor_operands(torch.rand(5))

    def test_cond_raise_user_error_on_branch_args_mismatch(self):
        """Branches with different arities make ``cond`` uncapturable under export.

        ``true_fn`` takes two operands while ``false_fn`` takes one.
        """

        def true_fn(x, y):
            return x.sin()

        def false_fn(x):
            return x.cos()

        def f_branch_args_mismatch(x, y):
            return cond(torch.tensor([[[[True]]]]), true_fn, false_fn, [x, y])

        example_inputs = (torch.rand(5), torch.rand(2))
        # Same full message as the sibling cond tests (was truncated to "compil").
        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "Cond doesn't work unless it is captured completely with torch.compile",
        ):
            torch._dynamo.export(
                f_branch_args_mismatch,
                aten_graph=True,
            )(
                *example_inputs,
            )

    @config.patch(suppress_errors=True)
    def test_uncaptured_higher_order_op_error_not_suppresed(self):
        """``suppress_errors=True`` must not swallow UncapturedHigherOrderOpError.

        The ``cond`` branches take different numbers of operands, so the error
        has to surface even with error suppression enabled.
        """

        def true_fn(x, y):
            return x.sin()

        def false_fn(x):
            return x.cos()

        def f_branch_args_mismatch(x, y):
            return cond(torch.tensor([[[[100]]]]), true_fn, false_fn, [x, y])

        example_inputs = (torch.rand(5), torch.rand(2))
        uncaptured_msg = (
            "Cond doesn't work unless it is captured completely with torch.compile"
        )
        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError, uncaptured_msg
        ):
            torch._dynamo.export(f_branch_args_mismatch, aten_graph=True)(
                *example_inputs
            )

    def test_cond_raise_user_error_on_branch_return_non_tensor(self):
        """Branches returning a Python float cannot be captured by ``cond``."""

        def f_branch_return_non_tensor(x):
            return cond(x.shape[0] <= 5, lambda x: 3.14, lambda x: 3.14, [x])

        example_inputs = (torch.rand(5),)
        uncaptured_msg = (
            "Cond doesn't work unless it is captured completely with torch.compile"
        )
        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError, uncaptured_msg
        ):
            torch._dynamo.export(f_branch_return_non_tensor, aten_graph=True)(
                *example_inputs
            )

    def test_cond_raise_user_error_on_branch_return_multiple_tensors(self):
        """Despite the name, this expects success.

        A ``cond`` whose branches return a tuple of tensors must export, and the
        exported graph must match eager.
        """
        # ``x`` is unused; only ``y`` is passed to ``cond`` as an operand.
        def f_branch_return_multiple_tensors(pred, x, y):
            return cond(pred, lambda x: (x, x), lambda x: (x, x), [y])

        example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2))
        gm, _ = torch._dynamo.export(
            f_branch_return_multiple_tensors,
            aten_graph=True,
        )(*example_inputs)
        self.assertEqual(
            gm(*example_inputs), f_branch_return_multiple_tensors(*example_inputs)
        )

    def test_multiple_outputs_op_with_evaluator(self):
        """Export an op with multiple outputs (``topk``) when only one is consumed."""

        class TopKModel(torch.nn.Module):
            def forward(self, x):
                values, _ = torch.topk(x, 3)
                return torch.sum(values)

        x = torch.arange(1.0, 6.0, requires_grad=True)
        model = TopKModel()
        torch._dynamo.export(model)(x)

    def test_cond_raise_user_error_on_mismatch_return_length(self):
        """Branches returning different numbers of outputs raise RuntimeError."""

        def true_fn(x):
            return x

        def false_fn(x):
            return (x, x)

        def f_mismatch_return_length(x):
            return cond(torch.tensor(100), true_fn, false_fn, [x])

        example_inputs = (torch.rand(5),)
        mismatch_msg = "Unmatched number of outputs from cond"
        with self.assertRaisesRegex(RuntimeError, mismatch_msg):
            torch._dynamo.export(f_mismatch_return_length, aten_graph=True)(
                *example_inputs
            )

    def test_cond_raise_user_error_on_mismatch_return_tensor_meta(self):
        """Branches returning tensors of different shape/dtype cannot be captured."""

        def true_fn(x):
            return torch.tensor([[3], [2]])

        def false_fn(x):
            return torch.tensor([3.14])

        def f_return_tensor_mismatch(x):
            return cond(x.shape[0] < 3, true_fn, false_fn, [x])

        example_inputs = (torch.rand(5),)
        uncaptured_msg = (
            "Cond doesn't work unless it is captured completely with torch.compile"
        )
        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError, uncaptured_msg
        ):
            torch._dynamo.export(f_return_tensor_mismatch, aten_graph=True)(
                *example_inputs,
            )

    def test_byte_tensor_does_not_crash(self):
        """Building a ByteTensor from a string's bytes compiles without crashing.

        Regression test for https://github.com/pytorch/pytorch/issues/100455;
        two string lengths (99 and 100) are compiled in turn.
        """

        def func(text):
            tensor = torch.ByteTensor(list(bytes(text, "utf8")))
            return tensor + tensor

        text = "".join(chr(a % 90 + 40) for a in range(111))
        opt_func = torch._dynamo.optimize("eager", dynamic=True)(func)
        for prefix_len in (99, 100):
            opt_func(text[:prefix_len])

    def test_export_defaults_ok(self):
        """Pin the exact aten graph for slices whose bounds come from tensor sizes.

        In the expected graph a single ``sym_size_int`` (dim 0 of the input) serves
        both ``x.size(0) - i`` and the ``x.size(2)`` end bound. For ``i == 0`` only
        the dim-2 slice is emitted.
        """
        class DynamicSliceExportMod(torch.nn.Module):
            def forward(self, x):
                results = []
                for i in range(4):
                    results.append(x[: x.size(0) - i, i : x.size(2), i:3])
                return tuple(results)

        gm, _ = torch._dynamo.export(DynamicSliceExportMod(), aten_graph=True)(
            torch.randn(5, 5, 5),
        )

        # Exact generated code; regenerate with EXPECTTEST_ACCEPT=1 if it changes.
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    arg0_1 = arg0
    sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
    slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
    sub = sym_size_int - 1
    slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub);  sub = None
    slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int);  slice_2 = None
    slice_4 = torch.ops.aten.slice.Tensor(slice_3, 2, 1, 3);  slice_3 = None
    sub_1 = sym_size_int - 2
    slice_5 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_1);  sub_1 = None
    slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, sym_size_int);  slice_5 = None
    slice_7 = torch.ops.aten.slice.Tensor(slice_6, 2, 2, 3);  slice_6 = None
    sub_2 = sym_size_int - 3
    slice_8 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_2);  arg0_1 = sub_2 = None
    slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 3, sym_size_int);  slice_8 = sym_size_int = None
    slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 3, 3);  slice_9 = None
    return pytree.tree_unflatten([slice_1, slice_4, slice_7, slice_10], self._out_spec)""",
        )

    def test_capture_symbolic_tracing_simple_within_fake_mode(self):
        """Export a plain function while a user FakeTensorMode is active.

        ``f`` creates a fresh tensor internally. Both ``aten_graph`` settings must
        yield an ``fx.GraphModule``.
        """
        from torch._dynamo.output_graph import config

        def f(x):
            y = torch.randn(3)
            return x + x * y

        shape_env = ShapeEnv(
            allow_scalar_outputs=config.capture_scalar_outputs,
            allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
        )
        with fake_tensor.FakeTensorMode(shape_env=shape_env):
            x = torch.randn(3)

            for aten_graph in (True, False):
                gm = torch._dynamo.export(f, aten_graph=aten_graph)(x).graph_module
                label = (
                    "test_capture_symbolic_tracing_simple_within_fake_mode_aten_graph_"
                    + str(aten_graph)
                )
                self.assertTrue(isinstance(gm, torch.fx.GraphModule), msg=label)

    def test_export_with_symbool_inputs(self):
        """Export a function whose ``bool`` argument is actually a SymBool.

        ``pred`` is a symbolic comparison on a size of a fake tensor. Each export
        must pick the matching branch, produce exactly one SHAPE_ENV guard, and
        leave the outer ShapeEnv with no guards.
        """
        def f(pred: bool, x: torch.Tensor):
            if pred:
                return x.sin()
            else:
                return x.cos()

        x = torch.randn([3, 4])

        def test_symbool_guards(
            f, size_tests, exp_graph, exp_guard_code, exp_shape_env_guards
        ):
            # For the i-th size in ``size_tests``, check the exported graph, the
            # SHAPE_ENV guard code mentioning ``L['pred']`` and the outer ShapeEnv
            # guards against the i-th entry of each expectation list.
            shape_env = ShapeEnv()
            with fake_tensor.FakeTensorMode(
                shape_env=shape_env,
            ) as fake_mode:
                # Fakify ``x`` with every dim dynamic so its sizes are symbolic.
                fake_x = fake_mode.from_tensor(
                    x,
                    symbolic_context=StatelessSymbolicContext(
                        dynamic_sizes=[DimDynamic.DYNAMIC for _ in range(x.dim())],
                    ),
                )
                for i, size in enumerate(size_tests):
                    # Symbolic comparison; the last expectation list requires it
                    # to leave no guards in ``shape_env``.
                    pred = fake_x.size(0) == size
                    gm, guards = torch._dynamo.export(f)(pred, x)
                    actual = normalize_gm(gm.print_readable(print_output=False))
                    # TODO: This is naughty, EXPECTTEST_ACCEPT=1 doesn't work
                    self.assertExpectedInline(actual, exp_graph[i])
                    dynamo_shape_env_guards = [
                        guard
                        for guard in guards
                        if guard.guard_types is not None
                        and "SHAPE_ENV" in guard.guard_types
                    ]
                    self.assertEqual(len(dynamo_shape_env_guards), 1)
                    guard_code_on_predicate = [
                        code
                        for code in dynamo_shape_env_guards[0].code_list
                        if "L['pred']" in code
                    ]
                    self.assertEqual(guard_code_on_predicate, exp_guard_code[i])
                    outter_shape_env_guards = [
                        str(guard.expr) for guard in shape_env.guards
                    ]
                    self.assertEqual(outter_shape_env_guards, exp_shape_env_guards[i])

        true_graph = """\
class GraphModule(torch.nn.Module):
    def forward(self, pred, x):
        arg1: "f32[s1, s2]";

        arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
        l_x_ = arg1

        sin: "f32[s1, s2]" = l_x_.sin();  l_x_ = None
        return pytree.tree_unflatten([sin], self._out_spec)
"""
        false_graph = """\
class GraphModule(torch.nn.Module):
    def forward(self, pred, x):
        arg1: "f32[s1, s2]";

        arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
        l_x_ = arg1

        cos: "f32[s1, s2]" = l_x_.cos();  l_x_ = None
        return pytree.tree_unflatten([cos], self._out_spec)
"""
        true_guard_code = [
            "cast_symbool_to_symint_guardless(L['pred']) == 1",
        ]
        false_guard_code = [
            "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
        ]
        # x has leading size 3, so sizes 3 take the sin branch and 4/5 take cos.
        test_symbool_guards(
            f,
            [3, 3, 4, 5],
            [true_graph, true_graph, false_graph, false_graph],
            [true_guard_code, true_guard_code, false_guard_code, false_guard_code],
            # Outer shape env should have no guards in it because we never specialize on the outer symbool.
            [[], [], [], []],
        )

    def test_invalid_input_global(self) -> None:
        """Reading a global tensor inside an exported function is a UserError.

        The message must name the global and show where it was accessed. Line
        numbers are munged to ``N``, but the quoted source line must match exactly.
        """
        global bulbous_bouffant
        bulbous_bouffant = torch.randn(3)

        def f(y):
            return bulbous_bouffant + y

        self.assertExpectedInlineMunged(
            UserError,
            lambda: torch._dynamo.export(f)(torch.randn(3)),
            """\
G['bulbous_bouffant'], accessed at:
  File "test_export.py", line N, in f
    return bulbous_bouffant + y
""",
        )

    def test_invalid_input_global_multiple_access(self) -> None:
        """A global tensor read from both ``f`` and its callee ``g`` is a UserError.

        The expected message shows only the first access, reached through ``g``.
        """
        global macademia
        macademia = torch.randn(3)

        def g(y):
            global macademia
            y = macademia + y
            return y

        def f(y):
            global macademia
            y = g(y)
            return macademia + y

        # NB: This doesn't actually work (it only reports the first usage),
        # but I'm leaving the test here in case we fix it later
        self.assertExpectedInlineMunged(
            UserError,
            lambda: torch._dynamo.export(f)(torch.randn(3)),
            """\
G['macademia'], accessed at:
  File "test_export.py", line N, in f
    y = g(y)
  File "test_export.py", line N, in g
    y = macademia + y
""",
        )

    def test_invalid_input_nonlocal(self) -> None:
        """A closed-over tensor that feeds the output is reported as a UserError."""
        arglebargle = torch.randn(3)

        def f(y):
            return arglebargle + y

        # The message refers to the closure variable by its local name.
        self.assertExpectedInlineMunged(
            UserError,
            lambda: torch._dynamo.export(f)(torch.randn(3)),
            """L['arglebargle'], a closed over free variable""",
        )

    def test_invalid_input_unused_nonlocal_ok(self) -> None:
        """A closed-over tensor that is read but never used in the output is allowed."""
        arglebargle = torch.randn(3)

        def f(y):
            # Bound to a local but never flows into the returned value.
            x = arglebargle
            return y

        torch._dynamo.export(f)(torch.randn(3))

    def test_symbolic_tracing_within_fake_mode_with_constraints(self):
        """Export under a user FakeTensorMode with one ``Dim`` shared by two inputs.

        ``a`` and ``c`` share dim 0. Export runs with and without ``aten_graph``.
        The last graph (``aten_graph=False``) is then run on real tensors outside
        fake mode and compared against eager.
        """
        fake_mode = fake_tensor.FakeTensorMode()

        class DynamicShapeSimpleModel(torch.nn.Module):
            def forward(self, a, b, c) -> torch.Tensor:
                d = (torch.matmul(a, b) + c) / 2
                d_s0 = d.shape[0]
                d_s1 = d.shape[1]
                d_s3 = d_s0 * d_s1
                e = d.view(d_s3)
                return torch.cat([e, e])

        with fake_mode:
            model = DynamicShapeSimpleModel()
            fake_inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
            dim = torch.export.Dim("dim")
            dynamic_shapes = ({0: dim}, None, {0: dim})
            for aten_graph in (True, False):
                gm = torch._dynamo.export(
                    model,
                    dynamic_shapes=dynamic_shapes,
                    aten_graph=aten_graph,
                )(*fake_inputs).graph_module

        # The model has no parameters, so it can run on real tensors even though
        # it was constructed under fake mode.
        real_inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
        self.assertEqual(model(*real_inputs), gm(*real_inputs))

    def test_symbolic_tracing_within_fake_mode_with_constraints_with_parameters(self):
        """Export a parametrized model under a user FakeTensorMode, dim 0 dynamic.

        This only checks that both ``aten_graph`` variants export without raising.
        """
        fake_mode = fake_tensor.FakeTensorMode()

        # TODO: Seems to choke if you don't make a fresh model and
        # just try to export Linear directly...
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                out = self.linear(x)
                return out

        with fake_mode:
            model = Model()
            inputs = (torch.randn(10, 2, 2),)
            dynamic_shapes = ({0: torch.export.Dim("dim")},)
            for aten_graph in (True, False):
                torch._dynamo.export(
                    model,
                    dynamic_shapes=dynamic_shapes,
                    aten_graph=aten_graph,
                )(*inputs)

    def test_capture_symbolic_tracing_within_fake_mode(self):
        """Export a model whose parameters and input were created under fake mode.

        The FakeTensorMode is user-instantiated with non-fake inputs disallowed.
        Both ``aten_graph`` settings must produce an ``fx.GraphModule``.
        """
        from torch._dynamo.output_graph import config

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)
                self.linear2 = torch.nn.Linear(2, 2)

            def forward(self, x):
                out = self.linear(x)
                out = self.linear2(out)
                return out

        # User-instantiated FakeTensorMode
        shape_env = ShapeEnv(
            allow_scalar_outputs=config.capture_scalar_outputs,
            allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
        )
        fake_mode = fake_tensor.FakeTensorMode(
            allow_non_fake_inputs=False,
            allow_fallback_kernels=True,
            shape_env=shape_env,
        )
        # Fakify both the input and the model before exporting.
        with fake_mode:
            x = torch.rand(5, 2, 2)
            model = Model()

            for aten_graph in (True, False):
                graph_module = torch._dynamo.export(model, aten_graph=aten_graph)(x)[0]
                label = (
                    "test_capture_symbolic_tracing_within_fake_mode_aten_graph_"
                    + str(aten_graph)
                )
                self.assertTrue(isinstance(graph_module, torch.fx.GraphModule), msg=label)

    def test_cond_op_param_buffer_lifted(self):
        """``cond`` branches that call buffer-reading submodules export correctly.

        The graph is traced at a leading dim of 6 (true branch). It is then checked
        on both branches (6 and 3 rows) against a fresh eager ``M``.
        """

        class A(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4))

            def forward(self):
                return self.buffer1.sum()

        class B(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer2 = torch.nn.Buffer(torch.ones(6, 4))

            def forward(self):
                return self.buffer2.sum()

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = A()
                self.b = B()

            def forward(self, x):
                def true_fn(x):
                    return x.cos() + self.a()

                def false_fn(x):
                    return x.sin() + self.b()

                return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),)

        gm = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4)).graph_module
        for rows in (6, 3):
            self.assertEqual(gm(torch.ones(rows, 4)), M()(torch.ones(rows, 4)))

    def test_nested_cond_op_param_buffer_lifted(self):
        # Buffers read inside a cond nested in another cond's true branch are
        # lifted as well. All three paths are compared against eager:
        # 6 rows -> true/true, 5 rows -> true/false, 3 rows -> false.
        class A(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4))

            def forward(self):
                return self.buffer1.sum()

        class B(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer2 = torch.nn.Buffer(torch.ones(6, 4))

            def forward(self):
                return self.buffer2.sum()

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = A()
                self.b = B()

            def forward(self, x):
                def inner_true(t):
                    return t.cos() + self.a()

                def inner_false(t):
                    return t.cos() + self.a() + 1

                def outer_true(t):
                    return cond(t.shape[0] > 5, inner_true, inner_false, [t])

                def outer_false(t):
                    return t.sin() + self.b()

                return (cond(x.shape[0] > 4, outer_true, outer_false, [x]),)

        gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4))
        for rows in (6, 5, 3):
            inp = torch.ones(rows, 4)
            self.assertEqual(gm(inp), M()(inp))

    def test_map_cond_param_buffer_lifted(self):
        # A cond inside a control-flow map body reads submodule buffers in both
        # branches. The graph is exported with (pred=True, xs of shape
        # (3, 2, 1)) and compared with eager on (pred=False, xs of shape
        # (4, 3, 2)).
        from functorch.experimental.control_flow import cond, map as cf_map

        class A(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4))

            def forward(self):
                return self.buffer1.sum()

        class B(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer2 = torch.nn.Buffer(torch.ones(6, 4))

            def forward(self):
                return self.buffer2.sum()

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = A()
                self.b = B()

            def inner(self, x, pred):
                def when_true(t):
                    return t + t + self.a()

                def when_false(t):
                    return t * t + self.b()

                return cond(pred, when_true, when_false, [x])

            def forward(self, pred, xs):
                def body(x, pred):
                    return self.inner(x, pred) + self.b()

                return cf_map(body, xs, pred)

        mod = Module()
        export_xs, export_pred = torch.randn(3, 2, 1), torch.tensor(True)
        check_xs, check_pred = torch.randn(4, 3, 2), torch.tensor(False)
        expected = mod(check_pred, check_xs)

        out_graph, _ = torch._dynamo.export(mod)(export_pred, export_xs)
        self.assertEqual(expected, out_graph(check_pred, check_xs))

    def test_cond_free_variables_overlapping(self):
        # Checks how closed-over tensors are passed to cond branches when the
        # branches share some captures (a, b) and each also has one of its own
        # (d only in true_fn, c only in false_fn). In the expected graphs below,
        # the union of captures is passed as cond operands. The branch-specific
        # ones show up as unused parameters in the other branch's subgraph,
        # suffixed _true_branch / _false_branch.
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, pred, x):
                # NOTE: these local names appear verbatim in the expected code.
                a = torch.ones(6, 4)
                b = torch.ones(6, 4)
                c = torch.ones(6, 4)
                d = torch.ones(6, 4)

                def true_fn(x):
                    return x + x + a.cos() + b.cos() + d.cos()

                def false_fn(x):
                    return x * x + a.sin() + b.sin() + c.sin()

                return cond(pred, true_fn, false_fn, [x])

        mod = Module()
        x = torch.ones(6, 4)
        pred_x = torch.tensor(True)

        out_graph, _ = torch._dynamo.export(mod)(pred_x, x)
        # Top-level graph: cond receives operands [a, b, l_x_, d, c].
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, pred, x):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
    l_pred_ = arg0
    l_x_ = arg1
    a = torch.ones(6, 4)
    b = torch.ones(6, 4)
    c = torch.ones(6, 4)
    d = torch.ones(6, 4)
    cond_true_0 = self.cond_true_0
    cond_false_0 = self.cond_false_0
    cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [a, b, l_x_, d, c]);  l_pred_ = cond_true_0 = cond_false_0 = a = b = l_x_ = d = c = None
    getitem = cond[0];  cond = None
    return pytree.tree_unflatten([getitem], self._out_spec)""",  # noqa: B950,E122
        )

        # True branch: uses d_true_branch; c_false_branch is accepted but unused.
        self.assertExpectedInline(
            out_graph.cond_true_0.code.strip(),
            """\
def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
    a_1 = a
    b_1 = b
    l_x__1 = l_x_
    add = l_x__1 + l_x__1;  l_x__1 = None
    cos = a_1.cos();  a_1 = None
    add_1 = add + cos;  add = cos = None
    cos_1 = b_1.cos();  b_1 = None
    add_2 = add_1 + cos_1;  add_1 = cos_1 = None
    cos_2 = d_true_branch.cos();  d_true_branch = None
    add_3 = add_2 + cos_2;  add_2 = cos_2 = None
    return (add_3,)""",
        )

        # False branch: same signature; uses c_false_branch, ignores d_true_branch.
        self.assertExpectedInline(
            out_graph.cond_false_0.code.strip(),
            """\
def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
    a_1 = a
    b_1 = b
    l_x__1 = l_x_
    mul = l_x__1 * l_x__1;  l_x__1 = None
    sin = a_1.sin();  a_1 = None
    add = mul + sin;  mul = sin = None
    sin_1 = b_1.sin();  b_1 = None
    add_1 = add + sin_1;  add = sin_1 = None
    sin_2 = c_false_branch.sin();  c_false_branch = None
    add_2 = add_1 + sin_2;  add_1 = sin_2 = None
    return (add_2,)""",
        )

    @unittest.skipIf(
        common_utils.TEST_WITH_ASAN,
        "Times out with ASAN, see https://github.com/pytorch/pytorch/issues/110416",
    )
    def test_retracibility(self):
        # Exporting an already-exported aten GraphModule must give a graph
        # that computes the same tuple outputs for a tuple-of-tensors input.
        class MyLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Plain tensors, not Parameters. in_features=98 matches the
                # width produced by the 3x3 conv on a width-100 input.
                self.weight = torch.randn(20, 98)
                self.bias = torch.randn(20)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.weight, self.bias)

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(16, 33, 3)
                self.linear = MyLinear()

            def forward(self, x):
                first, second = x
                first_out = self.linear(self.conv(first))
                second_out = self.linear(self.conv(second))
                return (
                    first_out.cos() + second_out.sin(),
                    first_out.sin() + second_out.cos(),
                )

        def make_inputs():
            return (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100))

        inp_container = make_inputs()
        gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True)
        gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True)

        inp_test = make_inputs()
        reference, retraced = gm(inp_test), gm2(inp_test)
        for lhs, rhs in zip(reference, retraced):
            self.assertTrue(torch.allclose(lhs, rhs))

    def test_retracibility_dict_container_inp_out(self):
        # Re-exporting an exported aten GraphModule whose input is a dict
        # (holding a tuple and a tensor) and whose output is a dict with a list
        # value must preserve every output leaf.
        class MyLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.randn(20, 98)
                self.bias = torch.randn(20)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.weight, self.bias)

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(16, 33, 3)
                self.linear = MyLinear()

            def forward(self, x):
                a1, a2 = x["a"]
                a1_out = self.linear(self.conv(a1))
                a2_out = self.linear(self.conv(a2))
                b_out = self.linear(self.conv(x["b"]))
                return {
                    "a": [
                        a1_out.cos() + b_out.sin(),
                        a1_out.cos() + b_out.sin(),
                    ],
                    "b": a2_out.sin() + b_out.cos(),
                }

        def make_inputs():
            return {
                "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
                "b": torch.randn(20, 16, 50, 100),
            }

        inp_container = make_inputs()
        gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True)
        gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True)

        inp_test = make_inputs()
        ref, retraced = gm(inp_test), gm2(inp_test)
        pairs = [
            (ref["a"][0], retraced["a"][0]),
            (ref["a"][1], retraced["a"][1]),
            (ref["b"], retraced["b"]),
        ]
        for lhs, rhs in pairs:
            self.assertTrue(torch.allclose(lhs, rhs))

    def test_retracibility_nested_list_out(self):
        # Re-exporting an exported aten GraphModule that returns a nested
        # (2 x 2) list must preserve every output leaf.
        class MyLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.randn(20, 98)
                self.bias = torch.randn(20)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.weight, self.bias)

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(16, 33, 3)
                self.linear = MyLinear()

            def forward(self, x):
                a1, a2 = x["a"]
                a1_out = self.linear(self.conv(a1))
                a2_out = self.linear(self.conv(a2))
                b_out = self.linear(self.conv(x["b"]))
                return [
                    [
                        a1_out.cos() + b_out.sin(),
                        a1_out.cos() + b_out.sin(),
                    ],
                    [
                        a2_out.sin() + b_out.cos(),
                        a2_out.sin() + b_out.cos(),
                    ],
                ]

        def make_inputs():
            return {
                "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
                "b": torch.randn(20, 16, 50, 100),
            }

        inp_container = make_inputs()
        gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True)
        gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True)

        inp_test = make_inputs()
        ref, retraced = gm(inp_test), gm2(inp_test)
        for i in range(2):
            for j in range(2):
                self.assertTrue(torch.allclose(ref[i][j], retraced[i][j]))

    def test_fx_pytree(self):
        # A function that flattens its input with both the utils and the fx
        # pytree APIs can be exported; the result must match eager.
        def foo(args):
            leaves, spec = torch.utils._pytree.tree_flatten(args)
            fx_leaves = torch.fx._pytree.tree_flatten_spec(args, spec)
            return fx_leaves[0] + leaves[0]

        inp_container = tuple(torch.randn(20, 16, 50, 100) for _ in range(2))
        gm, _ = torch._dynamo.export(foo, inp_container, aten_graph=True)
        self.assertTrue(torch.allclose(foo(inp_container), gm(inp_container)))

    @config.patch(suppress_errors=True)
    @config.patch(verbose=True)
    def test_export_with_map_zero_sized_tensor_suppress_errors(self):
        # Exporting a control-flow map over a tensor whose leading dimension is
        # 0 raises Unsupported, even with suppress_errors enabled.
        from functorch.experimental.control_flow import map

        class Module(torch.nn.Module):
            def forward(self, xs):
                def add_one(x):
                    return x + 1

                return map(add_one, xs)

        empty_xs = torch.randn(0, 2)
        with self.assertRaises(torch._dynamo.exc.Unsupported):
            torch._dynamo.export(Module(), empty_xs)

    def test_param_buffer_safe_from_mutation_simple(self):
        # forward() mutates buffer1 in place. The buffer registered on the
        # exported GraphModule is expected to still hold its initial all-zero
        # (5, 5) value.
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.zeros(5, 5))

            def forward(self, x):
                self.buffer1.add_(1)
                return x + self.buffer1

        gm, _ = torch._dynamo.export(Module(), torch.ones(5, 5), aten_graph=False)
        buffers = list(gm.named_buffers())
        self.assertEqual(len(buffers), 1)

        name, buffer = buffers[0]
        self.assertEqual(name, "L__self___buffer1")

        # torch.allclose broadcasts its operands, so the reference must have
        # the buffer's real shape. Check the shape explicitly as well, so a
        # wrongly-shaped buffer cannot pass through broadcasting.
        self.assertEqual(buffer.shape, torch.Size([5, 5]))
        self.assertTrue(torch.allclose(buffer, torch.zeros(5, 5)))

    def test_param_buffer_safe_from_mutation_recurse(self):
        # Like the simple variant, but a buffer of a child module is also
        # mutated. Every buffer on the exported GraphModule should still hold
        # its initial all-zero (5,) value.
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer2 = torch.nn.Buffer(torch.zeros(5))

            def forward(self, x):
                return x.sum() + self.buffer2.sum()

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.zeros(5))
                self.child = Child()

            def forward(self, x):
                self.buffer1.add_(1)
                self.child.buffer2.add_(2)
                return x.sum() + self.buffer1.sum() + self.child(x)

        gm, _ = torch._dynamo.export(Module(), torch.ones(5), aten_graph=False)
        buffers = [buffer for _, buffer in gm.named_buffers()]
        # Without this, the loop below would pass vacuously if export
        # registered no buffers at all.
        self.assertTrue(buffers)
        for buffer in buffers:
            # allclose broadcasts, so also pin the shape.
            self.assertEqual(buffer.shape, torch.Size([5]))
            self.assertTrue(torch.allclose(buffer, torch.zeros(5)))

    def test_predispatch_with_higher_order(self):
        # A shape-dependent cond exported with pre_dispatch=True must match
        # eager on inputs taking either branch (4 rows -> false, 6 rows -> true).
        def f(x):
            return cond(x.shape[0] > 4, lambda t: t + 5, lambda t: t - 3, [x])

        export_fn = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)
        gm, _ = export_fn(torch.randn(4, 4))
        for inp in (torch.randn(4, 4), torch.randn(6, 4)):
            self.assertTrue(torch.allclose(f(inp), gm(inp)))

    def test_predispatch_with_higher_order_nested(self):
        # Nested shape-dependent conds exported with pre_dispatch=True must
        # match eager on all paths: 4 rows -> outer false, 6 rows -> outer
        # true / inner false, 8 rows -> outer true / inner true.
        def f(x):
            def outer_true(t):
                return cond(t.shape[0] > 6, lambda u: u + 10, lambda u: u - 10, [t])

            return cond(x.shape[0] > 4, outer_true, lambda t: t - 3, [x])

        export_fn = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)
        gm, _ = export_fn(torch.randn(4, 4))
        for inp in (torch.randn(4, 4), torch.randn(6, 4), torch.randn(8, 4)):
            self.assertTrue(torch.allclose(f(inp), gm(inp)))

    def test_predispatch_with_for_out_dtype(self):
        # out_dtype(aten.mm) on int8 operands, requesting an int32 result, must
        # survive pre-dispatch export and match eager.
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return out_dtype(torch.ops.aten.mm.default, torch.int32, x, self.weight)

        def rand_int8():
            return torch.randint(-128, 127, (5, 5), dtype=torch.int8)

        m = M(rand_int8())
        x = rand_int8()
        gm, _ = torch._dynamo.export(m, x, aten_graph=True, pre_dispatch=True)
        self.assertTrue(torch.allclose(m(x), gm(x)))

    def test_predispatch_with_for_out_dtype_nested(self):
        # out_dtype calls inside both cond branches (aten.mm in the true
        # branch, aten.mul in the false branch) survive pre-dispatch export.
        # In the expected branch graphs, arg1_1 sits in x's position and
        # arg0_1 in self.weight's position of the out_dtype call.
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def true_fn(self, x):
                return out_dtype(
                    torch.ops.aten.mm.default, torch.int32, x, self.weight
                ).sum()

            def false_fn(self, x):
                return out_dtype(
                    torch.ops.aten.mul.Tensor, torch.int32, x, self.weight
                ).sum()

            def forward(self, x):
                return cond(x.sum() != 0, self.true_fn, self.false_fn, [x])

        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        m = M(weight)
        # All-ones input: x.sum() != 0, so the true (mm) branch runs.
        x = torch.ones((5, 5), dtype=torch.int8)
        gm, _ = torch._dynamo.export(m, x, aten_graph=True, pre_dispatch=True)

        self.assertTrue(torch.allclose(m(x), gm(x)))
        # All-zeros input: x.sum() == 0, so the false (mul) branch runs.
        y = torch.zeros((5, 5), dtype=torch.int8)
        self.assertTrue(torch.allclose(m(y), gm(y)))

        self.assertExpectedInline(
            gm.true_graph_0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mm.default, torch.int32, arg1_1, arg0_1);  arg1_1 = arg0_1 = None
    sum_1 = torch.ops.aten.sum.default(out_dtype);  out_dtype = None
    return (sum_1,)""",
        )

        self.assertExpectedInline(
            gm.false_graph_0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mul.Tensor, torch.int32, arg1_1, arg0_1);  arg1_1 = arg0_1 = None
    sum_1 = torch.ops.aten.sum.default(out_dtype);  out_dtype = None
    return (sum_1,)""",
        )

    def test_export_nn_module_stack_patched_module(self):
        # A submodule whose forward is replaced on the instance is traced
        # through the replacement (x * y, not M's x + y), and every
        # call_function node must still carry nn_module_stack metadata.
        def forward(self, x, y):
            return x * y

        class Toplevel(torch.nn.Module):
            def __init__(self, m):
                super().__init__()
                self.m = m

            def forward(self, x, y):
                return self.m(x, y)

        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        t = Toplevel(M())
        # Bind the free function as a method of this particular M instance.
        t.m.forward = forward.__get__(t.m, M)
        x, y = torch.rand(3), torch.rand(3)
        gm, _ = torch._dynamo.export(t, x, y)

        self.assertTrue(torch.allclose(forward(None, x, y), gm(x, y)))
        call_nodes = [n for n in gm.graph.nodes if n.op == "call_function"]
        for node in call_nodes:
            self.assertIn("nn_module_stack", node.meta)

    def test_preserve_fx_node_metadata(self):
        class Module1(torch.nn.Module):
            def forward(self, x):
                return torch.sin(x)

        class Module2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod1 = Module1()

            def forward(self, x):
                x = torch.cos(x)
                x = self.mod1(x)
                x = torch.relu(x)
                return x

        def fn(x):
            return torch.abs(x)

        mod = Module2()
        inp = torch.randn(3, 3)

        gm, _ = torch._dynamo.export(mod)(inp)

        # replace relu with fn
        gm_edit = copy.deepcopy(gm)
        for nd in gm_edit.graph.nodes:
            if nd.target == torch.relu:
                nd.target = fn
                nd.meta.clear()
                break
        gm_edit.recompile()

        gm2, _ = torch._dynamo.export(gm_edit)(inp)

        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    x = torch.cos(l_x_);  l_x_ = None
    x_1 = torch.sin(x);  x = None
    x_2 = torch.relu(x_1);  x_1 = None
    return pytree.tree_unflatten([x_2], self._out_spec)""",
        )

        def _constais_op(gm, target):
            for nd in gm.graph.nodes:
                if nd.target == target:
                    return True
            return False

        self.assertTrue(_constais_op(gm_edit, torch.cos))
        self.assertTrue(_constais_op(gm_edit, torch.sin))
        self.assertTrue(not _constais_op(gm_edit, torch.relu))

        self.assertExpectedInline(
            gm2.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    x = torch.cos(l_x_);  l_x_ = None
    x_1 = torch.sin(x);  x = None
    x_2 = torch.abs(x_1);  x_1 = None
    return pytree.tree_unflatten([x_2], self._out_spec)""",
        )

        # check for other metadata
        for op in (torch.sin, torch.cos):
            nd1 = next(filter(lambda nd: nd.target == op, gm.graph.nodes))
            nd2 = next(filter(lambda nd: nd.target == op, gm2.graph.nodes))
            self.assertTrue(
                ("nn_module_stack" in nd1.meta) == ("nn_module_stack" in nd2.meta)
            )
            if "nn_module_stack" in nd1.meta:
                self.assertEqual(
                    nd1.meta["nn_module_stack"], nd2.meta["nn_module_stack"]
                )
            self.assertEqual(nd1.meta["stack_trace"], nd2.meta["stack_trace"])

    def test_preserve_fx_node_metadata_recompile(self):
        # One export callable for an already-exported GraphModule is invoked
        # twice, with input shapes (3, 3) and then (5, 3). Before those calls,
        # the original fn is also run through torch._dynamo.optimize("eager").
        # Both resulting graphs must have identical code.
        def fn(x):
            return torch.sin(x)

        gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))
        do_export = torch._dynamo.export(gm)
        # Deliberately placed between creating do_export and calling it.
        torch._dynamo.optimize("eager")(fn)(torch.randn(3, 3))
        gm1, _ = do_export(torch.randn(3, 3))
        gm2, _ = do_export(torch.randn(5, 3))

        self.assertExpectedInline(
            gm1.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    sin = torch.sin(l_x_);  l_x_ = None
    return pytree.tree_unflatten([sin], self._out_spec)""",
        )
        self.assertExpectedInline(
            gm2.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    sin = torch.sin(l_x_);  l_x_ = None
    return pytree.tree_unflatten([sin], self._out_spec)""",
        )

    def test_preserve_fx_node_metadata_inline(self):
        # An exported GraphModule called from another exported function is
        # inlined: the outer graph contains torch.sin directly after torch.cos.
        def f1(x):
            return torch.sin(x)

        gm, _ = torch._dynamo.export(f1)(torch.randn(3, 3))

        def f2(x):
            # NOTE: the name `x` appears verbatim in the expected code below.
            x = torch.cos(x)
            return gm(x)

        gm2, _ = torch._dynamo.export(f2)(torch.randn(3, 3))

        self.assertExpectedInline(
            gm2.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    x = torch.cos(l_x_);  l_x_ = None
    sin = torch.sin(x);  x = None
    return pytree.tree_unflatten([sin], self._out_spec)""",
        )

    def test_preserve_fx_node_metadata_graph_break(self):
        # An exported graph is edited so that one node calls a function that
        # graph-breaks. Compiling it must produce exactly the two graphs listed
        # in `expected`, in order, each containing its snippet.
        def fn(x):
            x = torch.sin(x)
            x = torch.abs(x)
            return torch.cos(x)

        def bad_fn(x):
            torch._dynamo.graph_break()
            return x

        gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))

        # replace abs with graph break
        gm_edit = copy.deepcopy(gm)
        for nd in gm_edit.graph.nodes:
            if nd.target == torch.abs:
                nd.target = bad_fn
                nd.meta.clear()
                break
        gm_edit.recompile()

        # One snippet per compiled graph, in compilation order.
        expected = [
            """x = torch.sin(l_x_)""",
            """cos = torch.cos(l_stack0_)""",
        ]

        def test_backend(gm: torch.fx.GraphModule, example_inputs):
            # Fails if more graphs are compiled than snippets listed.
            self.assertTrue(expected)
            # Normalize output for dynamic and not
            for nd in gm.graph.nodes:
                if "example_value" in nd.meta:
                    del nd.meta["example_value"]
            self.assertIn(expected[0], gm.print_readable(print_output=False))
            expected.pop(0)
            return gm.forward

        torch._dynamo.reset()
        opt_gm_edit = torch.compile(gm_edit, backend=test_backend)
        opt_gm_edit(torch.randn(3, 3))
        # Every listed graph must actually have been compiled. Without this,
        # the test would pass vacuously if fewer graphs were produced.
        self.assertEqual(expected, [])

    def test_torch_inference_mode_ctx(self):
        # Export of torch.inference_mode used as a decorator (enabled and
        # disabled) and as a context manager. The graph records
        # _enter_inference_mode/_exit_inference_mode around the body. The
        # requires_grad behavior of the exported outputs is checked after each
        # export.
        @torch.inference_mode()
        def fn(x):
            return x + 1

        gm, _ = torch._dynamo.export(fn, torch.rand(2, 2))

        inp = torch.randn(2, 2)
        out = gm(inp)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_args_0_ = arg0
    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
    add = l_args_0_ + 1;  l_args_0_ = None
    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode);  _enter_inference_mode = _exit_inference_mode = None
    return pytree.tree_unflatten([add], self._out_spec)""",  # NOQA: B950
        )
        # The output was produced under inference mode, so requires_grad
        # cannot be switched on outside of it.
        self.assertEqual(out.requires_grad, False)
        with self.assertRaisesRegex(
            RuntimeError,
            "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.",
        ):
            out.requires_grad = True

        # inference_mode(False): the graph enters with False, and the output
        # accepts requires_grad=True.
        @torch.inference_mode(False)
        def fn_no_inference(x):
            return x + 1

        gm_no_inference, _ = torch._dynamo.export(fn_no_inference, torch.rand(2, 2))
        self.assertExpectedInline(
            gm_no_inference.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_args_0_ = arg0
    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(False)
    add = l_args_0_ + 1;  l_args_0_ = None
    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode);  _enter_inference_mode = _exit_inference_mode = None
    return pytree.tree_unflatten([add], self._out_spec)""",  # NOQA: B950
        )

        inp = torch.randn(2, 2)
        out = gm_no_inference(inp)
        self.assertEqual(out.requires_grad, False)
        out.requires_grad = True

        # Context-manager form. The traced input is named l_x_ here, versus
        # l_args_0_ in the decorator form above.
        def fn(x):
            with torch.inference_mode():
                return x + 1

        gm, _ = torch._dynamo.export(fn)(torch.rand(2, 2))
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
    add = l_x_ + 1;  l_x_ = None
    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode);  _enter_inference_mode = _exit_inference_mode = None
    return pytree.tree_unflatten([add], self._out_spec)""",  # NOQA: B950
        )
        # Even for an input that requires grad, the output computed under
        # inference mode does not.
        inp = torch.randn(2, 2, requires_grad=True)
        out = gm(inp)
        self.assertEqual(out.requires_grad, False)

    def test_export_masking_with_no_grad(self):
        # Boolean-mask setitem (x[b] = y) on inputs that require grad exports
        # under torch.no_grad() or torch.inference_mode(). Without either,
        # export raises Unsupported mentioning
        # "boolean masking setitem backwards".
        def fn(x, b, y):
            x = x.clone()
            x[b] = y
            return x

        def fn_no_grad(x, b, y):
            with torch.no_grad():
                return fn(x, b, y)

        def fn_inference_mode(x, b, y):
            with torch.inference_mode():
                return fn(x, b, y)

        x = torch.randn(4, requires_grad=True)
        # Two True entries, matching y's length of 2.
        b = torch.tensor([True, False, True, False])
        y = torch.randn(2, requires_grad=True)

        # no_grad is captured as a pair of _set_grad_enabled calls.
        gm, _ = torch._dynamo.export(fn_no_grad)(x, b, y)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x, b, y):
    arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
    l_x_ = arg0
    l_b_ = arg1
    l_y_ = arg2
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    x = l_x_.clone();  l_x_ = None
    x[l_b_] = l_y_;  setitem = x;  l_b_ = l_y_ = setitem = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
    return pytree.tree_unflatten([x], self._out_spec)""",
        )

        # inference_mode is captured as _enter/_exit_inference_mode.
        gm, _ = torch._dynamo.export(fn_inference_mode)(x, b, y)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x, b, y):
    arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
    l_x_ = arg0
    l_b_ = arg1
    l_y_ = arg2
    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
    x = l_x_.clone();  l_x_ = None
    x[l_b_] = l_y_;  setitem = x;  l_b_ = l_y_ = setitem = None
    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode);  _enter_inference_mode = _exit_inference_mode = None
    return pytree.tree_unflatten([x], self._out_spec)""",  # NOQA: B950
        )

        # Grad enabled: export is expected to raise before returning.
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported, "boolean masking setitem backwards"
        ):
            gm, _ = torch._dynamo.export(fn)(x, b, y)

    def test_dynamo_list_index(self):
        # list.index on a Python-list argument is supported by export:
        # [1, 2].index(2) == 1, so the result is x + 1.
        def fn(x, in_list):
            return x + in_list.index(2)

        x, in_list = torch.ones(2, 2), [1, 2]
        graph, _ = torch._dynamo.export(fn)(x, in_list)
        self.assertEqual(graph(x, in_list), torch.ones(2, 2) + 1)


common_utils.instantiate_parametrized_tests(ExportTests)

if __name__ == "__main__":
    # torch._dynamo.test_case is already imported at module level.
    torch._dynamo.test_case.run_tests()
