# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
from torch import distributed as dist
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    NestedWrappedModule,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


# Import-time guard: when the distributed package reports itself as
# unavailable, print a skip notice to stderr and exit with status 0 before
# any test class is defined.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

# Import-time guard: when the TEST_WITH_DEV_DBG_ASAN flag is set, print the
# reason to stderr and exit with status 0 instead of running the tests.
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestPureFP16(FSDPTest):
    """FSDP tests for models whose parameters are held in FP16."""

    @property
    def world_size(self):
        """Return the inherited world size, capped at 4."""
        # Test fails due to inaccuracies when using more than 4 GPUs
        inherited_world_size = super().world_size
        return min(4, inherited_world_size)

    @skip_if_lt_x_gpu(2)
    def test_pure_fp16_training(self):
        """Tests pure FP16 training, including when the parameter's dtype is
        changed after FSDP initialization and before training.

        One subtest is run per CPU offload setting: parameters offloaded
        first, then not offloaded.
        """
        subtest_config = {
            "cpu_offload": [
                CPUOffload(offload_params=offload) for offload in (True, False)
            ],
        }
        self.run_subtests(subtest_config, self._test_pure_fp16_training)

    def _test_pure_fp16_training(self, cpu_offload: CPUOffload):
        """Run the FSDP parity check in pure FP16 for one offload setting.

        Args:
            cpu_offload: CPU offload configuration forwarded to the parity
                helper.
        """
        # Run one iteration to avoid NaN without a gradient scaler
        single_iteration = 1
        self._test_fsdp_parity(
            NestedWrappedModule,
            FSDPInitMode.RECURSIVE,
            cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
            num_iters=single_iteration,
            cpu_offload=cpu_offload,
            use_pure_fp16=True,
        )

    @skip_if_lt_x_gpu(2)
    def test_fp16_dtypes(self):
        """
        Tests that both user-facing parameter/gradient dtypes and internal
        saved dtype attributes are as expected when using an FP16 model
        possibly with explicit mixed precision enabled.

        Subtests cover casting to half before or after FSDP construction,
        both ``use_orig_params`` settings, and three mixed precision configs.
        """
        mixed_precision_configs = [
            MixedPrecision(),
            MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float32,
            ),
            MixedPrecision(param_dtype=torch.float32),
        ]
        subtest_config = {
            "to_half_before_fsdp_init": [False, True],
            "use_orig_params": [False, True],
            "mixed_precision": mixed_precision_configs,
        }
        self.run_subtests(subtest_config, self._test_fp16_dtypes)

    def _test_fp16_dtypes(
        self,
        to_half_before_fsdp_init: bool,
        use_orig_params: bool,
        mixed_precision: MixedPrecision,
    ):
        """Check dtypes after one forward/backward pass of an FP16 FSDP model.

        Args:
            to_half_before_fsdp_init: If ``True``, ``.half()`` is called on
                the module before it is wrapped with FSDP; otherwise it is
                called on the FSDP-wrapped model.
            use_orig_params: Forwarded to the FSDP constructor.
            mixed_precision: Forwarded to the FSDP constructor; also used to
                derive the dtypes expected on each handle.
        """
        fp16 = torch.float16

        module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_NEVER,
            {},
        )
        device_id = torch.cuda.current_device()
        if to_half_before_fsdp_init:
            module = module.half()
        fsdp_model = FSDP(
            module,
            use_orig_params=use_orig_params,
            device_id=device_id,
            mixed_precision=mixed_precision,
        )
        if not to_half_before_fsdp_init:
            fsdp_model = fsdp_model.half()

        # Regardless of when the cast happened, every parameter must be FP16
        # before running the model
        for parameter in fsdp_model.parameters():
            self.assertEqual(parameter.dtype, fp16)

        # Cast tensor inputs to half; pass non-tensor inputs through as-is
        raw_inputs = fsdp_model.module.get_input(torch.device("cuda"))
        half_inputs = tuple(
            item.half() if torch.is_tensor(item) else item for item in raw_inputs
        )
        loss = fsdp_model(*half_inputs).sum()
        loss.backward()

        # Specifying `mixed_precision` takes precedence over the model
        # dtype for both `param_dtype` and `reduce_dtype`
        param_dtype = mixed_precision.param_dtype
        reduce_dtype = mixed_precision.reduce_dtype
        expected_fwd_bwd_dtype = fp16 if param_dtype is None else param_dtype
        if reduce_dtype is not None:
            expected_reduce_dtype = reduce_dtype
        elif param_dtype is not None:
            # Special case: infer reduce dtype from parameter dtype
            expected_reduce_dtype = param_dtype
        else:
            expected_reduce_dtype = fp16

        # Check handle dtype attributes
        for fsdp_handle in traversal_utils._get_fsdp_handles(fsdp_model):
            flat_param = fsdp_handle.flat_param
            self.assertEqual(flat_param.dtype, fp16)
            self.assertEqual(flat_param.grad.dtype, fp16)
            self.assertEqual(fsdp_handle._orig_param_dtype, fp16)
            self.assertEqual(
                fsdp_handle._fwd_bwd_param_dtype, expected_fwd_bwd_dtype
            )
            self.assertEqual(fsdp_handle._reduce_dtype, expected_reduce_dtype)

        # Check parameter/gradient dtypes
        for parameter in fsdp_model.parameters():
            self.assertEqual(parameter.dtype, fp16)
            gradient = parameter.grad
            if gradient is not None:
                self.assertEqual(gradient.dtype, fp16)


# Pass the test class through the parametrization helper imported from
# torch.testing._internal.common_utils.
instantiate_parametrized_tests(TestPureFP16)

# Run the tests only when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
