# Owner(s): ["module: inductor"]
import contextlib
import re
from unittest.mock import patch

import functorch
import torch
import torch._inductor.config as config
import torch.autograd
from torch._inductor import metrics
from torch._inductor.compile_fx import compile_fx, compile_fx_inner
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code

########################
# Explanation of Tests #
########################
# These tests are all testing *memory accesses* of TorchInductor.
# They are intended to be deterministic performance tests.
# The expect tests are all measuring the number of memory bytes read/written by
# the code that Inductor has generated
#
# If the test is failing because the number became smaller, feel free to lower it.
# On the other hand, if the test is failing because the number became larger,
# that means that your change is leading to *more* memory accesses on this test.
#
# That may still be acceptable, but be aware that you are likely lowering
# performance for that setting.
#
# Defines all the kernels for tests
from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda


if HAS_CUDA:
    import triton
    import triton.language as tl

    from torch.testing._internal.triton_utils import add_kernel

# Shorthand for the ATen operator namespace used throughout these tests.
aten = torch.ops.aten


def compile_but_use_eager(gm, example_inputs):
    """Dynamo backend that runs Inductor codegen but executes the eager graph.

    ``compile_fx`` is driven with a custom ``inner_compile`` hook. The hook
    invokes the real ``compile_fx_inner`` (so Inductor's metrics get
    recorded) and then discards its result, returning the FX graph module
    itself as the callable to run.
    """

    def _compile_then_discard(graph_module, *args, **kwargs):
        # Only the side effects of compilation matter here; the compiled
        # artifact is thrown away in favour of the original graph.
        compile_fx_inner(graph_module, *args, **kwargs)
        return graph_module

    return compile_fx(gm, example_inputs, inner_compile=_compile_then_discard)


def count_numel(f, *args):
    """Compile ``f``, run it once on ``args`` and report its memory traffic.

    Inductor's metrics are reset first, so the result reflects only this
    call. The byte count is divided by 4, i.e. it is expressed in 4-byte
    elements. It is therefore a true element count only when all tensors
    involved use a 4-byte dtype (the tests use fp32 / int32).

    Returns:
        The count as a decimal string, suitable for ``assertExpectedInline``.
    """
    metrics.reset()
    compiled = torch.compile(f, backend=compile_but_use_eager)
    compiled(*args)
    # Per-node breakdown; presumably kept to help when updating expectations.
    print(metrics.nodes_num_elem)
    bytes_per_elem = 4
    return str(metrics.num_bytes_accessed // bytes_per_elem)


def count_numel_train(f, *args):
    """Like ``count_numel`` but also runs a backward pass through ``f``.

    The forward output is iterated and the mean of each item is accumulated
    into a scalar loss, on which ``backward()`` is called. The reported
    number therefore covers both forward and backward compiled code. As
    with ``count_numel``, bytes are divided by 4 (assumes 4-byte dtypes).

    Returns:
        The count as a decimal string, suitable for ``assertExpectedInline``.
    """
    metrics.reset()

    compiled = torch.compile(f, backend=compile_but_use_eager)
    outputs = compiled(*args)
    loss = 0
    for item in outputs:
        # Note: if ``outputs`` is a single tensor this iterates along dim 0.
        loss += item.mean()
    loss.backward()
    print(metrics.nodes_num_elem)
    bytes_per_elem = 4
    return str(metrics.num_bytes_accessed // bytes_per_elem)


# Default device for the tensors built by T()/TI() and for TestCase.device.
DEVICE = "cuda"


def T(*size, dtype=torch.float32, device=DEVICE, grad=False):
    """Return a standard-normal tensor of shape ``size``.

    Defaults to fp32 on ``DEVICE``; pass ``grad=True`` for a leaf tensor
    that requires grad (used by the training tests).
    """
    return torch.randn(
        size,
        device=device,
        dtype=dtype,
        requires_grad=grad,
    )


def TI(*size, mx=10, dtype=torch.int32, device=DEVICE):
    """Return a random integer tensor of shape ``size`` with values in ``[0, mx)``.

    Defaults to int32 on ``DEVICE``; typically used as index tensors.
    """
    low, high = 0, mx
    return torch.randint(low, high, size, device=device, dtype=dtype)


class TestCase(InductorTestCase):
    """Inductor test base whose ``device`` attribute is the module-level DEVICE."""

    device = DEVICE


class NumBytesMetricTests(TestCase):
    """
    Primarily used for sanity testing that the num_bytes_accessed metric is correct.

    Expected values are ``num_bytes_accessed // 4`` as returned by
    ``count_numel``, i.e. memory traffic measured in 4-byte elements.
    """

    def test_pointwise(self):
        # Simple pointwise cases, small enough to check by hand
        # (e.g. 10 elements in + 10 out = 20).
        def f(x):
            return x.cos()

        inp = (T(10),)
        self.assertExpectedInline(count_numel(f, *inp), """20""")

        def f(x, y):
            return x + y

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """30""")

        def f(x, y):
            return x + y

        # Broadcast: (10, 10) + (10,) -> 100 + 10 in, 100 out.
        inp = (T(10, 10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """210""")

        def f(x):
            return x + x

        # Using the same input twice should not double-count its reads.
        inp = (T(10),)
        self.assertExpectedInline(count_numel(f, *inp), """20""")

        def f(x):
            return x + x.t()

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """200""")

        def f(a, b, c):
            return a.cos(), b.sin() + c.sin()

        inp = (T(10), T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """50""")

    def test_reduction(self):
        # Inner- and outer-dim reductions of a (10, 10) input.
        def f(x):
            return x.sum(dim=1)

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """110""")

        def f(x):
            return x.sum(dim=0)

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """110""")

    def test_extern(self):
        # Extern kernels (torch.mm) combined with pointwise ops.
        def f(x):
            return torch.mm(x, x)

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """200""")

        def f(a, b):
            return torch.mm(a, b)

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """300""")

        def f(x):
            x = x.cos()
            x = torch.mm(x, x)
            x = x.cos()
            return x

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """600""")

        def f(x):
            a = x.cos()
            b = x.sin()
            x = torch.mm(a, b)
            return x

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """600""")

    def test_cat(self):
        # torch.cat with pointwise, plain, extern and reduction inputs.
        def f(a, b):
            return torch.cat([a.sin(), b.sin()])

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """40""")

        def f(a, b):
            return torch.cat([a, b])

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """40""")

        def f(a, b):
            return torch.cat([a.cos(), b])

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """40""")

        def f(a):
            return torch.cat([a.cos(), a.sin()])

        inp = (T(10),)
        self.assertExpectedInline(count_numel(f, *inp), """30""")

        def f(a, b):
            return torch.cat([torch.mm(a, a), b.sin()])

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """400""")

        def f(a, b, c):
            return torch.cat((a + 1, b + 2, c + 3)) + 10

        inp = (T(10, 10), T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """600""")

        def f(a, b, c, d, e):
            return torch.cat((a + 1, b + 2, c + 3, d + 4, e + 5)) + 10

        inp = [T(10, 10) for _ in range(5)]
        self.assertExpectedInline(count_numel(f, *inp), """1000""")

        def f(a, b):
            return torch.cat([a.sum(dim=0), b.sum(dim=0)]) + 10

        inp = [T(10, 10, 10), T(10, 10, 10)]
        self.assertExpectedInline(count_numel(f, *inp), """2600""")

    def test_cat_pointwise(self):
        # Whether a cat consumed by/feeding pointwise ops is lowered as a
        # pointwise cat, depending on what kinds of inputs it has.
        def f(a, b):
            return torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)])

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """400""")

        def f(a, b):
            return torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)]).cos()

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """680""")

        # Should turn into pointwise even if only some of inputs are pointwise.
        def f(a, b):
            out = torch.cat([a.cos(), torch.mm(b, b)])
            return out.cos()

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """600""")

        # Should not turn into pointwise if all inputs are not pointwise
        def f(a, b):
            out = torch.cat([torch.mm(a, a), torch.mm(b, b)])
            return out.cos()

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """800""")

        def f(a, b):
            out = torch.cat([a, b])
            return out.cos()

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """400""")

        def f(a, b):
            b = b.cos()
            return torch.cat([a, b])

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """400""")

        def f(a, b):
            a = a @ a
            return torch.constant_pad_nd(torch.cat([a, b]), [2, 2], 0.5)

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """680""")

    # The config patches below disable the split/cat FX passes and restrict
    # the pre-/post-grad fusion options for this test.
    @patch.object(config, "split_cat_fx_passes", False)
    @patch.object(
        config,
        "pre_grad_fusion_options",
        {
            "batch_linear": {},
            "batch_linear_lhs": {},
            "batch_layernorm": {},
            "batch_tanh": {},
            "batch_relu": {},
            "batch_sigmoid": {},
        },
    )
    @patch.object(config, "post_grad_fusion_options", {})
    def test_cat_pointwise_many_complex_inputs(self):
        def f(*inputs):
            input = [torch.nn.functional.gelu(val) for val in inputs]
            return torch.cat(input) + 10

        # Note: `inp` is a generator; it is consumed once by `*inp` below.
        inp = (T(10, 10) for _ in range(16))
        self.assertExpectedInline(count_numel(f, *inp), """6400""")

    # Same setup as above but with relu inputs; the expected count differs.
    @patch.object(config, "split_cat_fx_passes", False)
    @patch.object(
        config,
        "pre_grad_fusion_options",
        {
            "batch_linear": {},
            "batch_linear_lhs": {},
            "batch_layernorm": {},
            "batch_tanh": {},
            "batch_relu": {},
            "batch_sigmoid": {},
        },
    )
    @patch.object(config, "post_grad_fusion_options", {})
    def test_cat_pointwise_many_simple_inputs(self):
        def f(*inputs):
            input = [torch.nn.functional.relu(val) for val in inputs]
            return torch.cat(input) + 10

        # Note: `inp` is a generator; it is consumed once by `*inp` below.
        inp = (T(10, 10) for _ in range(16))
        self.assertExpectedInline(count_numel(f, *inp), """9600""")

    # max_pointwise_cat_inputs=0 is patched in for this test.
    @patch.object(config, "max_pointwise_cat_inputs", 0)
    def test_cat_pointwise_config_option(self):
        def f(a, b):
            return torch.cat([a + 1, b + 2]) + 3

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """400""")

    def test_index(self):
        # Gather with an int32 index tensor (TI is also 4 bytes per element).
        def f(a, b):
            return a[b]

        inp = (T(10), TI(10, mx=10))
        self.assertExpectedInline(count_numel(f, *inp), """30""")


class FusionTests(TestCase):
    """
    Tests that things can be fused into a single kernel
    """

    def test_horizontal_reduction_pointwise(self):
        def f(a):
            b = a.sum(dim=1)
            c = a.cos()
            return b, c

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """210""")

    def test_horizontal_reduction_reduction(self):
        def f(a):
            b = a.sum(dim=1)
            c = a.amax(dim=1)
            return b, c

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """120""")

    def test_horizontal_reduction_pointwise2(self):
        def f(a, b):
            c = a.sum(dim=1)
            b = b.cos()
            return b + c

        inp = (T(10, 10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """120""")

    def test_horizontal_reduction_outer_pointwise(self):
        def f(a, b):
            c = a.sum(dim=0)
            b = b.cos()
            return b + c

        inp = (T(10, 10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """120""")

    def test_horizontal_sum_pw_broadcast(self):
        def f(a, b):
            a = a.sum(dim=1, keepdim=True)
            b = b.cos()
            return a * b

        inp = (T(10, 10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """210""")

    def test_vertical_sum_pw(self):
        def f(a):
            a = a.cos()
            a = a.sum(dim=1)
            return a.cos()

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """110""")

    def test_norm_chain(self):
        def f(a):
            b = a.sum(dim=1, keepdim=True)
            a = a * b
            b = a.sum(dim=1, keepdim=True)
            a = a * b
            b = a.sum(dim=1, keepdim=True)
            a = a * b
            return a

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """200""")

    def test_softmax_inner(self):
        def f(a):
            return torch.softmax(a, dim=1)

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """200""")

    def test_layer_norm(self):
        # TODO: Suboptimal! We shouldn't need to save normalization stats.
        mod = torch.nn.LayerNorm(10, device=self.device)

        def f(x):
            return mod(x)

        inp = (T(10, 10),)
        with torch.no_grad():
            self.assertExpectedInline(count_numel(f, *inp), """220""")

    def test_double_softmax(self):
        def f(x):
            x = torch.softmax(x, dim=1)
            x = torch.softmax(x, dim=1)
            return x

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """200""")

    def test_softmax_backward(self):
        def f(grad_out, out):
            return aten._softmax_backward_data(grad_out, out, 1, torch.float32)

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """300""")

    def test_neighbor(self):
        def f(a, b):
            return ((a - b) ** 2).sum(dim=-1).amax(dim=1)

        inp = (T(10, 1, 4), T(1, 10, 4))
        self.assertExpectedInline(count_numel(f, *inp), """90""")

    def test_factory_reduction(self):
        def f():
            a = torch.ones(10, device=self.device)
            b = torch.ones(10, 10, device=self.device)
            return (a + b).sum(dim=-1)

        inp = ()
        self.assertExpectedInline(count_numel(f, *inp), """10""")

    def test_index_pointwise(self):
        def f(a, b):
            return a[b].cos()

        inp = (T(10, 10), TI(20, mx=10))
        self.assertExpectedInline(count_numel(f, *inp), """320""")

    def test_index_reduction(self):
        def f(a, b):
            return a[b].cos().sum(dim=1)

        inp = (T(10, 10), TI(20, mx=10))
        self.assertExpectedInline(count_numel(f, *inp), """140""")

    def test_mutation_fusion(self):
        def f(a, b, c):
            a0 = a.add(c)
            b0 = b.add(a0)
            b.copy_(b0)
            a.copy_(a0)

        inp = (T(10, 10), T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """500""")

    def test_reduction_pointwise_multi_level_reduction(self):
        hidden_size = 4096
        # Built directly on the test device in fp32, matching test_layer_norm
        # (equivalent to the previous `.cuda().float()` chain).
        layer_norm = torch.nn.LayerNorm(
            hidden_size, device=self.device, dtype=torch.float
        )

        @torch.inference_mode()
        def f(x, scale, amax_keep_dim):
            x = layer_norm(x.to(dtype=torch.float))
            amax = torch.amax(torch.abs(x), keepdim=amax_keep_dim)
            x_scaled = x * scale
            y = torch.nn.functional.sigmoid(x_scaled)
            return (y, amax)

        inp = (T(4, 2048, hidden_size, dtype=torch.float), T(1, dtype=torch.float))

        # 2 kernels:
        # kernel 1: (input = X, scale, LN scale, LN bias, output = LN_pointwise(X), first-level amax (split-reduction))
        # kernel 2: (input = first-level amax, output = final amax)
        # scale (1) + X (4*2048*hidden_size) * 2 + LN scale (hidden_size) + LN bias (hidden_size) + amax (4 * 2048 * 2 + 1)
        expected_numel = (
            1 + hidden_size * 2 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1
        )
        self.assertExpectedInline(count_numel(f, *inp, True), str(expected_numel))
        self.assertExpectedInline(count_numel(f, *inp, False), str(expected_numel))

    def test_pointwise_multi_level_reduction(self):
        # TODO: this can be optimized by having the first pointwise kernel leveraging block sizes
        # of the first-level reduction kernel.
        hidden_size = 4096

        def f(x, scale, amax_keep_dim):
            x = x * 1.1
            amax = torch.amax(torch.abs(x), keepdim=amax_keep_dim)
            x_scaled = x * scale
            y = torch.nn.functional.sigmoid(x_scaled)
            return (y, amax)

        inp = (T(4, 2048, hidden_size, dtype=torch.float), T(1, dtype=torch.float))

        # Run once through the default torch.compile backend first; presumably
        # a warm-up before the measured count_numel runs below.
        compiled_f = torch.compile(f)
        compiled_f(*inp, True)

        # 3 kernels:
        # kernel 1: (input = X, scale, output = pointwise(X))
        # kernel 2: (input = X, output = first-level amax)
        # kernel 3: (input = first-level amax, output = final amax)
        # scale (1) + X (4*2048*hidden_size) * 3 + amax (num_splits * 2 + 1)
        # num_splits depends on SM architectures, so expected_numel omits the
        # num_splits * 2 term and is only a lower bound.
        expected_numel = 1 + 4 * 2048 * hidden_size * 3 + 1
        actual_numel_amax_keep_dim = count_numel(f, *inp, True)
        actual_numel_amax_no_keep_dim = count_numel(f, *inp, False)
        self.assertEqual(actual_numel_amax_keep_dim, actual_numel_amax_no_keep_dim)
        # count_numel returns a decimal *string*; convert before the ordering
        # check, since comparing strings is lexicographic (e.g. "9" > "10").
        self.assertGreaterEqual(int(actual_numel_amax_keep_dim), expected_numel)


class SchedulerFusionTests(TestCase):
    """
    Testing the fusion group creation heuristic (i.e. cases where we can't fuse
    everything into a single kernel)
    Disables inductor rematerialization for easier reasoning of tests.
    """

    @classmethod
    def setUpClass(cls):
        # Apply realize_opcount_threshold=0 for every test in this class; the
        # ExitStack keeps the patch active until tearDownClass closes it.
        super().setUpClass()
        cls._stack = contextlib.ExitStack()
        cls._stack.enter_context(patch.object(config, "realize_opcount_threshold", 0))

    @classmethod
    def tearDownClass(cls):
        # Undo the class-wide config patch, then run the base teardown.
        cls._stack.close()
        super().tearDownClass()

    @patch.object(config, "pattern_matcher", False)
    def test_fusion_choice1(self):
        # Doesn't matter where we break fusion group here
        def f(a):
            c = a.cos()
            d = torch.mm(c, c)
            e = c.cos()
            return d + e

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """700""")

    @patch.object(config, "pattern_matcher", False)
    def test_fusion_choice2(self):
        # We should materialize e (it's smaller!)
        # [c, e]: 210, [f]: 210, [d]: 200
        def f(a):
            c = a.cos()
            d = torch.mm(c, c)
            e = c.sum(dim=1)
            f = d + e
            return f

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """620""")

    @patch.object(config, "pattern_matcher", False)
    def test_fusion_choice3(self):
        # We should materialize e.
        # [c, e]: 300, [f]: 300, [d]: 200
        def f(a):
            c = a.cos()
            d = torch.mm(c, c)
            e = c + a
            f = d + e
            return f, e

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """800""")

    @patch.object(config, "pattern_matcher", False)
    def test_fusion_choice4_cpu(self):
        # Fuse nodes with same number of elements and compatible original var ranges
        # [buf0: {d0: 60, d1: 11}, buf1: {d0: 660}] -> buf0_buf1
        # Note: runs on CPU tensors (device="cpu") despite the CUDA default.
        def f(x, w):
            o1 = x * w
            output = o1 + 1.0
            return output

        inp = (T(2, 3, 10, 11, device="cpu"), T(11, device="cpu"))
        self.assertExpectedInline(count_numel(f, *inp), """1331""")

        # [buf0_buf1: {d0: 60, d1: 11}, buf2: {d0: 660}] -> buf0_buf1_buf2
        def f(x, w1, w2):
            o1 = x * w1
            o2 = x * w2
            output = o1 + o2
            return output

        inp = (T(2, 3, 10, 11, device="cpu"), T(11, device="cpu"), T(11, device="cpu"))
        self.assertExpectedInline(count_numel(f, *inp), """1342""")


class TilingTests(TestCase):
    """Pointwise ops mixing contiguous and transposed/permuted inputs.

    The expected counts equal reading each input once plus writing the
    output once.
    """

    def test_tiling_simple(self):
        def f(a, b):
            return a + b.t()

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """300""")

        def f(a, b):
            return a.t() + b

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """300""")

    def test_tiling_three(self):
        # Three inputs, each with a different permutation of the same shape.
        def f(a, b, c):
            return a + b.permute(1, 2, 0) + c.permute(2, 0, 1)

        inp = (T(10, 10, 10), T(10, 10, 10), T(10, 10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """4000""")


class MinCutPartitioningTests(TestCase):
    """Training-mode tests: counts cover forward + backward (count_numel_train).

    These exercise how the forward/backward graph partitioning chooses what
    to save versus recompute.
    """

    def test_partitioning_full_remat(self):
        def f(x):
            return x.cos().cos().cos()

        inp = (T(10, grad=True),)
        self.assertExpectedInline(count_numel_train(f, *inp), """50""")

    def test_partitioning_partial_remat(self):
        def f(a, b, c, d):
            x = a + b + c + d
            return x.cos().cos()

        inp = (T(10, grad=True), T(10, grad=True), T(10, grad=True), T(10, grad=True))
        self.assertExpectedInline(count_numel_train(f, *inp), """90""")

    def test_partitioning_dtype(self):
        # (x < 0) produces a bool mask; note count_numel_train divides bytes
        # by 4, so non-4-byte tensors are not counted as whole elements.
        def f(x):
            return (x < 0) * x

        inp = (T(100, grad=True),)
        self.assertExpectedInline(count_numel_train(f, *inp), """450""")

    # functorch's max_dist_from_bw is raised to 1000 for this test.
    @patch.object(functorch.compile.config, "max_dist_from_bw", 1000)
    def test_partitioning_unremat_bw(self):
        def f(x):
            return torch.mm(x, x.new_ones(x.shape)).tanh().tanh()

        inp = (T(10, 10, grad=True),)
        self.assertExpectedInline(count_numel_train(f, *inp), """1300""")

    @patch.object(config, "pattern_matcher", False)
    def test_partitioning_unremat_bw2(self):
        def f(a):
            a = torch.mm(a, a)
            a = a + 1
            b = a + 2
            c = torch.mm(a, b)
            return c

        inp = (T(10, 10, grad=True),)
        self.assertExpectedInline(count_numel_train(f, *inp), """2600""")

    def test_partitioning_keops(self):
        # Broadcast (20, 1) * (1, 20) followed by a reduction.
        def f(a, b):
            return (a * b).cos().sum(dim=1)

        inp = (T(20, 1, grad=True), T(1, 20, grad=True))
        self.assertExpectedInline(count_numel_train(f, *inp), """220""")

    def test_partitioning_cat(self):
        def f(a, b):
            a = torch.tanh(a)
            return torch.cat([a, b])

        inp = (T(10, grad=True), T(10, grad=True))
        self.assertExpectedInline(count_numel_train(f, *inp), """70""")

    def test_partitioning_with_view(self):
        # Custom autograd.Function that saves a viewed intermediate for backward.
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x.sin()
                x = x.cos()
                x = x.view(10, 10)
                ctx.save_for_backward(x, y)
                x = x.cos()
                return x

            @staticmethod
            def backward(ctx, gradOut):
                x, y = ctx.saved_tensors
                return torch.mm(gradOut, x).view(100) * y

        def f(a):
            return Foo.apply(a)

        inp = (T(100, grad=True),)
        # We do not want to recompute the x.cos().view() chain, as it's
        # materialized in backwards
        self.assertExpectedInline(count_numel_train(f, *inp), """900""")

    @patch.object(config, "pattern_matcher", False)
    def test_partitioning_long_chain_add(self):
        # Two iterations of pointwise -> mm -> pointwise -> residual add.
        def f(x):
            orig = x
            for _ in range(2):
                x = x * x
                x = torch.mm(x, x)
                x = x * 2
                x = orig + x
                orig = x
            return x

        inp = (T(10, 10, grad=True),)
        self.assertExpectedInline(count_numel_train(f, *inp), """3900""")


def unfusible(x):
    """Apply an op Inductor has no lowering or decomposition for.

    The noop tests wrap values with this so that Inductor must fall back to
    eager for this call, which keeps the preceding value from being fused
    away. ``aten._lazy_clone`` is the chosen op.
    """
    lazy_clone = aten._lazy_clone
    return lazy_clone(x)


class NoopTests(TestCase):
    """Ops that should be no-ops after compilation (clones, identity casts...).

    Each case feeds the result into ``unfusible`` (an eager fallback), so the
    expected count (20 for a 10-element input) reflects only the extern
    call's read and write.
    """

    def test_noop_clones(self):
        def f(a):
            b = a.clone()
            b = unfusible(b)
            return b

        inp = T(10)
        self.assertExpectedInline(count_numel(f, inp), """20""")

        # Here the clone is also returned, which changes the expected count.
        def f(a):
            b = a.clone()
            c = unfusible(b)
            return b, c

        self.assertExpectedInline(count_numel(f, inp), """40""")

    def test_noop_slice_scatter(self):
        # slice_scatter of `a` into itself with default slicing.
        def f(a):
            b = aten.slice_scatter(a, a)
            c = unfusible(b)
            return c

        inp = T(10)
        self.assertExpectedInline(count_numel(f, inp), """20""")

    def test_noop_dtype_conversion(self):
        # fp32 -> fp32 conversion.
        def f(a):
            b = torch.ops.prims.convert_element_type(a, torch.float32)
            c = unfusible(b)
            return c

        inp = T(10)
        self.assertExpectedInline(count_numel(f, inp), """20""")

    def test_noop_device_conversion(self):
        # cuda -> cuda device_put (inputs default to DEVICE = "cuda").
        def f(a):
            b = torch.ops.prims.device_put(a, "cuda")
            c = unfusible(b)
            return c

        inp = T(10)
        self.assertExpectedInline(count_numel(f, inp), """20""")

    def test_noop_int_ops(self):
        # Rounding ops and pow(a, 1) applied to an integer tensor.
        def f1(a):
            b = torch.ceil(a)
            c = unfusible(b)
            return c

        def f2(a):
            d = torch.floor(a)
            e = unfusible(d)
            return e

        def f3(a):
            f = torch.round(a)
            g = unfusible(f)
            return g

        def f4(a):
            f = torch.pow(a, 1)
            g = unfusible(f)
            return g

        inp = TI(10)
        self.assertExpectedInline(count_numel(f1, inp), """20""")
        self.assertExpectedInline(count_numel(f2, inp), """20""")
        self.assertExpectedInline(count_numel(f3, inp), """20""")
        self.assertExpectedInline(count_numel(f4, inp), """20""")

    def test_noop_cat(self):
        # torch.cat of a single tensor.
        def f1(a):
            b = torch.cat([a])
            return unfusible(b)

        inp = T(10)
        self.assertExpectedInline(count_numel(f1, inp), """20""")

        # Note: f2 does not go through unfusible.
        def f2(a):
            b = torch.cat([a])
            c = torch.cat([b])
            return c

        self.assertExpectedInline(count_numel(f2, inp), """20""")


class InplacingTests(TestCase):
    """Memory-traffic tests for mutation: index_put, user triton kernels and
    custom ops that write into their arguments.

    Several tests also inspect the generated code to count buffer
    allocations (``empty_strided_*`` calls).
    """

    def test_inplace_scatter(self):
        def f(a, b):
            a = a.cos()
            a[b] = 1
            return a

        inp = (T(10), TI(2, mx=5))
        self.assertExpectedInline(count_numel(f, *inp), """26""")

        def f(a, b):
            out = aten.index_put(a, (b,), torch.tensor(1.0))
            return a.copy_(out)

        inp = (T(10), TI(2, mx=5))
        self.assertExpectedInline(count_numel(f, *inp), """6""")

        def f(a, b):
            out = aten._unsafe_index_put(a, (b,), torch.tensor(1.0))
            return a.copy_(out)

        inp = (T(10), TI(2, mx=5))
        self.assertExpectedInline(count_numel(f, *inp), """6""")

    def test_inplace_scatter_noop_view(self):
        def f(a, b):
            a[:, b] = 1
            return a

        inp = (T(10, 10), TI(2, mx=5))
        self.assertExpectedInline(count_numel(f, *inp), """42""")

    @requires_cuda
    def test_inplace_triton_kernel_training(self):
        # A user-defined triton kernel that writes sin(in) into `out_ptr`,
        # used from both forward and backward of an autograd.Function.
        @triton.jit
        def sin_kernel(
            in_ptr0,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            output = tl.sin(x)
            tl.store(out_ptr + offsets, output, mask=mask)

        def sin_triton(x, out):
            n_elements = x.numel()
            sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)

        factory_op = torch.empty_like

        class MySin(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                out = factory_op(x)
                sin_triton(x, out)
                ctx.save_for_backward(out)
                return out

            @staticmethod
            def backward(ctx, grad):
                (saved,) = ctx.saved_tensors
                out = factory_op(grad)
                sin_triton(saved, out)
                return out

        def f(x):
            return MySin.apply(x)

        x = T(3, grad=True)
        self.assertExpectedInline(count_numel_train(f, x), """9""")

    @requires_cuda
    def test_inplace_custom_op_training_two_mutated_inputs(self):
        # Custom op declared (via mutates_args) to write into two outputs.
        # Note: despite the name, this uses count_numel (forward only).
        @torch.library.custom_op(
            "_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"}
        )
        def sin_cos(
            x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor
        ) -> None:
            out_sin.copy_(x.sin())
            out_cos.copy_(x.cos())

        def f(x):
            out0 = torch.empty_like(x)
            out1 = torch.empty_like(x)
            sin_cos(x, out0, out1)
            return x.clone(), out0, out1

        x = T(3, grad=True)
        self.assertExpectedInline(count_numel(f, x), """21""")

    @requires_cuda
    def test_inplace_custom_op_training(self):
        # Mutating custom op used inside an autograd.Function (fwd and bwd).
        @torch.library.custom_op("_reinplacing::sin", mutates_args={"result"})
        def sin(x: torch.Tensor, result: torch.Tensor) -> None:
            result.copy_(x.sin())

        factory_op = torch.empty_like

        class MySin(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                out = factory_op(x)
                sin(x, out)
                ctx.save_for_backward(out)
                return out

            @staticmethod
            def backward(ctx, grad):
                (saved,) = ctx.saved_tensors
                out = factory_op(grad)
                sin(saved, out)
                return out

        def f(x):
            return MySin.apply(x)

        x = T(3, grad=True)
        self.assertExpectedInline(count_numel_train(f, x), """9""")

    @requires_cuda
    def test_inplace_custom_op(self):
        # Schema-defined op whose `out` argument is marked mutable (Tensor(a!)).
        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            m.define("foo(Tensor x, Tensor(a!) out) -> ()")

            def foo(x: torch.Tensor, out: torch.Tensor) -> None:
                out.copy_(x.sin())

            m.impl("foo", foo, "CompositeExplicitAutograd")

            def f(x, out):
                torch.ops.mylib.foo(x, out)
                torch.ops.mylib.foo(out, out)
                torch.ops.mylib.foo(out, out)
                return out

            x = T(3)
            out = T(3)

            compiled_out, (code,) = run_and_get_code(
                torch.compile(f, fullgraph=True), x, out
            )
            self.assertEqual(compiled_out, x.sin().sin().sin())

            # Check that we are allocating the minimum number of intermediate buffers
            matches = re.findall(r"empty_strided_\w+\(", code)
            self.assertEqual(len(matches), 0)

            self.assertExpectedInline(count_numel(f, x, out), """21""")

    @requires_cuda
    def test_inplace_custom_op_intermediate(self):
        # Same as above, but `out` is a fresh intermediate (the parameter is
        # intentionally shadowed), so exactly one allocation is expected.
        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            m.define("foo(Tensor x, Tensor(a!) out) -> ()")

            def foo(x: torch.Tensor, out: torch.Tensor) -> None:
                out.copy_(x.sin())

            m.impl("foo", foo, "CompositeExplicitAutograd")

            def f(x, out):
                out = torch.empty_like(x)
                torch.ops.mylib.foo(x, out)
                torch.ops.mylib.foo(out, out)
                torch.ops.mylib.foo(out, out)
                return out

            x = T(3)
            out = T(3)

            compiled_out, (code,) = run_and_get_code(
                torch.compile(f, fullgraph=True), x, out
            )
            self.assertEqual(compiled_out, x.sin().sin().sin())

            # Check that we are allocating the minimum number of intermediate buffers
            matches = re.findall(r"empty_strided_\w+\(", code)
            self.assertEqual(len(matches), 1)

            self.assertExpectedInline(count_numel(f, x, out), """21""")

    @requires_cuda
    def test_inplace_custom_op_two_mutated_inputs(self):
        # Op with two mutable cache arguments and a functional return value;
        # `f` closes over q / k_cache / v_cache and takes no arguments.
        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            m.define("foo(Tensor q, Tensor(a!) k_cache, Tensor(b!) v_cache) -> Tensor")

            def foo(q, k_cache, v_cache):
                k_cache.add_(1)
                v_cache.add_(1)
                return q + 1

            m.impl("foo", foo, "CompositeExplicitAutograd")

            q = T(3)
            k_cache = T(3)
            v_cache = torch.rand_like(k_cache)

            def f():
                x = 0
                for _ in range(2):
                    x = x + torch.ops.mylib.foo(q, k_cache, v_cache)
                return x

            compiled_out, (code,) = run_and_get_code(
                torch.compile(f, fullgraph=True),
            )

            # Check that we are allocating the minimum number of intermediate buffers
            matches = re.findall(r"empty_strided_\w+\(", code)
            self.assertEqual(len(matches), 1)

            self.assertExpectedInline(count_numel(f), """39""")

    # The v1-v6 variants below call the shared `add_kernel` triton kernel
    # (from triton_utils) with `output` as its third argument, combined with
    # different surrounding reads/mutations/views of the inputs.
    @requires_cuda
    def test_inplace_triton_kernel_v1(self):
        def f(x: torch.Tensor, y: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = (n_elements,)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            return output

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """50""")

    @requires_cuda
    def test_inplace_triton_kernel_v2(self):
        def f(x: torch.Tensor, y: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = (n_elements,)
            tmp = torch.add(x, 1)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            return output, tmp

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """70""")

    @requires_cuda
    def test_inplace_triton_kernel_v3(self):
        def f(x: torch.Tensor, y: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = (n_elements,)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            x.add_(1)
            return output

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """80""")

    @requires_cuda
    def test_inplace_triton_kernel_v4(self):
        def f(x: torch.Tensor, y: torch.Tensor):
            x_view = x.view(-1)
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = (n_elements,)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            output2 = x_view.mul(2)
            return output, output2

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """70""")

    @requires_cuda
    def test_inplace_triton_kernel_v5(self):
        def f(x: torch.Tensor, y: torch.Tensor):
            x_view = x.view(-1)
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = (n_elements,)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            x_view.mul_(2)
            return output

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """80""")

    @requires_cuda
    def test_inplace_triton_kernel_v6(self):
        def f(x: torch.Tensor, y: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = (n_elements,)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            return output

        # Both kernel inputs alias the same storage.
        t = T(10)
        inp = (t, t.view(-1))
        self.assertExpectedInline(count_numel(f, *inp), """50""")

    def test_inplace_randperm_scatter(self):
        # index_add_ into the input at indices taken from a random permutation.
        def scaled_index_add(x, y, scale_y):
            index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
            out = x.index_add_(dim=0, source=y * scale_y, index=index)
            return out

        inp = (T(10, 10), T(5, 10), T(10))
        self.assertExpectedInline(count_numel(scaled_index_add, *inp), """250""")


# Test cases where we don't do the right thing yet.
class WouldBeNiceIfItWorked:
    """Aspirational expectations that do not currently hold.

    Deliberately NOT a TestCase subclass, so the test runner never collects
    these methods. The expected values record the desired memory traffic;
    move a test into a real TestCase once Inductor achieves it.
    """

    def test_horizontal(self):
        def f(a):
            b = a.sum(dim=0)
            c = a.cos()
            return b, c

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """210""")

    # TODO: We aren't fusing outer dim softmaxes
    def test_softmax_outer(self):
        def f(a):
            return torch.softmax(a, dim=0)

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """200""")

    # TODO: The greedy fusion strategy results in suboptimal grouping
    @patch.object(config, "realize_opcount_threshold", 0)
    def test_fusion_choice4(self):
        def f(a, b, b2):
            c = a + b
            d = torch.mm(c, c)
            e = c + b + b2
            f = d + e + b2
            return f, e

        inp = (T(10, 10), T(10, 10, dtype=torch.float16), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """1000""")

    # TODO: We materialize the intermediate if we don't unroll the reduction
    def test_neighbor(self):
        def f(a, b):
            return ((a - b) ** 2).sum(dim=-1).amax(dim=1)

        inp = (T(10, 1, 8), T(1, 10, 8))
        self.assertExpectedInline(count_numel(f, *inp), """170""")


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Tensors default to DEVICE ("cuda"), so only run when CUDA is available.
    if HAS_CUDA:
        run_tests(needs="filelock")
