# Owner(s): ["oncall: distributed"]

import copy
import sys
from typing import Dict

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import checkpoint, fully_shard, replicate
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp.api import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_dist_composable import (
    CompositeModel,
    CompositeParamModel,
    UnitModule,
)
from torch.testing._internal.common_distributed import (
    SaveForwardInputsModel,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


# Exit successfully (status 0) before defining any tests when the environment
# cannot run them: either torch was built without distributed support, or
# this is a dev-asan build, where torch + multiprocessing spawn is known to
# misbehave.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)
elif TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestFSDPCheckpoint(FSDPTest):
    @property
    def world_size(self) -> int:
        return 2

    # TODO: Define `use_same_inputs_across_ranks` for now for BC since some
    # test model configs do not have a simple base model to compare against. In
    # those cases, we use the same inputs across ranks so that the averaged
    # gradient equals the local gradient to check for parity. This means that
    # the gradient reduction is unchecked.
    def _test_parity(
        self,
        base_model: nn.Module,
        test_model: nn.Module,
        inp_size: torch.Size,
        inp_device: torch.device,
        grad_to_none: bool,
        use_same_inputs_across_ranks: bool,
    ):
        LR = 0.01
        base_optim = torch.optim.Adam(base_model.parameters(), lr=LR)
        test_optim = torch.optim.Adam(test_model.parameters(), lr=LR)

        for _ in range(5):
            if use_same_inputs_across_ranks:
                torch.manual_seed(0)
            x = torch.randn(inp_size, device=inp_device)
            test_loss = test_model(x).sum()
            base_loss = base_model(x).sum()

            self.assertEqual(test_loss, base_loss)

            test_loss.backward()
            test_optim.step()
            test_optim.zero_grad(set_to_none=grad_to_none)

            base_loss.backward()
            base_optim.step()
            base_optim.zero_grad(set_to_none=grad_to_none)

    @skip_if_lt_x_gpu(2)
    def test_wrap_same_submodule(self):
        model = UnitModule(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        # compose checkpoint and fully_shard
        test_model.seq = checkpoint(test_model.seq)
        test_model.seq = fully_shard(
            test_model.seq,
            policy=ModuleWrapPolicy({nn.Linear}),
        )

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    def _test_checkpoint_fsdp_submodules(self):
        model = CompositeModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        test_model.u1 = fully_shard(test_model.u1, policy=None)
        test_model.u2 = fully_shard(test_model.u2)

        test_model.u1.seq = checkpoint(test_model.u1.seq)
        test_model.u2.seq = checkpoint(test_model.u2.seq)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_non_reentrant(self):
        self._test_checkpoint_fsdp_submodules()

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fully_shard_cast_forward_inputs(self):
        self.run_subtests(
            {
                "checkpoint_strict_submodule": [False, True],
            },
            self._test_checkpoint_fully_shard_cast_forward_inputs,
        )

    def _test_checkpoint_fully_shard_cast_forward_inputs(
        self, checkpoint_strict_submodule: bool
    ):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        fp16_mp = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
        fp32_mp = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)

        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).cuda()
        x = torch.zeros(2, 100, device="cuda")

        fully_shard(model.c2, mixed_precision=fp16_mp)
        if checkpoint_strict_submodule:
            checkpoint(model.c2.l)
        else:
            checkpoint(model.c2)
        fully_shard(model, mixed_precision=fp32_mp)

        loss = model(x).sum()
        loss.backward()

        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
        # Notably, check that the recomputed forward preserves the right dtype
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float16)

    @skip_if_lt_x_gpu(2)
    def test_fully_shard_replicate_correct_replicate_params(self):
        model = CompositeParamModel(device=torch.device("cuda"))
        # Shard Linears within UnitModule
        fully_shard(model.u1, policy=ModuleWrapPolicy({nn.Linear}))
        fully_shard(model.u2, policy=ModuleWrapPolicy({nn.Linear}))
        # replicate the rest
        replicate(model)
        # Run fwd + bwd to initialize DDP
        inp = torch.randn(2, 100, device="cuda")
        model(inp).sum().backward()
        # Ensure replicate param names are as expected, i.e.
        # immediate parameters of model and parameters of model's non-UnitModule
        # submodules are replicated
        param_names = replicate.state(model)._param_names
        replicated_modules = [
            (name, mod)
            for (name, mod) in model.named_children()
            if mod not in [model.u1, model.u2]
        ]
        replicated_param_names = [
            f"{module_name}.{n}"
            for module_name, mod in replicated_modules
            for n, _ in mod.named_parameters()
        ]
        replicated_param_names.extend(
            [n for n, _ in model.named_parameters(recurse=False)]
        )
        self.assertEqual(set(param_names), set(replicated_param_names))

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_with_param(self):
        model = CompositeParamModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        test_model.u1.seq = checkpoint(test_model.u1.seq)
        test_model.u2.seq = checkpoint(test_model.u2.seq)
        test_model = fully_shard(test_model)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_with_param_no_shard(self):
        model = CompositeParamModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        test_model.u1.seq = checkpoint(test_model.u1.seq)
        test_model.u2.seq = checkpoint(test_model.u2.seq)
        test_model = fully_shard(test_model, strategy=ShardingStrategy.NO_SHARD)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_composable_fsdp_replicate(self):
        # Verify how the APIs can be composed, e.g. if both `fully_shard` and
        # `replicate` are applied on the same module, it should raise exception.
        model = CompositeModel(device=torch.device("cuda"))
        fully_shard(model.l1)
        with self.assertRaisesRegex(RuntimeError, "Cannot apply .*replicate"):
            replicate(model.l1)
        replicate(model.l2)  # should not raise

    @skip_if_lt_x_gpu(2)
    def test_fully_shard_replicate_composability(self):
        """
        Tests composing ``fully_shard`` and ``replicate``. To save unit test
        time, we run the different configs in subtests.
        """
        self.run_subtests(
            {
                "config": [
                    "1fm,1r",
                    "1r,1fm",
                    "1r,1fa",
                    "1r1fm,1fm",
                    "1r1fa,1fm",
                    "1fm1fm,1r1r,1fm",
                ]
            },
            self._test_replicate_in_fully_shard,
        )

    def _test_replicate_in_fully_shard(self, config: str):
        """
        To interpret the config, each comma delineates a level in the module
        tree ordered bottom-up; 'r' means ``replicate``; 'f' means
        ``fully_shard``; 'a' means auto wrap; and 'm' means manual wrap.
        """
        # Set the seed to ensure that all ranks initialize the same model
        torch.manual_seed(0)
        if config == "1fm,1r":
            base_model = CompositeModel(device=torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            fully_shard(test_model.l1)
            replicate(test_model)
        elif config == "1r,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model)
        elif config == "1r,1fa":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model, policy=ModuleWrapPolicy({UnitModule}))
        elif config == "1r1fm,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model.u2)
            fully_shard(test_model)
        elif config == "1r1fa,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model.u2, policy=ModuleWrapPolicy({UnitModule}))
            fully_shard(test_model)
        elif config == "1fm1fm,1r1r,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            fully_shard(test_model.u1.seq)
            fully_shard(test_model.u2.seq)
            replicate(test_model.u1)
            replicate(test_model.u2)
            fully_shard(test_model)
        else:
            raise ValueError(f"Unknown config: {config}")
        # Apply data parallelism to the base model for parity since we apply
        # data parallelism to the test model
        replicate(base_model)

        # Set the seed to ensure that ranks get different input data
        torch.manual_seed(self.rank + 1)
        self._test_parity(
            base_model,
            test_model,
            torch.Size((2, 100)),
            torch.device("cuda"),
            True,
            False,
        )

    @skip_if_lt_x_gpu(2)
    def test_state_dict_fsdp_submodules(self):
        model = CompositeModel(device=torch.device("cuda"))

        full_shard_args = {"strategy": ShardingStrategy.FULL_SHARD}
        no_shard_args = {"strategy": ShardingStrategy.NO_SHARD}

        model.u1 = fully_shard(model.u1, **full_shard_args)
        model.u2 = fully_shard(model.u2, **no_shard_args)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )

        state_dict = model.state_dict()
        for fqn, tensor in state_dict.items():
            if "u1" in fqn:
                self.assertIsInstance(tensor, ShardedTensor)
            elif "u2" in fqn:
                self.assertIsInstance(tensor, torch.Tensor)
        # Ensure that get_state_dict_type can still correctly get the settings.
        _ = FSDP.get_state_dict_type(model)


# Hand the test class to torch's parametrized-test utility; it returns the
# class, which is bound back to the same name.
TestFSDPCheckpoint = instantiate_parametrized_tests(TestFSDPCheckpoint)


if __name__ == "__main__":
    run_tests()
