# Owner(s): ["oncall: distributed"]

import os
import sys
import unittest
from functools import partial, wraps

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.distributed._tensor as dt
import torch.distributed.distributed_c10d as c10d
from functorch import make_fx
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton


# Builds without torch.distributed cannot run anything in this file (the
# imports below and every test use it), so report and exit with status 0.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    MultiThreadedTestCase,
    requires_nccl,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


def new_subgroups(group_size: int, pg_tag=None):
    world_size = dist.get_world_size()
    subgroups = []
    cur_subgroup = None

    for subgroup_id in range(world_size // group_size):
        start_rank = subgroup_id * group_size
        end_rank = start_rank + group_size
        ranks_in_subgroup = list(range(start_rank, end_rank))
        subgroup = c10d._new_group_with_tag(
            ranks=ranks_in_subgroup,
            pg_tag=pg_tag,
        )
        subgroups.append(subgroup)

        rank = dist.get_rank()
        if rank in ranks_in_subgroup:
            cur_subgroup = subgroup

    return cur_subgroup, subgroups


class TestExpand(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_expand_1d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3], "bla")
        self.assertEqual("bla", tag)

    def test_expand_2d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]], "blu")
        self.assertEqual("blu", tag)

        with self.assertRaisesRegex(ValueError, "group sizes must be identical"):
            ft_c._expand_group([[0], [1, 2, 3]])

    def test_expand_process_group(self):
        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD)
        self.assertEqual(c10d._get_group_tag(dist.group.WORLD), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
        self.assertEqual("bla", tag)

        my_pg, others = new_subgroups(group_size=2)
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual(c10d._get_group_tag(my_pg), tag)
        self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
        self.assertEqual(2, group_size)

        my_pg = None
        for i in range(dist.get_world_size()):
            group = c10d._new_group_with_tag([i], pg_tag="my_pg")
            if i == dist.get_rank():
                my_pg = group
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual("my_pg", tag)
        self.assertEqual([dist.get_rank()], rankset)
        self.assertEqual(1, group_size)

        tag, rankset, group_size = ft_c._expand_group(my_pg, "bla")
        self.assertEqual("bla", tag)

    def test_expand_device_mesh(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

    def test_expand_device_mesh_tuple(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
        with self.assertRaisesRegex(AssertionError, "Only 1D mesh"):
            tag, rankset, group_size = ft_c._expand_group(mesh)

        tag, rankset, group_size = ft_c._expand_group((mesh, 0))
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        expected_rankset = [0, 2] if dist.get_rank() in [0, 2] else [1, 3]
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group((mesh, 1))
        expected_rankset = [0, 1] if dist.get_rank() in [0, 1] else [2, 3]
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=1)), tag)
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)


class TestPgTag(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    """
    The behavior we want is as follow:

    - rankset+tag will always result in the same PG.
    Do we enforce this by failing creation of new PGs or returning existing ones?
        Return existing one.

    - default tag gives existing behavior.
        This means we should create duplicates.
    - _expand_group on _default-tagged pg should always resolve to it
        This mean we can't depend on empty tag + rankset.
    """

    def test_pg_creation_with_tag(self):
        my_group, _ = new_subgroups(group_size=2, pg_tag="blu")
        my_group2, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(my_group, my_group2)

        my_group3, _ = new_subgroups(group_size=2, pg_tag="blu2")
        self.assertNotEqual(my_group, my_group3)

        my_group4, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group, my_group4)

        my_group5, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group4, my_group5)

    def test_pg_lookup_roundtrip(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="blu2")
        pg_notag0, _ = new_subgroups(group_size=2)
        pg_notag1, _ = new_subgroups(group_size=2)

        def roundtrip(pg):
            tag, rankset, _ = ft_c._expand_group(pg)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag0))
        self.assertEqual(pg_tag1, roundtrip(pg_tag1))
        self.assertEqual(pg_notag0, roundtrip(pg_notag0))
        self.assertEqual(pg_notag1, roundtrip(pg_notag1))

    def test_pg_lookup_with_tag(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="bla")
        pg_notag0, _ = new_subgroups(group_size=2)

        def roundtrip(pg, pg_tag):
            tag, rankset, _ = ft_c._expand_group(pg, pg_tag)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag1, "blu"))
        self.assertEqual(pg_tag0, roundtrip(pg_notag0, "blu"))
        # Cannot erase the tag of a PG
        self.assertEqual(pg_tag0, roundtrip(pg_tag0, ""))

    def test_find_or_create_pg(self):
        pg = c10d._find_or_create_pg_by_ranks_and_tag("blu", [0, 1, 2, 3], 2)
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(pg, pg_tag0)

    def test_find_root_pg(self):
        pg = c10d._find_pg_by_ranks_and_tag("", [0, 1, 2, 3])
        self.assertEqual(dist.group.WORLD, pg)


@instantiate_parametrized_tests
class TestTraceableCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    @parametrize("device", ["cpu", "cuda"])
    def test_broadcast(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        if dist.get_rank() == 0:
            tensor = torch.ones([4], device=device)
        else:
            tensor = torch.zeros([4], device=device)

        mesh = dt.DeviceMesh(device, torch.arange(4))
        res = ft_c.broadcast(tensor, 0, mesh)
        self.assertEqual(res, torch.ones([4], device=device))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_eager(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        tensor = torch.ones([4], device=device)
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_reduce(tensor, "sum", mesh)
        self.assertEqual(res, torch.tensor([4, 4, 4, 4], dtype=torch.float))

        mesh = dt.DeviceMesh(device, torch.arange(4).view(2, 2))
        res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1))
        self.assertEqual(res2, torch.tensor([2, 2, 2, 2], dtype=torch.float))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_coalesced_eager(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        t0 = torch.ones([4], device=device)
        t1 = torch.ones([6], device=device) + 2
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_reduce_coalesced([t0, t1], "sum", mesh)
        self.assertEqual(res[0], t0 * 4)
        self.assertEqual(res[1], t1 * 4)

    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_tensor(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        # testing 1d/2d mesh
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            dims_to_gather = [0, 1, 2]
            for dim in dims_to_gather:
                output_size = [3, 3, 3]
                output_size[dim] *= mesh.size(0)
                # each rank have its own tensor, all_gather gives a bigger tensor
                local_tensor = torch.ones([3, 3, 3], device=device)
                gathered_tensor = ft_c.all_gather_tensor(
                    local_tensor, gather_dim=dim, group=(mesh, 0)
                )
                self.assertEqual(gathered_tensor, torch.ones(output_size))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_into_tensor_coalesced(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()], device=device), res[0])
        self.assertEqual(
            torch.ones([4 * dist.get_world_size()], device=device) + 1, res[1]
        )

    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_tensor(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        # testing 1d/2d mesh
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            dims_to_scatter = [0, 1]
            for dim in dims_to_scatter:
                group_size = mesh.size(0)
                input_size = [3, 3]
                output_size = [3, 3]
                output_size[dim] *= group_size
                input_tensor = torch.ones(output_size, device=device)
                res_num = 1 * group_size
                rs_tensor = ft_c.reduce_scatter_tensor(
                    input_tensor, "sum", scatter_dim=dim, group=(mesh, 0)
                )
                self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)

    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_into_tensor_coalesced(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())
        tensors = [
            torch.ones([4], dtype=torch.int64, device=device),
            torch.ones([4], dtype=torch.int64, device=device) + 1,
        ]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.reduce_scatter_tensor_coalesced(tensors, "sum", [0, 0], mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.tensor([4], device=device), res[0])
        self.assertEqual(torch.tensor([8], device=device), res[1])


class TestMetaCollectives(TestCase):
    def test_all_reduce(self):
        x = torch.rand((2, 3, 4), device="meta")
        out = ft_c.all_reduce(x, "sum", "0")
        self.assertEqual(x.size(), out.size())


class TestGradCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce(self):
        x = torch.rand([4], requires_grad=True)
        y = torch.rand([4], requires_grad=True)
        out = ft_c.all_reduce(x, "sum", dist.group.WORLD)
        (out + y).sum().backward()
        self.assertIsNone(x.grad)


class TestMakeFx(TestCase):
    def setUp(self):
        # make_fx is not thread-safe due to patching nd mutating global states
        # so create a fake_pg.
        self.rank = 0
        self.world_size = 2
        store = FakeStore()
        dist.init_process_group(
            backend="fake",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

    def tearDown(self):
        super().tearDown()

        self.assertFalse(torch.fx._symbolic_trace.is_fx_tracing())

    def test_all_reduce_tracing(self):
        def allred(input):
            return ft_c.all_reduce(input, "sum", group=dist.group.WORLD) + 1

        graph = make_fx(allred)(torch.rand(4))
        FileCheck().check("all_reduce").check("wait_tensor").run(str(graph.graph))

        mesh = dt.DeviceMesh("cpu", torch.arange(self.world_size))

        def allred_mesh(input):
            return ft_c.all_reduce(input, "sum", mesh) + 1

        mesh_graph = make_fx(allred_mesh)(torch.rand(4))
        FileCheck().check_not("get_attr").check("wait_tensor").run(
            str(mesh_graph.graph)
        )

        def allred_mesh_dim(input):
            return ft_c.all_reduce(input, "sum", (mesh, 0)) + 1

        mesh_dim_graph = make_fx(allred_mesh_dim)(torch.rand(4))
        FileCheck().check_not("get_attr").check("wait_tensor").run(
            str(mesh_dim_graph.graph)
        )


# Default backend for the multi-process tests: NCCL when CUDA is present,
# Gloo otherwise. with_comms replaces it with $BACKEND when that is set.
if torch.cuda.is_available():
    BACKEND = dist.Backend.NCCL
else:
    BACKEND = dist.Backend.GLOO
WORLD_SIZE = 2


def exit_if_lt_x_gpu(x):
    if torch.cuda.device_count() < x:
        sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)


def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        global BACKEND

        if "BACKEND" in os.environ:
            BACKEND = os.environ["BACKEND"]
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        func(self)
        self.destroy_comms()

    return wrapper


class TestCollectivesWithNCCL(MultiProcessTestCase):
    """Functional collectives exercised across spawned processes (NCCL)."""

    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        # Spawned workers inherit the environment; with_comms reads BACKEND
        # from it to choose the backend used by dist_init().
        os.environ["BACKEND"] = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank)

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    @requires_nccl()
    @with_comms()
    def test_all_gather_into_tensor_coalesced(self):
        exit_if_lt_x_gpu(self.world_size)

        tensors = [
            torch.ones([4], device=f"cuda:{self.rank}"),
            torch.ones([4], device=f"cuda:{self.rank}") + 1,
        ]
        mesh = dt.DeviceMesh(f"cuda:{self.rank}", torch.arange(self.world_size))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()]), res[0])
        self.assertEqual(torch.ones([4 * dist.get_world_size()]) + 1, res[1])

    @with_comms()
    def test_all_to_all_single(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        # Rank r sends (i + 1) * (r + 1) rows to rank i; the total is
        # (r + 1) * ws * (ws + 1) / 2, which is always an integer.
        row = self.world_size * (rank + 1) * (self.world_size + 1) // 2
        x = torch.ones(row, 5, device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_1d_input(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        row = self.world_size * (rank + 1) * (self.world_size + 1) // 2
        x = torch.ones(row, device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_split_sizes_none(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        x = torch.ones(self.world_size, self.world_size, device=device) * (rank + 1)
        y = ft_c.all_to_all_single(
            x, output_split_sizes=None, input_split_sizes=None, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.chunk(x, self.world_size)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing(self):
        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        compiled_allreduce = torch.compile(allreduce, fullgraph=True)
        compiled_allreduce(torch.randn(8, device=self.device), self.process_group)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_tracing_with_fakepg(self):
        exit_if_lt_x_gpu(self.world_size)

        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        compiled_allreduce = torch.compile(allreduce, fullgraph=True)
        dist.init_process_group(
            backend="fake",
            rank=0,
            world_size=8,
            store=FakeStore(),
        )
        try:
            # Call the compiled function; calling the eager ``allreduce`` here
            # would leave the tracing path untested.
            compiled_allreduce(torch.randn(8, device=self.device), pg=dist.group.WORLD)
        finally:
            # This test bypasses with_comms, so it tears down its own fake PG.
            dist.destroy_process_group()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing_with_dce_code(self):
        # The [1, 0] permutation below only describes a two-rank world.
        if self.world_size > 2:
            return

        def func(batch, group, rank):
            ret = ft_c.permute_tensor(batch, [1, 0], group)
            if hasattr(ret, "wait"):
                ret = ret.wait()
            # On rank != 0 the collective's result is unused, so the compiler
            # may try to eliminate it as dead code.
            if rank == 0:
                return ret
            else:
                return batch * 5

        compiled_func = torch.compile(func)
        compiled_func(
            torch.ones((100,), device="cuda"), self.process_group, self.rank
        )
        dist.barrier()


class TestNCCLCollectivesWithWorldSize4(TestCollectivesWithNCCL):
    """Re-runs the NCCL suite with four ranks and adds 2D-mesh coverage."""

    @property
    def world_size(self):
        return 4

    @requires_nccl()
    @with_comms()
    def test_permute_tensor_with_sub_group(self):
        exit_if_lt_x_gpu(self.world_size)

        device = "cuda"
        dim_names = ["dp", "tp"]
        mesh_2d = dt.init_device_mesh(
            device, (2, self.world_size // 2), mesh_dim_names=dim_names
        )

        for dim_name in dim_names:
            sub_mesh = mesh_2d[dim_name]
            rank = sub_mesh.get_local_rank()
            base = torch.arange(2, dtype=torch.float32, device=device)

            # The two members of each sub-mesh swap payloads:
            # local rank 0 sends [0., 1.], local rank 1 sends [2., 3.]
            send_tensor = base + 2 * rank
            recvd_tensor = ft_c.permute_tensor(send_tensor, [1, 0], group=sub_mesh)

            # ...so local rank 0 receives [2., 3.] and local rank 1 [0., 1.]
            peer = (rank + 1) % 2
            expected = base + 2 * peer
            self.assertEqual(
                recvd_tensor,
                expected,
                msg=f"Expected {expected} on {self.rank=} (local_rank={rank}), "
                f"but received {recvd_tensor} instead.",
            )


@instantiate_parametrized_tests
class TestFunctionalAutograd(MultiThreadedTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_threads()

    @property
    def world_size(self):
        return 2

    @parametrize("compile", [True, False])
    def test_all_to_all_single(self, compile: bool = True) -> None:
        group = dist.group.WORLD.group_name

        t = torch.ones((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            sizes = [1] * world_size
            t = t * 2
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        out = compiled(t, self.world_size)
        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        loss = out.sum()
        loss.backward()
        self.assertEqual(t.grad, torch.full_like(t, 2.0))

    def test_all_to_all_single_inductor(self) -> None:
        group = dist.group.WORLD.group_name

        t = torch.rand((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            sizes = [1] * world_size
            t = t * 10
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 2
            return out.sum()

        compiled = torch.compile(my_func, fullgraph=True)

        def run_with_backward():
            out = compiled(t, self.world_size)
            out.backward()

        res, codes = run_and_get_code(run_with_backward)
        for code in codes:
            FileCheck().check_count(
                "_c10d_functional.all_to_all_single.default", 1, exactly=True
            ).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run(
                code
            )

        self.assertIsNotNone(t.grad)

    @parametrize("compile", [True, False])
    def test_all_gather_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            out = ft_c.all_gather_tensor_autograd(
                t * 1.0,
                gather_dim=dim,
                group=group,
            )
            out = out * 1.0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        dims_to_gather = [0, 1, 2]
        for dim in dims_to_gather:
            output_size = [3, 3, 3]
            output_size[dim] *= self.world_size
            # each rank have its own tensor, all_gather gives a bigger tensor
            local_tensor = torch.ones([3, 3, 3], requires_grad=True)
            gathered_tensor = compiled(local_tensor, dim)
            self.assertEqual(gathered_tensor, torch.ones(output_size))

            gathered_tensor.sum().backward()
            self.assertEqual(
                local_tensor.grad,
                torch.full((3, 3, 3), fill_value=float(self.world_size)),
            )

    @parametrize("compile", [True, False])
    def test_reduce_scatter_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            rs_tensor = (
                ft_c.reduce_scatter_tensor_autograd(
                    input_tensor * 1.0, "sum", scatter_dim=dim, group=group
                )
                * 1.0
            )
            return rs_tensor

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        dims_to_scatter = [0, 1]
        for dim in dims_to_scatter:
            group_size = self.world_size
            input_size = [3, 3]
            output_size = [3, 3]
            output_size[dim] *= group_size
            input_tensor = torch.ones(output_size, requires_grad=True)
            rs_tensor = compiled(input_tensor, dim)
            res_num = 1 * group_size
            self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
            rs_tensor.sum().backward()
            self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))


class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
    """Autograd-aware functional collectives over a real multi-process PG."""

    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank)

    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # Bind each rank to its own GPU before issuing NCCL collectives.
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Every rank must arrive here before any of them shuts down.
        dist.barrier()
        dist.destroy_process_group()

    @requires_nccl()
    @with_comms()
    def test_all_to_all_single(self) -> None:
        group_name = self.process_group.group_name

        src = torch.ones((self.world_size, 2), requires_grad=True, device=self.device)

        split_sizes = [1] * self.world_size
        assert src.requires_grad
        doubled = src * 2
        exchanged = ft_c.all_to_all_single_autograd(
            doubled, split_sizes, split_sizes, group_name
        )
        out = exchanged + 0

        self.assertEqual(out.shape, src.shape)
        self.assertEqual(out, torch.full_like(src, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        out.sum().backward()
        # d(sum(2 * src))/d(src) == 2 everywhere.
        self.assertEqual(src.grad, torch.full_like(src, 2.0))


# Run the whole test suite when this file is executed directly.
if __name__ == "__main__":
    run_tests()
