# Owner(s): ["oncall: distributed"]

import operator
import os
import sys
import threading
from functools import reduce
from unittest import skip, SkipTest

import torch
import torch.autograd
import torch.distributed as dist
from torch._C._distributed_c10d import ReduceOp


if not dist.is_available():
    # Nothing in this file can run without torch.distributed: report the skip
    # on stderr and exit with status 0 so the run is not counted as a failure.
    print("Distributed not available, skipping tests", file=sys.stderr)
    raise SystemExit(0)

from torch.testing._internal.common_distributed import (
    MultiThreadedTestCase,
    skip_if_lt_x_gpu,
    spawn_threads_and_init_comms,
)
from torch.testing._internal.common_utils import IS_SANDCASTLE, run_tests, TestCase


# Default number of ranks for the multi-threaded collective tests in this
# file; keep any world sizes hard-coded in the tests consistent with it.
DEFAULT_WORLD_SIZE: int = 4


class TestCollectivesWithWrapper(TestCase):
    """Collective tests that bring up their own ranks per test.

    Each ``test_*`` method (or the inner function it builds) is wrapped with
    ``spawn_threads_and_init_comms(world_size=DEFAULT_WORLD_SIZE)``, so its body
    runs once per rank with ``dist`` already initialized; ``dist.get_rank()``
    identifies the current rank.
    """

    @spawn_threads_and_init_comms(world_size=DEFAULT_WORLD_SIZE)
    def test_broadcast_object_list(self):
        # Only rank 0 holds the payload; after the broadcast every rank's first
        # slot must carry rank 0's value.
        val = 99 if dist.get_rank() == 0 else None
        object_list = [val] * dist.get_world_size()

        dist.broadcast_object_list(object_list=object_list)
        self.assertEqual(99, object_list[0])

    def _assert_collective_error(self, should_fail):
        """Check that a failure on some ranks surfaces as ``RuntimeError``.

        Every rank performs an ``all_gather``. Ranks for which
        ``should_fail(rank)`` is true then raise ``AssertionError``. The other
        ranks proceed into a second ``all_gather`` that the failed ranks never
        join. The wrapped call is expected to report the failure to this
        thread as ``RuntimeError``.

        Args:
            should_fail: predicate ``rank -> bool`` selecting the failing ranks.
        """

        @spawn_threads_and_init_comms(world_size=DEFAULT_WORLD_SIZE)
        def _test_method(self):
            input_tensor = torch.ones(3, 3) * dist.get_rank()
            output_tensors = [
                torch.empty_like(input_tensor) for _ in range(dist.get_world_size())
            ]
            dist.all_gather(output_tensors, input_tensor)  # perform 1st all gather

            if should_fail(dist.get_rank()):
                raise AssertionError("Mimic real test failure.")

            dist.all_gather(output_tensors, input_tensor)  # perform 2nd all gather

        with self.assertRaises(RuntimeError):
            _test_method(self)

    def test_collective_error_on_rank_zero(self):
        self._assert_collective_error(lambda rank: rank == 0)  # fail on rank 0

    def test_collective_error_on_rank_non_zero(self):
        self._assert_collective_error(lambda rank: rank == 1)  # fail on rank 1

    def test_collective_error_on_rank_non_zero_all(self):
        # fail on all non-zero ranks
        self._assert_collective_error(lambda rank: rank > 0)

    def test_skip(self):
        @spawn_threads_and_init_comms(world_size=DEFAULT_WORLD_SIZE)
        @skip("check if skip exception can be captured correctly.")
        def _test_method(self):
            pass

        # The skip raised inside the rank threads must reach this thread as
        # SkipTest. This is only checked when not running on Sandcastle.
        if not IS_SANDCASTLE:
            with self.assertRaises(SkipTest):
                _test_method(self)

    # In the all_to_all_single tests every row a rank sends is filled with its
    # own rank, and each peer receives one row. So out[j] == [j, j] on every
    # rank, whichever way the (uniform) split sizes are given.

    @spawn_threads_and_init_comms(world_size=DEFAULT_WORLD_SIZE)
    def test_all_to_all_single_tensor(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        send = torch.full((world_size, 2), rank)
        sizes = torch.ones(world_size, dtype=torch.int64)

        out = torch.zeros(world_size, 2, dtype=send.dtype)
        dist.all_to_all_single(out, send, sizes, sizes)
        self.assertEqual(out.tolist(), list(zip(range(world_size), range(world_size))))

    @spawn_threads_and_init_comms(world_size=DEFAULT_WORLD_SIZE)
    def test_all_to_all_single_list(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        send = torch.full((world_size, 2), rank)
        sizes = [1] * world_size

        out = torch.zeros(world_size, 2, dtype=send.dtype)
        dist.all_to_all_single(out, send, sizes, sizes)
        self.assertEqual(out.tolist(), list(zip(range(world_size), range(world_size))))

    @spawn_threads_and_init_comms(world_size=DEFAULT_WORLD_SIZE)
    def test_all_to_all_single_none(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        send = torch.full((world_size, 2), rank)

        out = torch.zeros(world_size, 2, dtype=send.dtype)
        dist.all_to_all_single(out, send)
        self.assertEqual(out.tolist(), list(zip(range(world_size), range(world_size))))


class TestCollectivesWithBaseClass(MultiThreadedTestCase):
    """Collective tests driven by ``MultiThreadedTestCase``.

    ``setUp`` spawns the per-rank threads via ``self._spawn_threads()``. Each
    ``test_*`` body then runs on every rank, where ``dist.get_rank()`` (and
    ``self.rank``) identify the current rank.
    """

    @property
    def world_size(self):
        # Number of ranks the threaded process group is created with.
        return DEFAULT_WORLD_SIZE

    @staticmethod
    def _restore_env(name, value):
        """Set environment variable ``name`` back to ``value``; ``None`` unsets it."""
        if value is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = value

    def setUp(self):
        # Force TORCH_DIST_INIT_BARRIER=1 for this test. The restore is
        # registered as a cleanup *before* the environment is mutated. That way
        # the previous value (or its absence) comes back even if setUp fails
        # part-way: unittest skips tearDown in that case but still runs cleanups.
        self.addCleanup(
            self._restore_env,
            "TORCH_DIST_INIT_BARRIER",
            os.environ.get("TORCH_DIST_INIT_BARRIER"),
        )
        os.environ["TORCH_DIST_INIT_BARRIER"] = "1"
        super().setUp()
        self._spawn_threads()

    def tearDown(self):
        # The TORCH_DIST_INIT_BARRIER restore registered in setUp runs as a
        # cleanup after this returns.
        super().tearDown()

    def test_allgather(self):
        input_tensor = torch.ones(3, 3) * dist.get_rank()
        output_tensors = [
            torch.empty_like(input_tensor) for _ in range(self.world_size)
        ]
        dist.all_gather(output_tensors, input_tensor)
        # Slot i holds the tensor contributed by rank i.
        for rank, out_tensor in enumerate(output_tensors):
            self.assertEqual(out_tensor, torch.ones(3, 3) * rank)

    def test_broadcast(self):
        # Broadcast from each rank in turn; every rank must end up holding the
        # source rank's tensor.
        input_tensor = torch.ones(3, 3) * dist.get_rank()
        for rank in range(self.world_size):
            cloned_input = input_tensor.clone()
            dist.broadcast(cloned_input, src=rank)
            self.assertEqual(cloned_input, torch.ones(3, 3) * rank)

    def test_scatter(self):
        # Only rank 0 supplies scatter_list. dist.scatter is called without
        # src, so this relies on its default source being rank 0.
        if dist.get_rank() == 0:
            scatter_list = [torch.ones(3, 3) * rank for rank in range(self.world_size)]
        else:
            scatter_list = None
        output_tensor = torch.empty(3, 3)

        dist.scatter(output_tensor, scatter_list)
        self.assertEqual(output_tensor, torch.ones(3, 3) * dist.get_rank())

    def test_reduce_scatter(self):
        # Every rank contributes the same list [0, 1, ..., world_size - 1] * ones.
        # Rank r's shard therefore sums to r * world_size and averages to r.
        to_reduce_scatter = [torch.ones(3, 3) * rank for rank in range(self.world_size)]
        output_tensor = torch.empty(3, 3)

        dist.reduce_scatter(output_tensor, to_reduce_scatter)
        expected_tensor = torch.ones(3, 3) * dist.get_rank() * self.world_size
        self.assertEqual(output_tensor, expected_tensor)

        output_tensor = torch.empty(3, 3)
        dist.reduce_scatter(output_tensor, to_reduce_scatter, op=dist.ReduceOp.AVG)
        expected_tensor = torch.ones(3, 3) * dist.get_rank()
        self.assertEqual(output_tensor, expected_tensor)

    def test_broadcast_object_list(self):
        # Only rank 0 holds the payload; after the broadcast every rank's first
        # slot must carry rank 0's value.
        val = 99 if dist.get_rank() == 0 else None
        object_list = [val] * dist.get_world_size()

        dist.broadcast_object_list(object_list=object_list)
        self.assertEqual(99, object_list[0])

    def test_all_reduce(self):
        output = torch.ones(3, 3) * dist.get_rank()
        dist.all_reduce(output)
        # Expected sum over ranks: 0 + 1 + ... + (world_size - 1).
        res_num = sum(range(self.world_size))
        self.assertEqual(output, torch.ones(3, 3) * res_num)

    def test_all_to_all(self):
        # Rank r sends the values r*ws .. r*ws + ws - 1, slot j going to rank j.
        # Rank r therefore receives j*ws + r from each rank j.
        rank = self.rank
        world_size = self.world_size
        input_tensor_list = [
            torch.ones(3, 3) * x
            for x in range(rank * world_size, (rank + 1) * world_size)
        ]
        output_tensor_list = [torch.empty_like(tensor) for tensor in input_tensor_list]
        dist.all_to_all(output_tensor_list, input_tensor_list)
        expected_tensor_list = [
            torch.ones(3, 3) * x
            for x in range(rank, world_size * world_size, world_size)
        ]
        self.assertEqual(expected_tensor_list, output_tensor_list)

    def test_all_reduce_ops(self):
        # Each rank contributes rank + 1, i.e. the values 1..world_size. Every
        # result is compared with the same reduction computed in Python.
        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.PRODUCT)
        expected = reduce(operator.mul, range(1, self.world_size + 1))
        self.assertEqual(expected, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.MIN)
        self.assertEqual(1, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.MAX)
        self.assertEqual(self.world_size, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.BAND)
        expected = reduce(operator.and_, range(1, self.world_size + 1))
        self.assertEqual(expected, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.BOR)
        expected = reduce(operator.or_, range(1, self.world_size + 1))
        self.assertEqual(expected, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.BXOR)
        expected = reduce(operator.xor, range(1, self.world_size + 1))
        self.assertEqual(expected, tensor.item())

    def test_assert_equal_on_rank(self):
        # RNG is shared across threads. So instead of asserting on all threads
        # we only assert on rank 0
        self_tensor = torch.rand(3, 3)
        rank_0_tensor = self_tensor.clone()
        dist.broadcast(rank_0_tensor, src=0)
        self.assertEqualOnRank(rank_0_tensor, self_tensor, rank=0)
        self.assertNotEqualOnRank(rank_0_tensor, self_tensor, rank=1)

    def test_subpg(self):
        subpg0 = dist.new_group([0, 1])
        subpg1 = dist.new_group([2, 3])
        current_rank = dist.get_rank()
        output = torch.ones(3, 3) * current_rank

        # call all_reduce on subpg0 and subpg1 concurrently
        if current_rank in [0, 1]:
            dist.all_reduce(output, group=subpg0)
        else:
            dist.all_reduce(output, group=subpg1)

        # subpg0 sums 0 + 1 == 1; subpg1 sums 2 + 3 == 5.
        if current_rank in [0, 1]:
            self.assertEqual(output, torch.ones(3, 3) * 1)
        else:
            self.assertEqual(output, torch.ones(3, 3) * 5)

    def test_using_pg_from_another_thread(self):
        errors = []

        def stuff_in_other_thread(pg):
            try:
                x = torch.rand(4, requires_grad=True)
                dist.all_reduce(x, group=pg)
            except Exception as e:
                errors.append(e)

        t = threading.Thread(target=stuff_in_other_thread, args=(dist.group.WORLD,))
        t.start()
        t.join()
        # An exception in a plain threading.Thread never reaches join(). It is
        # only reported through threading.excepthook. Re-raise it here so a
        # failing collective actually fails this rank's test.
        if errors:
            raise errors[0]

    def test_gather(self):
        # Only the destination (rank 0) provides gather_list.
        if dist.get_rank() == 0:
            gather_list = [torch.empty(3, 3) for _ in range(self.world_size)]
        else:
            gather_list = None
        input_tensor = torch.ones(3, 3) * dist.get_rank()

        dist.gather(input_tensor, gather_list)
        if dist.get_rank() == 0:
            for i, gathered in enumerate(gather_list):
                self.assertEqual(gathered, torch.ones(3, 3) * i)

    def test_all_reduce_coalesced(self):
        t0 = torch.ones(3, 3) * dist.get_rank()
        t1 = torch.ones(3, 3) * dist.get_rank() * 2
        dist.all_reduce_coalesced([t0, t1])
        # Expected sum over ranks: 0 + 1 + ... + (world_size - 1); t1 is twice that.
        res_num = sum(range(self.world_size))
        self.assertEqual(t0, torch.ones(3, 3) * res_num)
        self.assertEqual(t1, torch.ones(3, 3) * (res_num * 2))

    @skip_if_lt_x_gpu(1)
    def test_bwd_sees_fwd_pg(self):
        # The backward pass must run on the same thread as forward and still
        # see this rank's initialized process group.
        fwd_tid = threading.current_thread().ident

        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, rank):
                result = rank * 2

                ctx.save_for_backward(result, rank)
                self.assertEqual(int(rank.item()), dist.get_rank())
                return result

            @staticmethod
            def backward(ctx, grad_output):
                result, rank = ctx.saved_tensors
                bwd_tid = threading.current_thread().ident

                self.assertEqual(
                    fwd_tid,
                    bwd_tid,
                    f"bwd not running in the same thread as fwd for rank {rank.item()}",
                )
                self.assertTrue(dist.is_initialized())
                self.assertEqual(int(rank.item()), dist.get_rank())
                dist.all_reduce(result)
                # Sum over ranks of 2 * rank, e.g. (0 + 1 + 2 + 3) * 2 == 12.
                self.assertEqual(int(result.item()), 2 * sum(range(self.world_size)))

                return grad_output * result

        x = torch.tensor(
            [dist.get_rank()], dtype=torch.float, device="cuda", requires_grad=True
        )
        x = MyFunc.apply(x)
        x.sum().backward()


# Run the tests when this file is executed directly, using run_tests from
# torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()
