# Owner(s): ["module: c10d"]
import threading
import unittest
from typing import List, Tuple

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch._C import FileCheck
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
from torch.distributed._functional_collectives import (
    all_gather_into_tensor_coalesced,
    all_gather_tensor,
    all_reduce,
    all_reduce_coalesced,
    all_to_all_single,
    AsyncCollectiveTensor,
    reduce_scatter_tensor,
    reduce_scatter_tensor_coalesced,
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
    run_tests,
    TestCase,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton


def load_test_module(name):
    """Load a test helper module from its source file and return it.

    ``name`` is a dotted module name.  It is resolved relative to the parent
    of the directory that contains this file: dots become path separators and
    ``.py`` is appended.  That parent directory is appended to ``sys.path``
    only while the module body executes, and the module is registered in
    ``sys.modules`` under ``name``.

    Raises:
        OSError: if the source file cannot be read (e.g. it does not exist).
    """
    import importlib.util
    import sys
    from importlib.machinery import SourceFileLoader
    from pathlib import Path
    from unittest import mock

    testdir = Path(__file__).absolute().parent.parent
    source_path = str(testdir / f"{name.replace('.', '/')}.py")

    # ``Loader.load_module()`` is deprecated; build the module from a spec and
    # execute it explicitly instead.
    loader = SourceFileLoader(name, source_path)
    spec = importlib.util.spec_from_loader(name, loader)

    with mock.patch("sys.path", [*sys.path, str(testdir)]):
        # Mirror load_module(): re-execute into an already registered module,
        # otherwise register a fresh one before running its body.
        previous = sys.modules.get(name)
        module = (
            previous if previous is not None else importlib.util.module_from_spec(spec)
        )
        sys.modules[name] = module
        try:
            loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module behind on first load.
            if previous is None:
                sys.modules.pop(name, None)
            raise
        return module


# AOTIRunnerUtil is loaded by file path through load_test_module(), i.e. from
# "inductor/test_aot_inductor_utils.py" under the parent of this file's directory.
AOTIRunnerUtil = load_test_module("inductor.test_aot_inductor_utils").AOTIRunnerUtil

import sys


# When torch.distributed is unavailable, say so on stderr and exit with
# status 0 instead of going on to define and run the tests below.
if not dist.is_available():
    print("distributed package not available, skipping tests", file=sys.stderr)
    sys.exit(0)


@requires_nccl()
class TestWithNCCL(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def ranks(self) -> List[int]:
        return list(range(self.world_size))

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process_group(self) -> None:
        # Allow testing aoti after torch.compile
        torch._inductor.config.triton.store_cubin = True
        torch._inductor.config.debug = True

        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_single(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) != id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_reduce(
            input,
            "avg",
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_single_(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce_(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) == id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_reduce_coalesced(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) != id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = all_reduce_coalesced(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            assert not output.completed
            assert output.eq(sum(self.ranks) / self.world_size * i).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced_(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_reduce_coalesced_(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) == id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

    @skip_if_lt_x_gpu(2)
    def test_all_gather_into_tensor_single(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_gather_into_tensor(
            input,
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        expect = torch.cat(
            [
                torch.full((10, 10), float(rank), device=self.device)
                for rank in self.ranks
            ]
        )
        assert torch.allclose(output, expect)
        assert output.eq(expect).all()

        # Test out-variant of all_gather_into_tensor
        output = torch.empty(expect.shape, device=self.device)
        output = torch.ops._c10d_functional.all_gather_into_tensor_out(
            input,
            self.world_size,
            "default",
            out=output,
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert torch.allclose(output, expect)
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_gather_tensor(
            input,
            0,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    # https://github.com/pytorch/pytorch/issues/126338
    def test_inductor_dtypeview_memory_leak(self):
        self._init_process_group()

        def func(arg: torch.Tensor) -> torch.Tensor:
            ag0 = torch.ops._c10d_functional.all_gather_into_tensor.default(
                arg,
                self.world_size,
                "default",
            )
            ag0_view = torch.ops.aten.view.dtype(ag0, torch.int32)
            return funcol.wait_tensor(ag0_view)

        arg = torch.full(
            (10, 10),
            float(self.rank),
            device=self.device,
            dtype=torch.float32,
        )
        compiled = torch.compile(func)
        mem_usage = {}
        # check if the aten.view.dtype is compiled to aten.view.dtype
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("torch.ops._c10d_functional.wait_tensor.default(aten.view.dtype")
            .run(code)
        )
        # check memory leak
        for i in range(1, 10):
            mem_usage[i] = torch.cuda.max_memory_allocated()
            compiled(arg)

        assert mem_usage[9] == mem_usage[8]

    @skip_if_lt_x_gpu(2)
    def test_all_gather_into_tensor_coalesced(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((10, 10), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
            inputs,
            self.world_size,
            "default",
        )
        expect = [
            torch.cat(
                [
                    torch.full((10, 10), float(rank) * i, device=self.device)
                    for rank in self.ranks
                ]
            )
            for i in range(10)
        ]
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert output.eq(expect[i]).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = all_gather_into_tensor_coalesced(
            inputs,
            "default",
        )
        for i, output in enumerate(outputs):
            assert not output.completed
            assert output.eq(expect[i]).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_single(self) -> None:
        self._init_process_group()

        input = torch.tensor(self.ranks, device=self.device)
        output = torch.ops._c10d_functional.reduce_scatter_tensor(
            input,
            "avg",
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert output.eq(self.rank).all()

        # Test Python API and AsyncCollectiveTensor
        output = reduce_scatter_tensor(
            input,
            "avg",
            0,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(self.rank).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_coalesced(self) -> None:
        self._init_process_group()

        inputs = [torch.tensor(self.ranks, device=self.device) * i for i in range(10)]
        outputs = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
            inputs,
            "avg",
            self.world_size,
            "default",
        )
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert output.eq(self.rank * i).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = reduce_scatter_tensor_coalesced(
            inputs,
            "avg",
            [0] * 10,
            "default",
        )
        for i, output in enumerate(outputs):
            assert not output.completed
            assert output.eq(self.rank * i).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_all_to_all_single(self) -> None:
        self._init_process_group()
        torch.cuda.set_device(self.device)

        torch.manual_seed(42)
        send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))

        input_split_sizes = send_sz_matrix[self.rank].tolist()
        output_split_sizes = send_sz_matrix[:, self.rank].tolist()
        input = torch.full((sum(input_split_sizes),), float(self.rank)).cuda()

        output = torch.ops._c10d_functional.all_to_all_single(
            input,
            output_split_sizes,
            input_split_sizes,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        expect = torch.cat(
            [
                torch.full((sz,), float(rank)).cuda()
                for rank, sz in enumerate(output_split_sizes)
            ]
        )
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_to_all_single(
            input, output_split_sizes, input_split_sizes, "default"
        )
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_broadcast(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.broadcast(
            input,
            1,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) != id(input)
        expect = 1
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = funcol.broadcast(
            input,
            1,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_unwaited(self) -> None:
        # Verify that the process can terminate gracefully
        # even with unwaited tensors
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )

    @skip_if_lt_x_gpu(2)
    def test_py_work(self) -> None:
        self._init_process_group()

        wait_called = False

        class MyWork(dist.Work):
            def wait(self, _):
                nonlocal wait_called
                wait_called = True

        tensor = torch.rand(2, 2)
        torch._C._distributed_c10d._register_work(tensor, MyWork())
        torch.ops._c10d_functional.wait_tensor(tensor)
        self.assertTrue(wait_called)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @fresh_inductor_cache()
    def test_threading(self):
        self._init_process_group()
        device = torch.device(f"cuda:{self.rank}")

        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            return ar0 + 1

        arg = torch.rand(4, 4, device=device)
        func(arg)

        compiled = torch.compile(func, fullgraph=True)
        code = run_and_get_triton_code(compiled, arg)
        FileCheck().check("all_reduce_.default(buf0, 'avg', '0')").run(code)

        # Unless explicitly specified (e.g. in a custom runtime), the process
        # group registry is shared among all threads in a process. Here we
        # verify that a process group registered in main thread can be resolved
        # in a different thread.
        class TestThread(threading.Thread):
            def run(self):
                self.exc = None
                try:
                    func(arg)
                    compiled(arg)
                except BaseException as exc:
                    self.exc = exc

            def join(self):
                threading.Thread.join(self)
                if self.exc:
                    raise self.exc

        t = TestThread()
        t.start()
        t.join()


class CompileTest(TestCase):
    """Checks the code Inductor generates for functional collectives.

    Every test runs against a process group created with the "fake" backend
    (world size 2, rank 0).  The tests compile small functions, match the
    generated code with ``FileCheck`` and, in several cases, also run the
    function through ``AOTIRunnerUtil``.
    """

    def setUp(self):
        """Configure Inductor, select ``cuda:0`` and create the fake group."""
        # Allow testing aoti after torch.compile
        torch._inductor.config.triton.store_cubin = True
        torch._inductor.config.debug = True

        self.rank = 0
        self.world_size = 2
        torch.cuda.set_device("cuda:0")

        store = FakeStore()
        dist.init_process_group(
            backend="fake",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

    def tearDown(self):
        """Destroy the process group created in ``setUp``."""
        dist.destroy_process_group()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_reduce_single(self):
        """all_reduce is lowered to the in-place op on Inductor-owned buffers."""

        def func(arg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce(arg, "avg", "0")
            ar1 = funcol.wait_tensor(ar1)
            return ar0, ar1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf7 = empty")
            # Expect in-place with inductor allocated buf
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no in-place with graph input (buf7 is a clone)
            .check("torch.ops._c10d_functional.all_reduce_.default(buf7")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf7")
            # Expect no extra copy on return
            .check("return (buf0, buf7, )")
            .run(code)
        )
        # The result of wait_tensor must never be bound to a name.
        self.assertNotIn("= torch.ops._c10d_functional.wait_tensor.default", code)

        # Test aoti
        AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_reduce_coalesced(self):
        """all_reduce_coalesced is lowered to the in-place coalesced op."""

        def func(
            args: List[torch.Tensor],
        ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
            bufs = [arg + 42 for arg in args]
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce_coalesced(bufs, "avg", "0")
            ar0 = [funcol.wait_tensor(out) for out in ar0]
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce_coalesced(args, "avg", "0")
            ar1 = [funcol.wait_tensor(out) for out in ar1]
            return ar0, ar1

        args = [torch.rand(4, 4, device="cuda") for _ in range(2)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf5 = empty")
            .check("buf1 = empty")
            .check("buf6 = empty")
            # Expect in-place with inductor allocated buf
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf0, buf1]"
            )
            # Expect no in-place with graph input (buf5, buf6 are clones)
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf5, buf6]"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf5")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf6")
            # Expect no extra copy on return
            .check("return (buf0, buf1, buf5, buf6, )")
            .run(code)
        )
        # The result of wait_tensor must never be bound to a name.
        self.assertNotIn("= torch.ops._c10d_functional.wait_tensor.default", code)

        # Test aoti
        AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_inplace_op_on_view(self):
        """The in-place all_reduce operates on a view of the allocated buffer."""

        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = (arg + 10)[:2]
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            return ar0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            # Ensure the all_reduce_ input is a view
            .check(
                "torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf0"
            )
            .check(
                "torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf0"
            )
            .check("return (reinterpret_tensor(buf0")
            .run(code)
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_reuse_buffer_after_inplace_collective(self):
        """The buffer used by an in-place collective is reused afterwards."""

        def func(arg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            # Expect allocation
            buf0 = arg + 42
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            # Expect allocation
            buf1 = torch.mm(arg, ar0)
            # Expect buf0 to be reused
            buf2 = torch.mm(arg, buf1)
            return buf1, buf2

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            # Expect allocation
            .check("buf0 = empty")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect allocation
            .check("buf7 = empty")
            .check("extern_kernels.mm(arg0_1, buf0, out=buf7")
            # Expect buf0 to be reused
            .check("buf8 = buf0; del buf0  # reuse")
            .check("extern_kernels.mm(arg0_1, buf7, out=buf8")
            # Expect no extra copy on return
            .check("return (buf7, buf8, )")
            .run(code)
        )
        # The result of wait_tensor must never be bound to a name.
        self.assertNotIn("= torch.ops._c10d_functional.wait_tensor.default", code)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_gather_into_tensor_single(self):
        """all_gather_tensor on a graph input returns the op's output buffer."""

        def func(arg: torch.Tensor) -> torch.Tensor:
            ag0 = funcol.all_gather_tensor(arg, 0, "0")
            ag0 = funcol.wait_tensor(ag0)
            return ag0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )
        # The result of wait_tensor must never be bound to a name.
        self.assertNotIn("= torch.ops._c10d_functional.wait_tensor.default", code)

        # Test aoti
        AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_gather_into_tensor_coalesced(self):
        """Coalesced all-gather unpacks, waits on and returns each output."""

        def func(args: List[torch.Tensor]) -> List[torch.Tensor]:
            ag0 = funcol.all_gather_into_tensor_coalesced(args, "0")
            ag0 = [funcol.wait_tensor(out) for out in ag0]
            return ag0

        args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )

        # Test aoti
        AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_reduce_scatter_tensor_single(self):
        """reduce_scatter_tensor on a graph input returns the op's output buffer."""

        def func(arg: torch.Tensor) -> torch.Tensor:
            rs0 = funcol.reduce_scatter_tensor(arg, "avg", 0, "0")
            rs0 = funcol.wait_tensor(rs0)
            return rs0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )

        # Test aoti
        AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_reduce_scatter_tensor_coalesced(self):
        """Coalesced reduce-scatter unpacks, waits on and returns each output."""

        def func(args: List[torch.Tensor]) -> List[torch.Tensor]:
            rs0 = funcol.reduce_scatter_tensor_coalesced(
                args, "avg", [0] * len(args), "0"
            )
            rs0 = [funcol.wait_tensor(out) for out in rs0]
            return rs0

        args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )

        # Test aoti
        AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_to_all_single(self):
        """all_to_all_single compiles with split sizes read from tensors."""

        def _tolist_with_constrain_as_size(tensor):
            # Mark every extracted element as a size via torch._check_is_size.
            lst = tensor.tolist()
            for elem in lst:
                torch._check_is_size(elem)
            return lst

        def func(
            input: torch.Tensor,
            output_split_sizes: torch.Tensor,
            input_split_sizes: torch.Tensor,
        ) -> torch.Tensor:
            output = funcol.all_to_all_single(
                input,
                _tolist_with_constrain_as_size(output_split_sizes),
                _tolist_with_constrain_as_size(input_split_sizes),
                "0",
            )
            return funcol.wait_tensor(output)

        torch.manual_seed(42)
        send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))

        # Row ``rank`` holds the sizes sent; column ``rank`` the sizes received.
        input_split_sizes = send_sz_matrix[self.rank]
        output_split_sizes = send_sz_matrix[:, self.rank].contiguous()
        input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).cuda()

        with torch._dynamo.config.patch(
            dynamic_shapes=True,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
        ):
            compiled = torch.compile(func, dynamic=True)
            code = run_and_get_triton_code(
                compiled, input, output_split_sizes, input_split_sizes
            )
        (
            FileCheck()
            # Both split-size lists must show up as pairs of ``u<N>`` symbols.
            .check_regex(
                "torch.ops._c10d_functional.all_to_all_single.default\\("
                "arg\\d+_\\d+, \\[u\\d+, u\\d+\\], \\[u\\d+, u\\d+\\]"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(")
            .run(code)
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_broadcast(self):
        """broadcast is lowered to the in-place op on Inductor-owned buffers."""

        def func(arg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            br0 = funcol.broadcast(buf0, 1, "0")
            br0 = funcol.wait_tensor(br0)
            # Expect no in-place with graph input
            br1 = funcol.broadcast(arg, 0, "0")
            br1 = funcol.wait_tensor(br1)
            return br0, br1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf7 = empty")
            # Expect in-place with inductor allocated buf
            .check("torch.ops._c10d_functional.broadcast_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no in-place with graph input (buf7 is a clone)
            .check("torch.ops._c10d_functional.broadcast_.default(buf7")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf7")
            # Expect no extra copy on return
            .check("return (buf0, buf7, )")
            .run(code)
        )

        # Test aoti
        AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_ranks_and_tag(self):
        """A group given as a rank list plus tag compiles to group name '0'."""

        def func(arg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce(buf0, "avg", [0, 1], "")
            ar0 = funcol.wait_tensor(ar0)
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce(arg, "avg", [0, 1], "")
            ar1 = funcol.wait_tensor(ar1)
            return ar0, ar1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func, fullgraph=True)

        code = run_and_get_triton_code(compiled, arg)
        (FileCheck().check("all_reduce_.default(buf0, 'avg', '0')").run(code))


if __name__ == "__main__":
    run_tests()
