# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import os
import unittest
from functools import wraps
from typing import Any, Callable, Dict, Tuple

import numpy as np

import torch
from torch import nn
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    Replicate,
    Shard,
)
from torch.testing._internal.common_utils import run_tests, TestCase


# wrapper to check xla test requirements
def with_xla(func: Callable) -> Callable:
    assert func is not None

    @wraps(func)  # pyre-ignore[6]
    def wrapper(
        self, *args: Tuple[object], **kwargs: Dict[str, Any]  # type: ignore[misc]
    ) -> None:
        # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
        os.environ["XLA_USE_SPMD"] = "1"
        try:
            import torch_xla  # type:ignore[import]  # noqa: F401
        except ImportError as exc:
            raise unittest.SkipTest("torch_xla is not installed.") from exc
        self.device_type = "xla"
        func(self, *args, **kwargs)  # type: ignore[misc]
        os.environ["XLA_USE_SPMD"] = "0"

    return wrapper


class DTensorXLAIntegrationTest(TestCase):
    class SimpleLinear(nn.Module):
        def __init__(self) -> None:
            super(DTensorXLAIntegrationTest.SimpleLinear, self).__init__()
            self.fc1 = nn.Linear(128, 64)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(64, 1)

        def forward(self, x):
            y = self.relu(self.fc1(x))
            z = self.fc2(y)
            return z

    @with_xla
    def test_xla_distribute_tensor_1d_shard(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        if device_count > 1:
            device_mesh = DeviceMesh("xla", list(range(device_count)))
            shard_spec = [Shard(0)]

            for requires_grad in [True, False]:
                tensor_to_shard = torch.randn(
                    3 * device_count, 3, requires_grad=requires_grad
                )
                dist_tensor = distribute_tensor(
                    tensor_to_shard, device_mesh, shard_spec
                )
                # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
                assert type(dist_tensor).__name__ == "XLAShardedTensor"
                global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
                self.assertEqual(
                    global_tensor.size(), torch.Size([3 * device_count, 3])
                )
                local_tensor = dist_tensor.local_shards[0].data
                self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
                if requires_grad:
                    self.assertTrue(dist_tensor.global_tensor.requires_grad)
                    self.assertTrue(dist_tensor.is_leaf)

    @with_xla
    def test_xla_distribute_tensor_1d_replicate(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        device_mesh = DeviceMesh("xla", list(range(device_count)))
        shard_spec = [Replicate()]

        for requires_grad in [True, False]:
            tensor_to_shard = torch.randn(
                3 * device_count, 3, requires_grad=requires_grad
            )
            dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
            # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
            assert type(dist_tensor).__name__ == "XLAShardedTensor"
            global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
            self.assertEqual(global_tensor.size(), torch.Size([3 * device_count, 3]))
            local_tensor = dist_tensor.local_shards[0].data
            self.assertEqual(local_tensor.size(), torch.Size([3 * device_count, 3]))
            if requires_grad:
                self.assertTrue(dist_tensor.global_tensor.requires_grad)
                self.assertTrue(dist_tensor.is_leaf)

    @with_xla
    def test_xla_distribute_tensor_2d(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        if device_count > 1:
            device_mesh = DeviceMesh(
                "xla", np.array(range(device_count)).reshape(2, device_count // 2)
            )
            shard_spec = [Replicate(), Shard(0)]

            for requires_grad in [True, False]:
                tensor_to_shard = torch.randn(
                    3 * device_count // 2, 3, requires_grad=requires_grad
                )
                dist_tensor = distribute_tensor(
                    tensor_to_shard, device_mesh, shard_spec
                )
                # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
                assert type(dist_tensor).__name__ == "XLAShardedTensor"
                global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
                self.assertEqual(
                    global_tensor.size(), torch.Size([3 * device_count // 2, 3])
                )
                local_tensor = dist_tensor.local_shards[0].data
                self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
                if requires_grad:
                    self.assertTrue(dist_tensor.global_tensor.requires_grad)
                    self.assertTrue(dist_tensor.is_leaf)

    @with_xla
    def text_xla_distribute_module(self):
        import torch_xla  # type:ignore[import]
        import torch_xla.core.xla_model as xm  # type:ignore[import]
        import torch_xla.runtime as xr  # type:ignore[import]

        model = self.SimpleLinear().to(xm.xla_device())

        device_count = xr.global_runtime_device_count()
        device_mesh = DeviceMesh("xla", list(range(device_count)))

        def shard_params(mod_name, mod, mesh):
            shard_spec = [Shard(0)]
            # annoate fc1 and fc2
            if isinstance(mod, nn.Linear):
                for name, param in mod.named_parameters():
                    # annotate the parameter tensors directly
                    distribute_tensor(param, mesh, shard_spec)

        sharded_model = distribute_module(model, device_mesh, shard_params)
        self.assertTrue(
            torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != ""
        )
        self.assertTrue(
            torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != ""
        )


if __name__ == "__main__":
    run_tests()
