# Owner(s): ["oncall: distributed"]

import os
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import DTensor
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests


class Net(nn.Module):
    """Small model used by the tests: three chained ``nn.Linear(2, 2)`` layers.

    No activation is applied between the layers, so ``forward`` is simply
    ``fc3(fc2(fc1(x)))``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Layers are created in fc1 -> fc2 -> fc3 order.
        self.fc1 = nn.Linear(in_features=2, out_features=2)
        self.fc2 = nn.Linear(in_features=2, out_features=2)
        self.fc3 = nn.Linear(in_features=2, out_features=2)

    def forward(self, x):
        """Pass ``x`` through ``fc1``, ``fc2`` and ``fc3`` in that order."""
        hidden = self.fc1(x)
        hidden = self.fc2(hidden)
        return self.fc3(hidden)


class ReplicateStateDictTest(MultiProcessTestCase):
    """Checks that ``replicate()`` does not change a module's ``state_dict()``.

    Each test builds a plain ``Net``, applies ``replicate()`` to a deep copy
    (or to the copy's submodules) and compares the two state dicts.
    """

    def setUp(self) -> None:
        """Run the base-class setup, then spawn the test processes."""
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        """Run the base-class teardown, then remove ``self.file_name``."""
        super().tearDown()
        # Best-effort cleanup: any OSError from the removal is deliberately
        # ignored so that it cannot fail the test.
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _init_pg(self):
        """Initialize a gloo process group backed by a ``FileStore``."""
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def _check_state_dict_parity(self, sd_1, sd_2):
        """Assert that two state dicts hold the same entries in the same order.

        Args:
            sd_1: First state dict.
            sd_2: Second state dict.

        Raises:
            AssertionError: If the dicts differ in size, in any key, or in any
                value (entries are paired by iteration order).
        """
        # zip() stops at the shorter input, so without this size check extra
        # or missing trailing entries in one dict would go unnoticed.
        self.assertEqual(
            len(sd_1),
            len(sd_2),
            msg=(
                "state dicts differ in size: "
                f"{list(sd_1.keys())} vs {list(sd_2.keys())}"
            ),
        )

        for (k1, v1), (k2, v2) in zip(sd_1.items(), sd_2.items()):
            self.assertEqual(k1, k2)
            self.assertEqual(v1, v2)

    def test_replicate_single_module_save_load(self):
        """
        Tests that replicate() on a single module state_dict
        matches local module state_dict.
        """
        self._init_pg()
        model = Net()
        replicate_model = replicate(deepcopy(model))
        local_sd = model.state_dict()
        ddp_sd = replicate_model.state_dict()
        self._check_state_dict_parity(local_sd, ddp_sd)

    def test_replicate_non_root_multiple_save_load(self):
        """
        Tests that replicate() on multiple submodules matches
        local module state_dict.
        """
        self._init_pg()
        model = Net()
        replicate_model = deepcopy(model)
        # Apply replicate() to each submodule rather than to the root module.
        replicate(replicate_model.fc1)
        replicate(replicate_model.fc2)
        replicate(replicate_model.fc3)

        local_sd = model.state_dict()
        ddp_sd = replicate_model.state_dict()
        self._check_state_dict_parity(local_sd, ddp_sd)


class ReplicateTest(MultiProcessTestCase):
    """Two-process tests for the composable ``replicate()`` API.

    Covers training parity against an unwrapped copy of the model, moving
    inputs to the device, ``ignored_modules``, extra keyword arguments and
    ``device_id`` handling.
    """

    @property
    def world_size(self) -> int:
        """Number of processes used by these tests."""
        return 2

    def setUp(self) -> None:
        """Run the base-class setup, then spawn the test processes."""
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        """Run the base-class teardown, then remove ``self.file_name``."""
        super().tearDown()
        # Removal is best effort; an OSError here is ignored on purpose.
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _init_pg(self):
        """Initialize a gloo process group backed by a ``FileStore``."""
        file_store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="gloo",
            store=file_store,
            rank=self.rank,
            world_size=self.world_size,
        )

    def _compare_module(self, mod, replicate_mod):
        """Train both modules for two steps and assert their parameters match.

        ``mod`` is stepped on the whole global batch, while ``replicate_mod``
        is stepped only on this rank's rows of the same batch.

        Args:
            mod: Plain local module used as the reference.
            replicate_mod: Copy of ``mod`` that has had ``replicate()`` applied.
        """
        per_rank_batch = 1
        total_batch = self.world_size * per_rank_batch
        full_input = torch.randn(total_batch, 2)
        full_target = torch.randn(total_batch, 2)
        # Rows of the global batch that belong to this rank.
        rank_rows = slice(
            self.rank * per_rank_batch, (self.rank + 1) * per_rank_batch
        )

        def train_step(net, batch, labels):
            # One forward/backward pass followed by a manual update
            # (param -= grad); grads are cleared afterwards.
            net.train()
            prediction = net(batch)
            F.mse_loss(prediction, labels.to(prediction.device)).backward()
            with torch.no_grad():
                for weight in net.parameters():
                    weight -= weight.grad
                    weight.grad = None

        for step in range(2):
            train_step(mod, full_input, full_target)
            train_step(
                replicate_mod, full_input[rank_rows], full_target[rank_rows]
            )

            local_params = list(mod.parameters())
            replicated_params = list(replicate_mod.parameters())
            self.assertEqual(len(local_params), len(replicated_params))
            for expected, actual in zip(local_params, replicated_params):
                self.assertEqual(expected, actual, rtol=1.3e-06, atol=5e-5)

            # Permute the input rows (targets keep their order) so that the
            # next step feeds the replicated module different data.
            torch.manual_seed(step)
            full_input = full_input[torch.randperm(total_batch)]

    def test_replicate_single_module(self):
        """replicate() on the root module trains like the local module."""
        self._init_pg()
        reference = Net()
        replicated = replicate(deepcopy(reference))
        self._compare_module(reference, replicated)

    @skip_if_lt_x_gpu(2)
    def test_replicate_move_args_kwargs_to_device(self):
        """A CUDA module with replicate() accepts CPU args and kwargs."""

        class MyNet(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Linear(2, 2)

            def forward(self, inp, *, kwarg=None):
                if kwarg is not None:
                    inp = inp @ kwarg
                return self.a(inp)

        self._init_pg()
        torch.cuda.set_device(self.rank)
        model = MyNet().cuda()
        replicate(model, device_id=torch.cuda.current_device())
        # Both tensors are created on the CPU while the module lives on the
        # GPU, so the forward pass only works if they are moved to the device.
        positional = torch.randn(2, 2)
        keyword = torch.randn(2, 2)
        output = model(positional, kwarg=keyword)
        output.sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_replicate_ignore_module(self):
        """Grads of an ignored submodule are not synchronized across ranks."""

        def gather_from_all_ranks(local_grad):
            # Collect every rank's copy of ``local_grad`` into a list.
            gathered = [
                torch.zeros_like(local_grad) for _ in range(dist.get_world_size())
            ]
            dist.all_gather(gathered, local_grad)
            return gathered

        self._init_pg()
        torch.cuda.set_device(self.rank)
        # Per-rank seeds give each rank different inputs and therefore
        # different local grads.
        torch.manual_seed(self.rank)
        torch.cuda.manual_seed(self.rank)
        model = Net().cuda()
        replicate(model, ignored_modules=[model.fc1])
        # The input is created on the GPU and scaled by a rank-dependent
        # factor.
        inp = torch.randn(5, 2, device="cuda") * (self.rank + 1)
        scaled_out = model(inp) * 10
        scaled_out.sum().backward()

        # fc1 is ignored, so its grads must differ between ranks.
        first, *others = gather_from_all_ranks(model.fc1.weight.grad)
        for other in others:
            self.assertNotEqual(first, other)

        # fc2 and fc3 are replicated, so their grads must agree on all ranks.
        for synced_grad in (model.fc2.weight.grad, model.fc3.weight.grad):
            first, *others = gather_from_all_ranks(synced_grad)
            for other in others:
                self.assertEqual(first, other)

    def test_replicate_multi_module(self):
        """replicate() on each submodule trains like the local module."""
        self._init_pg()
        reference = Net()
        replicated = deepcopy(reference)
        replicate(replicated.fc1)
        replicate(replicated.fc2)
        replicate(replicated.fc3)
        self._compare_module(reference, replicated)

    def test_replicate_with_kwargs(self):
        """replicate() accepts extra keyword arguments and still trains alike."""
        self._init_pg()
        reference = Net()
        replicated = replicate(
            deepcopy(reference), bucket_cap_mb=1, gradient_as_bucket_view=True
        )
        self._compare_module(reference, replicated)

    @skip_if_lt_x_gpu(2)
    def test_replicate_device_id(self):
        """``device_id`` may be a CPU device, a CUDA device, or an int."""

        def ddp_device_ids(module):
            # DDP instance is attached in first pre forward, so run one
            # forward pass before looking it up.
            module(torch.randn(2, 2))
            return replicate.state(module)._ddp_weakref().device_ids

        self._init_pg()
        cpu_model = Net()
        cuda_model = deepcopy(cpu_model).cuda()
        cuda_model_int_id = deepcopy(cuda_model)

        replicate(cpu_model, device_id=torch.device("cpu"))
        # Should be None for CPU training
        self.assertEqual(None, ddp_device_ids(cpu_model))

        replicate(cuda_model, device_id=torch.device(torch.cuda.current_device()))
        self.assertEqual([0], ddp_device_ids(cuda_model))

        # Pass in int as device_id
        replicate(cuda_model_int_id, device_id=int(torch.cuda.current_device()))
        self.assertEqual([0], ddp_device_ids(cuda_model_int_id))

    def test_replicate_wrong_device_id_type(self):
        """A ``device_id`` that is a list raises ``RuntimeError``."""
        self._init_pg()
        model = Net()
        bad_device_id = [torch.device("cpu")]
        with self.assertRaisesRegex(
            RuntimeError, "Expected device_id to be int or torch.device"
        ):
            replicate(model, device_id=bad_device_id)


class ReplicateFullyShardInit(ReplicateTest):
    """Checks ``replicate()`` applied on top of ``fully_shard()`` submodules.

    Subclasses ``ReplicateTest`` to reuse its process setup; as with any
    ``TestCase`` subclass, the inherited ``test_*`` methods are part of this
    class too.
    """

    @skip_if_lt_x_gpu(2)
    def test_replicate_fully_shard_init(self):
        """Sharded weights stay ``DTensor`` before and after the first forward."""

        class ToyModel(nn.Module):
            def __init__(self, dim: int):
                super().__init__()
                # Three bias-free square linear layers, then a projection.
                self.linears = nn.Sequential(
                    *(nn.Linear(dim, dim, bias=False) for _ in range(3))
                )
                self.proj = nn.Linear(dim, dim, bias=False)

            def forward(self, x: torch.Tensor):
                return self.proj(self.linears(x))

        def check_weights_are_dtensors():
            for layer in model.linears:
                self.assertIsInstance(layer.weight, DTensor)

        self._init_pg()
        torch.cuda.set_device(self.rank)
        feature_dim, batch_size = 3, 2
        model = ToyModel(feature_dim).cuda()
        # Shard each linear layer, then the container holding them, and
        # finally apply replicate() to the whole model.
        for layer in model.linears:
            fully_shard(layer)
        fully_shard(model.linears)
        replicate(model, device_id=torch.cuda.current_device())

        check_weights_are_dtensors()
        inp = torch.rand(batch_size, feature_dim)
        # trigger lazy init
        model(inp).sum()
        check_weights_are_dtensors()


# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    run_tests()
