# Owner(s): ["oncall: distributed"]
from copy import deepcopy

import torch
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import init_device_mesh, Replicate
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    ShardingStrategy,
    StateDictType,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class SimpleModel(torch.nn.Module):
    """Three-layer MLP (5 -> 8 -> 4 -> 12) with a ReLU after every layer.

    Used by the HSDP checkpoint tests when ``is_even_sharded_model`` is True.
    """

    def __init__(self) -> None:
        super().__init__()
        # Creation order fixes both the order in which parameters are randomly
        # initialized and the order of the resulting state_dict keys.
        self.net1 = nn.Linear(in_features=5, out_features=8)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(in_features=8, out_features=4)
        self.net3 = nn.Linear(in_features=4, out_features=12)

    def forward(self, x):
        # Every linear layer is followed by the (parameter-free) ReLU.
        for layer in (self.net1, self.net2, self.net3):
            x = self.relu(layer(x))
        return x

    def get_input(self):
        """Return a random CUDA batch of shape ``(4, 5)`` sized for ``net1``."""
        batch_shape = (4, 5)
        return torch.rand(batch_shape, device="cuda")


class SimpleModelUneven(torch.nn.Module):
    """Four-layer MLP (5 -> 10 -> 15 -> 30 -> 5) with a ReLU after every layer.

    Used by the HSDP checkpoint tests when ``is_even_sharded_model`` is False.
    """

    def __init__(self) -> None:
        super().__init__()
        # Creation order fixes both the order in which parameters are randomly
        # initialized and the order of the resulting state_dict keys.
        self.net1 = nn.Linear(in_features=5, out_features=10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(in_features=10, out_features=15)
        self.net3 = nn.Linear(in_features=15, out_features=30)
        self.net4 = nn.Linear(in_features=30, out_features=5)

    def forward(self, x):
        # Every linear layer is followed by the (parameter-free) ReLU.
        for layer in (self.net1, self.net2, self.net3, self.net4):
            x = self.relu(layer(x))
        return x

    def get_input(self):
        """Return a random CUDA batch of shape ``(4, 5)`` sized for ``net1``."""
        batch_shape = (4, 5)
        return torch.rand(batch_shape, device="cuda")


class TestHSDPCheckpoint(DTensorTestBase):
    """Distributed-checkpoint tests for FSDP models using ``HYBRID_SHARD``.

    HSDP models are wrapped on a 2-D device mesh of shape
    ``(2, world_size // 2)`` and checkpointed through
    ``torch.distributed.checkpoint`` (``dist_cp``) with
    ``StateDictType.SHARDED_STATE_DICT``.
    """

    @property
    def backend(self):
        # Gloo handles CPU tensors and NCCL handles CUDA tensors.
        return "cpu:gloo,cuda:nccl"

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("is_even_sharded_model", [True, False])
    def test_hsdp_checkpoint(self, is_even_sharded_model) -> None:
        """Round-trip an HSDP checkpoint back into the same model.

        The model's sharded state dict is saved, the model is changed by one
        optimizer step, and the checkpoint is then loaded back. The reloaded
        state dict must match the saved snapshot.

        Args:
            is_even_sharded_model: build ``SimpleModel`` when True, otherwise
                ``SimpleModelUneven``.
        """
        CHECKPOINT_DIR = self.temp_dir
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model = FSDP(
            simple_model().cuda(),
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            device_mesh=mesh_2d,
        )
        optim = torch.optim.Adam(model.parameters(), lr=0.1)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )
        state_dict = {"model": model.state_dict()}
        # Take an independent snapshot; the optimizer step below updates the
        # live model's parameters.
        state_dict_to_save = deepcopy(state_dict)

        dist_cp.save(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
            planner=DefaultSavePlanner(),
        )

        # Update the parameters so the current model state_dict differs from
        # state_dict_to_save.
        model(model.get_input()).sum().backward()
        optim.step()

        # At this point, the current state dict is different from state_dict_to_save.
        state_dict_before_load = model.state_dict()
        # zip() silently stops at the shorter input, so compare sizes first.
        self.assertEqual(len(state_dict_to_save["model"]), len(state_dict_before_load))
        for (k1, v1), (k2, v2) in zip(
            state_dict_to_save["model"].items(), state_dict_before_load.items()
        ):
            self.assertEqual(k1, k2)
            self.assertEqual(v1.device_mesh, v2.device_mesh)
            self.assertEqual(v1.placements, v2.placements)
            self.assertNotEqual(v1.to_local(), v2.to_local())

        # dist_cp.load fills state_dict_to_save in place from the checkpoint;
        # the result is then pushed into the model.
        dist_cp.load(
            state_dict=state_dict_to_save,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
            planner=DefaultLoadPlanner(),
        )
        model.load_state_dict(state_dict_to_save["model"])

        state_dict_after_load = model.state_dict()
        # After loading, the current model state dict should be the same as state_dict_to_save.
        self.assertEqual(len(state_dict_to_save["model"]), len(state_dict_after_load))
        for (k1, v1), (k2, v2) in zip(
            state_dict_to_save["model"].items(), state_dict_after_load.items()
        ):
            self.assertEqual(k1, k2)
            self.assertEqual(v1.device_mesh, v2.device_mesh)
            self.assertEqual(v1.placements, v2.placements)
            self.assertEqual(v1.to_local(), v2.to_local())

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("is_even_sharded_model", [True, False])
    def test_hsdp_fsdp_checkpoint_conversion(self, is_even_sharded_model) -> None:
        """Load a checkpoint saved by an HSDP model into a plain FSDP model.

        The HSDP model uses a 2-D mesh and the FSDP model a 1-D mesh over all
        ranks. Their state-dict DTensors therefore differ in mesh and
        placements, so values are compared after redistributing both sides to
        fully replicated tensors.

        Args:
            is_even_sharded_model: build ``SimpleModel`` when True, otherwise
                ``SimpleModelUneven``.
        """
        CHECKPOINT_DIR = self.temp_dir
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

        # save the hsdp model state_dict
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        hsdp_model = FSDP(
            simple_model().cuda(),
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            device_mesh=mesh_2d,
        )
        FSDP.set_state_dict_type(
            hsdp_model,
            StateDictType.SHARDED_STATE_DICT,
        )
        hsdp_state_dict = {"model": hsdp_model.state_dict()}
        dist_cp.save(
            state_dict=hsdp_state_dict,
            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
            planner=DefaultSavePlanner(),
        )

        # initialize a fsdp model to load checkpoint into
        mesh_1d = init_device_mesh(self.device_type, (self.world_size,))
        fsdp_model = FSDP(
            simple_model().cuda(),
            device_mesh=mesh_1d,
        )
        FSDP.set_state_dict_type(
            fsdp_model,
            StateDictType.SHARDED_STATE_DICT,
        )
        fsdp_state_dict = {"model": fsdp_model.state_dict()}

        # At this point, the hsdp model parameters are different from the fsdp
        # model parameters.
        self.assertEqual(len(hsdp_state_dict["model"]), len(fsdp_state_dict["model"]))
        for (k1, v1), (k2, v2) in zip(
            hsdp_state_dict["model"].items(), fsdp_state_dict["model"].items()
        ):
            self.assertEqual(k1, k2)
            self.assertNotEqual(v1.device_mesh, v2.device_mesh)
            self.assertNotEqual(v1.placements, v2.placements)
            v1_all_gather = v1.redistribute(
                mesh_2d, placements=(Replicate(), Replicate())
            )
            v2_all_gather = v2.redistribute(mesh_1d, placements=(Replicate(),))
            self.assertNotEqual(v1_all_gather.to_local(), v2_all_gather.to_local())

        # load the fsdp state_dict from storage
        dist_cp.load(
            state_dict=fsdp_state_dict,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
            planner=DefaultLoadPlanner(),
        )
        fsdp_model.load_state_dict(fsdp_state_dict["model"])

        state_dict_after_load = fsdp_model.state_dict()
        # After loading, the current model state dict should be the same as hsdp_state_dict.
        self.assertEqual(len(hsdp_state_dict["model"]), len(state_dict_after_load))
        for (k1, v1), (k2, v2) in zip(
            hsdp_state_dict["model"].items(), state_dict_after_load.items()
        ):
            self.assertEqual(k1, k2)
            self.assertNotEqual(v1.device_mesh, v2.device_mesh)
            self.assertNotEqual(v1.placements, v2.placements)
            v1_all_gather = v1.redistribute(
                mesh_2d, placements=(Replicate(), Replicate())
            )
            v2_all_gather = v2.redistribute(mesh_1d, placements=(Replicate(),))
            self.assertEqual(v1_all_gather.to_local(), v2_all_gather.to_local())


# Instantiate the @parametrize-decorated test methods of TestHSDPCheckpoint
# (one concrete variant per ``is_even_sharded_model`` value).
instantiate_parametrized_tests(TestHSDPCheckpoint)
if __name__ == "__main__":
    # Entry point when this file is executed directly rather than collected by a runner.
    run_tests()
