# Owner(s): ["oncall: distributed"]

import sys
from typing import List
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_handle_fqns_from_root
from torch.distributed.fsdp._flat_param import HandleTrainingState
from torch.distributed.fsdp._runtime_utils import (
    _get_handle_to_prefetch,
    _get_training_state,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


NUM_ITERS = 2
DECODER_PARAM_FQNS = [
    "decoder.layers.{index}.self_attn.in_proj_weight",
    "decoder.layers.{index}.self_attn.in_proj_bias",
    "decoder.layers.{index}.self_attn.out_proj.weight",
    "decoder.layers.{index}.self_attn.out_proj.bias",
    "decoder.layers.{index}.multihead_attn.in_proj_weight",
    "decoder.layers.{index}.multihead_attn.in_proj_bias",
    "decoder.layers.{index}.multihead_attn.out_proj.weight",
    "decoder.layers.{index}.multihead_attn.out_proj.bias",
    "decoder.layers.{index}.linear1.weight",
    "decoder.layers.{index}.linear1.bias",
    "decoder.layers.{index}.linear2.weight",
    "decoder.layers.{index}.linear2.bias",
    "decoder.layers.{index}.norm1.weight",
    "decoder.layers.{index}.norm1.bias",
    "decoder.layers.{index}.norm2.weight",
    "decoder.layers.{index}.norm2.bias",
    "decoder.layers.{index}.norm3.weight",
    "decoder.layers.{index}.norm3.bias",
]
ENCODER_PARAM_FQNS = [
    "encoder.layers.{index}.self_attn.in_proj_weight",
    "encoder.layers.{index}.self_attn.in_proj_bias",
    "encoder.layers.{index}.self_attn.out_proj.weight",
    "encoder.layers.{index}.self_attn.out_proj.bias",
    "encoder.layers.{index}.linear1.weight",
    "encoder.layers.{index}.linear1.bias",
    "encoder.layers.{index}.linear2.weight",
    "encoder.layers.{index}.linear2.bias",
    "encoder.layers.{index}.norm1.weight",
    "encoder.layers.{index}.norm1.bias",
    "encoder.layers.{index}.norm2.weight",
    "encoder.layers.{index}.norm2.bias",
]
TOTAL_NUM_PREFETCH_FOR_PRE = 12
TOTAL_NUM_PREFETCH_FOR_POST = 11
ENCODER_BEGIN_INDEX_FOR_PRE = 6
ENCODER_BEGIN_INDEX_FOR_POST = 5
ENCODER_PREFETCH_NUM = 5

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestBackwardPrefetch(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
        rank = self.rank
        orig_get_handle_to_prefetch = _get_handle_to_prefetch

        torch.manual_seed(0)
        policy = ModuleWrapPolicy(
            {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
        )
        model = FSDP(
            nn.Transformer(d_model=1024, nhead=8, device="cuda"),
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=policy,
            use_orig_params=True,
            backward_prefetch=backward_prefetch,
        )
        optim = torch.optim.SGD(model.parameters(), lr=1e-2)

        # prepare input
        torch.manual_seed(rank + 1)
        src = torch.randn((10, 1, 1024), device="cuda")
        tgt = torch.randn((20, 1, 1024), device="cuda")

        # monkey patch
        all_handle_fqns: List[List[str]] = []

        def patched_get_handle_to_prefetch(*args, **kwargs):
            handle = orig_get_handle_to_prefetch(*args, **kwargs)

            self.assertEqual(
                len(args), 2, "expect _get_handle_to_prefetch(state, current_handle)"
            )
            state = args[0]
            current_handle = args[1]
            training_state = _get_training_state(current_handle)
            if (
                training_state == HandleTrainingState.BACKWARD_PRE
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
            ) or (
                training_state == HandleTrainingState.BACKWARD_POST
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
            ):
                nonlocal all_handle_fqns
                # FQNs prefixed from the root module
                # state._exec_order_data.param_to_fqn
                fqns = _get_handle_fqns_from_root(state, handle)
                all_handle_fqns.append(fqns)
            return handle

        # flat params from prefetch handle should match
        # DECODER_PARAM_FQNS and ENCODER_PARAM_FQNS
        with patch(
            "torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch",
            patched_get_handle_to_prefetch,
        ):
            for _ in range(NUM_ITERS):
                optim.zero_grad()
                loss = model(src, tgt).sum()
                loss.backward()
                optim.step()
                if backward_prefetch is None:
                    self.assertEqual(len(all_handle_fqns), 0)
                    continue
                elif backward_prefetch == BackwardPrefetch.BACKWARD_PRE:
                    # state._exec_order_data.handles_post_forward_order
                    # equals forward order
                    #     encoder 0...5 -> decoder 0...5 -> root
                    # pre-backward hook order
                    #     root -> decoder 5...0 -> encoder 5...0
                    # prefetch order
                    #     decoder 5...0 -> encoder 5...0 -> None
                    # None: when current_handle=encoder 0,
                    #       _get_handle_to_prefetch returns None
                    # +1 is for the above None
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_PRE
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_PRE + 1
                    )
                elif backward_prefetch == BackwardPrefetch.BACKWARD_POST:
                    # state._exec_order_data.handles_post_forward_order
                    # equals forward order (same as BACKWARD_PRE)
                    #     encoder 0...5 -> decoder 0...5 -> root
                    # post-backward hook (AccumulateGrad) order
                    #     decoder 5, 4...0 -> encoder 5...0 -> root
                    # prefetch order
                    #     decoder 4...0 -> encoder 5...0 -> None -> None
                    # 1st None: when current_handle=encoder 0,
                    #           _get_handle_to_prefetch returns None
                    # 2nd None: when current_handle=root,
                    #           get decoder 5 inside _get_handle_to_prefetch
                    #           but not needed since decoder 5 is computed already
                    # +2 is for the above Nones
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_POST
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_POST + 2
                    )

                # ith_prefetch: 0, 1st, 2nd, 3rd, 4th ... ith prefetch
                for ith_prefetch, fqns in enumerate(all_handle_fqns):
                    if ith_prefetch >= 0 and ith_prefetch < encoder_begin_index:
                        layer_index = encoder_begin_index - 1 - ith_prefetch
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in DECODER_PARAM_FQNS],
                        )
                    elif (
                        ith_prefetch >= encoder_begin_index
                        and ith_prefetch <= encoder_begin_index + ENCODER_PREFETCH_NUM
                    ):
                        layer_index = (
                            encoder_begin_index + ENCODER_PREFETCH_NUM - ith_prefetch
                        )
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in ENCODER_PARAM_FQNS],
                        )
                    else:
                        self.assertTrue(fqns is None)

                all_handle_fqns = []

    @skip_if_lt_x_gpu(2)
    def test_backward_prefetch(self):
        # subtest reuse process group to shorten test time
        self.run_subtests(
            {
                "backward_prefetch": [
                    None,
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
            },
            self._test_backward_prefetch,
        )

    def _test_backward_prefetch(self, backward_prefetch: BackwardPrefetch):
        self._dist_train(backward_prefetch)


if __name__ == "__main__":
    run_tests()
