# Owner(s): ["module: nn"]
import itertools
import random
import unittest
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCUDA,
    onlyNativeDeviceTypes,
    skipCUDAIf,
    skipMeta,
    TEST_WITH_ROCM,
)

from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    _assertGradAndGradgradChecks,
    dtype2prec_DONTUSE,
    instantiate_parametrized_tests,
    IS_JETSON,
    parametrize as parametrize_test,
    run_tests,
    set_default_dtype,
    skipIfTorchDynamo,
)


class TestEmbeddingNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_embedding_max_norm_unsorted_repeating_indices(self):
        """Compare CPU and CUDA lookups when ``max_norm`` sees heavily repeated indices.

        The same seeded ``Embedding`` is built on both devices and queried with
        the pattern ``[0, 1, 0, 1, ...]`` (4000 entries); the CUDA output must
        equal the CPU output.
        """

        def build_seeded_embedding(device):
            # Reset the RNG first so both devices start from identical weights.
            torch.manual_seed(0)
            module = torch.nn.Embedding(
                num_embeddings=20, embedding_dim=64, max_norm=1.0
            )
            return module.to(device)

        cpu_indices = torch.arange(2, device="cpu", dtype=torch.long).repeat(2000)
        expected = build_seeded_embedding("cpu")(cpu_indices)

        cuda_indices = cpu_indices.to("cuda")
        actual = build_seeded_embedding("cuda")(cuda_indices)
        self.assertEqual(actual.cpu(), expected)

    def test_embedding_sparse_basic(self):
        """A ``sparse=True`` embedding yields a sparse gradient shaped like its weight."""
        module = nn.Embedding(10, 20, sparse=True)
        indices = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long)

        loss = module(indices).sum()
        loss.backward()

        grad = module.weight.grad
        self.assertTrue(grad.is_sparse)
        self.assertEqual(grad.shape, module.weight.shape)

    def test_embedding_sparse_empty_tensor(self):
        """Sparse backward must work for degenerate embedding shapes.

        Covers a ``0 x 0`` weight queried with an empty index tensor and a
        ``10 x 0`` weight (zero-width rows) queried with a ``2 x 4`` index
        tensor. In both cases the gradient has to be sparse and shaped like
        the weight.
        """
        scenarios = (
            # (num_embeddings, embedding_dim, indices)
            (0, 0, torch.tensor([], dtype=torch.int64)),
            (10, 0, torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long)),
        )
        for num_embeddings, embedding_dim, indices in scenarios:
            module = nn.Embedding(num_embeddings, embedding_dim, sparse=True)
            module(indices).sum().backward()

            grad = module.weight.grad
            self.assertTrue(grad.is_sparse)
            self.assertEqual(grad.shape, module.weight.shape)

    def test_move_sparse_half_embedding(self):
        """A sparse embedding can be cast to float16 and moved between devices."""
        module = nn.Embedding(10, 3, sparse=True)
        self.assertEqual(module.weight.device.type, "cpu")
        self.assertEqual(module.weight.dtype, torch.get_default_dtype())

        # Casting must change the dtype but leave the module's geometry alone.
        module.to(torch.float16)
        self.assertEqual(module.weight.dtype, torch.float16)
        self.assertEqual(module.embedding_dim, 3)
        self.assertEqual(module.num_embeddings, 10)

        if not torch.cuda.is_available():
            return

        # Round trip: CPU -> CUDA -> CPU.
        for target in ("cuda", "cpu"):
            module.to(target)
            self.assertEqual(module.weight.device.type, target)

    def test_embedding_max_norm(self):
        """With ``max_norm=1.0`` every looked-up row has an L2 norm of at most 1."""
        module = nn.Embedding(22, 5, max_norm=1.0)
        indices = torch.tensor([2, 8, 8, 6], dtype=torch.long)
        rows = module(indices)

        # Positions 1 and 2 both look up index 8, so they must be identical.
        self.assertEqual(rows[1], rows[2])

        row_norms = rows.data.norm(p=2, dim=1)
        self.assertTrue(row_norms.le(1).all())

    @parametrize_test(
        "dtype",
        (
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float,
            torch.double,
        ),
    )
    def test_embedding_from_pretrained(self, dtype):
        """``Embedding.from_pretrained`` keeps the given table, for each parametrized dtype."""
        pretrained = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
        module = nn.Embedding.from_pretrained(pretrained)
        self.assertEqual(pretrained, module.weight.data)

        # Looking up every row in order must reproduce the table itself.
        every_row = torch.LongTensor([0, 1])
        self.assertEqual(pretrained, module(every_row))

    def test_embedding_bag_from_pretrained(self):
        """``EmbeddingBag.from_pretrained`` with one index per bag returns the table rows."""
        pretrained = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        bag_module = nn.EmbeddingBag.from_pretrained(pretrained)
        self.assertEqual(pretrained, bag_module.weight)

        indices = torch.tensor([0, 1], dtype=torch.long)
        # Offsets [0, 1] put each index in its own bag.
        offsets = torch.arange(indices.size(0))
        self.assertEqual(pretrained, bag_module(indices, offsets))

    def test_embedding_from_pretrained_padding_idx(self):
        """``from_pretrained`` must keep the caller's values in the padding row."""
        pad_slot = 2
        pad_row = torch.ones(3) * 7
        table = torch.rand(4, 3, requires_grad=True)
        # Write the marker row without recording the write in autograd.
        with torch.no_grad():
            table[pad_slot] = pad_row

        module = nn.Embedding.from_pretrained(table, padding_idx=pad_slot)
        self.assertEqual(module.weight[pad_slot], pad_row)

    def test_embedding_bag_from_pretrained_padding_idx(self):
        """``EmbeddingBag.from_pretrained`` with ``padding_idx`` leaves the table unchanged."""
        pad_slot = 2
        table = torch.rand(4, 3, requires_grad=True)
        bag_module = nn.EmbeddingBag.from_pretrained(table, padding_idx=pad_slot)
        self.assertEqual(bag_module.weight, table)

    def test_embedding_from_pretrained_options(self):
        """Keyword options given to ``Embedding.from_pretrained`` take effect.

        Runs with float64 as the default dtype. The test expects the lookup to
        match the pretrained tensor, the pretrained tensor to no longer hold
        its initial values 1..6, and every output row to respect ``max_norm``
        under the configured ``norm_type``.
        """
        with set_default_dtype(torch.double):
            pretrained = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            opts = dict(
                max_norm=2.0,
                norm_type=0.5,
                scale_grad_by_freq=False,
                sparse=True,
            )
            module = nn.Embedding.from_pretrained(pretrained, **opts)
            looked_up = module(torch.LongTensor([0, 1]))

            # test output and that weight matrix was renormalized
            self.assertEqual(pretrained, looked_up)
            initial_values = torch.arange(1, 7, dtype=pretrained.dtype).view(2, 3)
            self.assertTrue(pretrained.ne(initial_values).all())

            row_norms = looked_up.data.norm(p=opts["norm_type"], dim=1)
            self.assertTrue(row_norms.le(opts["max_norm"]).all())

    def test_embedding_functional(self):
        """``F.embedding`` agrees with ``nn.Embedding``, with and without ``padding_idx``."""
        indices = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long)
        table = torch.rand(4, 3, requires_grad=True)

        module = torch.nn.Embedding(4, 3)
        module.weight.data = table.data
        # A silly test for eager, this test is useful for when we run under PYTORCH_TEST_WITH_DYNAMO=1
        # as it ensures that setattr correctly works.
        self.assertEqual(module.weight.data, table.data)
        from_module = module(indices)

        from_functional = F.embedding(indices, table)
        self.assertEqual(from_module, from_functional)

        # Same comparison, now with a padding index.
        module = torch.nn.Embedding(4, 3).from_pretrained(table, padding_idx=2)
        from_module = module(indices)
        from_functional = F.embedding(indices, table, padding_idx=2)

        self.assertEqual(from_module, from_functional)

    # https://github.com/pytorch/pytorch/issues/130806
    @largeTensorTest("40GB", device="cuda")
    def test_large_tensors(self):
        """Embedding lookup whose output holds 2**31 elements (131072 x 16384)."""
        num_rows, row_width, num_lookups = 16032, 16384, 131072

        indices = torch.randint(
            low=0, high=num_rows, size=[num_lookups], device="cuda"
        )
        table = torch.randn([num_rows, row_width], device="cuda")
        result = torch.nn.functional.embedding(indices, table)

        self.assertEqual(result.dim(), 2)
        # 131072 * 16384 == 2**31, one more than the largest signed 32-bit integer.
        self.assertEqual(result.numel(), 2147483648)

    def test_embedding_bag_functional(self):
        """``F.embedding_bag`` agrees with ``nn.EmbeddingBag``, with and without ``padding_idx``."""
        indices = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long)
        table = torch.rand(4, 3, requires_grad=True)

        bag_module = torch.nn.EmbeddingBag(4, 3)
        bag_module.weight = torch.nn.Parameter(table)
        from_module = bag_module(indices)

        from_functional = F.embedding_bag(indices, table)
        self.assertEqual(from_module, from_functional)

        # Same comparison, now with a padding index.
        bag_module = torch.nn.EmbeddingBag(4, 3).from_pretrained(table, padding_idx=2)
        from_module = bag_module(indices)
        from_functional = F.embedding_bag(indices, table, padding_idx=2)

        self.assertEqual(from_module, from_functional)

    # Make sure that error is thrown if padding_idx is out of bounds
    def test_embedding_bag_padding_idx_error(self):
        """Sweep ``padding_idx`` across and beyond the valid range.

        Values in ``[-num_embeddings, num_embeddings)`` must be accepted by both
        ``F.embedding_bag`` and the ``nn.EmbeddingBag`` constructor. Anything
        outside must raise ``RuntimeError`` from the functional form and
        ``AssertionError`` from the module constructor.
        """
        indices = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long)
        num_embeddings = 4
        num_features = 3
        table = torch.rand(num_embeddings, num_features, requires_grad=True)

        functional_err_msg = r"padding_idx must be within the number of embeddings"
        module_err_msg = r"padding_idx must be within num_embeddings"

        for padding_idx in range(-(num_embeddings + 2), num_embeddings + 2):
            if -num_embeddings <= padding_idx < num_embeddings:
                # Valid: neither call may raise.
                F.embedding_bag(indices, table, padding_idx=padding_idx)
                torch.nn.EmbeddingBag(
                    num_embeddings, num_features, padding_idx=padding_idx
                )
                continue

            with self.assertRaisesRegex(RuntimeError, functional_err_msg):
                F.embedding_bag(indices, table, padding_idx=padding_idx)
            with self.assertRaisesRegex(AssertionError, module_err_msg):
                torch.nn.EmbeddingBag(
                    num_embeddings, num_features, padding_idx=padding_idx
                )

    def test_embeddingbag_from_pretrained(self):
        """A default ``EmbeddingBag.from_pretrained`` returns the mean of the rows in a bag."""
        pretrained = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        bag_module = nn.EmbeddingBag.from_pretrained(pretrained)
        self.assertEqual(pretrained, bag_module.weight.data)

        # A single bag containing both rows.
        one_bag = torch.LongTensor([[0, 1]])
        pooled = bag_module(one_bag)
        expected = pretrained.mean(0, keepdim=True)
        self.assertEqual(expected, pooled)

    def test_embeddingbag_from_pretrained_options(self):
        """Keyword options given to ``EmbeddingBag.from_pretrained`` take effect.

        Runs with float64 as the default dtype and ``mode="max"``. The test
        expects the pooled output to equal the column-wise max of the
        pretrained tensor, the pretrained tensor to no longer hold its initial
        values 1..6, and each of its rows to respect ``max_norm`` under the
        configured ``norm_type``.
        """
        with set_default_dtype(torch.double):
            pretrained = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            opts = dict(
                max_norm=2.0,
                norm_type=0.5,
                scale_grad_by_freq=False,
                mode="max",
                sparse=False,
            )
            bag_module = nn.EmbeddingBag.from_pretrained(pretrained, **opts)

            # A single bag containing both rows.
            pooled = bag_module(torch.LongTensor([[0, 1]]))

            column_max = pretrained.max(0, keepdim=True)[0]
            self.assertEqual(column_max, pooled)

            initial_values = torch.arange(1, 7, dtype=pretrained.dtype).view(2, 3)
            self.assertTrue(pretrained.ne(initial_values).all())

            row_norms = pretrained.norm(p=opts["norm_type"], dim=1)
            self.assertTrue(row_norms.le(opts["max_norm"]).all())

    def test_embeddingbag_include_last_offset(self):
        """Regression test for ``include_last_offset=True`` combined with ``padding_idx``.

        Both offset vectors are expected to give the same two bags: the mean of
        rows 0..2, followed by row 3 on its own.
        """
        # Test case from https://github.com/pytorch/pytorch/issues/89677
        bag_module = nn.EmbeddingBag(100, 3, include_last_offset=True, padding_idx=61)
        indices = torch.tensor([0, 1, 2, 3])

        results = [
            bag_module(indices, torch.tensor(offsets))
            for offsets in ([0, 3, 3], [0, 3, 4])
        ]

        table = bag_module.weight
        expected = torch.stack([table[0:3].mean(0), table[3]])

        for result in results:
            self.assertEqual(expected, result)


class TestEmbeddingNNDeviceType(NNTestCase):
    def test_embedding_dense_grad(self, device):
        """Run grad and gradgrad checks on a dense embedding lookup in float64."""
        with set_default_dtype(torch.double):
            table = nn.Embedding(20, 20).to(device).weight

            def lookup(weight):
                # The index pattern includes a repeated index (1 appears twice).
                indices = torch.tensor(
                    [[0, 1, 1, 2], [3, 5, 7, 11]], dtype=torch.long
                ).to(device)
                return torch.nn.functional.embedding(indices, weight)

            _assertGradAndGradgradChecks(self, lookup, (table,))

    def test_embedding_scalar_weight_error(self, device):
        """``F.embedding`` must reject a weight that is not 2-D (here 0-D and 3-D)."""
        indices = torch.rand(2, 2, device=device).long()
        scalar_weight = torch.tensor(1.0, device=device)
        cube_weight = torch.tensor(1.0, device=device).reshape(1, 1, 1)

        for bad_weight in (scalar_weight, cube_weight):
            with self.assertRaisesRegex(RuntimeError, "'weight' must be 2-D"):
                torch.nn.functional.embedding(indices, bad_weight)

    @dtypesIfCUDA(torch.float16, torch.float64)
    @dtypes(torch.float64)
    def test_embedding_backward(self, device, dtype):
        """Inspect the raw sparse gradient after one and after two backward passes.

        After one pass the gradient's indices are the looked-up indices and its
        values are rows of ones. After two passes the test expects the indices
        and values of both passes side by side (compared through ``_indices()``
        and ``_values()``).
        """
        module = nn.Embedding(10, 3, sparse=True)
        idx = torch.tensor([[7, 1, 3]])
        unit_rows = torch.tensor(1.0, dtype=dtype).expand(3, 3)
        idx_twice = idx.repeat(1, 2)
        unit_rows_twice = torch.cat((unit_rows, unit_rows))

        module = module.to(dtype=dtype).to(device)
        idx, unit_rows, idx_twice, unit_rows_twice = (
            t.to(device) for t in (idx, unit_rows, idx_twice, unit_rows_twice)
        )

        def run_backward():
            module(idx[0]).sum().backward()

        def check_grad(expected_indices, expected_values):
            grad = module.weight.grad
            self.assertEqual(grad._indices(), expected_indices)
            self.assertEqual(grad._values(), expected_values)

        # One backward pass.
        module.zero_grad()
        run_backward()
        check_grad(idx, unit_rows)

        # Two backward passes over the same indices.
        module.zero_grad()
        run_backward()
        run_backward()
        check_grad(idx_twice, unit_rows_twice)

        # Two backward passes where the second one looks up 8 instead of 7.
        module.zero_grad()
        run_backward()
        idx[0, 0] = 8
        run_backward()
        idx_twice[0, 3] = 8
        check_grad(idx_twice, unit_rows_twice)

    @dtypesIfCUDA(
        *(
            (torch.float, torch.double, torch.bfloat16, torch.half)
            if TEST_WITH_ROCM
            else (torch.float, torch.double, torch.half)
        )
    )
    @dtypes(torch.float32)
    def test_embedding_max_norm_backward(self, device, dtype):
        """The gradient of a summed ``max_norm`` lookup equals the per-row lookup count."""
        # can't use gradcheck since in place renorm makes analytical gradients different from produced ones
        table = (torch.randn((4, 4), device=device, dtype=dtype) * 2).requires_grad_()
        indices = torch.tensor([0, 1, 2, 2], device=device)

        loss = nn.functional.embedding(indices, table, max_norm=1.0).sum()
        loss.backward()

        # Row i of the gradient is filled with the number of times i was looked
        # up: rows 0 and 1 once, row 2 twice, row 3 never.
        lookup_counts = torch.tensor([1.0, 1.0, 2.0, 0.0], device=device, dtype=dtype)
        expected_grad = lookup_counts.unsqueeze(1).expand(4, 4)
        self.assertEqual(table.grad, expected_grad)

    @dtypesIfCUDA(
        *(
            (torch.float, torch.double, torch.bfloat16, torch.half)
            if TEST_WITH_ROCM
            else (torch.float, torch.double, torch.half)
        )
    )
    @dtypes(torch.float32)
    def test_embedding_max_norm_fwd_AD(self, device, dtype):
        """Forward-mode AD through a ``max_norm`` lookup passes an all-ones tangent through."""
        if torch.device(device).type == "xla":
            self.skipTest("forward AD doesn't work on xla")

        fwd_ad = torch.autograd.forward_ad

        # can't use gradcheck since in place renorm makes analytical gradients different from produced ones
        primal = torch.randn((4, 4), device=device, dtype=dtype) * 2
        tangent = torch.ones((4, 4), device=device, dtype=dtype)
        indices = torch.tensor([[0, 1], [2, 2]], device=device)

        with fwd_ad.dual_level():
            dual_table = fwd_ad.make_dual(primal, tangent)
            looked_up = nn.functional.embedding(indices, dual_table, max_norm=1.0)
            jvp = fwd_ad.unpack_dual(looked_up).tangent

        # Every looked-up row should carry the all-ones tangent of its source row.
        expected_jvp = torch.ones((2, 2, 4), device=device, dtype=dtype)
        self.assertEqual(jvp, expected_jvp)

    @dtypesIfCUDA(
        *(
            (torch.float, torch.double, torch.bfloat16, torch.half)
            if TEST_WITH_ROCM
            else (torch.float, torch.double, torch.half)
        )
    )
    @dtypes(torch.float32)
    def test_embedding_padding_idx(self, device, dtype):
        """Exercise ``padding_idx`` on ``nn.Embedding`` from several angles.

        Checks, in order: the padding row reads back as zeros (dense and
        sparse, positive and negative ``padding_idx``); a manually overwritten
        padding row is returned as written; out-of-range ``padding_idx`` values
        are rejected by the constructor; and the padding row receives no
        gradient in backward or double backward.
        """
        # (padding_idx, index rows, (row, col) positions that hit the padding row)
        zero_row_cases = (
            (0, [[0, 2, 4, 5], [4, 3, 0, 9]], ((0, 0), (1, 2))),
            # negative indexing check for padding_idx
            # padding_idx=-2, num_embeddings=10 ==> index 8 padded
            (-2, [[0, 2, 8, 5], [4, 8, 0, 9]], ((0, 2), (1, 1))),
        )
        for pad, index_rows, padded_positions in zero_row_cases:
            for sparse in (False, True):
                embedding = nn.Embedding(10, 20, padding_idx=pad, sparse=sparse).to(
                    device, dtype
                )
                input = torch.tensor(index_rows, dtype=torch.long).to(device)
                output = embedding(input)
                for row, col in padded_positions:
                    self.assertEqual(output[row][col].sum(), 0)

        # change padding vector
        padding_vector = torch.ones(20, dtype=dtype, device=device)
        embedding = nn.Embedding(10, 20, padding_idx=2, sparse=True).to(device, dtype)
        with torch.no_grad():
            embedding.weight[2] = padding_vector
        input = torch.tensor([0, 2], dtype=torch.long).to(device)
        output = embedding(input)
        self.assertEqual(output[1], padding_vector)

        # out of bounds check for padding_idx
        for out_of_range in (25, -25):
            with self.assertRaises(AssertionError):
                nn.Embedding(num_embeddings=10, embedding_dim=20, padding_idx=out_of_range)

        padding_idx = 0
        embedding = nn.Embedding(5, 2, padding_idx=padding_idx).to(device, dtype)

        def assert_padding_row_unchanged(row_before):
            # Adding the gradient to the weight must leave the padding row as it
            # was, i.e. the padding row's gradient is zero.
            row_after = (embedding.weight + embedding.weight.grad)[padding_idx]
            embedding.zero_grad()
            self.assertEqual(row_after, row_before)

        # Need large N to trigger all the methods we have implemented
        for n in (1, 2, 1000):
            for other_indices in ([], [1, 3], [2]):
                indices = torch.tensor(
                    other_indices + [padding_idx] * n, dtype=torch.long
                ).to(device)
                pre = embedding.weight[padding_idx].clone()

                embedding(indices).sum().backward()
                assert_padding_row_unchanged(pre)

                # test double backward
                emb_sum = embedding(indices).sum()
                emb_grad = torch.autograd.grad(
                    outputs=emb_sum,
                    inputs=list(embedding.parameters()),
                    retain_graph=True,
                )
                scalar = emb_grad[0].sum() + emb_sum
                scalar.backward()
                assert_padding_row_unchanged(pre)

    # Check correctness of torch.nn.functional.embedding_bag forward and
    # backward functions with padding_idx, given a 1D input separated into bags
    # with an offset array. Compare against an equivalent 2D input that uses
    # padding indices to fill in the gaps indicated by the offset array

    @skipIfTorchDynamo("see https://github.com/pytorch/pytorch/pull/95621")
    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    @dtypesIfCUDA(torch.half, torch.bfloat16)
    def test_embedding_bag_1D_padding_idx(self, device, dtype):
        """Compare 1-D (indices + offsets) ``embedding_bag`` against a padded 2-D call.

        For every combination of mode, sparse, include_last_offset and allpad,
        a random 1-D bag layout is generated and converted to an equivalent 2-D
        index matrix whose unused slots hold a padding index. The forward
        results and the weight gradients of the two calls must match, for each
        index value present in the input used as ``padding_idx`` and also for
        ``padding_idx=None``.
        """
        num_features = 3
        max_indices_per_bag = 10
        num_bags = 10
        num_words = 100

        def gen_1D_indices_offsets(include_last_offset, allpad):
            """Return random ``(indices, offsets)`` tensors describing ``num_bags`` bags.

            One bag is completely full, one is empty and the others hold between
            1 and ``max_indices_per_bag - 1`` entries. With ``allpad`` every
            index is 1; otherwise indices are drawn from ``[0, num_words)``.
            """
            indices = []
            offsets = []
            cur_offset = 0

            # Make one bag full and one bag empty, for extra coverage
            empty_bag = random.randint(0, num_bags - 1)
            full_bag = empty_bag
            while full_bag == empty_bag:
                full_bag = random.randint(0, num_bags - 1)

            for bag in range(num_bags):
                offsets.append(cur_offset)
                if bag == full_bag:
                    bag_size = max_indices_per_bag
                elif bag == empty_bag:
                    bag_size = 0
                else:
                    bag_size = random.randint(1, max_indices_per_bag - 1)
                indices += [
                    1 if allpad else random.randint(0, num_words - 1)
                    for _ in range(bag_size)
                ]
                cur_offset += bag_size

            # embedding_bag requires first entry of offsets to be 0
            assert offsets[0] == 0

            indices = torch.tensor(indices, device=device)

            if include_last_offset:
                # The extra trailing offset is the total number of indices.
                offsets.append(indices.size(0))

            offsets = torch.tensor(offsets, device=device)

            return indices, offsets

        # Convert a 1-D indices-offsets representation into 2-D. Fill any empty
        # indices with padding_idx
        def gen_2D_indices_from_1D(
            indices_1D, offsets, include_last_offset, padding_idx
        ):
            """Return a ``(num_bags, max_indices_per_bag)`` index matrix for the 1-D layout."""
            assert offsets[0] == 0
            if include_last_offset:
                # Drop the trailing end marker; only bag starts are needed here.
                offsets = offsets[:-1]
            indices_2D = torch.empty(
                num_bags, max_indices_per_bag, device=device, dtype=torch.long
            )
            for bag in range(num_bags):
                # Determine the start and end position of the bag within indices_1D
                start = offsets[bag]
                end = len(indices_1D) if bag + 1 == num_bags else offsets[bag + 1]
                end = min(len(indices_1D), end)

                # Pull out the bag's indices from indices_1D, and fill any
                # remaining space with padding indices
                indices_in_bag = []
                for item_pos in range(max_indices_per_bag):
                    if (start + item_pos) < end:
                        indices_in_bag.append(indices_1D[start + item_pos])
                    else:
                        indices_in_bag.append(padding_idx)
                indices_2D[bag] = torch.tensor(indices_in_bag, device=device)

            return indices_2D

        test_cases = product(
            ["max", "mean", "sum"], [False, True], [False, True], [False, True]
        )

        for mode, sparse, include_last_offset, allpad in test_cases:
            # Max sparse and bfloat16 are not supported
            if mode == "max":
                if sparse or (dtype == torch.bfloat16):
                    continue
            indices_1D, offsets = gen_1D_indices_offsets(include_last_offset, allpad)
            for padding_idx_1D in list(set(indices_1D.tolist())) + [None]:
                # Include every swept parameter so a failure identifies its case.
                msg = (
                    f"mode: '{mode}', sparse: {sparse}, include_last_offset: {include_last_offset}, "
                    f"allpad: {allpad}, padding_idx_1D: {padding_idx_1D}"
                )

                # If 1D input does not use a padding index, we still need one for the 2D input,
                # so we can add one dummy word to the weights to act as the padded word
                padding_idx_2D = (
                    padding_idx_1D if padding_idx_1D is not None else num_words
                )
                num_words_with_padding = (
                    num_words if padding_idx_1D is not None else num_words + 1
                )

                indices_2D = gen_2D_indices_from_1D(
                    indices_1D, offsets, include_last_offset, padding_idx_2D
                )

                weights = torch.randn(
                    num_words_with_padding,
                    num_features,
                    dtype=dtype,
                    device=device,
                    requires_grad=True,
                )
                # Independent leaf with the same values, so each call gets its own grad.
                weights_check = weights.clone().detach().requires_grad_(True)

                bag = torch.nn.functional.embedding_bag(
                    indices_1D,
                    weights,
                    offsets,
                    padding_idx=padding_idx_1D,
                    mode=mode,
                    sparse=sparse,
                    include_last_offset=include_last_offset,
                )

                bag_check = torch.nn.functional.embedding_bag(
                    indices_2D,
                    weights_check,
                    padding_idx=padding_idx_2D,
                    mode=mode,
                    sparse=sparse,
                )
                self.assertEqual(bag, bag_check, msg=msg)

                bag.sum().backward()
                bag_check.sum().backward()

                # Sometimes, half dtype gradients mismatch by a greater amount
                # than other dtypes
                if dtype in [torch.half, torch.bfloat16]:
                    atol = 0.01
                    rtol = 0.01
                else:
                    atol = None
                    rtol = None
                self.assertEqual(
                    weights.grad, weights_check.grad, msg=msg, atol=atol, rtol=rtol
                )

    # Check correctness of torch.nn.functional.embedding_bag forward and
    # backward functions with padding_idx, given a 2D indices input. Compare
    # against torch.nn.functional.embedding followed by a reduction.
    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    @dtypesIfCUDA(torch.half, torch.bfloat16)
    def test_embedding_bag_2D_padding_idx(self, device, dtype):
        """Compare 2-D ``embedding_bag`` with ``padding_idx`` against a Python reference.

        For every combination of mode, sparse, allpad and number of rows, each
        index value present in the input is used once as ``padding_idx``. The
        forward output and the weight gradient must match the reference built
        from ``torch.nn.functional.embedding`` plus an explicit reduction.
        """

        # Use a Python implementation of embedding_bag with padding_idx support
        # to check torch.nn.functional.embedding_bag correctness
        def embedding_bag_check(indices, weights, mode, sparse, padding_idx):
            """Reference result: embedding lookup reduced over the last index dimension."""
            assert padding_idx is not None
            embedding = torch.nn.functional.embedding(
                indices, weights, padding_idx=padding_idx, sparse=sparse
            )

            reduction_dim = indices.dim() - 1

            if mode == "sum" or mode == "mean":
                # We must avoid including elements at padding_idx in the
                # sum/mean, so multiply those elements by 0, and multiply
                # all other elements by 1
                per_sample_weights = indices.ne(padding_idx).to(dtype).unsqueeze(-1)
                res = embedding.mul(per_sample_weights).sum(dim=reduction_dim)

                if mode == "mean":
                    # Divide by the number of non-padding entries in each bag.
                    weights_sum = per_sample_weights.sum(dim=reduction_dim)
                    res = res.div(weights_sum)

            elif mode == "max":
                # We must avoid allowing elements at padding_idx to be chosen
                # as the max, so set those elements to negative infinity
                res = embedding.masked_fill(
                    indices.unsqueeze(-1) == padding_idx, -float("inf")
                ).amax(dim=reduction_dim)

            else:
                raise RuntimeError(f"mode '{mode}' is not available")

            # If a row is all padding, set its corresponding result row to 0.
            # This is needed because the above mean and max mode
            # implementations set these elements to nan and -inf, respectively
            if mode in ["mean", "max"]:
                res = res.masked_fill(
                    indices.eq(padding_idx).all(dim=-1).unsqueeze(-1), 0
                )

            return res

        num_features = 3
        num_words = 10
        indices_dim1 = 10

        for mode, sparse, allpad, indices_dim0 in product(
            ["max", "mean", "sum"], [False, True], [False, True], [1, 10]
        ):
            # Max sparse and bfloat16 are not supported
            if mode == "max":
                if sparse or (dtype == torch.bfloat16):
                    continue

            if allpad:
                # Every entry is index 1, so padding_idx=1 pads the whole input.
                indices = torch.empty(
                    indices_dim0, indices_dim1, dtype=torch.long, device=device
                ).fill_(1)
            else:
                indices = torch.randint(
                    0, num_words, (indices_dim0, indices_dim1), device=device
                )

                if indices_dim0 > 1:
                    # Fill one row with duplicate index so we can test with a fully
                    # padded row
                    duplicate_row = random.randint(0, indices_dim0 - 1)
                    indices[duplicate_row] = indices[duplicate_row][0]

            # Use each distinct index value in the input as the padding index once.
            for padding_idx in list(set(indices.flatten(0, -1).tolist())):
                weights = torch.randn(
                    num_words,
                    num_features,
                    dtype=dtype,
                    device=device,
                    requires_grad=True,
                )
                # Independent leaf with the same values, so each path gets its own grad.
                weights_check = weights.clone().detach().requires_grad_(True)

                msg = (
                    f"mode: '{mode}', sparse: {sparse}, padding_idx: {padding_idx}, "
                    f"allpad: {allpad}, indices.size(): {indices.size()}"
                )

                # Check forward with a Python implementation of padding_idx embedding_bag
                bag_check = embedding_bag_check(
                    indices, weights_check, mode, sparse, padding_idx
                )
                bag = torch.nn.functional.embedding_bag(
                    indices, weights, padding_idx=padding_idx, mode=mode, sparse=sparse
                )

                self.assertEqual(bag, bag_check, msg=msg)

                bag_check.sum().backward()
                grad_check = weights_check.grad

                bag.sum().backward()
                grad = weights.grad

                # Sometimes, half dtype gradients mismatch by a greater amount
                # than other dtypes
                if dtype in [torch.half, torch.bfloat16]:
                    atol = 0.01
                    rtol = 0.01
                else:
                    atol = None
                    rtol = None
                self.assertEqual(grad, grad_check, msg=msg, atol=atol, rtol=rtol)

    @onlyCUDA
    @dtypes(
        *(
            (torch.float, torch.double, torch.bfloat16, torch.half)
            if TEST_WITH_ROCM
            else (torch.float, torch.double, torch.half)
        )
    )
    def test_embedding_max_norm_device(self, device, dtype):
        """With ``max_norm=1.0`` every looked-up row has an L2 norm of at most 1, per dtype."""
        module = nn.Embedding(22, 5, max_norm=1.0).to(device, dtype=dtype)
        # nn.Embedding only takes LongTensor as input
        indices = torch.tensor([2, 8, 8, 6], device=device, dtype=torch.long)
        rows = module(indices)

        # Positions 1 and 2 both look up index 8, so they must be identical.
        self.assertEqual(rows[1], rows[2])

        row_norms = rows.data.norm(p=2, dim=1)
        self.assertTrue(row_norms.le(1).all())

    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
    def test_embedding_bag_empty_input(self, device, dtypes):
        """An ``EmbeddingBag`` fed no indices must return all zeros.

        ``dtypes`` is an ``(index dtype, offsets dtype)`` pair. Both a single
        offset ``[0]`` and a double offset ``[0, 0]`` are tried, for sparse and
        dense modules.
        """
        index_dtype, offset_dtype = dtypes[0], dtypes[1]
        num_embeddings, embedding_dim = 4, 3
        no_indices = torch.tensor([], device=device, dtype=index_dtype)

        for sparse in (True, False):
            bag_module = torch.nn.EmbeddingBag(
                num_embeddings, embedding_dim, sparse=sparse
            )
            bag_module.to(device)

            for offset_values in ([0], [0, 0]):
                offsets = torch.tensor(offset_values, device=device, dtype=offset_dtype)
                result = bag_module(input=no_indices, offsets=offsets)
                self.assertEqual(result, torch.zeros_like(result))

    @skipCUDAIf(True, "no out-of-bounds check on CUDA for perf.")
    @dtypes(*itertools.product((torch.float, torch.double), (torch.int, torch.long)))
    @parametrize_test("padding_idx", [None, 0])
    @parametrize_test("mode", ["sum", "mean", "max"])
    def test_embedding_bag_out_of_bounds_idx(self, device, dtypes, padding_idx, mode):
        """``F.embedding_bag`` must raise on negative and too-large indices.

        ``dtypes`` is a ``(weight dtype, index dtype)`` pair. In ``sum`` mode
        the check is repeated with ``per_sample_weights`` supplied.
        """
        # NOTE(review): this assignment discards the parametrized value, so the
        # padding_idx=None variant runs exactly the same code as padding_idx=0.
        # Confirm whether that is intentional before removing it.
        padding_idx = 0
        w_dtype, idx_dtype = dtypes

        out_of_bounds_inputs = (
            # negative out-of-bound
            torch.tensor([[-1, 1]], device=device, dtype=idx_dtype),
            # positive out-of-bound
            torch.tensor([[11, 8]], device=device, dtype=idx_dtype),
        )
        weight = torch.randn(10, 2, device=device, dtype=w_dtype)

        per_sample_weights = [None]
        if mode == "sum":
            # Only `sum` supports per_sample_weight
            per_sample_weights.append(
                torch.randn_like(out_of_bounds_inputs[0], device=device, dtype=w_dtype)
            )

        expected_error = "Expected idx >= 0 && idx < num_embeddings"
        for p_s_weights in per_sample_weights:
            for idx in out_of_bounds_inputs:
                with self.assertRaisesRegex(RuntimeError, expected_error):
                    torch.nn.functional.embedding_bag(
                        idx,
                        weight,
                        per_sample_weights=p_s_weights,
                        padding_idx=padding_idx,
                        mode=mode,
                    )

    def test_embedding_bag_dimension_errors(self, device):
        funcs = (
            lambda x, y, z: torch.nn.functional.embedding_bag(y, x, z),
            torch.embedding_bag,
            torch._embedding_bag,
            torch._embedding_bag_forward_only,
        )
        for i, f in enumerate(funcs):
            err_type = (ValueError, RuntimeError) if i == 0 else RuntimeError

            weight = torch.full(
                (
                    2,
                    6,
                ),
                0,
                dtype=torch.float64,
                device=device,
            )
            indices = torch.full(
                (
                    2,
                    0,
                    0,
                    6,
                    6,
                ),
                2,
                dtype=torch.int64,
                device=device,
            )
            offsets = torch.full((2, 0, 0, 6, 6), 0, dtype=torch.int64, device=device)

            if i == 0:
                error_msg = "input has to be 1D or 2D Tensor"
            else:
                error_msg = "input has to be a 1D or 2D Tensor"
            torch._dynamo.disable(self.assertRaisesRegex)(
                err_type, error_msg, lambda: f(weight, indices, offsets)
            )

            weight = torch.full((2, 2), 0, dtype=torch.float64, device=device)
            indices = torch.full((2,), 1, dtype=torch.int64, device=device)

            torch._dynamo.disable(self.assertRaisesRegex)(
                err_type,
                "offsets has to be a 1D Tensor",
                lambda: f(weight, indices, offsets),
            )

            weight = torch.full((2, 2, 2), 0, dtype=torch.float64, device=device)
            indices = torch.full((2,), 2, dtype=torch.int64, device=device)
            offsets = torch.full((2,), 0, dtype=torch.int64, device=device)

            torch._dynamo.disable(self.assertRaisesRegex)(
                err_type,
                "weight has to be a 2D Tensor",
                lambda: f(weight, indices, offsets),
            )

    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
    def test_EmbeddingBag_per_sample_weights_failures(self, device, dtypes):
        # Failure 1: mismatched embeddings / per_sample_weights dtype
        es = nn.EmbeddingBag(5, 2, mode="sum").to(dtype=torch.float, device=device)
        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device)
        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtypes[1], device=device)
        per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device)
        if device == "cpu":
            with self.assertRaisesRegex(RuntimeError, "have the same type as"):
                es(input, offsets, per_sample_weights)
        else:
            with self.assertRaisesRegex(RuntimeError, "expected scalar type"):
                es(input, offsets, per_sample_weights)

        # Failure 2.1: input/per_sample_weights have different sizes (1d input)
        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device)
        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtypes[1], device=device)
        per_sample_weights = torch.randn(5, dtype=torch.float, device=device)
        with self.assertRaisesRegex(ValueError, "same shape as the input"):
            es(input, offsets, per_sample_weights)

        # Failure 2.2: input/per_sample_weights have different sizes (2d input)
        input = torch.randint(5, (7, 3), dtype=dtypes[0], device=device)
        offsets = None
        per_sample_weights = torch.randn(7 * 3, dtype=torch.float, device=device)
        with self.assertRaisesRegex(ValueError, "same shape as the input"):
            es(input, offsets, per_sample_weights)

        # Failure 3: Unsupported per_sample_weights and mode=('max', 'mean')
        for unsupported_mode in ("max", "mean"):
            es = nn.EmbeddingBag(5, 2, mode=unsupported_mode).to(
                dtype=torch.float, device=device
            )
            input = torch.randint(5, (7, 3), dtype=dtypes[0], device=device)
            offsets = None
            per_sample_weights = torch.randn(7, 3, dtype=torch.float, device=device)
            with self.assertRaisesRegex(
                NotImplementedError, "only supported for mode='sum'"
            ):
                es(input, offsets, per_sample_weights)

    def _embedding_bag_reference_impl(
        self,
        input,
        weight,
        offsets=None,
        mode="sum",
        per_sample_weights=None,
        include_last_offset=False,
    ):
        """Straightforward reference for ``embedding_bag`` with explicit offsets.

        Selects the rows of ``weight`` named by ``input``, scales each by its
        per-sample weight (all ones when none are given), and reduces every bag
        with ``mode`` (``"sum"``, ``"mean"`` or ``"max"``). Empty bags produce
        a row of zeros.

        Without ``include_last_offset`` every entry of ``offsets`` starts a bag
        and the last bag runs to the end of ``input``; with it, consecutive
        entries of ``offsets`` delimit the bags, so there is one bag fewer than
        offsets.

        Returns the bags stacked into a single tensor.
        """
        assert mode == "sum" or per_sample_weights is None
        assert offsets is not None
        if per_sample_weights is None:
            per_sample_weights = torch.ones(
                input.size(), dtype=weight.dtype, device=weight.device
            )
        assert input.numel() == per_sample_weights.numel()

        long_input = input.to(torch.long)
        scaled_rows = weight.index_select(0, long_input) * per_sample_weights.unsqueeze(
            1
        )

        # Pair every bag start with the position where the bag ends.
        starts = list(offsets)
        if include_last_offset:
            ends = starts[1:]
            starts = starts[:-1]
        else:
            ends = starts[1:] + [len(long_input)]

        pooled = []
        for begin, end in zip(starts, ends):
            count = end - begin
            if count == 0:
                pooled.append(
                    torch.zeros(
                        weight.size(1),
                        dtype=scaled_rows.dtype,
                        device=scaled_rows.device,
                    )
                )
                continue

            members = scaled_rows.narrow(0, begin, count)
            if mode == "sum":
                pooled.append(members.sum(0))
            elif mode == "mean":
                pooled.append(members.sum(0).div(count))
            else:
                assert mode == "max"
                pooled.append(members.max(0)[0])
        return torch.stack(pooled)

    @skipMeta
    @dtypes(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.half, torch.bfloat16, torch.float, torch.double),
        )
    )
    @dtypesIfCUDA(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half),
        )
    )
    def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes):
        """Forward/backward with empty input and empty ``per_sample_weights``.

        ``dtypes`` is ``(index dtype, offsets dtype, weight dtype)``.
        """
        # Test empty input and per sample weight, and backward pass. There was a CUDA
        # invalid configuration bug (more context in #46572)
        index_dtype, offset_dtype, weight_dtype = dtypes
        tolerance = dtype2prec_DONTUSE[weight_dtype]

        def check(mode, trainable_scale):
            bag = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=weight_dtype, device=device)
            known_values = torch.arange(1, 11, device=device).view_as(bag.weight)
            bag.weight.data.copy_(known_values.to(weight_dtype))

            indices = torch.tensor([], device=device, dtype=index_dtype)
            bag_offsets = torch.tensor(
                [0, 0, 0, 0, 0], device=device, dtype=offset_dtype
            )
            scales = torch.randn_like(indices, dtype=weight_dtype).requires_grad_(
                trainable_scale
            )
            reference_scales = scales.detach().requires_grad_(trainable_scale)
            reference_weights = bag.weight.detach().requires_grad_()

            expected = self._embedding_bag_reference_impl(
                indices, reference_weights, bag_offsets, mode, reference_scales
            )
            result = bag(indices, bag_offsets, scales)
            self.assertEqual(result, expected, atol=tolerance, rtol=0)

            upstream_grad = torch.randn_like(expected)
            result.backward(upstream_grad)
            # the reference impl doesn't have grad fn for empty input; but the grad should
            # simply be a zero tensor
            self.assertEqual(
                bag.weight.grad,
                torch.zeros_like(bag.weight),
                atol=tolerance,
                rtol=0,
            )
            if trainable_scale:
                # `scales` has no elements, so only the (empty) shape is compared.
                self.assertEqual(
                    scales.grad,
                    torch.empty_like(scales),
                    atol=tolerance,
                    rtol=0,
                )

        for trainable in (True, False):
            check("sum", trainable)

    @skipMeta
    @dtypes(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half, torch.bfloat16),
        )
    )
    @dtypesIfCUDA(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half),
        )
    )
    def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes):
        """Compare ``EmbeddingBag`` with the reference for weighted ``sum`` bags.

        ``dtypes`` is ``(index dtype, offsets dtype, weight dtype)``. Outputs
        and gradients (weights, and scales when trainable) must agree.
        """
        index_dtype, offset_dtype, weight_dtype = dtypes
        tolerance = dtype2prec_DONTUSE[weight_dtype]

        def check(mode, trainable_scale):
            bag = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=weight_dtype, device=device)
            known_values = torch.arange(1, 11, device=device).view_as(bag.weight)
            bag.weight.data.copy_(known_values.to(weight_dtype))

            indices = torch.tensor(
                [3, 1, 1, 1, 4, 0], device=device, dtype=index_dtype
            )
            bag_offsets = torch.tensor(
                [0, 0, 3, 3, 6], device=device, dtype=offset_dtype
            )
            scales = torch.randn_like(indices, dtype=weight_dtype).requires_grad_(
                trainable_scale
            )
            reference_scales = scales.detach().requires_grad_(trainable_scale)
            reference_weights = bag.weight.detach().requires_grad_()

            expected = self._embedding_bag_reference_impl(
                indices, reference_weights, bag_offsets, mode, reference_scales
            )
            result = bag(indices, bag_offsets, scales)
            self.assertEqual(result, expected, atol=tolerance, rtol=0)

            # Push the same upstream gradient through both implementations.
            upstream_grad = torch.randn_like(expected).to(
                dtype=weight_dtype, device=device
            )
            result.backward(upstream_grad)
            expected.backward(upstream_grad)
            self.assertEqual(
                bag.weight.grad, reference_weights.grad, atol=tolerance, rtol=0
            )
            if trainable_scale:
                self.assertEqual(
                    scales.grad, reference_scales.grad, atol=tolerance, rtol=0
                )

        for trainable in (True, False):
            check("sum", trainable)

    @skipMeta
    @dtypes(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half, torch.bfloat16),
        )
    )
    @dtypesIfCUDA(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half),
        )
    )
    def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes):
        """Compare ``EmbeddingBag`` with the reference, with/without last offset.

        ``dtypes`` is ``(index dtype, offsets dtype, weight dtype)``. Runs
        ``sum`` with and without ``per_sample_weights`` and ``max``/``mean``
        without, each with ``include_last_offset`` on and off.
        """
        index_dtype, offset_dtype, weight_dtype = dtypes
        tolerance = dtype2prec_DONTUSE[weight_dtype]

        def check(mode, trainable_scale, include_last_offset, has_weight=True):
            bag = nn.EmbeddingBag(
                5, 2, mode=mode, include_last_offset=include_last_offset
            ).to(dtype=weight_dtype, device=device)
            known_values = torch.arange(1, 11, device=device).view_as(bag.weight)
            bag.weight.data.copy_(known_values.to(weight_dtype))

            indices = torch.tensor(
                [3, 1, 1, 1, 4, 0], device=device, dtype=index_dtype
            )
            raw_offsets = [0, 0, 3, 3, 6]
            if include_last_offset:
                # The extra trailing entry marks the end of the last bag.
                raw_offsets.append(indices.size(0))
            bag_offsets = torch.tensor(raw_offsets, device=device, dtype=offset_dtype)

            scales = None
            reference_scales = None
            if has_weight:
                scales = torch.randn_like(
                    indices, device=device, dtype=weight_dtype
                ).requires_grad_(trainable_scale)
                reference_scales = scales.detach().requires_grad_(trainable_scale)

            reference_weights = bag.weight.detach().requires_grad_()

            expected = self._embedding_bag_reference_impl(
                indices,
                reference_weights,
                bag_offsets,
                mode,
                reference_scales,
                include_last_offset,
            )
            result = bag(indices, bag_offsets, scales)
            self.assertEqual(result, expected, atol=tolerance, rtol=0)

            upstream_grad = torch.randn_like(expected)
            result.backward(upstream_grad)
            expected.backward(upstream_grad)
            self.assertEqual(
                bag.weight.grad, reference_weights.grad, atol=tolerance, rtol=0
            )
            if has_weight and trainable_scale:
                self.assertEqual(
                    scales.grad, reference_scales.grad, atol=tolerance, rtol=0
                )

        # (mode, whether per_sample_weights are supplied)
        configurations = (("sum", False), ("sum", True), ("max", False), ("mean", False))
        for mode, has_weight in configurations:
            for trainable in (True, False):
                for include_last_offset in (True, False):
                    check(mode, trainable, include_last_offset, has_weight)

    def _test_EmbeddingBag_vs_Embedding(
        self,
        N,
        D,
        B,
        L,
        max_norm=None,
        mode="mean",
        device="cpu",
        wdtype=torch.float,
        dtype=torch.long,
        test_per_sample_weights=False,
        trainable_per_sample_weights=False,
        sparse=False,
        test_backward=True,
        backward_prec=None,
    ):
        """Compare ``EmbeddingBag`` with ``Embedding`` followed by a reduction.

        Builds an ``N x D`` table shared by both modules and ``B`` bags of
        constant length ``L``; checks that the outputs agree and, when
        ``test_backward`` is set, that the weight gradients (and the
        per-sample-weight gradients, if trainable) agree too.

        ``wdtype`` is the weight dtype, ``dtype`` the index/offset dtype.
        ``backward_prec`` overrides the absolute tolerance of the weight
        gradient check. ``per_sample_weights`` are only valid with
        ``mode="sum"``.
        """
        bag = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(
            device, wdtype
        )
        table = nn.Embedding(N, D, max_norm=max_norm).to(device, wdtype)
        table.weight.data.copy_(bag.weight)

        indices = torch.randint(N, (B, L), device=device, dtype=dtype)
        # Bag b starts at flat position b * L.
        bag_starts = torch.arange(0, B, device=device, dtype=dtype).mul_(L)
        upstream_grad = torch.rand(B, D, device=device, dtype=wdtype)

        per_sample_weights = None
        per_sample_weights_reference = None
        if test_per_sample_weights:
            # To prevent large gradients, weights should sum to 1 for each bag
            per_sample_weights = torch.randn(B, L, device=device, dtype=wdtype).softmax(
                dim=-1
            )
            per_sample_weights_reference = per_sample_weights.clone().requires_grad_(
                trainable_per_sample_weights
            )
            per_sample_weights.requires_grad_(trainable_per_sample_weights)
            output = bag(indices.view(-1), bag_starts, per_sample_weights.view(-1))
        else:
            output = bag(indices.view(-1), bag_starts)

        looked_up = table(indices)
        if mode == "sum":
            if test_per_sample_weights:
                looked_up = looked_up * per_sample_weights_reference.unsqueeze(-1)
            ref_output = looked_up.sum(1)
        elif mode == "mean":
            assert not test_per_sample_weights
            ref_output = looked_up.mean(1)
        elif mode == "max":
            assert not test_per_sample_weights
            ref_output = looked_up.max(1)[0]

        self.assertEqual(output, ref_output, atol=dtype2prec_DONTUSE[wdtype], rtol=0)

        if not test_backward:
            return

        output.backward(upstream_grad)
        ref_output.backward(upstream_grad)
        bag_weight_grad = bag.weight.grad.to_dense() if sparse else bag.weight.grad

        # We have more floating point error here because we are dealing with larger numbers
        use_default_prec = backward_prec is None
        needed_prec = (
            dtype2prec_DONTUSE[wdtype] * 5 if use_default_prec else backward_prec
        )
        rtol = 0.02 if use_default_prec and wdtype == torch.half else 0

        self.assertEqual(
            bag_weight_grad, table.weight.grad, atol=needed_prec, rtol=rtol
        )

        if test_per_sample_weights and trainable_per_sample_weights:
            self.assertEqual(
                per_sample_weights.grad,
                per_sample_weights_reference.grad,
                atol=dtype2prec_DONTUSE[wdtype],
                rtol=0,
            )

    @dtypesIfCUDA(
        *itertools.product(
            (torch.int, torch.long), (torch.half, torch.float, torch.double)
        )
    )
    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
    def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device, dtypes):
        def run_tests(mode, sparse, trainable_per_sample_weights):
            kwargs = dict(
                test_per_sample_weights=True,
                device=device,
                mode=mode,
                wdtype=dtypes[1],
                dtype=dtypes[0],
                sparse=sparse,
                trainable_per_sample_weights=trainable_per_sample_weights,
            )

            # Simple case
            self._test_EmbeddingBag_vs_Embedding(2, 3, 5, 7, **kwargs)

            # B * L > 1000
            self._test_EmbeddingBag_vs_Embedding(2, 5, 53, 23, **kwargs)

            # Large num_embedding
            self._test_EmbeddingBag_vs_Embedding(101, 5, 3, 7, **kwargs)

            # Large embedding_dim
            self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs)

        modes = ("sum",)
        sparsity = (True, False)
        trainable_scale = (True, False)
        for mode, sparse, trainable_per_sample_weights in itertools.product(
            modes, sparsity, trainable_scale
        ):
            run_tests(mode, sparse, trainable_per_sample_weights)

        # Test CUDA Dense on half precision
        if device == "cuda":
            modes = ("sum",)
            sparsity = (False,)
            trainable_scale = (True, False)
            for mode, sparse, trainable_per_sample_weights in itertools.product(
                modes, sparsity, trainable_scale
            ):
                run_tests(mode, sparse, trainable_per_sample_weights)

    def _test_EmbeddingBag(
        self,
        device,
        mode,
        sparse,
        wdtype=torch.double,
        dtype=torch.long,
        odtype=torch.long,
        test_backward=True,
    ):
        """Exercise ``nn.EmbeddingBag`` for one ``mode``/``sparse`` combination.

        Four stages run in order: a hand-computed example (1-D input with
        offsets, then the same data as 2-D input), a call where every bag is
        empty, a comparison against ``nn.Embedding`` via
        ``_test_EmbeddingBag_vs_Embedding``, and a check that illegal
        input/offsets combinations raise.

        Args:
            device: device the modules and tensors are placed on.
            mode: reduction mode passed to ``nn.EmbeddingBag``; expected values
                are only defined here for ``"sum"``, ``"mean"`` and ``"max"``.
            sparse: passed to ``nn.EmbeddingBag``; sparse gradients are
                densified before comparison.
            wdtype: dtype of the embedding weights.
            dtype: dtype of the index tensors.
            odtype: dtype of the offsets tensors.
            test_backward: forwarded to ``_test_EmbeddingBag_vs_Embedding``
                only; the backward checks of the hand-computed example and of
                the all-empty case always run.
        """
        # check a known test example
        es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, wdtype)
        # Fill the 5 x 2 weight with 1..10 so expected values can be written by hand.
        es.weight.data.copy_(
            torch.arange(1, 11, device=device).view_as(es.weight).to(wdtype)
        )
        # Five offsets over six indices: per the expected outputs below, bags
        # 0, 2 and 4 are empty and bags 1 and 3 hold three indices each.
        input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtype)
        offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=odtype)

        grad_output = torch.tensor([1, 2, 3, 4], device=device, dtype=wdtype).view(2, 2)
        # Same gradient rows as grad_output, with rows of 99 at the positions
        # of the empty bags (rows 0, 2 and 4).
        grad_output_with_empty = torch.tensor(
            [99, 99, 1, 2, 99, 99, 3, 4, 99, 99], device=device, dtype=wdtype
        ).view(5, 2)

        if mode == "sum" or mode == "mean":
            # Each non-empty bag has three entries, hence the divisor for "mean".
            denominator = 1 if mode == "sum" else 3
            expected_output = (
                torch.tensor([[13, 16], [13, 16]], device=device, dtype=wdtype)
                / denominator
            )

            expected_output_with_empty = (
                torch.tensor(
                    [[0, 0], [13, 16], [0, 0], [13, 16], [0, 0]],
                    device=device,
                    dtype=wdtype,
                )
                / denominator
            )

            expected_grad_weight = (
                torch.tensor(
                    [[3, 4], [5, 8], [0, 0], [1, 2], [3, 4]],
                    device=device,
                    dtype=wdtype,
                )
                / denominator
            )
        elif mode == "max":
            expected_output = torch.tensor(
                [[7, 8], [9, 10]], device=device, dtype=wdtype
            )

            expected_output_with_empty = torch.tensor(
                [[0, 0], [7, 8], [0, 0], [9, 10], [0, 0]], device=device, dtype=wdtype
            )

            expected_grad_weight = torch.tensor(
                [[0, 0], [0, 0], [0, 0], [1, 2], [3, 4]], device=device, dtype=wdtype
            )
        output = es(input, offsets)
        output.backward(grad_output_with_empty)

        es_weight_grad = es.weight.grad
        if sparse:
            es_weight_grad = es.weight.grad.to_dense()
        self.assertEqual(output, expected_output_with_empty)
        self.assertEqual(
            es_weight_grad,
            expected_grad_weight,
            atol=dtype2prec_DONTUSE[wdtype],
            rtol=0,
        )

        # check same example except as 2D (2 x 3)
        input = input.view(2, -1)
        # Clear the gradient left over from the 1-D run before the next backward.
        es.zero_grad()
        output = es(input)
        output.backward(grad_output)

        es_weight_grad = es.weight.grad
        if sparse:
            es_weight_grad = es.weight.grad.to_dense()
        self.assertEqual(output, expected_output)
        self.assertEqual(
            es_weight_grad,
            expected_grad_weight,
            atol=dtype2prec_DONTUSE[wdtype],
            rtol=0,
        )

        # test all empty bags
        es.zero_grad()
        inputs = torch.tensor([], dtype=dtype, device=device)
        offsets = torch.tensor([0, 0, 0, 0], dtype=odtype, device=device)
        es(inputs, offsets).sum().backward()
        dense_grad = es.weight.grad
        if dense_grad.is_sparse:
            dense_grad = dense_grad.to_dense()
        # With no indices at all, no weight row should receive any gradient.
        self.assertEqual(dense_grad, torch.zeros_like(es.weight))

        # now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length
        # Sizes are drawn from Python's `random` module, so they vary with its state.
        N, D, B, L = (
            random.randint(1, 100),
            random.randint(1, 100),
            random.randint(1, 50),
            random.randint(1, 50),
        )
        kwargs = dict(
            mode=mode,
            sparse=sparse,
            device=device,
            wdtype=wdtype,
            dtype=dtype,
            test_backward=test_backward,
        )
        self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs)
        # Additionally cover every (N, D, B, L) combination drawn from {1, 2},
        # with and without max_norm.
        for max_norm in (None, 3):
            for p in itertools.product([1, 2], repeat=4):
                self._test_EmbeddingBag_vs_Embedding(*p, max_norm=max_norm, **kwargs)

        # check that giving illegal input combos raises error
        # NOTE: the module and tensors below are created without `device=`, so
        # this part does not run on `device`.
        es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse)
        input = torch.ones(3, 4, dtype=dtype)
        offset = torch.arange(0, 3, dtype=odtype)
        # 2-D input together with offsets, and 1-D input without offsets, must
        # both raise ValueError.
        torch._dynamo.disable(self.assertRaises)(ValueError, lambda: es(input, offset))
        torch._dynamo.disable(self.assertRaises)(ValueError, lambda: es(input.view(-1)))
        offset[0] = 1
        # The RuntimeError checks on the offsets' contents are only made for the
        # CPU device type.
        if self.device_type == "cpu":
            # First offset is 1 rather than 0.
            torch._dynamo.disable(self.assertRaises)(
                RuntimeError, lambda: es(input.view(-1), offset)
            )
            offset[0] = 0
            # Last offset (100) lies beyond the 12 flattened indices.
            offset[-1] = 100
            torch._dynamo.disable(self.assertRaises)(
                RuntimeError, lambda: es(input.view(-1), offset)
            )

    @skipMeta
    @dtypes(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half, torch.bfloat16),
        )
    )
    @dtypesIfCUDA(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half),
        )
    )
    def test_embedding_bag_device(self, device, dtypes):
        """Run ``_test_EmbeddingBag`` for dense and sparse gradients on ``device``.

        ``dtypes`` is ``(index dtype, offsets dtype, weight dtype)``. Dense
        gradients are checked for ``sum``, ``mean`` and ``max``; sparse
        gradients for ``sum`` and ``mean``, with the backward comparison
        enabled only for some weight dtypes.
        """
        if IS_JETSON and torch.bfloat16 in dtypes and device == "cpu":
            self.skipTest("bfloat16 not supported with Jetson cpu")

        index_dtype, offset_dtype, weight_dtype = dtypes
        dtype_kwargs = dict(wdtype=weight_dtype, dtype=index_dtype, odtype=offset_dtype)

        with set_default_dtype(torch.double):
            for dense_mode in ("sum", "mean", "max"):
                self._test_EmbeddingBag(device, dense_mode, False, **dtype_kwargs)

            if self.device_type == "cuda":
                # see 'todo' in test_embedding_bag.
                test_backward = weight_dtype is not torch.float16
            elif self.device_type == "cpu":
                # TODO: figure out why precision on sparse embeddings isn't the
                # same as for dense.
                test_backward = not (
                    weight_dtype is torch.float or weight_dtype is torch.float16
                )
            else:
                test_backward = False

            for sparse_mode in ("sum", "mean"):
                self._test_EmbeddingBag(
                    device,
                    sparse_mode,
                    True,
                    test_backward=test_backward,
                    **dtype_kwargs,
                )

    @skipMeta
    @dtypes(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half, torch.bfloat16),
        )
    )
    @dtypesIfCUDA(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half),
        )
    )
    def test_embedding_bag_non_contiguous_weight(self, device, dtypes):
        weight_tensor = torch.randn(3, 4, dtype=dtypes[2], device=device)

        weight_tensor_non_contig = weight_tensor[
            :, :3
        ]  # This is non-contiguous strided.
        weight_tensor_contig = (
            weight_tensor_non_contig.clone().contiguous()
        )  # Contig-strided.

        index = torch.tensor([0, 1, 2], dtype=dtypes[0], device=device)
        offsets = torch.tensor([0, 2], dtype=dtypes[1], device=device)
        for mode in ["sum", "mean", "max"]:
            output_non_contig = F.embedding_bag(
                input=index,
                weight=weight_tensor_non_contig,
                offsets=offsets,
                mode=mode,
            )
            output_contig = F.embedding_bag(
                input=index,
                weight=weight_tensor_contig,
                offsets=offsets,
                mode=mode,
            )
        self.assertEqual(output_non_contig, output_contig)

    @onlyNativeDeviceTypes  # currently fails on XLA
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
    def test_embedding_bag_bfloat16(self, device, dtypes):
        """Sparse-gradient ``sum`` and ``mean`` bags with bfloat16 weights.

        ``dtypes`` is ``(index dtype, offsets dtype)``.
        """
        index_dtype, offset_dtype = dtypes
        with set_default_dtype(torch.double):
            for mode in ("sum", "mean"):
                self._test_EmbeddingBag(
                    device,
                    mode,
                    True,
                    wdtype=torch.bfloat16,
                    dtype=index_dtype,
                    odtype=offset_dtype,
                    test_backward=True,
                )

    @onlyNativeDeviceTypes  # currently fails on XLA
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
    def test_embedding_bag_half(self, device, dtypes):
        self._test_EmbeddingBag(
            device,
            "sum",
            True,
            wdtype=torch.float16,
            dtype=dtypes[0],
            odtype=dtypes[1],
            test_backward=True,
        )


# Presumably expands the device-generic class into per-device test classes and
# registers them in this module's namespace (hence `globals()`) -- confirm
# against torch.testing._internal.common_device_type.
instantiate_device_type_tests(TestEmbeddingNNDeviceType, globals())
# Presumably expands the `parametrize`-decorated tests of TestEmbeddingNN into
# concrete test methods -- confirm against common_utils.
instantiate_parametrized_tests(TestEmbeddingNN)

# Run the suite when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
