# Owner(s): ["oncall: distributed"]
import contextlib
import itertools
import math
import sys
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp._common_utils import clean_tensor_name
from torch.distributed.fsdp._flat_param import FlatParameter
from torch.distributed.fsdp.fully_sharded_data_parallel import FLAT_PARAM
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    NestedWrappedModule,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestUnshardParamsBase(FSDPTest):
    """
    This contains the methods common to both the sharded and non-sharded test cases.
    """

    @property
    def device(self) -> torch.device:
        return torch.device("cuda", self.rank)

    def _test_unshard_params_writeback(
        self,
        writeback: bool,
        check_outer: bool,
        **fsdp_kwargs: Any,
    ):
        model = nn.Sequential(
            nn.Linear(5, 5, bias=False, device=self.device),
            nn.Linear(5, 3, bias=False, device=self.device),
        )
        model[0] = FSDP(model[0], **fsdp_kwargs)
        model = FSDP(model, **fsdp_kwargs)
        uses_sharded_strategy = model.sharding_strategy != ShardingStrategy.NO_SHARD
        offloading_params = model.cpu_offload.offload_params

        # Assumes depth-first `.parameters()`
        outer_param: Union[FlatParameter, nn.Parameter] = next(model.parameters())
        inner_param: Union[FlatParameter, nn.Parameter] = next(model[0].parameters())
        param_to_check = outer_param if check_outer else inner_param
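        # With `use_orig_params=False`, these are `FlatParameter`s; with
        # `use_orig_params=True`, they are the original parameters (views into
        # the flat parameter)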

        # Write a known value to all elements of the *sharded* parameter or
        # `FlatParameter` to check
        with torch.no_grad():
            param_to_check.zero_()
            param_to_check += self.rank + 2
        # Zero the *unsharded* parameters
        with FSDP.summon_full_params(model, writeback=writeback), torch.no_grad():
            for param in model.parameters():
                param.zero_()

        # Check the first scalar element of the sharded parameter to see
        # whether the zeroing from inside the context persisted
        param_elem_to_check = param_to_check[0]
        if param_elem_to_check.numel() > 1:
            # For `use_orig_params=True` and `NO_SHARD`, the parameter
            # preserves the original 2D shape, so we must access one more time
            param_elem_to_check = param_elem_to_check[0]
        if writeback or (not uses_sharded_strategy and not offloading_params):
            # When FSDP does not use a sharded strategy and is not offloading
            # parameters to CPU, it directly exposes the tensor storage that
            # serves as the unsharded source of truth, so the write is always
            # reflected regardless of `writeback`.
            self.assertEqual(param_elem_to_check, 0)
        else:
            self.assertEqual(param_elem_to_check, self.rank + 2)
        if offloading_params:
            cpu_device = torch.device("cpu")
            for param in model.parameters():
                self.assertEqual(param.device, cpu_device)

    def _get_test_unshard_params_writeback_config(self) -> Dict[str, List[Any]]:
        return {
            "writeback": [True, False],
            "check_outer": [True, False],
            "mixed_precision": [MixedPrecision(param_dtype=torch.float16), None],
            "cpu_offload": [
                CPUOffload(offload_params=False),
                CPUOffload(offload_params=True),
            ],
            "use_orig_params": [True, False],
        }

    def _test_unshard_params_param_data(
        self,
        rank0_only: bool,
        offload_to_cpu: bool,
        cpu_offload: CPUOffload,
        mixed_precision: Optional[MixedPrecision],
        use_orig_params: bool,
    ):
        local_model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs={},
            deterministic=True,
        )
        # Apply FSDP to submodules only so that the root module is not
        # FSDP-wrapped while containing multiple FSDP root submodules (as
        # proven later)
        fsdp_model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs={
                "cpu_offload": cpu_offload,
                "mixed_precision": mixed_precision,
                "use_orig_params": use_orig_params,
            },
            deterministic=True,
        )
        self.assertFalse(isinstance(fsdp_model, FSDP))
        # Hard code the following names because getting them is non-trivial
        non_fsdp_managed_param_names = {
            "module.0.weight",
            "module.0.bias",
            "module.3.weight",
            "module.3.bias",
        }
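        # These parameters are not managed by any FSDP instance, so
        # `offload_to_cpu=True` should not move them off their original device
        # (checked below)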

        with FSDP.summon_full_params(
            fsdp_model,
            rank0_only=rank0_only,
            writeback=not rank0_only,
            offload_to_cpu=offload_to_cpu,
        ):
            if not rank0_only or self.rank == 0:
                for p1, (n2, p2) in zip(
                    local_model.parameters(), fsdp_model.named_parameters()
                ):
                    self.assertEqual(p1.shape, p2.shape)
                    if (
                        offload_to_cpu
                        and clean_tensor_name(n2) not in non_fsdp_managed_param_names
                    ):
                        self.assertEqual(torch.device("cpu"), p2.device)
                    else:
                        self.assertEqual(p1.device, p2.device)
                    self.assertEqual(
                        p1.dtype, p2.dtype
                    )  # even if FSDP uses mixed precision
                    self.assertEqual(p1, p2)
                    self.assertTrue(isinstance(p2, nn.Parameter))
            else:
                # Check that each `FlatParameter` has the sharded size as a
                # proxy for it being resharded
                for handle in traversal_utils._get_fsdp_handles(fsdp_model):
                    if handle.uses_sharded_strategy:
                        self.assertEqual(
                            handle.flat_param.shape, handle.flat_param._sharded_size
                        )
                    else:
                        self.assertEqual(
                            handle.flat_param.shape,
                            handle.flat_param._unpadded_unsharded_size,
                        )

        # Prove the number of FSDP roots after lazy initialization
        num_fsdp_roots = 0
        for fsdp_state in traversal_utils._get_fsdp_states(fsdp_model):
            num_fsdp_roots += fsdp_state._is_root
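        # Since the top-level module is not FSDP-wrapped, each top-level FSDP
        # submodule becomes its own root, so there should be multiple roots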
        self.assertGreater(num_fsdp_roots, 1)

    def _get_test_unshard_params_param_data_config(self) -> Dict[str, List[Any]]:
        return {
            "rank0_only": [False, True],
            "offload_to_cpu": [False, True],
            "cpu_offload": [
                CPUOffload(offload_params=False),
                CPUOffload(offload_params=True),
            ],
            "mixed_precision": [MixedPrecision(param_dtype=torch.float16), None],
            "use_orig_params": [True, False],
        }


class TestUnshardParams(TestUnshardParamsBase):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    def test_unshard_params_writeback(self):
        """Tests the ``writeback`` argument (using default for all others)."""
        self.run_subtests(
            self._get_test_unshard_params_writeback_config(),
            self._test_unshard_params_writeback,
        )

    @skip_if_lt_x_gpu(2)
    def test_unshard_params_param_data(self):
        """
        Tests that parameters are exposed correctly for ``recurse=True`` and
        all other argument configurations when the root module is not
        FSDP-wrapped.
        """
        self.run_subtests(
            self._get_test_unshard_params_param_data_config(),
            self._test_unshard_params_param_data,
        )

    @skip_if_lt_x_gpu(2)
    def test_unshard_singleton_param_writeback(self):
        """
        Tests ``writeback=True`` for a singleton parameter, which includes
        testing that writing to padding does not persist.
        NOTE: This method depends on FSDP internals.
        """
        model = FSDP(nn.Linear(1, 1, bias=False, device=self.device))
        flat_param = model._handle.flat_param
        self.assertEqual(1, flat_param.numel())
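        # With world_size=2, the 1-element unsharded `FlatParameter` is padded
        # to 2 elements, so rank 0's shard holds the real element and rank 1's
        # shard holds only padding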
        # Write a known value to the *sharded* `FlatParameter`
        with torch.no_grad():
            # For nonzero ranks, this write is to padding
            flat_param[0] = self.rank + 2
        with FSDP.summon_full_params(model, writeback=True):
            self.assertEqual(1, flat_param.numel())
            with torch.no_grad():
                flat_param.zero_()
        # NOTE: This checks that writes to padding did not persist, which is
        # *not* strictly required for correctness.
        if self.rank == 0:  # did not write to padding
            self.assertEqual(0, flat_param[0])
        else:  # wrote to padding
            self.assertEqual(self.rank + 2, flat_param[0])

    @skip_if_lt_x_gpu(2)
    def test_unshard_params_respects_reshard(self):
        """
        Tests that unsharding parameters respects the expected reshard behavior
        between forward and backward as well as after backward.

        For mixed precision, we should *not* respect the reshard behavior
        because ``summon_full_params()`` forces full precision, which uses a
        different all-gather tensor than the low-precision one already in
        memory, so modifications would not persist to that tensor correctly.
        """
        self.run_subtests(
            {
                "rank0_only": [False, True],
                "offload_to_cpu": [False, True],
                "mixed_precision": [MixedPrecision(param_dtype=torch.float16), None],
                "use_orig_params": [False, True],
            },
            self._test_unshard_params_respects_reshard,
        )

    def _test_unshard_params_respects_reshard(
        self,
        rank0_only: bool,
        offload_to_cpu: bool,
        mixed_precision: Optional[MixedPrecision],
        use_orig_params: bool,
    ):
        """NOTE: This method depends on FSDP internals."""
        fsdp_kwargs = {
            "mixed_precision": mixed_precision,
            "use_orig_params": use_orig_params,
        }
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False, device=self.device), **fsdp_kwargs),
                nn.Linear(5, 3, bias=False, device=self.device),
            ),
            **fsdp_kwargs,
        )
        outer_flat_param = model._handle.flat_param
        inner_flat_param = model.module[0]._handle.flat_param
        # NOTE: This assumes uniform sharding with padding across ranks.
        expected_outer_flat_param_unsharded_numel = (
            outer_flat_param.numel() * self.world_size
        )

        def _get_unsharded_storage_size(flat_param: FlatParameter):
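            # FSDP frees `_full_param_padded` when it reshards by resizing its
            # storage to 0 elements, so a size of 0 indicates that the flat
            # parameter is currently resharded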
            return flat_param._full_param_padded.storage().size()

        # Validate the expected behavior: the root does not reshard after
        # forward; the non-root reshards after forward; and both reshard after
        # backward
        output = model(torch.zeros(5, device=self.device))
        self.assertEqual(
            expected_outer_flat_param_unsharded_numel,
            _get_unsharded_storage_size(outer_flat_param),
        )
        self.assertEqual(0, _get_unsharded_storage_size(inner_flat_param))
        output.sum().backward()
        self.assertEqual(0, _get_unsharded_storage_size(outer_flat_param))
        self.assertEqual(0, _get_unsharded_storage_size(inner_flat_param))

        # Check that unsharding parameters between forward and backward and
        # again after backward matches the expected reshard behavior
        output = model(torch.zeros(5, device=self.device))
        with FSDP.summon_full_params(
            model,
            rank0_only=rank0_only,
            writeback=not rank0_only,
            offload_to_cpu=offload_to_cpu,
        ):
            pass
        if mixed_precision is not None:
            # After forcing full precision, we must invalidate the existing
            # unsharded low-precision flat parameter since it will not persist
            # changes from the `summon_full_params()` context, so we cannot
            # respect the reshard behavior
            expected_outer_flat_param_unsharded_numel = 0
        self.assertEqual(
            expected_outer_flat_param_unsharded_numel,
            _get_unsharded_storage_size(outer_flat_param),
        )
        self.assertEqual(0, _get_unsharded_storage_size(inner_flat_param))
        output.sum().backward()
        with FSDP.summon_full_params(
            model,
            rank0_only=rank0_only,
            writeback=not rank0_only,
            offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(0, _get_unsharded_storage_size(outer_flat_param))
        self.assertEqual(0, _get_unsharded_storage_size(inner_flat_param))

    @skip_if_lt_x_gpu(2)
    def test_unshard_params_recurse(self):
        """Tests the ``recurse`` argument (using default for all others)."""
        self.run_subtests(
            {
                "recurse": [False, True],
                "unshard_outer": [False, True],
                "mixed_precision": [MixedPrecision(param_dtype=torch.float16), None],
                "use_orig_params": [False, True],
            },
            self._test_unshard_params_recurse,
        )

    def _test_unshard_params_recurse(
        self,
        recurse: bool,
        unshard_outer: bool,
        mixed_precision: Optional[MixedPrecision],
        use_orig_params: bool,
    ):
        """NOTE: This method depends on FSDP internals."""
        fsdp_kwargs = {
            "mixed_precision": mixed_precision,
            "use_orig_params": use_orig_params,
        }
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False, device=self.device), **fsdp_kwargs),
                nn.Linear(5, 3, bias=False, device=self.device),
            ),
            **fsdp_kwargs,
        )
        # Hard code the numel values based on the model
        unsharded_inner_numel = 5 * 5
        unsharded_outer_numel = 5 * 3
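        # The inner FSDP instance manages `nn.Linear(5, 5)` (25 elements),
        # while the outer FSDP instance only manages the unwrapped
        # `nn.Linear(5, 3)` (15 elements) since the inner linear's parameters
        # already belong to the nested FSDP instance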
        if use_orig_params:
            # Account for unsharded padding: since each `FlatParameter` only
            # has one original parameter, we only need to pad for divisibility
            # by world size and not address alignment
            if unsharded_inner_numel % self.world_size:
                unsharded_inner_numel += self.world_size - (
                    unsharded_inner_numel % self.world_size
                )
            if unsharded_outer_numel % self.world_size:
                unsharded_outer_numel += self.world_size - (
                    unsharded_outer_numel % self.world_size
                )
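        # For example, with world_size=2 and `use_orig_params=True`, the
        # unsharded numels above become 26 (inner) and 16 (outer), and the
        # sharded numels below are 13 and 8, respectively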
        # Round up the sharded numel to account for padding
        sharded_inner_numel = int(math.ceil(unsharded_inner_numel / self.world_size))
        sharded_outer_numel = int(math.ceil(unsharded_outer_numel / self.world_size))
        inner_flat_param = model.module[0]._handle.flat_param
        outer_flat_param = model._handle.flat_param
        self.assertEqual(sharded_inner_numel, inner_flat_param.numel())
        self.assertEqual(sharded_outer_numel, outer_flat_param.numel())
        expected_outer_numel = (
            unsharded_outer_numel if unshard_outer else sharded_outer_numel
        )
        expected_inner_numel = (
            unsharded_inner_numel
            if recurse or not unshard_outer
            else sharded_inner_numel
        )
        module_to_unshard = model if unshard_outer else model[0]
        with FSDP.summon_full_params(module_to_unshard, recurse=recurse):
            self.assertEqual(expected_outer_numel, outer_flat_param.numel())
            self.assertEqual(expected_inner_numel, inner_flat_param.numel())

    @skip_if_lt_x_gpu(2)
    def test_named_parameters_and_buffers(self):
        """
        Tests that ``named_parameters()`` and ``named_buffers()`` for a
        top-level FSDP-wrapped model match their behavior for the equivalent
        non-wrapped module.
        """
        self.run_subtests(
            {"prefix": ["", "test_prefix"], "recurse": [False, True]},
            self._test_named_parameters_and_buffers,
        )

    def _test_named_parameters_and_buffers(self, prefix: str, recurse: bool):
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        model.buffer = nn.Buffer(torch.ones(1))
        # Wrap the top-level module with FSDP since `named_parameters()` and
        # `named_buffers()` will contain FSDP prefixes if called on a non-FSDP
        # root module
        fsdp_model = FSDP(
            NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.NO_FSDP,
                CUDAInitMode.CUDA_BEFORE,
                deterministic=True,
            ),
            self.process_group,
        )
        fsdp_model.buffer = nn.Buffer(torch.ones(1))
        with FSDP.summon_full_params(fsdp_model):
            for call in ["named_parameters", "named_buffers"]:
                for (n1, p1), (n2, p2) in itertools.zip_longest(
                    getattr(fsdp_model, call)(prefix=prefix, recurse=recurse),
                    getattr(model, call)(prefix=prefix, recurse=recurse),
                ):
                    self.assertEqual(n1, n2)
                    self.assertEqual(p1, p2)

    @skip_if_lt_x_gpu(2)
    def test_with_grads_core(self):
        """
        Tests the core usage of ``with_grads=True`` by comparing against DDP
        as the unsharded equivalent.
        """
        self.run_subtests(
            {
                "writeback": [False, True],
                "offload_to_cpu": [False, True],
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
                "use_orig_params": [True],
            },
            self._test_with_grads_core,
        )

    def _test_with_grads_core(
        self,
        writeback: bool,
        offload_to_cpu: bool,
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
    ):
        def _check_grads(
            ddp_model: DDP,
            fsdp_model: FSDP,
            old_fsdp_grads: Optional[List[torch.Tensor]],
        ):
            """
            Checks that writes to the FSDP parameters' gradients persist or do
            not persist depending on ``writeback`` and the sharding strategy.
            The DDP model is used for checking gradient parity to ensure that
            FSDP all-gathers the correct gradient values.
            """
            WRITEBACK_FACTOR = 2
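            # Scale the unsharded gradients by a known factor inside the
            # context so that we can check below whether the write persists
            # after exiting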
            with FSDP.summon_full_params(
                fsdp_model,
                writeback=writeback,
                offload_to_cpu=offload_to_cpu,
                with_grads=True,
            ):
                for (n1, p1), (n2, p2) in zip(
                    ddp_model.module.named_parameters(),
                    fsdp_model.named_parameters(),
                ):
                    self.assertEqual(n1, clean_tensor_name(n2))
                    assert p1.grad is not None
                    torch.testing.assert_close(p1.grad, p2.grad)
                    # Ensure that the tensor is not all zeros, which would
                    # mean that the multiplication is vacuous
                    assert torch.count_nonzero(p2.grad) > 0
                    p2.grad *= WRITEBACK_FACTOR
            new_fsdp_grads = [
                param.grad
                for param in fsdp_model.parameters()
                if param.grad is not None
            ]
            writeback_persists = writeback or (
                sharding_strategy == ShardingStrategy.NO_SHARD and not offload_to_cpu
            )
            for old_grad, new_grad in zip(old_fsdp_grads, new_fsdp_grads):
                if writeback_persists:
                    torch.testing.assert_close(old_grad * WRITEBACK_FACTOR, new_grad)
                else:
                    torch.testing.assert_close(old_grad, new_grad)
            if writeback_persists:
                # Modify the DDP gradients in the same way for parity
                for param in ddp_model.parameters():
                    param.grad *= WRITEBACK_FACTOR

        def _get_error_context(is_supported: bool):
            return (
                contextlib.nullcontext()
                if is_supported
                else self.assertRaises(NotImplementedError)
            )  # some configs are not implemented yet

        def _get_fsdp_grads(fsdp_model: FSDP, is_supported: bool):
            if is_supported:
                return [
                    param.grad.clone()
                    for param in fsdp_model.parameters()
                    if param.grad is not None
                ]
            return None  # unused

        is_supported = use_orig_params and not offload_to_cpu
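        # `with_grads=True` currently requires `use_orig_params=True` and does
        # not support `offload_to_cpu=True`; unsupported configurations are
        # expected to raise `NotImplementedError`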
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        ddp_model = DDP(model, device_ids=[self.rank])
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
            fsdp_kwargs={
                "use_orig_params": use_orig_params,
                "sharding_strategy": sharding_strategy,
            },
        )
        with FSDP.summon_full_params(fsdp_model):
            for p1, p2 in zip(ddp_model.module.parameters(), fsdp_model.parameters()):
                assert torch.all(torch.isclose(p1, p2))

        # Check calling after backward
        inp = fsdp_model.get_input(torch.device("cuda"))
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        ddp_out.sum().backward()
        fsdp_out.sum().backward()
        old_fsdp_grads = _get_fsdp_grads(fsdp_model, is_supported)
        with _get_error_context(is_supported):
            _check_grads(ddp_model, fsdp_model, old_fsdp_grads)

        # Check calling between forward and backward
        inp = fsdp_model.get_input(torch.device("cuda"))
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        old_fsdp_grads = _get_fsdp_grads(fsdp_model, is_supported)
        with _get_error_context(is_supported):
            _check_grads(ddp_model, fsdp_model, old_fsdp_grads)

    @skip_if_lt_x_gpu(2)
    def test_with_grads_none_grads(self):
        """
        Tests that, if the ``FlatParameter`` on every rank has a ``None``
        gradient, then each original parameter sees a ``None`` gradient as
        well.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ]
            },
            self._test_with_grads_none_grads,
        )

    def _test_with_grads_none_grads(self, sharding_strategy: ShardingStrategy):
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
            fsdp_kwargs={
                "use_orig_params": True,
                "sharding_strategy": sharding_strategy,
            },
        )
        for fsdp_module in FSDP.fsdp_modules(fsdp_model):
            if fsdp_module._handle:
                assert fsdp_module._handle.flat_param.grad is None
        with FSDP.summon_full_params(fsdp_model, with_grads=True):
            for param in fsdp_model.parameters():
                self.assertTrue(param.grad is None)

    @skip_if_lt_x_gpu(2)
    def test_unshard_submodule(self):
        model = nn.Sequential(
            nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)),
            nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)),
        ).cuda()
        model = FSDP(model, auto_wrap_policy=ModuleWrapPolicy((nn.Sequential,)))
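        # The auto wrap policy wraps each inner `nn.Sequential` in its own
        # FSDP instance, and the explicit `FSDP` call wraps the root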
        with FSDP.summon_full_params(model[0]):
            # Check that the summoned module does not expose its flat
            # parameter via `named_parameters()`
            for param_name, param in model[0].named_parameters():
                self.assertFalse(FLAT_PARAM in param_name)
            self.assertGreater(len(list(model[0].parameters())), 1)


class TestUnshardParamsNoShard(TestUnshardParamsBase):
    @property
    def world_size(self) -> int:
        return 1

    @skip_if_lt_x_gpu(1)
    def test_unshard_params_writeback_no_shard(self):
        """Tests the ``writeback`` argument (using default for all others)."""
        self.run_subtests(
            self._get_test_unshard_params_writeback_config(),
            self._test_unshard_params_writeback,
        )

    @skip_if_lt_x_gpu(1)
    def test_unshard_params_param_data_no_shard(self):
        """
        Tests that parameters are exposed correctly for ``recurse=True`` and
        all other argument configurations when the root module is not
        FSDP-wrapped.
        """
        config = self._get_test_unshard_params_param_data_config()
        # TODO: `offload_to_cpu=True` with `NO_SHARD` is not supported yet. See
        # `test_offload_to_cpu_no_shard_raises()`.
        config["offload_to_cpu"] = [False]
        self.run_subtests(
            config,
            self._test_unshard_params_param_data,
        )


class TestUnshardParamsErrors(TestUnshardParamsBase):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    def test_unshard_params_from_forward_raises(self):
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Parameter(torch.zeros(5))

            def forward(self, fsdp_module):
                with fsdp_module.summon_full_params(fsdp_module):
                    pass

        model = FSDP(MyModule()).cuda(self.rank)
        with self.assertRaisesRegex(
            AssertionError, "Cannot manually unshard parameters during forward/backward"
        ):
            model(model)

    @skip_if_lt_x_gpu(2)
    def test_unshard_params_from_backward_raises(self):
        model = FSDP(nn.Linear(2, 1, device=self.device))
        output = model(torch.ones(2, device=self.device))

        def invalid_backward_hook(*args, **kwargs):
            with FSDP.summon_full_params(model):
                pass

        self.assertTrue(output.requires_grad)
        output.register_hook(invalid_backward_hook)
        with self.assertRaisesRegex(
            AssertionError, "Cannot manually unshard parameters during forward/backward"
        ):
            output.backward()

    @skip_if_lt_x_gpu(2)
    def test_rank0_only_with_writeback_raises(self):
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
        )
        with self.assertRaisesRegex(NotImplementedError, "is not supported"):
            with FSDP.summon_full_params(
                nested_wrapped_module, rank0_only=True, writeback=True
            ):
                pass

    @skip_if_lt_x_gpu(2)
    def test_offload_to_cpu_no_shard_raises(self):
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            {"sharding_strategy": ShardingStrategy.NO_SHARD},
        )
        with self.assertRaisesRegex(NotImplementedError, "is not supported"):
            with FSDP.summon_full_params(
                nested_wrapped_module, offload_to_cpu=True
            ):
                pass


if __name__ == "__main__":
    run_tests()
