# Owner(s): ["module: dynamo"]
import math
import random
import unittest

import numpy as np

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.nn.functional as F
from torch._dynamo.comptime import comptime
from torch._dynamo.testing import CompileCounter, same
from torch.testing._internal.common_utils import skipIfWindows
from torch.testing._internal.logging_utils import logs_to_string


# The intention of this test file is you should put test cases specifically
# for assume_static_by_default=False, aka you want to YOLO make everything as
# dynamic as possible.  If you want to test the more normal situation where
# you assume static by default, put it in a regular test file and
# test_dynamic_shapes will cover both the YOLO and non-YOLO cases.


@torch._dynamo.config.patch(assume_static_by_default=False)
class UnspecTests(torch._dynamo.test_case.TestCase):
    def test_numpy_correctness(self):
        def fn(x, y, z):
            xy = [x + y, y, False]
            np_x = x.numpy()
            np_y = y.numpy()
            return {
                "x": x,
                "z": z,
                "a": np_y.sum(),
                "b": xy,
                "c": np_y[0][0] / 68,
                "d": np_x.sum(),
                "e": np_x + np_y,
            }, x + np_y.sum() + z

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        y = torch.ones([2, 2], dtype=torch.int64)
        z = np.int64(12)
        res1 = fn(x, y, z)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x, y, z)
        self.assertEqual(res1, res2)

    def test_no_recompilations(self):
        # no recompilations if passing on different numpy int values
        def fn(x, y):
            return {"a": x + 1, "b": y / 2}

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        for i in range(10):
            opt_fn(x, np.int64(i))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    @unittest.expectedFailure  # array scalars decay to 0D arrays
    def test_builtin_max_min(self):
        # test unspecialized primitive max/min
        def fn(x, y, z):
            return z + 1, max(x, y), min(x - 4, y)

        x = np.int64(12)
        y = 10
        z = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        res1 = fn(x, y, z)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x, y, z)
        self.assertTrue(same(res1, res2, relax_numpy_equality=True))

    def test_feed_random_values_into_graph_only(self):
        def fn(shape):
            torch.manual_seed(123)
            x = torch.randn(shape, device="cpu") * random.randint(30, 100)
            return x

        shape = [2, 3]
        random.seed(1)
        res1 = fn(shape)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(shape)

        self.assertTrue(same(res1, res2))

    def test_random_values_with_graph_break(self):
        def fn(x):
            r1 = random.random()
            y = x + random.uniform(10, 20)
            y.sum().item()
            r2 = random.randint(2, 18)  # no graph output in this frame
            y.sum().item()
            return y + r1, r2

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        random.seed(1)
        res1 = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    # Really annoying intersection of specialization and RandomValueSource
    # If we get a RandomValueSource with a single element tensor, we should return a ConstantVariable like other
    # unspects... but if we do, we break the bytecode assumptions and guards will not work as we will be referring
    # to a name from a source that is not there. If we call .item() and take the wrapped_value out, where we do
    # wrapped_value = wrapped_value.item() where we send unspec down to wrap_fx_proxy, this test passes and then
    # some models fail on missing codegen.tx.output.random_values_var. If we let the tensor value go into wrap as
    # it is, this test fails.
    # The real solution here is to rewrite RandomValueSource and all the codegen it does from the ground up.
    def test_multiple_consecutive_random_calls_before_graph(self):
        def fn(x):
            dim1 = random.randrange(start=0, stop=5)
            dim2 = random.randrange(start=0, stop=5)
            dim3 = random.randrange(start=0, stop=5)
            y = torch.rand(dim1, dim2, dim3)
            return x + 2, y

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        random.seed(1)
        res1 = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    def test_compiled_random_calls_are_random(self):
        # For compiled functions with random calls,
        # it should return different values for every iteration.
        # https://github.com/pytorch/pytorch/issues/95425
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return (x + 1) * random.uniform(0, 1)

        res = []
        for _ in range(5):
            res.append(fn(torch.ones(2)))
        for i in range(1, 5):
            self.assertFalse(same(res[i - 1], res[i]))

    def test_random_call_with_while_loop(self):
        def fn(x):
            dim1 = random.randrange(start=0, stop=3)
            dim2 = dim1
            while dim1 == dim2:
                dim2 = random.randrange(start=0, stop=3)
            return x * 2

        x = torch.randn(4)
        random.seed(1)
        res1 = fn(x)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

        random.seed(10)
        res1 = fn(x)
        random.seed(10)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    def test_random_object(self):
        # test argument passing, mutation, reconstruction, state correctness
        def fn(x, rand2):
            r1 = random.randint(1, 9)
            r2 = rand2.randint(1, 9)
            rand3 = random.Random(42)
            r3 = rand3.randint(1, 9)

            y = x + r1 + r2 + r3
            return y, rand2, rand3

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        random.seed(0)
        y_1, rand2_1, rand3_1 = fn(inp, random.Random(12))
        state_1 = random.getstate()
        random.seed(0)
        y_2, rand2_2, rand3_2 = opt_fn(inp, random.Random(12))
        state_2 = random.getstate()
        self.assertEqual(y_1, y_2)
        self.assertEqual(state_1, state_2)
        self.assertEqual(rand2_1.getstate(), rand2_2.getstate())
        self.assertEqual(rand3_1.getstate(), rand3_2.getstate())

    def test_random_object_methods(self):
        def fn(x, rand1, rand2, rand3):
            rand1.seed(42)
            rand4 = random.Random(9002)
            rand2.setstate(rand4.getstate())
            r1 = rand1.random()
            r2 = rand2.randint(1, 10)
            r3 = rand3.randrange(10)
            r4 = rand4.uniform(0, 1)
            return x + r1 + r2 + r3 + r4

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        rand1_1 = random.Random(1)
        rand2_1 = random.Random(2)
        rand3_1 = random.Random(3)
        rand1_2 = random.Random(1)
        rand2_2 = random.Random(2)
        rand3_2 = random.Random(3)
        y1 = fn(inp, rand1_1, rand2_1, rand3_1)
        y2 = opt_fn(inp, rand1_2, rand2_2, rand3_2)
        self.assertEqual(y1, y2)
        self.assertEqual(rand1_1.getstate(), rand1_2.getstate())
        self.assertEqual(rand2_1.getstate(), rand2_2.getstate())
        self.assertEqual(rand3_1.getstate(), rand3_2.getstate())

    def test_random_object_overriden_methods(self):
        # these will result in graph breaks, but we shouldn't crash
        def get_rng():
            rand1 = random.Random(1)
            rand2 = random.Random(2)

            orig_random = rand1.random

            def custom_random():
                return orig_random()

            orig_getstate = rand2.getstate

            def custom_getstate():
                return orig_getstate()

            rand1.random = custom_random
            rand2.getstate = custom_getstate
            return rand1, rand2

        def fn(x, rand1, rand2):
            r1 = rand1.random()
            rand3 = random.Random()
            rand3.setstate(rand2.getstate())
            r2 = rand3.random()
            return x + r1 + r2

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager")
        y1 = fn(inp, *get_rng())
        y2 = opt_fn(inp, *get_rng())
        self.assertEqual(y1, y2)

    def test_builtin_getitem(self):
        # builtin getitem args[0] is python list and args[1] is unspec
        def fn(x, idx):
            return (torch.zeros(idx), x[idx], x[idx:])

        x = list(range(50))
        ref = fn(x, 48)  # 48 is unspecialized
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, 48)
        self.assertTrue(same(ref, res))

    def test_use_and_specialize(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(x, y):
            x = x + y
            if y == 2:
                return x - 1
            else:
                return x + 1

        self.assertTrue(same(fn(torch.tensor([5]), 2), 6))
        self.assertTrue(same(fn(torch.tensor([6]), 2), 7))
        self.assertTrue(same(fn(torch.tensor([5]), 3), 9))
        self.assertTrue(same(fn(torch.tensor([4]), 3), 8))
        self.assertEqual(cnt.frame_count, 2)

    def test_no_recompiles(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(x, y):
            return x + y

        self.assertTrue(same(fn(torch.tensor([5]), 100), 105))
        self.assertTrue(same(fn(torch.tensor([4]), 200), 204))
        self.assertTrue(same(fn(torch.tensor([3]), 300), 303))
        self.assertTrue(same(fn(torch.tensor([2]), 400), 402))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_no_recompiles_prod_backward(self):
        # https://github.com/pytorch/pytorch/issues/120608
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(t):
            return torch.prod(t, 3, keepdim=True)

        input_shapes = [(8, 10, 3, 2), (8, 3, 5, 2), (8, 4, 8, 2)]
        for s in input_shapes:
            t1 = torch.randn(s, requires_grad=True)
            h_result = fn(t1)
            grad = torch.ones_like(h_result)
            h_result.backward(grad)

        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_builtin_functions_on_cuda(self):
        def fn(x, scaler):
            m = torch.nn.ReLU()
            y = m(x) * scaler
            return y

        x = torch.randn([3, 6], device="cuda")
        scaler = 0.23  # 0.23 is unspecialized
        ref = fn(x, scaler)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, scaler)
        self.assertTrue(same(ref, res))
        self.assertEqual(ref.device, res.device)

    def test_unspec_float_precision(self):
        def fn(image, scale_factor):
            image = torch.nn.functional.interpolate(
                image[None],
                size=None,
                scale_factor=scale_factor,
                mode="bilinear",
                recompute_scale_factor=True,
                align_corners=False,
            )[0]

            return image.shape

        x = torch.rand([3, 427, 640])
        scale_factor = 1.873536229133606
        ref = fn(x, scale_factor)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, scale_factor)
        self.assertTrue(same(ref, res))

    @unittest.expectedFailure  # fails as long as numpy scalars are 0D arrays
    def test_specializing_numpy_float_in_control_flow(self):
        # np.float64 is unspecialized by default,
        # but it should be specialized when used in control flow.
        def fn(x, y):
            if y > 1.0:
                return x + 1
            else:
                return x - 1

        x = torch.rand(4)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        for t in [np.float16, np.float32, np.float64]:
            y = t(1.23)
            ref = fn(x, y)
            res = opt_fn(x, y)
            self.assertTrue(same(ref, res))

    def test_mark_static_inside(self):
        def fn(x):
            torch._dynamo.mark_static(x, 0)
            comptime.assert_static(x.size(0))
            return x + 1

        opt_fn = torch.compile(fn, dynamic=True, fullgraph=True)
        opt_fn(torch.randn(12, 23))

    def test_shape_graph_break(self):
        from torch._dynamo.comptime import comptime

        def fn(x):
            x_shape = x.size()
            comptime.graph_break()
            return x + torch.randn(x_shape)

        x = torch.randn(20)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn(x)

    def test_isinstance_symint(self):
        def fn(x):
            assert isinstance(x.size(0), int)
            return x * 2

        x = torch.randn(20)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn(x)
        y = torch.randn(30)
        torch._dynamo.mark_dynamic(y, 0)
        opt_fn(y)

    def test_mark_01_dynamic(self):
        def fn(x):
            return x * 2

        x = torch.randn(1)
        torch._dynamo.mark_dynamic(x, 0)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        # This will fail to compile a generic kernel, but we should not
        # complain about it (mark dynamic will try its best but 0/1
        # specialization is allowed)
        opt_fn(x)

    def test_conv1d_symint_padding(self):
        kernel = torch.randn(1, 1, 4)

        def func(x):
            padding = math.ceil((kernel.shape[-1] + x.shape[-1] % 2) / 2) - 1
            out = F.conv1d(x, kernel, padding=padding, stride=2)
            return out

        opt_func = torch.compile(func)

        x = torch.randn(1, 1, 175)
        opt_func(x)  # passes
        x = torch.randn(1, 1, 249)
        opt_func(x)  # crashes

    @torch._dynamo.config.patch("assume_static_by_default", True)
    def test_propagate_dynamic_dim(self):
        x = torch.randn(20)
        torch._dynamo.mark_dynamic(x, 0)

        @torch.compile()
        def fn(x):
            y = x * 2
            comptime.graph_break()
            z = y * 2
            return z

        z = fn(x)
        self.assertEqual(z._dynamo_weak_dynamic_indices, {0})

    def test_rshift_dynamic(self):
        def shift_right(tensor: torch.Tensor) -> torch.Tensor:
            return (tensor >> 2).to(torch.long)

        opt_fn = torch.compile(shift_right, fullgraph=True, dynamic=True)
        sample_input = torch.tensor([4, 4, 16, 32], dtype=torch.uint8)
        opt_fn(sample_input)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_symfloat_to_tensor(self):
        def f1(v):
            return torch.tensor([v.item()])

        def f2(v):
            return torch.tensor([[v.item()], [2.0]])

        def f3(v):
            return torch.tensor(v.item())

        def f4(v):
            return torch.tensor((v.item(),))

        optimize = torch.compile(backend="aot_eager", fullgraph=True)

        r = torch.randn(1)

        self.assertEqual(f1(r), optimize(f1)(r))
        self.assertEqual(f2(r), optimize(f2)(r))
        self.assertEqual(f3(r), optimize(f3)(r))
        self.assertEqual(f4(r), optimize(f4)(r))

    @skipIfWindows(
        msg="AssertionError: The values for attribute 'dtype' do not match: torch.int32 != torch.int64."
    )
    def test_to_tensor(self):
        def f1():
            a = np.random.uniform(low=-1, high=1, size=(20, 1))
            return torch.tensor([a, a, a, a], dtype=torch.float64, device="cpu")

        def f2():
            a = torch.tensor([[[123]]])
            return torch.tensor([a, a])

        def f3():
            a = torch.tensor(123)
            return torch.tensor([a, a])

        def f4():
            a = torch.tensor(123)
            b = torch.tensor([[[456]]])
            return torch.tensor([a, b])

        def f5():
            a = np.array([1, 2])
            return torch.tensor([a, a])

        optimize = torch.compile(backend="aot_eager", fullgraph=True)

        self.assertEqual(f1().shape, optimize(f1)().shape)
        self.assertEqual(f2(), optimize(f2)())
        self.assertEqual(f3(), optimize(f3)())
        self.assertEqual(f4(), optimize(f4)())
        self.assertEqual(f5(), optimize(f5)())

    def test_sym_int_conversion(self):
        def f(x):
            y = x.size(0)
            return x * int(y == 0)

        opt_fn = torch.compile(f, backend="eager", fullgraph=True)
        x = torch.randn(2, 3)
        opt_fn(x)

    def test_sum_dimlist_spec(self):
        def fn(inputs, dim):
            return torch.sum(inputs, dim)

        inputs = torch.randn(128, 5, 24, 24)
        dim = (-1, 1, 0, 2)
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(inputs, dim), fn(inputs, dim))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_max(self):
        def fn(x):
            return torch.ones(max(x.item(), 1024))

        x = torch.tensor([1000])
        y = torch.tensor([2000])
        compl_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), compl_fn(x))
        self.assertEqual(fn(y), compl_fn(y))

    # https://github.com/pytorch/pytorch/issues/104812
    def test_argmin_coerces_symint_to_intlist_spec(self):
        def fn(x, dim):
            # the python arg parser coerces dim into a vector<int>
            return torch.amin(x, dim=dim, keepdim=True)

        x = torch.randn(4, 4, 4)
        dim = 2
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(x, dim), fn(x, dim))

    def test_exponential(self):
        def fn(inputs, op_inputs_dict):
            res = inputs.exponential_(**op_inputs_dict)
            return res

        inputs = torch.randn(2, 3, 4)
        op_inputs_dict = {"lambd": 10, "generator": None}
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(inputs, op_inputs_dict), fn(inputs, op_inputs_dict))

    def test_symbol_guard_limit_before_specialize(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnts, dynamic=True)
        def fn(x):
            torch._check(x.size(0) != 3)
            torch._check(x.size(0) != 4)
            torch._check(x.size(0) != 5)
            torch._check(x.size(0) != 6)
            return x + 2

        # Control test
        fn(torch.randn(12))
        fn(torch.randn(13))
        fn(torch.randn(14))

        self.assertExpectedInline(cnts.frame_count, """1""")
        cnts.frame_count = 0

        torch._dynamo.reset()

        with torch.fx.experimental._config.patch(
            symbol_guard_limit_before_specialize=3
        ):
            fn(torch.randn(12))
            fn(torch.randn(13))
            fn(torch.randn(14))

            self.assertExpectedInline(cnts.frame_count, """3""")

    def test_defaults(self):
        def g(x, i=8):
            comptime.assert_static(i)
            return x * i

        def fn(x):
            return g(x)

        inputs = torch.randn(2, 3, 4)
        compl_fn = torch.compile(fn, dynamic=True, backend="eager")
        self.assertEqual(compl_fn(inputs), fn(inputs))

    @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=True)
    def test_unspec_float_input(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def f(x, y):
            if y == 5.0:
                return x + 2
            else:
                return x + y

        cf = torch.compile(backend=cnts, fullgraph=True)(f)

        x = torch.randn(3)
        self.assertEqual(f(x, 3.0), cf(x, 3.0))
        self.assertEqual(f(x, 4.0), cf(x, 4.0))
        self.assertExpectedInline(cnts.frame_count, """1""")  # no recompile
        self.assertEqual(f(x, 5.0), cf(x, 5.0))
        self.assertExpectedInline(cnts.frame_count, """2""")  # guard worked
        self.assertEqual(f(x, math.nan), cf(x, math.nan))
        self.assertExpectedInline(cnts.frame_count, """3""")  # nan always recompiles

    @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=True)
    def test_unspec_float_output(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def f(x, y):
            return x + 1, y * 2

        cf = torch.compile(backend=cnts, fullgraph=True)(f)
        x = torch.randn(3)

        self.assertEqual(f(x, 3.0), cf(x, 3.0))
        self.assertEqual(f(x, 4.0), cf(x, 4.0))
        self.assertEqual(f(x, 5.0), cf(x, 5.0))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_data_dependent_evaluate_expr_graph_break(self):
        cnts = torch._dynamo.testing.CompileCounter()

        # To ensure that the continuation frame is compiled,
        # have to write the test function in this funny way.
        # See https://github.com/pytorch/pytorch/issues/111918
        def test(y):
            if y > 2:
                return True
            else:
                return False

        @torch._dynamo.optimize(cnts)
        def fn(x):
            x = x + 1
            y = x.item()
            if test(y):
                return x * 2
            else:
                return x * 3

        x = torch.tensor([3.0])
        fn(x)

        self.assertExpectedInline(cnts.frame_count, """2""")
        self.assertExpectedInline(cnts.op_count, """4""")

    def test_prune_torch_check(self):
        log_stream, ctx = logs_to_string("torch._dynamo.output_graph", "graph_code")

        @torch.compile(fullgraph=True, dynamic=True, backend="eager")
        def f(x, y):
            torch._check(y + 5 == 85)
            torch._check(x.size(0) == 80)

        with ctx():
            f(torch.randn(80, 100), 80)

        out = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
        self.assertExpectedInline(
            out,
            """\
def forward(self):
        return ()""",
        )

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_split_aot_autograd(self):
        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, i):
            y, z = i.tolist()
            return torch.split(x, [y, z])

        print(f(torch.randn(10, requires_grad=True), torch.tensor([7, 3])))

    def test_bool_tensor_ctor(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, dynamic=True, fullgraph=True)
        def f(x):
            y = torch.empty((x.size(0) // 13) * 13)
            return torch.tensor(y.numel() == 0)

        self.assertTrue(f(torch.empty(8)).item())
        self.assertFalse(f(torch.empty(13)).item())

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked(self):
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x * 2

        main_model = TestModel()
        opt_model = torch.compile(main_model, mode="max-autotune", dynamic=True)

        x1 = torch.rand(3, 5, 4, 8)
        x2 = torch.rand(1, 5, 4, 8)

        torch._dynamo.decorators.mark_unbacked(x1, 0)

        o1_ref = main_model(x1, 2)
        o1 = opt_model(x1, 2)
        self.assertEqual(o1_ref, o1)

        o1_2_ref = main_model(x2, 2)
        o1_2 = opt_model(x2, 2)
        self.assertEqual(o1_2_ref, o1_2)

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked_hint_consistency(self):
        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

        x = torch.randn(1)
        torch._dynamo.decorators.mark_unbacked(x, 0)

        @torch.compile()
        def f(x):
            if guard_size_oblivious(x.size(0) != 1):
                return x + 3
            else:
                return x + 4

        self.assertEqual(f(x), x + 3)

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked_channels_last(self):
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x * 2

        main_model = TestModel()
        opt_model = torch.compile(main_model, mode="max-autotune", dynamic=True)

        x1 = torch.rand(3, 5, 4, 8).to(memory_format=torch.channels_last)
        x2 = torch.rand(1, 5, 4, 8).to(memory_format=torch.channels_last)

        torch._dynamo.decorators.mark_unbacked(x1, 0)

        o1_ref = main_model(x1, 2)
        o1 = opt_model(x1, 2)
        self.assertEqual(o1_ref, o1)

        o1_2_ref = main_model(x2, 2)
        o1_2 = opt_model(x2, 2)
        self.assertEqual(o1_2_ref, o1_2)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
