# Owner(s): ["oncall: distributed"]

from copy import deepcopy

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class UnevenShardedModel(torch.nn.Module):
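    """Model whose layer sizes (e.g. 15 and 1) do not divide evenly across ranks,
    so tensor parallelism produces uneven DTensor shards."""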
    def __init__(self, device):
        super().__init__()
        torch.manual_seed(5)
        self.net1 = torch.nn.Linear(5, 10, device=device)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 15, device=device)
        self.net3 = torch.nn.Linear(15, 1, device=device)

    def forward(self, x):
        return self.net3(self.net2(self.relu(self.net1(x))))


class TestTpCheckpoint(DTensorTestBase):
    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_tp_checkpoint(self):
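        """Save a TP-sharded state_dict with DCP, perturb the parameters with an
        optimizer step, then load the checkpoint back and verify the original
        parameters are restored."""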
        CHECKPOINT_DIR = self.temp_dir
        mesh_shape = (self.world_size,)
        tp_mesh = init_device_mesh(self.device_type, mesh_shape)

        # Create the model and move it to the GPU with device id self.rank.
        model = MLPModule(self.device_type).cuda(self.rank)
        # Parallelize the module based on the given Parallel Style.
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model = parallelize_module(model, tp_mesh, parallelize_plan)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.25)
        original_state_dict = deepcopy(model.state_dict())

        dcp.save(
            state_dict=original_state_dict,
            storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR),
            planner=DefaultSavePlanner(),
        )

        # Update the parameters so model.state_dict() will be different from original_state_dict.
        torch.manual_seed(0)
        inp = torch.rand(20, 10).cuda(self.rank)
        output = model(inp)
        output.sum().backward()
        optimizer.step()
        state_dict = model.state_dict()

        # ensure the current model parameters are different from original_state_dict before loading from checkpoint
        for param1, param2 in zip(original_state_dict.values(), state_dict.values()):
            self.assertNotEqual(param1.to_local(), param2.to_local())

        # dcp.load loads the checkpoint in place into the provided state_dict.
        dcp.load(
            state_dict=state_dict,
            storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR),
            planner=DefaultLoadPlanner(),
        )

        # After loading from the checkpoint, the current model parameters should match original_state_dict again.
        for param1, param2 in zip(original_state_dict.values(), state_dict.values()):
            self.assertEqual(param1.to_local(), param2.to_local())

    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_tp_checkpoint_load_on_meta_device(self):
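        """Save an unevenly sharded TP model with DCP, then load the checkpoint into a
        replica initialized on the meta device and verify every parameter is
        materialized and matches the original."""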
        CHECKPOINT_DIR = self.temp_dir
        mesh_shape = (self.world_size,)
        tp_mesh = init_device_mesh(self.device_type, mesh_shape)

        # Create the model and move it to the GPU with device id self.rank.
        model = UnevenShardedModel(self.device_type).cuda(self.rank)
        # Parallelize the module based on the given Parallel Style.
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
            "net3": ColwiseParallel(),
        }
        model = parallelize_module(model, tp_mesh, parallelize_plan=parallelize_plan)
        original_state_dict = {
            "model": model.state_dict(),
        }

        dcp.save(
            state_dict=original_state_dict,
            storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR),
        )

        model2 = parallelize_module(
            UnevenShardedModel("meta"), tp_mesh, parallelize_plan=parallelize_plan
        )
        model2_sd_before_load = model2.state_dict()
        state_dict_to_load = {"model": model2_sd_before_load}

        dcp.load(
            state_dict=state_dict_to_load,
            storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR),
        )
        # dcp.load loads in place, so state_dict_to_load["model"] should still be the same
        # object as model2_sd_before_load.
        self.assertTrue(state_dict_to_load["model"] is model2_sd_before_load)

        # assign=True keeps the loaded DTensors instead of copying into the meta-device
        # parameters, which materializes model2 off the meta device.
        model2.load_state_dict(state_dict_to_load["model"], assign=True)
        state_dict_after_load = {"model": model2.state_dict()}

        self.assertEqual(
            len(original_state_dict["model"]), len(state_dict_to_load["model"])
        )
        self.assertEqual(
            len(original_state_dict["model"]), len(state_dict_after_load["model"])
        )

        for name, param in original_state_dict["model"].items():
            param_to_load = state_dict_to_load["model"][name]
            param_after_load = state_dict_after_load["model"][name]

            # Explicitly check that the params are no longer on the meta device, since
            # assertEqual currently doesn't handle DTensors on the meta device.
            self.assertTrue(not param_to_load.is_meta)
            self.assertTrue(not param_after_load.is_meta)
            self.assertEqual(param.to_local(), param_to_load.to_local())
            self.assertEqual(param.to_local(), param_after_load.to_local())


if __name__ == "__main__":
    run_tests()
