# Owner(s): ["oncall: distributed"]

import contextlib
import sys
from copy import deepcopy
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    offload_wrapper,
)
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import _maybe_wrap_fsdp, FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)
from torch.utils.checkpoint import checkpoint


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


# Module-level flag set by the patched save_on_cpu so tests can assert that
# activation offloading was actually exercised.
_save_on_cpu_called = False


def get_patched_save_on_cpu():
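    """Return a wrapper around ``checkpoint_wrapper.save_on_cpu`` that sets the
    module-level ``_save_on_cpu_called`` flag before delegating to the original."""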
    orig_save_on_cpu = (
        torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu
    )

    def patched_save_on_cpu(*args, **kwargs):
        global _save_on_cpu_called
        _save_on_cpu_called = True
        return orig_save_on_cpu(*args, **kwargs)

    return patched_save_on_cpu


@contextlib.contextmanager
def patch_save_on_cpu(new_save_on_cpu):
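    """Temporarily replace ``save_on_cpu`` in the ``checkpoint_wrapper`` module with
    ``new_save_on_cpu``, restoring the original on exit."""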
    orig_save_on_cpu = (
        torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu
    )
    torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = (
        new_save_on_cpu
    )
    try:
        yield
    finally:
        torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = (
            orig_save_on_cpu
        )


class TestFSDPCheckpoint(FSDPTest):
    class SequentialModule(nn.Module):
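        """Three ``nn.Linear(3, 3)`` layers run in sequence, each optionally wrapped
        with activation checkpointing (or activation offloading) and/or FSDP."""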
        def __init__(
            self,
            checkpoint_layer=False,
            offload_activations=False,
            wrap_fsdp=False,
            *fsdp_args,
            **fsdp_kwargs,
        ):
            torch.manual_seed(0)
            torch.cuda.manual_seed(0)
            super().__init__()
            l1 = nn.Linear(3, 3).cuda()
            l2 = nn.Linear(3, 3).cuda()
            l3 = nn.Linear(3, 3).cuda()

            if checkpoint_layer:
                if offload_activations:
                    ckpt_wrapper = offload_wrapper
                else:
                    ckpt_wrapper = checkpoint_wrapper

                l1 = ckpt_wrapper(l1)
                l2 = ckpt_wrapper(l2)
                l3 = ckpt_wrapper(l3)

            fsdp_wrapper = partial(
                _maybe_wrap_fsdp, *fsdp_args, wrap_fsdp=wrap_fsdp, **fsdp_kwargs
            )
            self.ffn = nn.Sequential(
                fsdp_wrapper(l1),
                fsdp_wrapper(l2),
                fsdp_wrapper(l3),
            )

        def forward(self, x):
            return self.ffn(x)

    def _verify_parity(self, losses, outputs, models):
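        """Assert that every model matches the first (reference) model in loss,
        output, and per-parameter gradients."""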
        assert losses
        assert outputs
        assert models

        for l, o in zip(losses[1:], outputs[1:]):
            self.assertEqual(losses[0], l)
            self.assertEqual(outputs[0], o)

        # Verify grads
        ref_model = models[0]
        ref_grads = [p.grad for p in ref_model.parameters()]
        for m in models[1:]:
            grads = [p.grad for p in m.parameters()]
            for ref_g, g in zip(ref_grads, grads):
                self.assertEqual(ref_g, g)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("offload_activations", [True, False])
    @parametrize("use_orig_params", [False, True])
    def test_checkpoint_fsdp_wrapping(
        self,
        cpu_offload: CPUOffload,
        offload_activations: bool,
        use_orig_params: bool,
    ):
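        """Check parity between checkpointing a module whose layers are FSDP-wrapped,
        FSDP-wrapping individually checkpointed layers, and a no-checkpoint FSDP
        baseline, and verify that activation offloading is exercised when enabled."""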
        # Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
        if offload_activations:
            wrapper_to_use = offload_wrapper
        else:
            wrapper_to_use = checkpoint_wrapper

        fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
        ckpt_sequential_wrapped_fsdp = wrapper_to_use(
            TestFSDPCheckpoint.SequentialModule(
                wrap_fsdp=True,
                **fsdp_kwargs,
            ),
        )
        # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
        inner_ckpt = TestFSDPCheckpoint.SequentialModule(
            checkpoint_layer=True,
            offload_activations=offload_activations,
            wrap_fsdp=True,
            **fsdp_kwargs,
        )

        baseline = TestFSDPCheckpoint.SequentialModule(
            wrap_fsdp=True,
            **fsdp_kwargs,
        )

        # Note that reentrant-based checkpointing requires the input to have its
        # requires_grad flag set.
        inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

        global _save_on_cpu_called
        models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]
        with patch_save_on_cpu(get_patched_save_on_cpu()):
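            # Run two forward/backward iterations and verify parity each time.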
            for i in range(2):
                losses = []
                outputs = []
                for m in models:
                    check_offload = m != baseline and i == 0 and offload_activations
                    if check_offload:
                        self.assertFalse(_save_on_cpu_called)
                    out = m(inp)
                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                        _save_on_cpu_called = False
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)

                self._verify_parity(losses, outputs, models)

        dist.barrier()

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("offload_activations", [True, False])
    @parametrize("use_orig_params", [False, True])
    def test_basic_checkpoint_end_to_end(
        self,
        cpu_offload: CPUOffload,
        offload_activations: bool,
        use_orig_params: bool,
    ):
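        """Check parity between plain FSDP, checkpoint-wrapped FSDP, FSDP around a
        checkpoint-wrapped module, and FSDP driven by manual ``checkpoint()`` calls."""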
        fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
        global _save_on_cpu_called
        with patch_save_on_cpu(get_patched_save_on_cpu()):
            seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
            # Runs FSDP with no checkpointing
            fsdp_only_seq = FSDP(deepcopy(seq), **fsdp_kwargs)
            # Runs checkpoint-wrapped FSDP
            if offload_activations:
                wrapper_to_use = offload_wrapper
            else:
                wrapper_to_use = checkpoint_wrapper

            checkpointed_fsdp = wrapper_to_use(
                FSDP(deepcopy(seq), **fsdp_kwargs),
            )
            # Runs FSDP-wrapped checkpointed module
            fsdp_wrapped_checkpoint = FSDP(
                wrapper_to_use(deepcopy(seq)),
                **fsdp_kwargs,
            )
            # Runs FSDP with manual calls to checkpoint.
            fsdp_call_checkpoint = FSDP(deepcopy(seq), **fsdp_kwargs)
            # Note that reentrant-based checkpointing requires the input to have
            # its requires_grad flag set.

            inp = torch.randn(
                10, 3, device=torch.cuda.current_device(), requires_grad=True
            )

            models = [
                fsdp_only_seq,
                checkpointed_fsdp,
                fsdp_wrapped_checkpoint,
                fsdp_call_checkpoint,
            ]
            # Ensure save_on_cpu has not been called yet
            self.assertFalse(_save_on_cpu_called)
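            # Run several forward/backward iterations and verify parity each time.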
            for i in range(6):
                losses = []
                outputs = []
                for m in models:
                    check_offload = (
                        m != fsdp_only_seq and i == 0 and offload_activations
                    )
                    if m == fsdp_call_checkpoint:
                        # save_on_cpu should not have been called yet
                        self.assertFalse(_save_on_cpu_called)
                        offload_ctx = (
                            get_patched_save_on_cpu()(pin_memory=True)
                            if offload_activations
                            else contextlib.nullcontext()
                        )
                        with offload_ctx:
                            out = checkpoint(m, inp, use_reentrant=True)
                    else:
                        # save_on_cpu should not have been called yet
                        self.assertFalse(_save_on_cpu_called)
                        out = m(inp)

                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)
                    _save_on_cpu_called = False

                self._verify_parity(losses, outputs, models)

        dist.barrier()


instantiate_parametrized_tests(TestFSDPCheckpoint)


class CheckpointModule(nn.Module):
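    """Four ``nn.Linear(100, 100)`` layers that optionally run their forward under
    ``torch.utils.checkpoint.checkpoint``."""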
    def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
        super().__init__()
        self.seq = nn.Sequential(*[nn.Linear(100, 100) for _ in range(4)])
        self.checkpoint = checkpoint
        self.use_reentrant = use_reentrant

    def forward(self, x):
        return (
            checkpoint(self.seq, x, use_reentrant=self.use_reentrant)
            if self.checkpoint
            else self.seq(x)
        )


class ModelWithCheckpointSubmodule(nn.Module):
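    """``nn.Linear`` -> two ``CheckpointModule`` submodules -> ReLU -> ``nn.Linear``."""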
    def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        self.s1 = CheckpointModule(checkpoint, use_reentrant)
        self.s2 = CheckpointModule(checkpoint, use_reentrant)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(100, 100)

    def forward(self, x):
        return self.l2(self.relu(self.s2(self.s1(self.l1(x)))))


class TestModel(nn.Module):
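    """``nn.Linear`` -> two ``ModelWithCheckpointSubmodule`` blocks -> ReLU -> ``nn.Linear``."""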
    def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        self.relu = nn.ReLU()
        self.checkpoint1 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
        self.checkpoint2 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
        self.l2 = nn.Linear(100, 100)

    def forward(self, x):
        return self.l2(self.relu(self.checkpoint2(self.checkpoint1(self.l1(x)))))


class TestFSDPCheckpointSubmodule(FSDPTest):
    # TODO: grad value checks occasionally fail when use_reentrant = True
    @skip_if_lt_x_gpu(2)
    @parametrize("use_reentrant", [False])
    def test_checkpoint_submodule(self, use_reentrant: bool):
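        """Check gradient parity between an FSDP-wrapped model and a copy whose
        ``CheckpointModule`` submodules have activation checkpointing enabled."""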
        model = TestModel(use_reentrant=use_reentrant).cuda()
        model_ac = deepcopy(model)

        for _, m in model_ac.named_modules():
            if isinstance(m, CheckpointModule):
                m.checkpoint = True

        self.assertTrue(model_ac.checkpoint1.s1.checkpoint)
        self.assertTrue(model_ac.checkpoint2.s2.checkpoint)

        fsdp_kwargs = {
            "device_id": torch.cuda.current_device(),
            "sharding_strategy": ShardingStrategy.NO_SHARD,
        }

        # Wrap the non-checkpointing model's submodules with FSDP
        model.checkpoint1 = FSDP(module=model.checkpoint1, **fsdp_kwargs)
        model.checkpoint2 = FSDP(module=model.checkpoint2, **fsdp_kwargs)

        # Wrap the checkpointing model's submodules with FSDP
        model_ac.checkpoint1 = FSDP(module=model_ac.checkpoint1, **fsdp_kwargs)
        model_ac.checkpoint2 = FSDP(module=model_ac.checkpoint2, **fsdp_kwargs)

        x = torch.randn(2, 100, device="cuda")

        model(x).sum().backward()
        model_ac(x).sum().backward()

        for (n1, p1), (n2, p2) in zip(
            model.named_parameters(), model_ac.named_parameters()
        ):
            self.assertEqual(n1, n2)
            self.assertTrue(p1.grad.allclose(p2.grad))


instantiate_parametrized_tests(TestFSDPCheckpointSubmodule)


if __name__ == "__main__":
    run_tests()
