# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)
from torch.utils.checkpoint import checkpoint


# Exit with status 0 before any test is defined when torch.distributed
# reports that it is not available; the reason is written to stderr.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

# Likewise exit early when TEST_WITH_DEV_DBG_ASAN is set; the message below
# gives the reason (known issues with torch + multiprocessing spawn).
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


def get_cur_mem(rank, result, prefix):
    """Store the currently allocated CUDA memory, rounded to whole MB, in ``result[prefix]``.

    Args:
        rank: accepted for the caller's convenience; not used by this function.
        result: dict that receives the reading.
        prefix: key under which the reading is stored.
    """
    # Clear the cuBLAS workspaces before sampling (presumably so that they do
    # not skew the reading -- confirm against the allocator behaviour).
    torch._C._cuda_clearCublasWorkspaces()
    allocated_bytes = torch.cuda.memory_allocated()
    result[prefix] = round(allocated_bytes / 1024 / 1024)


class Model(nn.Module):
    """Small conv net: a stem, three conv/norm/ReLU stages with pooling, and a linear head.

    Args:
        hidden_dim: channel width of the conv stages and input width of the head.
        with_fsdp: when true, every ``BatchNorm2d`` layer is individually wrapped in FSDP.
        with_checkpoint: when true, ``forward`` runs ``self.blocks`` through
            ``torch.utils.checkpoint.checkpoint`` (reentrant variant).
    """

    def __init__(self, hidden_dim, with_fsdp=False, with_checkpoint=False):
        super().__init__()

        def make_norm(channels):
            # The only difference between the FSDP and plain variants of this
            # model is whether the batch-norm layers are wrapped.
            norm = nn.BatchNorm2d(channels)
            return FSDP(norm) if with_fsdp else norm

        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            make_norm(64),
            nn.ReLU(inplace=True),
        )

        # Three conv -> norm -> ReLU stages; only the first one changes the
        # channel count (64 -> hidden_dim). Layers are created in the same
        # order as they appear in the resulting Sequential.
        stage_layers = []
        in_channels = 64
        for _ in range(3):
            stage_layers.append(
                nn.Conv2d(in_channels, hidden_dim, kernel_size=5, padding=2)
            )
            stage_layers.append(make_norm(hidden_dim))
            stage_layers.append(nn.ReLU(inplace=True))
            in_channels = hidden_dim
        stage_layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
        stage_layers.append(nn.Flatten())
        self.blocks = nn.Sequential(*stage_layers)

        self.head = nn.Linear(hidden_dim, 10)
        self.with_checkpoint = with_checkpoint

    def forward(self, x):
        """Apply stem, blocks (optionally checkpointed) and head to ``x``."""
        features = self.stem(x)
        if self.with_checkpoint:
            pooled = checkpoint(self.blocks, features, use_reentrant=True)
        else:
            pooled = self.blocks(features)
        return self.head(pooled)


def create_model(with_fsdp, with_checkpoint, model_hidden_dim):
    """Build a ``Model`` after seeding the RNG with 0.

    When ``with_fsdp`` is true, the ``stem``, ``blocks`` and ``head``
    submodules are additionally each wrapped in FSDP (in that order).
    """
    torch.manual_seed(0)
    net = Model(model_hidden_dim, with_fsdp, with_checkpoint)
    if not with_fsdp:
        return net

    for child_name in ("stem", "blocks", "head"):
        setattr(net, child_name, FSDP(getattr(net, child_name)))
    return net


class TestFSDPMemory(FSDPTest):
    """Compares CUDA memory readings taken around every phase of a training
    step of an FSDP-wrapped ``Model`` with a table of expected values (MB)."""

    @property
    def world_size(self):
        """Number of ranks this test runs with."""
        return 2

    def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
        """Run ``iterations`` training steps and check the sampled memory usage.

        Args:
            with_checkpoint: forwarded to ``create_model``.
            expected: dict mapping ``"iter <i>: <phase>"`` to the expected
                allocated memory in MB.
            model_hidden_dim: forwarded to ``create_model``.
            iterations: number of training steps to run.

        The test fails if the sampled keys differ from ``expected`` or if any
        reading is off by more than 1 MB.
        """
        gpu_id = self.rank

        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        model = FSDP(
            create_model(
                with_fsdp=True,
                with_checkpoint=with_checkpoint,
                model_hidden_dim=model_hidden_dim,
            ).cuda()
        )

        # Momentum is enabled so that, once the first step has run, the
        # optimizer state contributes to the total memory in use.
        criterion = nn.MSELoss()
        optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

        results = {}  # memory readings keyed by "iter <i>: <phase>"
        for iteration in range(iterations):
            get_cur_mem(gpu_id, results, f"iter {iteration}: start")

            pred = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            # Rebind the same name so the raw model output is not kept alive
            # by an extra local reference while memory is being sampled.
            pred = sum(o.sum() for o in pred[0])
            fake_loss = criterion(pred, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

            fake_loss.backward()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

            optimizer.step()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

            # `set_to_none=True` matters here: the gradients have to be
            # dropped (not just zeroed) for their memory to be reclaimed.
            model.zero_grad(set_to_none=True)
            get_cur_mem(gpu_id, results, f"iter {iteration}: done")

        self.assertEqual(results.keys(), expected.keys())
        report = []
        for k, v in results.items():
            exp = expected[k]
            if abs(exp - v) <= 1:  # tolerate 1MB rounding differences
                continue
            report.append(f"{k}: got {v}, expected {exp}\n")
        self.assertEqual("".join(report), "")

    @skip_if_lt_x_gpu(2)
    @parametrize("ckpt", ["no_ckpt", "ckpt"])
    def test_fsdp_memory(self, ckpt):
        """Build the expected memory table and run the training loop against it."""
        # hidden_dim 128: model size ~4MB
        model_hidden_dim = 128

        # Measure the unsharded model size by placing a plain model on the GPU.
        baseline = create_model(
            with_fsdp=False, with_checkpoint=False, model_hidden_dim=model_hidden_dim
        ).cuda()
        model_size_mb = round(torch.cuda.memory_allocated() / 1024 / 1024)
        del baseline

        sharded_model_size_mb = int(model_size_mb / self.world_size)

        # It has been observed (not on this test, but on much bigger ones)
        # that the 4th iteration can fail after the 3rd succeeded, so 4
        # iterations are run here just in case.
        iterations = 4

        with_ckpt = ckpt == "ckpt"
        # These values are hard to calculate; they were taken from printed
        # memory usage.
        fwd_mb = 51 if with_ckpt else 340

        expected = {}
        for iteration in range(iterations):
            # After the optimizer step of the first iteration, memory usage
            # grows by sharded_model_size_mb because of the optimizer states.
            optim_state_mb = 0 if iteration == 0 else sharded_model_size_mb

            # sharded model size (+ optimizer states) + 1MB temp memory
            expected[f"iter {iteration}: start"] = (
                sharded_model_size_mb + optim_state_mb + 1
            )
            expected[f"iter {iteration}: after fwd"] = fwd_mb + optim_state_mb
            expected[f"iter {iteration}: after loss"] = fwd_mb + optim_state_mb
            # sharded model size + sharded grad size (+ optimizer states)
            # + 1MB temp memory
            expected[f"iter {iteration}: after bwd"] = (
                2 * sharded_model_size_mb + optim_state_mb + 1
            )
            # sharded model size + sharded grad size + optimizer states
            # + 1MB temp memory
            expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1
            # grad memory is reclaimed after setting grad = None:
            # sharded model size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1

        self._dist_train(
            with_ckpt,
            expected,
            model_hidden_dim,
            iterations,
        )


# NOTE(review): presumably generates the concrete test variants for the
# @parametrize-decorated methods of TestFSDPMemory -- confirm against
# torch.testing._internal.common_utils.
instantiate_parametrized_tests(TestFSDPMemory)


# Run the tests when this file is executed directly.
if __name__ == "__main__":
    run_tests()
