# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch.distributed._tensor import (
    distribute_module,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


funcol = torch.ops.c10d_functional


class TestEmbeddingOp(DTensorTestBase):
    def _apply_sharding(self, embedding_mod, shard_dim, device_mesh):
        def shard_embedding_fn(name, module, device_mesh):
            for name, param in module.named_parameters():
                dist_param = torch.nn.Parameter(
                    distribute_tensor(param, device_mesh, [Shard(shard_dim)])
                )
                module.register_parameter(name, dist_param)

        sharded_embedding = distribute_module(
            embedding_mod, device_mesh, shard_embedding_fn
        )
        return sharded_embedding

    def _run_embedding_op_test(
        self,
        device_mesh,
        shard_dim,
        input_size,
        num_embeddings,
        embedding_dim,
        **kwargs,
    ):
        # Use same seed.
        torch.manual_seed(0)
        local_embedding = torch.nn.Embedding(
            num_embeddings,
            embedding_dim,
            device=self.device_type,
            **kwargs,
        )
        sharded_embedding = torch.nn.Embedding(
            num_embeddings,
            embedding_dim,
            device=self.device_type,
            **kwargs,
        )

        # Shard the parameter of local embedding and set it to sharded embedding.
        sharded_embedding.weight = torch.nn.Parameter(
            local_embedding.weight.clone().detach()
        )

        sharded_embedding = self._apply_sharding(
            sharded_embedding, shard_dim, device_mesh
        )

        # Run sharded computation
        torch.manual_seed(10)
        inp = torch.randint(
            0, num_embeddings, tuple(input_size), device=self.device_type
        )
        target = torch.empty(
            *inp.size(), embedding_dim, dtype=torch.float, device=self.device_type
        ).random_(0, 1)
        dist_inp = distribute_tensor(inp, device_mesh, [Replicate()])

        # fwd computation, ensure no comm happened
        with CommDebugMode() as fwd_mode:
            dist_output = sharded_embedding(dist_inp)
            self.assertEqual(fwd_mode.get_total_counts(), 0)

        output = dist_output.full_tensor()
        # Run local computation
        local_output = local_embedding(inp)

        # Verify
        self.assertEqual(local_output, output)

        # Use a sample cross entry loss to verify backward and grad computation.
        loss = torch.nn.CrossEntropyLoss()
        emb_loss = loss(
            output,
            target,
        )
        emb_dup_loss = loss(
            local_output,
            target,
        )

        # local embedding backward
        emb_dup_loss.backward()

        # sharded embedding bwd computation, ensure no comm happened
        with CommDebugMode() as bwd_mode:
            emb_loss.backward()
            self.assertEqual(bwd_mode.get_total_counts(), 0)

        gradient = sharded_embedding.weight.grad.full_tensor()

        local_grad = local_embedding.weight.grad

        # Verify gradient.
        self.assertEqual(gradient, local_grad)

        # Validate for torch.nn.functional.embedding version.
        local_output = torch.nn.functional.embedding(
            inp,
            local_embedding.weight,
            **kwargs,
        )
        sharded_output = torch.nn.functional.embedding(
            DTensor.from_local(inp, device_mesh, [Replicate()], run_check=False),
            sharded_embedding.weight,
            **kwargs,
        )
        self.assertEqual(local_output, sharded_output.full_tensor())

    @with_comms
    def test_sharded_embedding_colwise(self):
        mesh = self.build_device_mesh()
        self._run_embedding_op_test(mesh, 1, [5, 4], 17, 12)
        self._run_embedding_op_test(mesh, 1, [6, 7, 6], 21, 11)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4], 23, 13)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4, 7], 23, 16)
        self._run_embedding_op_test(mesh, 1, [4], 15, 14)
        self._run_embedding_op_test(mesh, 1, [34], 15, 14, padding_idx=10)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12)

    @with_comms
    def test_sharded_embedding_colwise_max_norm_errors(self):
        mesh = self.build_device_mesh()
        with self.assertRaisesRegex(
            NotImplementedError,
            "aten.embedding_renorm_.default does not have a sharding strategy registered.",
        ):
            self._run_embedding_op_test(
                mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12, max_norm=2.0
            )

    @with_comms
    def test_sharded_embedding_rowwise(self):
        mesh = self.build_device_mesh()
        # test correctness
        self._run_embedding_op_test(mesh, 0, [5, 12], 16, 22)
        self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22)
        self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10)

        from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

        # test collectives
        embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type)
        sharded_embedding = self._apply_sharding(embedding_mod, 0, mesh)
        inp = torch.randint(0, 10, (8, 8), device=self.device_type)
        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
        output = sharded_embedding(replicated_inp)
        self.assertIsInstance(output.placements[0], _MaskPartial)

        comm_mode = CommDebugMode()

        with comm_mode:
            output.full_tensor()
            self.assertEqual(comm_mode.get_total_counts(), 1)
            self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)

    @with_comms
    def test_multiple_embeddings_rowwise(self):
        mesh = self.build_device_mesh()

        inp = torch.randint(0, 10, (4, 4), device=self.device_type)
        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)

        from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

        # case 1: two embeddings with the same shape, thus sharing the underying _MaskPartial
        # and MaskBuffer, because of cache hit from sharding propagation

        emb1 = torch.nn.Embedding(10, 23, device=self.device_type)
        sharded_emb1 = self._apply_sharding(emb1, 0, mesh)
        output1 = sharded_emb1(replicated_inp)

        emb2 = torch.nn.Embedding(10, 29, device=self.device_type)
        sharded_emb2 = self._apply_sharding(emb2, 0, mesh)
        output2 = sharded_emb2(replicated_inp)

        partial_placement1 = output1.placements[0]
        self.assertIsInstance(partial_placement1, _MaskPartial)
        output1.full_tensor()

        partial_placement2 = output2.placements[0]
        self.assertIsInstance(partial_placement2, _MaskPartial)
        output2.full_tensor()

        self.assertTrue(id(partial_placement1), id(partial_placement2))

        # case 2: two embeddings with the same logical_dim_size, but different logical_shape
        # thus they will have different _MaskPartial placements (with no cache hit)

        emb3 = torch.nn.Embedding(10, 29, device=self.device_type)
        sharded_emb3 = self._apply_sharding(emb3, 0, mesh)
        output3 = sharded_emb3(replicated_inp)
        partial_placement3 = output3.placements[0]
        self.assertIsInstance(partial_placement3, _MaskPartial)
        output2.full_tensor()

        # not equal because of different logical_shape, despite of same logical_dim_size
        self.assertNotEqual(partial_placement1, partial_placement3)


if __name__ == "__main__":
    run_tests()
