# Owner(s): ["module: functorch"]
# flake8: noqa: B950
import unittest
from collections import deque
from functools import partial
from typing import List, TYPE_CHECKING

import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
from functorch.compile import (
    aot_function,
    default_decompositions,
    min_cut_rematerialization_partition,
    nop,
)
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.effects import with_effects
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
    _get_torch_cuda_version,
    SM70OrLater,
    SM80OrLater,
)
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_WITH_ROCM,
    TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle

from torch.testing._internal.two_tensor import TwoTensor


def extract_graph(fx_g, _, graph_cell):
    """Compiler callback that records the graph it is given and returns it.

    Args:
        fx_g: The object handed to the compiler callback (the traced graph).
        _: Second compiler argument; ignored.
        graph_cell: Mutable sequence whose slot 0 is overwritten with ``fx_g``.

    Returns:
        ``fx_g`` itself, unchanged.
    """
    captured = fx_g
    graph_cell[0] = captured
    return captured


def get_fw_bw_graph(
    f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False
):
    """Run ``f`` through ``aot_function`` and return the graphs it captured.

    Args:
        f: Callable to trace.
        inps: Inputs for ``f``; unpacked positionally when calling it.
        partitioner: Passed to ``aot_function`` as ``partition_fn``.
        dynamic: Forwarded to ``aot_function`` unchanged.

    Returns:
        A ``(forward_graph, backward_graph)`` tuple. The backward entry stays
        ``None`` when no tensor in ``inps`` requires grad, because in that case
        the backward compiler is ``nop`` and no backward pass is run.
    """
    fw_slot = [None]
    bw_slot = [None]

    # Collect the requires_grad flag of every tensor found in the input pytree.
    grad_flags = []
    torch.utils._pytree.tree_map_only(
        torch.Tensor, lambda t: grad_flags.append(t.requires_grad), inps
    )
    needs_grad = any(grad_flags)

    if needs_grad:
        bw_compiler = partial(extract_graph, graph_cell=bw_slot)
    else:
        bw_compiler = nop

    compiled = aot_function(
        f,
        fw_compiler=partial(extract_graph, graph_cell=fw_slot),
        bw_compiler=bw_compiler,
        partition_fn=partitioner,
        decompositions=default_decompositions,
        dynamic=dynamic,
    )
    out = compiled(*inps)

    # The backward graph is only compiled once a backward pass actually runs.
    if needs_grad:
        out.sum().backward()

    return (fw_slot[0], bw_slot[0])


def make_inputs_non_leaves(inps):
    """Return ``inps`` with every tensor in the pytree replaced by ``t.add(1)``.

    Non-tensor entries are passed through untouched. Because each tensor is the
    result of an op, tensors that require grad come back as non-leaf tensors.
    """

    def _bump(t):
        return t.add(1)

    return torch.utils._pytree.tree_map_only(torch.Tensor, _bump, inps)


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't support")
class TestWithEffects(TestCase):
    def setUp(self):
        """Call ``init_torchbind_implementations()`` before each test in this class."""
        init_torchbind_implementations()

    def test_print(self):
        """Check how ``aten._print`` calls appear in traced and exported graphs.

        ``make_fx`` output must contain both print calls verbatim, while
        ``aot_export_module`` must wrap them in ``with_effects`` with one input
        and one output token. With ``unlift_effect_tokens=True`` the expected
        graph creates and sinks the token itself.
        """

        class PrintTwice(torch.nn.Module):
            def forward(self, x):
                torch.ops.aten._print("moo")
                doubled = x + x
                torch.ops.aten._print("moo")
                return (doubled,)

        example_inputs = (torch.randn(3),)

        # No functionalization: both prints are plain nodes in the graph.
        traced = make_fx(PrintTwice())(*example_inputs)
        print_counter = FileCheck().check_count(
            "torch.ops.aten._print.default", 2, exactly=True
        )
        print_counter.run(traced.code)

        # Functionalized export: each print is wrapped with with_effects().
        exported, signature = aot_export_module(
            PrintTwice(), example_inputs, trace_joint=False
        )
        exported_code = str(exported.code).strip()
        self.assertExpectedInline(
            exported_code,
            """\
def forward(self, arg0_1, arg1_1):
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'moo');  arg0_1 = None
    getitem = with_effects[0];  with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1);  arg1_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo');  getitem = None
    getitem_2 = with_effects_1[0];  with_effects_1 = None
    return (getitem_2, add)""",
        )
        for tokens in (signature.input_tokens, signature.output_tokens):
            self.assertEqual(len(tokens), 1)

        # Same export, but with the effect tokens unlifted from the signature.
        with torch._functorch.config.patch(unlift_effect_tokens=True):
            exported, signature = aot_export_module(
                PrintTwice(), example_inputs, trace_joint=False
            )
            unlifted_code = str(exported.code).strip()
            self.assertExpectedInline(
                unlifted_code,
                """\
def forward(self, arg1_1):
    _make_token_default = torch.ops.prims._make_token.default()
    with_effects = torch.ops.higher_order.with_effects(_make_token_default, torch.ops.aten._print.default, 'moo');  _make_token_default = None
    getitem = with_effects[0];  with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1);  arg1_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo');  getitem = None
    getitem_2 = with_effects_1[0];  with_effects_1 = None
    _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem_2]);  getitem_2 = _sink_tokens_default = None
    return [add]""",  # noqa: B950
            )

    def test_torchbind_custom_op(self):
        """Export a module that passes a torchbind object to a custom op.

        The exported graph is expected to route ``takes_foo`` through
        ``with_effects`` and to carry one input and one output token.
        """

        class FooHolder(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                foo_out = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
                return (x + foo_out,)

        with enable_torchbind_tracing():
            exported, signature = aot_export_module(
                FooHolder(), (torch.ones(2, 3),), trace_joint=False
            )

        exported_code = str(exported.code).strip()
        self.assertExpectedInline(
            exported_code,
            """\
def forward(self, arg0_1, arg1_1):
    _torchbind_obj0 = self._torchbind_obj0
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TorchScriptTesting.takes_foo.default, _torchbind_obj0, arg1_1);  arg0_1 = _torchbind_obj0 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, getitem_1);  arg1_1 = getitem_1 = None
    return (getitem, add)""",  # noqa: B950
        )
        for tokens in (signature.input_tokens, signature.output_tokens):
            self.assertEqual(len(tokens), 1)

    def test_print_with_buffer_mutations(self):
        """Export a module that prints and mutates a buffer in place.

        The expected graph threads both prints through ``with_effects`` and
        returns the updated buffer value; the signature must report one input
        token, one output token and one mutated buffer.
        """

        class PrintAndMutateBuffer(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(torch.ones(3))

            def forward(self, x):
                torch.ops.aten._print("moo")
                doubled = x + x
                self.buf.add_(doubled)
                result = self.buf + x
                torch.ops.aten._print("moo")
                return (result,)

        example_inputs = (torch.randn(3),)

        # Functionalized export: prints are wrapped with with_effects().
        exported, signature = aot_export_module(
            PrintAndMutateBuffer(), example_inputs, trace_joint=False
        )
        exported_code = str(exported.code).strip()
        self.assertExpectedInline(
            exported_code,
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'moo');  arg0_1 = None
    getitem = with_effects[0];  with_effects = None
    add = torch.ops.aten.add.Tensor(arg2_1, arg2_1)
    add_1 = torch.ops.aten.add.Tensor(arg1_1, add);  arg1_1 = add = None
    add_2 = torch.ops.aten.add.Tensor(add_1, arg2_1);  arg2_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo');  getitem = None
    getitem_2 = with_effects_1[0];  with_effects_1 = None
    return (getitem_2, add_1, add_2)""",
        )
        for tracked in (
            signature.input_tokens,
            signature.output_tokens,
            signature.buffers_to_mutate,
        ):
            self.assertEqual(len(tracked), 1)

    def test_print_with_input_mutations(self):
        """Export a module that prints and mutates its input in place.

        Only the graph signature is checked: one input token, one output token
        and one mutated user input.
        """

        class PrintAndMutateInput(torch.nn.Module):
            def forward(self, x):
                torch.ops.aten._print("moo")
                doubled = x + x
                x.add_(doubled)
                result = x + x
                torch.ops.aten._print("moo")
                return (result,)

        example_inputs = (torch.randn(3),)

        # Functionalized export: prints are wrapped with with_effects().
        _, signature = aot_export_module(
            PrintAndMutateInput(), example_inputs, trace_joint=False
        )
        for tracked in (
            signature.input_tokens,
            signature.output_tokens,
            signature.user_inputs_to_mutate,
        ):
            self.assertEqual(len(tracked), 1)

    def test_alias_op(self):
        """Tracing ``with_effects`` around an aliasing op must raise an AssertionError."""

        def f(token, x):
            new_token, result = with_effects(
                token, torch.ops.aten.absolute_.default, x
            )
            return new_token, result

        expected_message = r"Ops with aliasing is not supported"
        with self.assertRaisesRegex(AssertionError, expected_message):
            make_fx(f)(torch.tensor([]), torch.tensor(4))

    def test_compile_aot_eager(self):
        """A function with print side effects compiles with ``aot_eager`` and matches eager."""

        def f(x):
            torch.ops.aten._print("moo")
            doubled = x + x
            torch.ops.aten._print("moo")
            return doubled

        example_inputs = (torch.randn(2, 3),)

        compiled_f = torch.compile(f, backend="aot_eager")
        actual = compiled_f(*example_inputs)
        expected = f(*example_inputs)
        self.assertTrue(torch.allclose(actual, expected))

    @unittest.skipIf(IS_WINDOWS, "triton")
    @unittest.skipIf(not SM70OrLater, "triton")
    def test_compile_inductor(self):
        """A function with print side effects compiles with ``inductor`` and matches eager."""

        def f(x):
            torch.ops.aten._print("moo")
            doubled = x + x
            torch.ops.aten._print("moo")
            return doubled

        example_inputs = (torch.randn(2, 3),)

        compiled_f = torch.compile(f, backend="inductor")
        actual = compiled_f(*example_inputs)
        expected = f(*example_inputs)
        self.assertTrue(torch.allclose(actual, expected))

    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @skipIfNoDynamoSupport
    def test_compile_inductor_external_op_return_none(self):
        """Compile a function that calls a custom in-place op returning nothing.

        ``mylib::inplace_add`` mutates its ``output`` argument and has a ``()``
        return; the inductor-compiled result must match eager.
        """
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::inplace_add",
                "(Tensor input, Tensor(a!) output) -> ()",
                lib=lib,
            )

            # Kernel parameter names mirror the schema and are kept as-is.
            def inplace_add(input: torch.Tensor, output: torch.Tensor) -> None:
                assert input.device == output.device
                output.add_(input)

            lib.impl("inplace_add", inplace_add, "CompositeExplicitAutograd")

            def f(x):
                scratch = torch.empty(3)
                accumulator = torch.zeros_like(scratch)
                torch.ops.mylib.inplace_add(x, accumulator)
                return accumulator

            example_inputs = (torch.randn(3),)

            compiled_f = torch.compile(f, backend="inductor")
            actual = compiled_f(*example_inputs)
            expected = f(*example_inputs)
            self.assertTrue(torch.allclose(actual, expected))

    def test_compile_aot_eager_requires_grad(self):
        """Like the ``aot_eager`` print test, but the input requires grad and backward is run."""

        def f(x):
            torch.ops.aten._print("moo")
            doubled = x + x
            torch.ops.aten._print("moo")
            return doubled

        example_inputs = (torch.randn(2, 3, requires_grad=True),)

        compiled_f = torch.compile(f, backend="aot_eager")
        actual = compiled_f(*example_inputs)
        expected = f(*example_inputs)
        self.assertTrue(torch.allclose(actual, expected))

        # Backward through the compiled output must not raise.
        actual.sum().backward()

    @unittest.skipIf(IS_WINDOWS, "triton")
    @unittest.skipIf(TEST_WITH_ROCM, "triton")
    @unittest.skipIf(not SM80OrLater, "triton")
    @unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "triton")
    @unittest.skipIf(not TEST_CUDA, "triton")
    @skipIfNoDynamoSupport
    def test_register_effectful_custom_op(self):
        """Compile a hooked module whose forward hooks call an effectful custom op.

        Defines ``mylib::record_scalar_tensor``, registers it as an ORDERED
        effectful op, attaches forward hooks that call it to a small CUDA
        module and its children, and compiles the module with inductor. The
        test checks that the compiled output matches eager, that backward
        produces a weight gradient, and that exactly the two expected prefixes
        were recorded.
        """
        from torch._higher_order_ops.effects import (
            _deregister_effectful_op,
            _EffectType,
            _register_effectful_op,
        )

        # Patch the dynamo flags for the duration of this test only. Assigning
        # the config attributes directly would leave them set for every test
        # that runs afterwards in the same process.
        with torch.library._scoped_library(
            "mylib", "FRAGMENT"
        ) as lib, torch._dynamo.config.patch(
            capture_scalar_outputs=True,
            capture_dynamic_output_shape_ops=True,
        ):
            torch.library.define(
                "mylib::record_scalar_tensor",
                "(Tensor x, str prefix) -> ()",
                lib=lib,
            )

            # Store for the tensors recorded by the custom op, keyed by prefix.
            recorded_dict = {}

            # PyTorch custom op implementation
            @torch.library.impl(
                "mylib::record_scalar_tensor",
                "CompositeExplicitAutograd",
                lib=lib,
            )
            def record_scalar_tensor(x, prefix):
                recorded_dict[prefix] = x.clone()
                return

            # Fake (meta) function of the custom op
            @torch.library.register_fake(
                "mylib::record_scalar_tensor",
                lib=lib,
            )
            def record_scalar_tensor_meta(x, prefix):
                return

            effectful_op = torch.ops.mylib.record_scalar_tensor.default
            _register_effectful_op(effectful_op, _EffectType.ORDERED)
            try:
                my_config = {
                    "MockModule": "mean",
                    "MockModule.linear": "mean",
                    "MockModule.relu": "mean",
                }

                class MyLinear(torch.nn.Module):
                    def __init__(self, in_features, out_features):
                        super().__init__()
                        self.weight = torch.nn.Parameter(
                            torch.randn(out_features, in_features),
                            requires_grad=True,
                        )
                        self.bias = torch.nn.Parameter(
                            torch.randn(out_features), requires_grad=True
                        )

                    def forward(self, x):
                        return torch.nn.functional.linear(x, self.weight, self.bias)

                class MockModule(torch.nn.Module):
                    def __init__(self) -> None:
                        super().__init__()
                        self.linear = MyLinear(10, 10)
                        self.register_buffer(
                            "buf0", torch.randn(10, 10, requires_grad=True)
                        )

                    def forward(self, x):
                        return torch.nn.functional.relu(self.linear(x) + self.buf0)

                def forward_hook(
                    module: torch.nn.Module,
                    inputs: torch.Tensor,
                    output: torch.Tensor,
                    prefix: str,
                    aggregate_method: str,
                ) -> torch.Tensor:
                    if aggregate_method == "mean":
                        torch.ops.mylib.record_scalar_tensor(output.mean(), prefix)
                    elif aggregate_method == "max":
                        torch.ops.mylib.record_scalar_tensor(output.max(), prefix)
                    else:
                        # Demo purpose: any other method falls back to "sum".
                        torch.ops.mylib.record_scalar_tensor(output.sum(), prefix)
                    return output

                def add_hooks(module, config):
                    """Attach ``forward_hook`` to ``module`` and all its descendants."""
                    handles: List[RemovableHandle] = []
                    q = deque([(module.__class__.__name__, module)])
                    while q:
                        name, m = q.pop()
                        children = [
                            (name + "." + n, y) for (n, y) in m.named_children()
                        ]
                        q.extend(children)
                        aggregate_method = config.get(name, "mean")
                        prefix = name + ":" + aggregate_method
                        handle = m.register_forward_hook(
                            partial(
                                forward_hook,
                                prefix=prefix,
                                aggregate_method=aggregate_method,
                            )
                        )
                        if handle:
                            handles.append(handle)
                    return handles

                x = torch.randn(10, 10, device="cuda")
                mod = MockModule().to("cuda")

                add_hooks(mod, my_config)

                opt_mod = torch.compile(backend="inductor")(mod)
                y = opt_mod(x)

                self.assertTrue(torch.allclose(y, mod(x)))
                # Ensure it works well with backward
                y.sum().backward()
                # Ensure the grad exists
                self.assertIsInstance(opt_mod.linear.weight.grad, torch.Tensor)

                self.assertEqual(len(recorded_dict), 2)
                self.assertIn("MockModule.linear:mean", recorded_dict)
                self.assertIn("MockModule:mean", recorded_dict)
            finally:
                # Undo the effectful-op registration, mirroring the other
                # tests in this file that register effectful ops.
                _deregister_effectful_op(effectful_op)

    @skipIfNoDynamoSupport
    def test_effectful_custom_op_with_subclasses(self):
        """Check effectful custom ops on dense and ``TwoTensor`` inputs.

        ``_mylib::zoo`` and ``_mylib::zoo2`` (the latter is called from zoo's
        backward) bump call counters in their CPU/CUDA kernels. The test
        compares counters and results between eager and compiled runs, with
        and without ``requires_grad``, and pins the forward/backward graphs
        captured for the ``TwoTensor`` inputs that require grad.
        """
        from torch._higher_order_ops.effects import (
            _deregister_effectful_op,
            _EffectType,
            _register_effectful_op,
        )

        with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
            lib.define("zoo(Tensor x) -> Tensor")
            lib.define("zoo2(Tensor x) -> Tensor")

            # How often the CPU/CUDA kernels of zoo ("fw") and zoo2 ("bw") ran.
            d = {"fw": 0, "bw": 0}

            def reset_counter():
                d["fw"] = 0
                d["bw"] = 0

            def assert_counter(fw, bw):
                self.assertEqual(d["fw"], fw)
                self.assertEqual(d["bw"], bw)

            def foo_impl(a):
                d["fw"] = d["fw"] + 1
                return 2 * a.clone()

            def foo_meta(a):
                return a.clone()

            def foo2_impl(x):
                d["bw"] = d["bw"] + 1
                return x.clone()

            def foo2_meta(a):
                return a.clone()

            for backend in ["CPU", "CUDA"]:
                lib.impl("zoo", foo_impl, backend)
                lib.impl("zoo2", foo2_impl, backend)
            # The Meta kernels do not touch the counters.
            lib.impl("zoo", foo_meta, "Meta")
            lib.impl("zoo2", foo2_meta, "Meta")

            def foo_bwd(ctx, grad):
                torch.ops._mylib.zoo2(grad)
                return grad.clone()

            torch.library.register_autograd("_mylib::zoo", foo_bwd, lib=lib)

            # Track what was registered so the ``finally`` below undoes exactly
            # that, even if the second registration fails.
            registered_ops = []
            try:
                for op in (
                    torch.ops._mylib.zoo.default,
                    torch.ops._mylib.zoo2.default,
                ):
                    _register_effectful_op(op, _EffectType.ORDERED)
                    registered_ops.append(op)

                def fn(x, y):
                    return torch.ops._mylib.zoo(x) + y

                def ins_sc():
                    return (
                        TwoTensor(
                            torch.tensor([1.0, 2.0, 3.0]),
                            torch.tensor([1.0, 2.0, 3.0]),
                        ),
                        torch.tensor([4.0, 5.0, 6.0]),
                    )

                def ins_dense():
                    return torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])

                for ins_fn, expected_fw_count in zip([ins_sc, ins_dense], [2, 1]):
                    reset_counter()
                    ref_out = fn(*ins_fn())
                    assert_counter(expected_fw_count, 0)

                    compiled_fn = torch.compile(fn, backend="aot_eager")
                    # Warm-up call; the counters are reset afterwards so that
                    # only the second call is measured.
                    out = compiled_fn(*ins_fn())
                    reset_counter()
                    out = compiled_fn(*ins_fn())
                    assert_counter(expected_fw_count, 0)

                    self.assertEqual(ref_out, out)

                def ins_dense_req_grad():
                    return (
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    )

                def ins_sc_req_grad():
                    return (
                        TwoTensor(
                            torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                            torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                        ),
                        TwoTensor(
                            torch.tensor([7.0, 8.0, 9.0], requires_grad=True),
                            torch.tensor([10.0, 11.0, 12.0], requires_grad=True),
                        ),
                    )

                for ins_fn_req_grad, (
                    expected_fw_count,
                    expected_fw_count_after_bw,
                    expected_bw_count_after_bw,
                ) in zip(
                    [ins_dense_req_grad, ins_sc_req_grad], [(1, 1, 1), (2, 2, 2)]
                ):
                    ref_ins = ins_fn_req_grad()
                    reset_counter()
                    ref_out = fn(*ref_ins)
                    assert_counter(expected_fw_count, 0)
                    ref_out.sum().backward()
                    assert_counter(
                        expected_fw_count_after_bw, expected_bw_count_after_bw
                    )

                    compiled_fn = torch.compile(fn, fullgraph=True)

                    ins = ins_fn_req_grad()
                    # Warm-up call; only the second call is measured.
                    out = compiled_fn(*ins)
                    reset_counter()
                    out = compiled_fn(*ins)
                    assert_counter(expected_fw_count, 0)
                    self.assertEqual(ref_out, out)
                    out.sum().backward()
                    assert_counter(
                        expected_fw_count_after_bw, expected_bw_count_after_bw
                    )
                    self.assertEqual(ref_ins[1].grad, ins[1].grad)
                    self.assertEqual(ref_ins[0].grad, ins[0].grad)

                fw_graph, bw_graph = get_fw_bw_graph(fn, ins_sc_req_grad())
                self.assertExpectedInline(
                    fw_graph.code.strip(),
                    """\
def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.zoo.default, primals_2);  primals_1 = primals_2 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._mylib.zoo.default, primals_3);  getitem = primals_3 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1];  with_effects_1 = None
    add = torch.ops.aten.add.Tensor(getitem_1, primals_4);  getitem_1 = primals_4 = None
    add_1 = torch.ops.aten.add.Tensor(getitem_3, primals_5);  getitem_3 = primals_5 = None
    return (getitem_2, add, add_1)""",
                )
                self.assertExpectedInline(
                    bw_graph.code.strip(),
                    """\
def forward(self, tangents_1, tangents_2, tangents_token):
    with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.zoo2.default, tangents_1);  tangents_token = None
    getitem_4 = with_effects_2[0];  with_effects_2 = None
    with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.zoo2.default, tangents_2);  getitem_4 = None
    getitem_6 = with_effects_3[0];  with_effects_3 = None
    clone = torch.ops.aten.clone.default(tangents_1)
    clone_1 = torch.ops.aten.clone.default(tangents_2)
    return (clone, clone_1, tangents_1, tangents_2, getitem_6)""",
                )
            finally:
                # Undo the effectful-op registrations, mirroring the other
                # tests in this file that register effectful ops.
                for op in registered_ops:
                    _deregister_effectful_op(op)

    def test_effects_and_input_mutation_return(self):
        """An effectful print combined with ``torch.sin(a, out=b)``, whose result is returned.

        The compiled output must match eager, and the captured forward graph is
        pinned below.
        """

        def fn(a, b):
            torch.ops.aten._print("effect")
            return torch.sin(a, out=b)

        inp = [torch.randn(3, 3), torch.ones(3, 3)]
        expected = fn(*inp)
        compiled_fn = torch.compile(fn, fullgraph=True)
        actual = compiled_fn(*inp)
        self.assertEqual(expected, actual)

        fw_graph, _ = get_fw_bw_graph(fn, inp)
        fw_code = fw_graph.code.strip()
        self.assertExpectedInline(
            fw_code,
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'effect');  arg0_1 = None
    getitem = with_effects[0];  with_effects = None
    sin = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
    return (getitem, sin, sin)""",
        )

    def test_effects_and_input_output_view_simple(self):
        """Returning a view of a non-leaf input compiles correctly.

        Runs once without and once with ``requires_grad`` on the base tensor,
        then pins the forward graph captured for the last input.
        """

        def fn(a):
            return a.view(-1)

        for needs_grad in (False, True):
            inp = [torch.ones(2, 2, requires_grad=needs_grad).add(1)]
            expected = fn(*inp)
            actual = torch.compile(fn, fullgraph=True)(*inp)
            self.assertEqual(expected, actual)

        # ``inp`` is still the input built with requires_grad=True.
        fw_graph, _ = get_fw_bw_graph(fn, inp)

        fw_code = fw_graph.code.strip()
        self.assertExpectedInline(
            fw_code,
            """\
def forward(self, arg0_1):
    view = torch.ops.aten.view.default(arg0_1, [-1]);  arg0_1 = None
    return (view,)""",
        )

    def test_effects_and_aliased_outputs(self):
        """Two outputs that alias each other stay aliased after ``aot_function``.

        The outputs are compared with eager both before and after an in-place
        update of the first output.
        """

        def fn(a):
            b = a.mul(2)
            torch.ops.aten._print("effect")
            c = b.view(-1)
            return b, c

        def assert_outputs_match(expected, actual):
            for exp, act in zip(expected, actual):
                self.assertEqual(exp, act)

        f_compiled = aot_function(fn, nop)
        for req_grad in (True, False):
            inp = torch.ones(3, requires_grad=req_grad)
            out_ref = fn(inp)
            out_test = f_compiled(inp)
            assert_outputs_match(out_ref, out_test)
            # Mutate the first output, which the second output is a view of.
            out_ref[0].mul_(3)
            out_test[0].mul_(3)
            # Both outputs must still agree, i.e. the aliasing was preserved.
            assert_outputs_match(out_ref, out_test)

    def test_effects_and_input_mutation_is_output(self):
        """An input mutated in place and returned, next to an effectful print.

        Compared against eager for a non-leaf input derived from a tensor that
        requires grad and for a plain tensor; the forward graph for the plain
        tensor is pinned below.
        """

        def fn(a):
            a.mul_(2)
            torch.ops.aten._print("effect")
            return a

        def check_against_eager(inp):
            expected = fn(*inp)
            compiled_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
            actual = compiled_fn(*inp)
            self.assertEqual(expected, actual)

        check_against_eager(
            make_inputs_non_leaves([torch.ones(3, 3, requires_grad=True)])
        )

        inp = [torch.ones(3, 3, requires_grad=False)]
        check_against_eager(inp)

        fw_graph, _ = get_fw_bw_graph(fn, inp)
        fw_code = fw_graph.code.strip()
        self.assertExpectedInline(
            fw_code,
            """\
def forward(self, arg0_1, arg1_1):
    mul = torch.ops.aten.mul.Tensor(arg1_1, 2);  arg1_1 = None
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'effect');  arg0_1 = None
    getitem = with_effects[0];  with_effects = None
    return (getitem, mul, mul)""",
        )

    @skipIfTorchDynamo()
    def test_effectful_op_in_backward(self):
        """Check an effectful custom op that is also called from its own backward.

        ``_mylib::foo`` clones its input, and its autograd backward calls
        ``_mylib::foo`` again on the gradient. With the op registered as
        ORDERED effectful, ``fn`` is compiled with inductor for a dense input
        pair and for a ``TwoTensor``/dense pair; outputs and gradients are
        compared with eager, and the forward/backward graphs returned by
        ``get_fw_bw_graph`` are pinned for both cases. The registration is
        removed in the ``finally`` block.
        """
        with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
            lib.define("foo(Tensor x) -> Tensor")

            def foo_impl(a):
                return a.clone()

            def foo_bwd(ctx, grad):
                return torch.ops._mylib.foo(grad)

            # The same clone kernel serves the CPU, CUDA and Meta dispatch keys.
            for backend in ["CPU", "CUDA", "Meta"]:
                lib.impl("foo", foo_impl, backend)

            torch.library.register_autograd("_mylib::foo", foo_bwd, lib=lib)

            from torch._higher_order_ops.effects import (
                _deregister_effectful_op,
                _EffectType,
                _register_effectful_op,
            )

            # Paired with the deregistration in the ``finally`` clause below.
            _register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
            try:

                def fn(x, y):
                    return torch.ops._mylib.foo(x) + y

                def ins_dense_req_grad():
                    return (
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    )

                def ins_sc_req_grad():
                    return (
                        TwoTensor(
                            torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                            torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                        ),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    )

                # i == 0: dense inputs; i == 1: TwoTensor first argument.
                for i, ins_fn in enumerate([ins_dense_req_grad, ins_sc_req_grad]):
                    ref_ins = ins_fn()

                    ref_out = fn(*ref_ins)
                    ref_out.sum().backward()

                    compiled_fn = torch.compile(fn, backend="inductor", fullgraph=True)
                    ins = ins_fn()
                    out = compiled_fn(*ins)
                    self.assertEqual(ref_out, out)
                    out.sum().backward()
                    self.assertEqual(ref_ins[1].grad, ins[1].grad)
                    self.assertEqual(ref_ins[0].grad, ins[0].grad)

                    fw_graph, bw_graph = get_fw_bw_graph(fn, ins)
                    if i == 0:
                        self.assertExpectedInline(
                            fw_graph.code.strip(),
                            """\
def forward(self, primals_1, primals_2, primals_3):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2);  primals_1 = primals_2 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    add = torch.ops.aten.add.Tensor(getitem_1, primals_3);  getitem_1 = primals_3 = None
    return (getitem, add)""",
                        )
                        self.assertExpectedInline(
                            bw_graph.code.strip(),
                            """\
def forward(self, tangents_1, tangents_token):
    with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1);  tangents_token = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1];  with_effects_1 = None
    return (getitem_3, tangents_1, getitem_2)""",
                        )
                    elif i == 1:
                        self.assertExpectedInline(
                            fw_graph.code.strip(),
                            """\
def forward(self, primals_1, primals_2, primals_3, primals_4):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2);  primals_1 = primals_2 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._mylib.foo.default, primals_3);  getitem = primals_3 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1];  with_effects_1 = None
    add = torch.ops.aten.add.Tensor(getitem_1, primals_4);  getitem_1 = None
    add_1 = torch.ops.aten.add.Tensor(getitem_3, primals_4);  getitem_3 = primals_4 = None
    return (getitem_2, add, add_1)""",
                        )
                        self.assertExpectedInline(
                            bw_graph.code.strip(),
                            """\
def forward(self, tangents_1, tangents_2, tangents_token):
    with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1);  tangents_token = None
    getitem_4 = with_effects_2[0]
    getitem_5 = with_effects_2[1];  with_effects_2 = None
    with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.foo.default, tangents_2);  getitem_4 = None
    getitem_6 = with_effects_3[0]
    getitem_7 = with_effects_3[1];  with_effects_3 = None
    return (getitem_5, getitem_7, tangents_1, tangents_2, getitem_6)""",
                        )
                    else:
                        raise NotImplementedError
            finally:
                _deregister_effectful_op(torch.ops._mylib.foo.default)

    @skipIfNoDynamoSupport
    def test_regular_effectful_op_only_in_backward(self):
        """Check an effectful op that shows up only in the backward graph.

        ``aten.cos.default`` is registered as ORDERED effectful while ``fn``
        only calls ``sin``. The pinned graphs show a forward without
        ``with_effects`` and a backward that wraps ``cos`` in ``with_effects``,
        for a dense input and for a ``TwoTensor`` input. The registration is
        removed in the ``finally`` block.
        """
        from torch._higher_order_ops.effects import (
            _deregister_effectful_op,
            _EffectType,
            _register_effectful_op,
        )

        # Paired with the deregistration in the ``finally`` clause below.
        _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
        try:

            def fn(x):
                return x.sin()

            def inps_fn():
                return (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)

            # Only checks that inductor compilation and the call succeed.
            torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn())

            fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn())
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1):
    sin = torch.ops.aten.sin.default(primals_1)
    return (sin, primals_1)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, primals_1, tangents_1, tangents_token):
    with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1);  tangents_token = primals_1 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1);  tangents_1 = getitem_1 = None
    return (mul, getitem)""",
            )

            # Same scenario with a TwoTensor subclass input.
            def inps_fn_sc():
                return (
                    TwoTensor(
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    ),
                )

            torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn_sc())
            fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn_sc())
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2):
    sin = torch.ops.aten.sin.default(primals_1)
    sin_1 = torch.ops.aten.sin.default(primals_2)
    return (sin, sin_1, primals_1, primals_2)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2, tangents_1, tangents_2, tangents_token):
    with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1);  tangents_token = primals_1 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten.cos.default, primals_2);  getitem = primals_2 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1];  with_effects_1 = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1);  tangents_1 = getitem_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(tangents_2, getitem_3);  tangents_2 = getitem_3 = None
    return (mul, mul_1, getitem_2)""",
            )
        finally:
            _deregister_effectful_op(torch.ops.aten.cos.default)

    @skipIfNoDynamoSupport
    def test_regular_effectful_op_in_forward_and_backward(self):
        """Check an effectful op that shows up in both forward and backward graphs.

        ``aten.cos.default`` is registered as ORDERED effectful and ``fn``
        computes ``x.cos().sin()``. The pinned graphs show ``cos`` wrapped in
        ``with_effects`` in the forward and again in the backward. The
        registration is removed in the ``finally`` block.
        """
        from torch._higher_order_ops.effects import (
            _deregister_effectful_op,
            _EffectType,
            _register_effectful_op,
        )

        # Paired with the deregistration in the ``finally`` clause below.
        _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
        try:

            def fn(x):
                x = x.cos()
                return x.sin()

            inps = (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)
            # Only checks that inductor compilation and the call succeed.
            torch.compile(fn, backend="inductor", fullgraph=True)(*inps)

            fw_graph, bw_graph = get_fw_bw_graph(fn, inps)
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.aten.cos.default, primals_2);  primals_1 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    sin = torch.ops.aten.sin.default(getitem_1)
    return (getitem, sin, primals_2, getitem_1)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
    with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, getitem_1);  tangents_token = getitem_1 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1];  with_effects_1 = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_3);  tangents_1 = getitem_3 = None
    sin_1 = torch.ops.aten.sin.default(primals_2);  primals_2 = None
    neg = torch.ops.aten.neg.default(sin_1);  sin_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(mul, neg);  mul = neg = None
    return (mul_1, getitem_2)""",
            )
        finally:
            _deregister_effectful_op(torch.ops.aten.cos.default)


# Run the tests in this file when it is executed as a script.
if __name__ == "__main__":
    run_tests()
