# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import torch
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class MyModel(nn.Module):
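    # A simple stack of Linear layers used as the target module for the
    # distribute_module tests below.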
    def __init__(self, n_features, n_layers, device):
        super().__init__()
        self.seq = nn.Sequential(
            *[nn.Linear(n_features, n_features, device=device) for _ in range(n_layers)]
        )

    def forward(self, x):
        return self.seq(x)

    def reset_parameters(self):
        for m in self.seq:
            m.reset_parameters()


class DTensorAPITest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        # hard-code the world size to 4 since some of these tests
        # need at least a 2d device mesh
        return 4

    @with_comms
    def test_distribute_tensor(self):
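        # Shard a global tensor along dim 0 of a 1d mesh: the DTensor keeps the
        # global shape while each rank holds a 3x3 local shard, and
        # requires_grad / leaf-ness carry over from the input tensor.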
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(0)]

        for requires_grad in [True, False]:
            tensor_to_shard = torch.randn(
                3 * self.world_size, 3, requires_grad=requires_grad
            )
            dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
            self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3]))
            local_tensor = dist_tensor.to_local()
            self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
            if requires_grad:
                self.assertTrue(dist_tensor.requires_grad)
                self.assertTrue(dist_tensor.is_leaf)

        # a negative shard dim should be normalized to the corresponding positive dim
        shard_minus_spec = [Shard(-1)]
        tensor_to_shard = torch.randn(3, 3 * self.world_size)
        dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_minus_spec)
        self.assertEqual(dist_tensor.placements[0].dim, 1)

    @with_comms
    def test_distribute_tensor_errors(self):
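        # distribute_tensor should reject a placements list whose length does not
        # match the mesh ndim, refuse to shard a non-leaf tensor, and refuse to
        # redistribute an existing DTensor to a different mesh or placements.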
        device_mesh = DeviceMesh(
            self.device_type, torch.arange(self.world_size).reshape(2, 2)
        )
        tensor_shape = [3 * self.world_size, 3 * self.world_size]
        tensor_to_distribute = torch.randn(*tensor_shape)

        with self.assertRaisesRegex(ValueError, "must have the same length"):
            shard_spec = [Shard(0)]
            distribute_tensor(tensor_to_distribute, device_mesh, shard_spec)

        with self.assertRaisesRegex(RuntimeError, "distribute leaf tensor"):
            shard_spec = [Shard(0)]
            global_tensor = torch.randn(*tensor_shape, requires_grad=True)
            global_tensor_to_distribute = global_tensor + 2
            distribute_tensor(global_tensor_to_distribute, device_mesh, shard_spec)

        spec = [Shard(0), Shard(1)]
        dtensor = distribute_tensor(tensor_to_distribute, device_mesh, spec)

        with self.assertRaisesRegex(ValueError, "to a different device mesh"):
            new_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
            distribute_tensor(dtensor, new_mesh, [Shard(0)])

        with self.assertRaisesRegex(ValueError, "to a different placements"):
            new_spec = [Shard(0), Replicate()]
            distribute_tensor(dtensor, device_mesh, new_spec)

    @with_comms
    def test_distribute_tensor_uneven_sharding(self):
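        # When the sharded dim is not evenly divisible by the mesh size, each
        # rank's local shard should match the corresponding (uneven) chunk
        # produced by torch.chunk on the global tensor.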
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        input_sizes_and_shard_dims = [
            ((self.world_size * 3 + 1, 3, 3), 0),
            ((self.world_size * 3 + 2, 3, 3), 0),
            ((3, self.world_size * 3 + 1, 3), 1),
            ((3, self.world_size * 3 + 2, 3), 1),
            ((3, 3, self.world_size * 3 + 1), 2),
            ((3, 3, self.world_size * 3 + 2), 2),
        ]
        for input_size, shard_dim in input_sizes_and_shard_dims:
            shard_spec = [Shard(shard_dim)]
            tensor_to_shard = torch.randn(input_size)
            split_tensor_list = list(
                torch.chunk(tensor_to_shard, self.world_size, dim=shard_dim)
            )
            dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
            self.assertEqual(dist_tensor.size(), torch.Size(input_size))
            local_tensor = dist_tensor.to_local()
            self.assertEqual(local_tensor, split_tensor_list[self.rank])

    @with_comms
    def test_distribute_module(self):
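        # Cover three distribute_module flows: a partition_fn that shards every
        # Linear on dim 0, the default behavior (replicate all parameters), and
        # a partition_fn that only shards selected submodules.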
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        # fully shard all linear modules on dim 0
        module_to_shard = MyModel(5 * self.world_size, 20, device=self.device_type)
        shard_spec = [Shard(0)]

        def shard_fn(name, module, device_mesh):
            if isinstance(module, nn.Linear):
                for param_name, param in module.named_parameters():
                    dist_param = torch.nn.Parameter(
                        distribute_tensor(param, device_mesh, shard_spec)
                    )
                    module.register_parameter(param_name, dist_param)

        sharded_module = distribute_module(module_to_shard, device_mesh, shard_fn)
        for param in sharded_module.parameters():
            self.assertIsInstance(param, DTensor)
            self.assertEqual(param.placements, shard_spec)

        replica_spec = [Replicate()]
        # fully replicate all parameters when no partition_fn is passed in
        module_to_replicate = MyModel(5, 20, device=self.device_type)
        replica_module = distribute_module(module_to_replicate, device_mesh)
        for param in replica_module.parameters():
            self.assertIsInstance(param, DTensor)
            self.assertEqual(param.placements, replica_spec)

        # fully replicate all parameters via an explicit partition_fn
        def replicate_fn(name, module, device_mesh):
            if isinstance(module, nn.Linear):
                for param_name, param in module.named_parameters():
                    dist_param = torch.nn.Parameter(
                        distribute_tensor(param, device_mesh, replica_spec)
                    )
                    module.register_parameter(param_name, dist_param)

        module_to_replicate = MyModel(5, 20, device=self.device_type)
        replica_module = distribute_module(
            module_to_replicate, device_mesh, replicate_fn
        )
        for param in replica_module.parameters():
            self.assertIsInstance(param, DTensor)
            self.assertEqual(param.placements, replica_spec)

        # only shard part of the module; the remaining parameters should be replicated
        def shard_fn(name, module, device_mesh):
            if isinstance(module, nn.Linear) and (name == "seq.0" or name == "seq.8"):
                for param_name, param in module.named_parameters():
                    dist_param = torch.nn.Parameter(
                        distribute_tensor(param, device_mesh, shard_spec)
                    )
                    module.register_parameter(param_name, dist_param)

        module_to_distribute = MyModel(5 * self.world_size, 20, device=self.device_type)
        dist_module = distribute_module(module_to_distribute, device_mesh, shard_fn)
        for name, param in dist_module.named_parameters():
            self.assertIsInstance(param, DTensor)
            if name.startswith(("seq.0", "seq.8")):
                self.assertEqual(param.placements, shard_spec)
            else:
                self.assertEqual(param.placements, replica_spec)

    @with_comms
    def test_distribute_module_input_fn_output_fn(self):
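        # input_fn converts the plain input tensor into a DTensor before the
        # forward pass, and output_fn converts the DTensor output back into a
        # local torch.Tensor.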
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        # fully replicate all linear modules
        module_to_replicate = MyModel(20, 1, device=self.device_type)

        # mark input sharding on dim 0
        def input_fn(mod, inputs, device_mesh):
            return DTensor.from_local(inputs[0], device_mesh, [Shard(0)])

        def output_fn(mod, outputs, device_mesh):
            assert isinstance(outputs, DTensor)
            return outputs.to_local()

        replica_module = distribute_module(
            module_to_replicate,
            device_mesh,
            input_fn=input_fn,
            output_fn=output_fn,
        )

        input_tensor = torch.randn(5, 20, device=self.device_type)
        local_out = replica_module(input_tensor)
        self.assertIsInstance(local_out, torch.Tensor)
        self.assertNotIsInstance(local_out, DTensor)

        # fully replicate the module and its inputs; parameter gradients should
        # then be replicated DTensors as well
        model = MyModel(10, 10, device=self.device_type)

        def replicate_input_fn(mod, inputs, device_mesh):
            return DTensor.from_local(inputs[0], device_mesh, [Replicate()])

        replica_model = distribute_module(
            model,
            device_mesh,
            input_fn=replicate_input_fn,
        )
        input = torch.randn(10, 10, requires_grad=True)
        output = replica_model(input)
        output.sum().backward()
        param_grad = next(iter(replica_model.parameters())).grad
        self.assertIsInstance(param_grad, DTensor)
        self.assertIsInstance(param_grad.placements[0], Replicate)

    @with_comms
    def test_distribute_module_input_fn_output_fn_warning(self):
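        # The legacy input_fn/output_fn signature that omits the module argument
        # should still work but emit a FutureWarning about its deprecation.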
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        # fully replicate all linear modules
        module_to_replicate = MyModel(20, 1, device=self.device_type)

        # mark input sharding on dim 0
        def input_fn(inputs, device_mesh):
            return DTensor.from_local(inputs[0], device_mesh, [Shard(0)])

        def output_fn(outputs, device_mesh):
            assert isinstance(outputs, DTensor)
            return outputs.to_local()

        with self.assertWarnsRegex(FutureWarning, "Deprecating"):
            replica_module = distribute_module(
                module_to_replicate,
                device_mesh,
                input_fn=input_fn,
                output_fn=output_fn,
            )

        input_tensor = torch.randn(5, 20, device=self.device_type)
        local_out = replica_module(input_tensor)
        self.assertIsInstance(local_out, torch.Tensor)
        self.assertNotIsInstance(local_out, DTensor)

    @with_comms
    def test_distribute_module_casting(self):
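        # Casting via .to(dtype) should update both the DTensor's dtype and the
        # dtype of its underlying _local_tensor; running the forward pass under
        # autocast should produce bfloat16 outputs.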
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        # check DTensor casting
        dt = DTensor.from_local(torch.rand(10), device_mesh, [Replicate()])
        dt = dt.to(torch.bfloat16)
        self.assertEqual(dt.dtype, torch.bfloat16)
        self.assertEqual(dt._local_tensor.dtype, torch.bfloat16)

        # check distribute_tensor casting
        dt = distribute_tensor(torch.rand(10), device_mesh, [Replicate()])
        dt = dt.to(torch.bfloat16)
        self.assertEqual(dt.dtype, torch.bfloat16)
        self.assertEqual(dt._local_tensor.dtype, torch.bfloat16)

        # check distribute_module casting
        model = MyModel(10, 10, device=self.device_type)
        replica_model = distribute_module(
            model,
            device_mesh,
        )
        replica_model = replica_model.to(torch.bfloat16)
        self.assertEqual(replica_model.seq[0].weight.dtype, torch.bfloat16)
        self.assertEqual(
            replica_model.seq[0].weight._local_tensor.dtype, torch.bfloat16
        )

        # check autocast
        dt = distribute_tensor(torch.rand(10), device_mesh, [Replicate()])
        replica_model = distribute_module(
            model,
            device_mesh,
        )
        with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
            output = replica_model(dt)
        self.assertEqual(output.dtype, torch.bfloat16)

    @with_comms
    def test_distribute_module_meta(self):
        # If the model is too big, the user may first create the entire model on
        # the meta device and then initialize it on the real device inside the
        # partition function.
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        # fully shard all parameters on dim 0
        module_to_shard = MyModel(5 * self.world_size, 20, device="meta")

        shard_spec = [Shard(0)]

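        # Shard the meta parameters first, then materialize them on the real
        # device with empty_like; the values stay uninitialized, which is fine
        # because this test only checks placement and device.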
        def shard_fn(name, module, device_mesh):
            for param_name, param in module._parameters.items():
                dist_param = distribute_tensor(param, device_mesh, shard_spec)
                dist_param = torch.empty_like(
                    dist_param, device=device_mesh.device_type
                )
                module.register_parameter(param_name, torch.nn.Parameter(dist_param))

        sharded_module = distribute_module(module_to_shard, device_mesh, shard_fn)
        for param in sharded_module.parameters():
            self.assertIsInstance(param, DTensor)
            self.assertFalse(param.is_meta)
            self.assertEqual(param.device.type, device_mesh.device_type)


if __name__ == "__main__":
    run_tests()
