# Owner(s): ["module: nestedtensor"]

import io
import itertools
import math
import sys
import unittest
from functools import partial
from typing import Optional, Tuple

import numpy as np

import torch
import torch._dynamo
import torch._dynamo.testing
import torch.nn
import torch.nn.functional as F
from torch.nested._internal.nested_tensor import (
    buffer_from_jagged,
    jagged_from_list,
    nested_view_from_values_offsets,
    NestedTensor,
    ViewNestedFromBuffer,
)
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    SM70OrLater,
    SM80OrLater,
)
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    ops,
    PYTORCH_CUDA_MEMCHECK,
    skipCPUIf,
    skipCUDAIf,
    skipCUDAIfRocm,
    skipMeta,
)
from torch.testing._internal.common_dtype import floating_types_and_half
from torch.testing._internal.common_utils import (
    decorateIf,
    freeze_rng_state,
    gradcheck,
    instantiate_parametrized_tests,
    IS_FBCODE,
    IS_WINDOWS,
    markDynamoStrictTest,
    NestedTensorTestCase,
    parametrize,
    run_tests,
    skipIfSlowGradcheckEnv,
    skipIfTorchDynamo,
    subtest,
    TEST_WITH_ROCM,
    xfailIfTorchDynamo,
)
from torch.testing._internal.opinfo.definitions.nested import njt_op_db
from torch.utils._pytree import tree_flatten
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts


# Tests are ported from pytorch/nestedtensor.
# This makes porting as_nested_tensor easier in the future.


def _iter_constructors():
    # yield as_nested_tensor
    yield torch.nested.nested_tensor


# Returns True if the function recompiles between inputs1 and inputs2 with the
# specified dynamic setting.
def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
    compile_count = [0]

    def counter(gm, example_inputs):
        compile_count[0] += 1
        return gm

    compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
    compiled_f(*inputs1)
    compiled_f(*inputs2)
    return compile_count[0] > 1
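

# Illustrative usage sketch (an assumption about typical torch.compile behavior,
# not executed as part of the suite): with dynamic=False, changing an input's
# shape between calls is expected to trigger a recompile, while dynamic=True
# should compile once with symbolic shapes.
#
#   fn = lambda x: x * 2
#   _recompiles_for_inputs(fn, (torch.randn(3),), (torch.randn(4),), dynamic=False)  # True
#   _recompiles_for_inputs(fn, (torch.randn(3),), (torch.randn(4),), dynamic=True)   # False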


# Helper function to generate a pair of random nested tensors:
# one is contiguous, the other is not, but they hold the same entries.
# Each output nested tensor consists of
# * `len(ragged_sizes)` matrices
# * matrices[i].shape == (20, ragged_sizes[i])


def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16):
    xs = []
    for size in ragged_sizes:
        xs.append(torch.randn((size, 20), device=device, dtype=dtype))
    # contiguous nested tensor
    ys = []
    for x in xs:
        ys.append(x.transpose(-1, -2))
    nt_contiguous = torch.nested.nested_tensor(ys)
    # noncontiguous nested tensor
    nt_noncontiguous = torch.nested.nested_tensor(xs).transpose(-1, -2)
    return nt_contiguous, nt_noncontiguous
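

# Illustrative shapes (a sketch assuming ragged_sizes=(2, 3)):
#   nt_contiguous.unbind()[i].shape    == (20, ragged_sizes[i])
#   nt_noncontiguous.unbind()[i].shape == (20, ragged_sizes[i])
# Both hold the same values; nt_noncontiguous is built from the (size, 20)
# tensors and transposed afterwards, so its underlying buffer is noncontiguous.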


# Helper function to pad a noncontiguous nested tensor;
# it can be replaced once to_padded_tensor supports noncontiguous memory.


def noncontiguous_to_padded_tensor(input, shape=None):
    tensors = input.unbind()
    ntensors = len(tensors)
    assert ntensors > 0
    if shape is None:
        shape = []
        for size in tensors[0].shape:
            shape.append(size)
        for i in range(1, ntensors):
            new_shape = tensors[i].shape
            for j in range(len(shape)):
                shape[j] = max(shape[j], new_shape[j])
        shape = [ntensors] + shape
    result = tensors[0].new_zeros(shape)
    for itensor in range(ntensors):
        tensor = tensors[itensor]
        view = result[itensor]
        for idim in range(tensor.dim()):
            view = view.narrow(idim, 0, tensor.size(idim))
        view.copy_(tensor)
    return result
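

# Illustrative sketch of the padding behavior above (assumed example): the
# result is zero-initialized at the per-dim max shape and each component is
# copied into the leading corner of its slice.
#   nt = torch.nested.nested_tensor([torch.ones(2, 5), torch.ones(3, 4)])
#   padded = noncontiguous_to_padded_tensor(nt)  # padded.shape == (2, 3, 5)
#   # padded[0, :2, :5] == 1, padded[1, :3, :4] == 1, zeros elsewhere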


# Helper function to generate a random nested tensor


def random_nt(
    device,
    dtype,
    num_tensors,
    max_dims,
    min_dims=None,
    layout=torch.strided,
    require_non_empty=True,
):
    if min_dims is None:
        min_dims = tuple([0] * len(max_dims))

    assert len(max_dims) == len(min_dims)
    for min_dim, max_dim in zip(min_dims, max_dims):
        assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim"
        assert min_dim >= 0, "random_nt: min_dim must be non-negative"
        if require_non_empty:
            assert not (
                min_dim == 0 and max_dim == 1
            ), "random_nt: zero cannot be the only possible value if require_non_empty is True"

    if require_non_empty:
        # Select a random idx that will be required to be non-empty
        non_zero_idx = torch.randint(low=0, high=num_tensors, size=(1,)).item()

    ts1 = []
    for i in range(num_tensors):
        tensor_dims = []
        for min_dim, max_dim in zip(min_dims, max_dims):
            new_min_dim = min_dim
            if require_non_empty and i == non_zero_idx and min_dim == 0:
                new_min_dim = 1
            tensor_dims.append(
                torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item()
            )
        t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
        ts1.append(t1)

    return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout)
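

# Illustrative sketch (assumed example): build an NT of 4 strided-layout
# components, each 2-D with dims drawn from [0, 4); require_non_empty=True
# (the default) forces at least one component to have no zero-sized dims.
#   nt = random_nt(torch.device("cpu"), torch.float32, num_tensors=4, max_dims=(4, 4))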


# Alternate approach to generating a random NT.
# dims should be something like [5, None, 10], with None indicating that a
# random ragged structure should be used
def random_nt_from_dims(
    dims, device=None, dtype=None, layout=torch.strided, requires_grad=False
):
    sizes = [
        [
            d if d is not None else torch.randint(2, 10, size=(1,)).item()
            for d in dims[1:]
        ]
        for _ in range(dims[0])
    ]
    return torch.nested.nested_tensor(
        [torch.randn(*size) for size in sizes],
        device=device,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
    )
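

# Illustrative sketch (assumed example): dims=[5, None, 10] yields an NT with
# 5 components of shape (ragged_i, 10), where each ragged_i is drawn from [2, 10).
#   nt = random_nt_from_dims([5, None, 10])
#   # nt.size(0) == 5; nt.unbind()[i].shape == (ragged_i, 10)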


# Creates an NT with the same number of components as another NT, matching
# that NT's shape / ragged structure for every dim specified as -1.
def random_nt_from_similar(other, dims=None):
    if dims is None:
        return torch.randn_like(other)
    assert len(dims) == other.dim()
    assert dims[0] == -1 or dims[0] == other.size(0)

    ret_sizes = []
    for t in other.unbind():
        other_size = t.shape
        ret_size = []
        for i, d in enumerate(dims[1:]):
            if d == -1:
                ret_size.append(other_size[i])
            else:
                ret_size.append(d)
        ret_sizes.append(ret_size)

    return torch.nested.nested_tensor(
        [torch.randn(*size) for size in ret_sizes], device=other.device
    )
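

# Illustrative sketch (assumed example): keep the batch size and ragged
# structure of `other` but replace the last dim with 8.
#   x = random_nt_from_dims([5, None, 10])
#   y = random_nt_from_similar(x, dims=[-1, -1, 8])
#   # y.unbind()[i].shape == (x.unbind()[i].shape[0], 8)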


# makes naming nice for tests that parametrize over layout.
def layout_name(layout):
    # e.g. "torch.jagged" -> "jagged"
    return layout.__repr__().split(".")[-1]


def get_op_name(op):
    # e.g. "<OpOverload(op='aten.sum', overload='dim_IntList')>" -> "sum"
    return op.__name__.split(".")[0].split("_")[-1]


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_dense_to_nested_tensor_legacy(values):
    offsets = torch.arange(
        0, values.shape[0] * values.shape[1] + 1, values.shape[1], device=values.device
    )
    metadata_cache = {"max_seqlen": values.shape[1], "min_seqlen": 1}
    nt = ViewNestedFromBuffer.apply(
        values.view(-1, values.shape[-1]), offsets, metadata_cache
    )
    return nt
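

# Worked example of the offsets arithmetic above (a sketch, not executed): for
# dense values of shape (B, S, D) = (2, 3, 4), the buffer is flattened to
# shape (B * S, D) = (6, 4) and offsets == tensor([0, 3, 6]), i.e. one
# boundary per sequence in steps of S.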


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_jagged_to_nested_tensor_legacy(
    values: torch.Tensor, offsets: torch.Tensor, max_length: int
) -> torch.Tensor:
    metadata_cache = {"max_seqlen": max_length, "min_seqlen": 1}
    nt = ViewNestedFromBuffer.apply(values, offsets, metadata_cache)
    return nt


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_nt_to_jagged_legacy(nt):
    return buffer_from_jagged(nt)


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_dense_to_nested_tensor(values):
    nt = torch.nested.as_nested_tensor(values, layout=torch.jagged)
    return nt


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_jagged_to_nested_tensor(
    values: torch.Tensor, offsets: torch.Tensor, max_length: int
) -> torch.Tensor:
    nt = torch.nested.nested_tensor_from_jagged(
        values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length
    )
    return nt


# Helper function for test_dummy_mha_with_nt
def convert_nt_to_jagged(nt):
    return nt.values()


@markDynamoStrictTest
class TestNestedTensor(NestedTensorTestCase):
    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
        nested_tensor_list = nested_tensor.unbind()
        for i in range(batch_size):
            self.assertEqual(
                nested_tensor_list[i], nested_tensor_ref_list[i].type(torch.int64)
            )

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
            row = [list(item * np.arange(max_seq_len)) for item in row]
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
        nested_tensor_list = nested_tensor.unbind()
        for i in range(batch_size):
            self.assertEqual(
                nested_tensor_list[i], nested_tensor_ref_list[i].type(torch.int64)
            )

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size):
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(
                np.random.randint(low=0, high=vocab_size, size=(length,)).astype(float)
            )
            row = [list(item * np.arange(max_seq_len)) for item in row]
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.float)
        nested_tensor_list = nested_tensor.unbind()
        for i in range(batch_size):
            self.assertEqual(
                nested_tensor_list[i], nested_tensor_ref_list[i].type(torch.float)
            )

    @torch.inference_mode()
    def _test_unbind_case(self, a, b):
        nt = torch.nested.nested_tensor([a, b])
        a1, b1 = nt.unbind()
        self.assertIsNot(a, a1)
        self.assertIsNot(b, b1)

        nt = torch.nested.nested_tensor([a, b], dtype=a.dtype)
        a1, b1 = nt.unbind(0)
        self.assertEqual(a, a1)
        self.assertEqual(b, b1)

        a = torch.randn((2, 3)).add_(1)
        nt = torch.nested.nested_tensor([a])
        self.assertEqual(a, nt.unbind(0)[0])

    @torch.inference_mode()
    def test_unbind_0(self):
        self._test_unbind_case(torch.tensor([1, 2]), torch.tensor([7, 8]))

    @torch.inference_mode()
    def test_unbind_1(self):
        self._test_unbind_case(torch.tensor([1]), torch.tensor([7]))

    @torch.inference_mode()
    def test_unbind_3(self):
        self._test_unbind_case(torch.tensor([1.0]), torch.tensor([]))

    @torch.inference_mode()
    def test_unbind_4(self):
        self._test_unbind_case(torch.tensor([]), torch.tensor([]))

    @torch.inference_mode()
    def test_unbind_dim(self):
        def _test_fn(unbind_fn):
            a = torch.rand(3, 2)
            b = torch.rand(2, 3)
            nt = torch.nested.nested_tensor([a, b])
            self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1))

        # Both of these tests are necessary, because we're using
        # __torch_function__.
        _test_fn(lambda x, dim: x.unbind(dim))
        # TODO: Re-enable this once using torch_dispatch
        # _test_fn(lambda x, dim: torch.unbind(x, dim))

    @torch.inference_mode()
    def test_nested_tensor(self):
        self.assertRaises(
            TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0]))
        )
        self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0))

    @torch.inference_mode()
    def test_nested_tensor_matching_dim(self):
        self.assertRaisesRegex(
            RuntimeError,
            "Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.",
            lambda: torch.nested.nested_tensor([torch.tensor(1.0), torch.tensor([])]),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.",
            lambda: torch.nested.nested_tensor(
                [torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])]
            ),
        )

    @torch.inference_mode()
    def test_default_nested_tensor(self):
        self.assertRaises(TypeError, lambda: torch.nested.nested_tensor())
        default_nested_tensor = torch.nested.nested_tensor([])
        default_tensor = torch.tensor([])
        # self.assertEqual(default_nested_tensor.nested_dim(), 1)
        # self.assertEqual(default_nested_tensor.nested_size(), ())
        self.assertEqual(default_nested_tensor.dim(), default_tensor.dim())
        self.assertEqual(default_nested_tensor.layout, default_tensor.layout)
        self.assertEqual(default_nested_tensor.device, default_tensor.device)
        self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype)
        self.assertEqual(
            default_nested_tensor.requires_grad, default_tensor.requires_grad
        )
        self.assertIsNone(default_tensor.grad)
        # TODO: Re-enable once we have a performance driven
        # use case and implementation.
        # self.assertEqual(default_nested_tensor.is_pinned(),
        #                  default_tensor.is_pinned())

    @torch.inference_mode()
    def test_dim(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertEqual(a1.dim(), 1)
            a1 = constructor([torch.tensor(3.0)])
            self.assertEqual(a1.dim(), 1)
            a1 = constructor([torch.tensor([1, 2, 3, 4])])
            self.assertEqual(a1.dim(), 2)

    @unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.")
    @torch.inference_mode()
    def test_numel(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertEqual(a1.numel(), 0)
            a1 = constructor([torch.tensor(3.0), torch.tensor(4.0)])
            self.assertEqual(a1.numel(), 2)
            a1 = constructor([torch.randn(2, 2, 2)])
            self.assertEqual(a1.numel(), 8)
            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(3, 2, 1)])
            self.assertEqual(a1.numel(), 12)
            a1 = constructor([torch.randn([1, 1, 3]), torch.randn(3, 2, 4)])
            self.assertEqual(a1.numel(), 27)
            a1 = constructor([torch.randn([5, 5, 5]), torch.randn(6, 6, 6)])
            self.assertEqual(a1.numel(), 341)

            # Interesting edge case
            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(1, 2, 0)])
            self.assertEqual(a1.numel(), 6)

    @torch.inference_mode()
    def test_size(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertRaisesRegex(
                RuntimeError,
                "NestedTensorImpl doesn't support sizes",
                lambda: a1.size(),
            )

    def test_size_dim(self):
        a = torch.nested.nested_tensor([])
        self.assertEqual(a.size(0), 0)

        a = torch.nested.nested_tensor([torch.tensor(1)])
        self.assertEqual(a.size(0), 1)

        a = torch.nested.nested_tensor([torch.tensor(1), torch.tensor(2)])
        self.assertEqual(a.size(0), 2)

        a = torch.nested.nested_tensor([torch.rand(1, 2), torch.rand(1, 8)])
        self.assertEqual(a.size(0), 2)
        self.assertEqual(a.size(1), 1)
        self.assertRaisesRegex(
            RuntimeError,
            "Given dimension 2 is irregular and does not have a size",
            lambda: a.size(2),
        )

        a = torch.nested.nested_tensor([torch.rand(3, 4), torch.rand(5, 4)])
        self.assertEqual(a.size(0), 2)
        self.assertRaisesRegex(
            RuntimeError,
            "Given dimension 1 is irregular and does not have a size",
            lambda: a.size(1),
        )
        self.assertEqual(a.size(2), 4)

    @unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.")
    @torch.inference_mode()
    def test_stride(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertRaisesRegex(
                RuntimeError,
                "NestedTensorImpl doesn't support strides",
                lambda: a1.stride(),
            )

    @unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.")
    @torch.inference_mode()
    def test_is_contiguous(self):
        # Test empty case
        nt_empty = torch.nested.nested_tensor([])
        self.assertTrue(nt_empty.is_contiguous())
        self.assertEqual(nt_empty, nt_empty.contiguous())

        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))

        # Test contiguous case
        self.assertTrue(nt_contiguous.is_contiguous())
        self.assertEqual(nt_contiguous, nt_contiguous.contiguous())

        # Test noncontiguous case
        self.assertFalse(nt_noncontiguous.is_contiguous())
        self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous())

        # Test querying by memory_format
        self.assertTrue(
            nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertFalse(
            nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)
        )

    @torch.inference_mode()
    def test_repr_string(self):
        a = torch.nested.nested_tensor([])
        expected = "nested_tensor([\n\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

        a = torch.nested.nested_tensor([torch.tensor(1.0)])
        expected = "nested_tensor([\n  tensor(1.)\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

        a = torch.nested.nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
        expected = "nested_tensor([\n  tensor([[1, 2]]),\n  tensor([[4, 5]])\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

    def test_to_padded_tensor_on_empty_tensor(self):
        nt = torch.nested.nested_tensor([])
        empty = torch.nested.to_padded_tensor(nt, 4)
        self.assertEqual(empty, torch.tensor([]))

    def test_nested_namespace(self):
        nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 5)])
        result = nt.to_padded_tensor(4)
        nested_namespace_result = torch.nested.to_padded_tensor(nt, 4)
        self.assertEqual(result, nested_namespace_result)

    def test_to(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))

        def test_copy_behavior(t, non_blocking=False):
            self.assertIs(t, t.to(t, non_blocking=non_blocking))
            self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
            self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking))
            self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True))
            self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True))
            self.assertIsNot(
                t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)
            )

            devices = [t.device]
            if t.device.type == "cuda":
                if t.device.index == -1:
                    devices.append(f"cuda:{torch.cuda.current_device()}")
                elif t.device.index == torch.cuda.current_device():
                    devices.append("cuda")
            for device in devices:
                self.assertIs(t, t.to(device, non_blocking=non_blocking))
                self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking))
                self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True))
                self.assertIsNot(
                    t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True)
                )

        test_copy_behavior(nt)
        self.assertEqual(nt.device, nt.to("cpu").device)
        self.assertEqual(nt.device, nt.to("cpu", dtype=torch.float32).device)
        self.assertIs(torch.float32, nt.to("cpu", dtype=torch.float32).dtype)
        self.assertEqual(nt.device, nt.to(torch.float32).device)
        self.assertIs(torch.float32, nt.to(dtype=torch.float32).dtype)

        def test_data_ptr(getter):
            self.assertEqual(getter(nt), getter(nt.to("cpu")))
            self.assertEqual(
                getter(nt), getter(nt.to(dtype=nt.dtype, device=nt.device, copy=False))
            )
            self.assertEqual(getter(nt), getter(nt.to("cpu", copy=False)))
            self.assertNotEqual(getter(nt), getter(nt.to("cpu", copy=True)))

        test_data_ptr(lambda nt: nt.data_ptr())

        if torch.cuda.is_available():
            for non_blocking in [True, False]:
                for cuda in [
                    "cuda",
                    "cuda:0" if torch.cuda.device_count() == 1 else "cuda:1",
                ]:
                    nt2 = random_nt(cuda, torch.float32, ntensors, (4, 4))
                    test_copy_behavior(nt2, non_blocking)
                    self.assertEqual(
                        nt2.device, nt2.to(cuda, non_blocking=non_blocking).device
                    )
                    self.assertEqual(
                        nt.device, nt2.to("cpu", non_blocking=non_blocking).device
                    )
                    self.assertEqual(
                        nt2.device, nt.to(cuda, non_blocking=non_blocking).device
                    )
                    self.assertIs(
                        torch.int32,
                        nt2.to(
                            "cpu", dtype=torch.int32, non_blocking=non_blocking
                        ).dtype,
                    )
                    self.assertEqual(
                        nt.device,
                        nt2.to(
                            "cpu", dtype=torch.int32, non_blocking=non_blocking
                        ).device,
                    )
                    self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype)
                    self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device)

    def test_copy_(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        nt_copy = torch.empty_like(nt)
        nt_copy.copy_(nt)

        for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
            self.assertEqual(nt_ub, nt_copy_ub)

        nt_error = torch.nested.nested_tensor([torch.tensor([0, 0])])
        self.assertRaisesRegex(
            RuntimeError,
            "copy_ only supports tensors that are the same size for Nested implementations",
            lambda: nt_error.copy_(nt),
        )

        if torch.cuda.is_available():
            nt = random_nt(torch.device("cuda"), torch.float32, ntensors, (4, 4))
            nt_copy = torch.empty_like(nt, device=torch.device("cpu"))
            nt_copy.copy_(nt, non_blocking=True)
            torch.cuda.current_stream(torch.cuda.current_device()).synchronize()
            for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
                self.assertEqual(nt_ub, nt_copy_ub)

            nt_copy = torch.empty_like(nt, device=torch.device("cpu"))
            nt_copy.copy_(nt, non_blocking=False)
            for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
                self.assertEqual(nt_ub, nt_copy_ub)

    def test_fill_(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        nt.fill_(10.0)
        for nt_ub in nt.unbind():
            t = torch.empty_like(nt_ub)
            t.fill_(10.0)
            self.assertEqual(nt_ub, t)

        fill_tensor = torch.tensor([11.0])
        self.assertRaisesRegex(
            RuntimeError,
            "fill_ only supports 0-dimension value tensor",
            lambda: nt.fill_(fill_tensor),
        )

        nt.fill_(fill_tensor[0])
        for nt_ub in nt.unbind():
            t = torch.empty_like(nt_ub)
            t.fill_(11.0)
            self.assertEqual(nt_ub, t)

    def test_zero_(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        nt.zero_()
        for nt_ub in nt.unbind():
            t = torch.empty_like(nt_ub)
            t.fill_(0.0)
            self.assertEqual(nt_ub, t)

    @parametrize(
        "func",
        [torch.ones_like, torch.zeros_like, torch.randn_like],
        name_fn=lambda f: f.__name__,
    )
    def test_like_functions(self, func):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        torch.manual_seed(1)
        nt_like = func(nt)

        torch.manual_seed(1)
        for nt_ub in nt_like.unbind():
            t_like = func(nt_ub)
            self.assertEqual(nt_ub, t_like)

    def test_cat(self):
        # dim=0 success case
        # No constraints on ragged structures matching.
        x = random_nt_from_dims([5, None, 10])
        y = random_nt_from_dims([3, 4, None])
        output = torch.cat([x, y], dim=0)
        for out_component, xy_component in zip(
            output.unbind(), itertools.chain(x.unbind(), y.unbind())
        ):
            self.assertEqual(out_component, xy_component)

        # dim=-1 success case
        # shape (B, *, D)
        x = random_nt_from_dims([5, None, 10])
        # shape (B, *, D'); same structure as x but dim=-1 differs
        y = random_nt_from_similar(x, dims=[-1, -1, 8])
        # should be shape (B, *, D + D') when supported
        output = torch.cat([x, y], dim=-1)
        for out_component, x_component, y_component in zip(
            output.unbind(), x.unbind(), y.unbind()
        ):
            self.assertEqual(
                out_component, torch.cat([x_component, y_component], dim=-1)
            )

        # dim between 0 and -1 success case
        x = random_nt_from_dims([5, None, 2, 3])
        # same structure as x but dim=2 differs
        y = random_nt_from_similar(x, dims=[-1, -1, 4, -1])
        output = torch.cat([x, y], dim=2)
        for out_component, x_component, y_component in zip(
            output.unbind(), x.unbind(), y.unbind()
        ):
            self.assertEqual(
                out_component, torch.cat([x_component, y_component], dim=1)
            )

        # error case: mixed NT / dense inputs
        x = random_nt_from_dims([5, None, 2])
        y = torch.randn(5, 3, 2)
        with self.assertRaisesRegex(
            RuntimeError, "expected each tensor in given list to be nested"
        ):
            torch.cat([x, y], dim=-1)

        # error case: NTs with different dims
        x = random_nt_from_dims([5, None, 2])
        y = random_nt_from_dims([5, None, 2, 3])
        with self.assertRaisesRegex(
            RuntimeError,
            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
        ):
            torch.cat([x, y], dim=-1)

        # error case: non-contiguous NT
        x, y = random_nt_noncontiguous_pair((2, 3, 4), dtype=torch.float32)
        # transpose to put ragged dim next to batch dim
        x, y = x.transpose(-2, -1), y.transpose(-2, -1)
        with self.assertRaisesRegex(
            RuntimeError, "only contiguous nested tensors are supported"
        ):
            torch.cat([x, y], dim=-1)

        # error case: multiple ragged dims in inputs
        x = random_nt_from_dims([5, None, None, 2])
        y = random_nt_from_similar(x)
        with self.assertRaisesRegex(
            RuntimeError,
            "only nested tensors with a single ragged dim next to the batch dim are supported",
        ):
            torch.cat([x, y], dim=-1)

        # error case: ragged dim not next to batch dim
        x = random_nt_from_dims([5, 2, None])
        y = random_nt_from_similar(x)
        with self.assertRaisesRegex(
            RuntimeError,
            "only nested tensors with a single ragged dim next to the batch dim are supported",
        ):
            torch.cat([x, y], dim=1)

        # error case: NTs with different batch sizes
        x = random_nt_from_dims([5, None, 2])
        y = random_nt_from_dims([3, None, 2])
        with self.assertRaisesRegex(
            RuntimeError,
            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
        ):
            torch.cat([x, y], dim=-1)

        # error case: NTs with different ragged structures
        x = torch.nested.nested_tensor(
            [
                torch.randn(2, 6),
                torch.randn(4, 6),
                torch.randn(5, 6),
            ]
        )
        y = torch.nested.nested_tensor(
            [
                torch.randn(5, 6),
                torch.randn(4, 6),
                torch.randn(2, 6),
            ]
        )
        with self.assertRaisesRegex(
            RuntimeError,
            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
        ):
            torch.cat([x, y], dim=-1)


@markDynamoStrictTest
class TestNestedTensorDeviceType(NestedTensorTestCase):
    # Helper function to generate a pair of random nested tensors;
    # the two nested tensors have the same shapes.
    def random_nt_pair(self, device, dtype, num_tensors, max_dims):
        ts1 = []
        ts2 = []
        for _ in range(num_tensors):
            tensor_dims = tuple(
                [
                    torch.randint(low=0, high=max_dim, size=(1,)).item()
                    for max_dim in max_dims
                ]
            )
            t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
            t2 = torch.randn(tensor_dims, device=device, dtype=dtype)
            ts1.append(t1)
            ts2.append(t2)
        return (
            torch.nested.nested_tensor(ts1, device=device, dtype=dtype),
            torch.nested.nested_tensor(ts2, device=device, dtype=dtype),
        )

    @dtypes(*floating_types_and_half())
    def test_detach(self, device, dtype):
        a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False)
        b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False)
        x = torch.nested.nested_tensor([a, b], requires_grad=True)

        x_detach = x.detach()

        z = x_detach * 4
        self.assertFalse(x_detach.requires_grad)
        self.assertFalse(z.requires_grad)

        a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True)
        b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True)
        x = torch.nested.as_nested_tensor([a, b])

        y = x * 2
        y = y.detach()
        self.assertFalse(y.requires_grad)
        self.assertIsNone(y.grad_fn)

        z = x + y
        torch.nested.to_padded_tensor(z, 0).sum().backward()
        # This is an incorrect gradient, but we assume that's what the user
        # wanted. detach() is an advanced option.
        self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
        self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))

    @dtypes(torch.float, torch.float16, torch.double)
    def test_unbind_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        ub_contiguous = nt_contiguous.unbind()
        ub_noncontiguous = nt_noncontiguous.unbind()
        self.assertEqual(len(ub_contiguous), len(ub_noncontiguous))
        n = len(ub_contiguous)
        for i in range(n):
            self.assertEqual(ub_contiguous[i], ub_noncontiguous[i])

    @dtypes(torch.float)
    @skipMeta
    def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype):
        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
        ts = list(torch.unbind(t))
        ts[0] = ts[0][:-1]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        padded = torch.nested.to_padded_tensor(nt, 0)

        nt_to = torch._nested_from_padded_and_nested_example(padded, nt)

        for t1, t2 in zip(nt.unbind(), nt_to.unbind()):
            self.assertEqual(t1, t2)
        self.assertEqual(nt.device, nt_to.device)

    @dtypes(torch.float)
    @dtypesIfCUDA(torch.float, torch.half)
    @skipMeta
    @torch.inference_mode()
    def test_layer_norm(self, device, dtype):
        def _test(size):
            # Simple shapes test
            t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
            t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
            ts = [t0, t1, t0, t1]
            nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
            layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
            nt_result = layer_norm(nt)
            for nt_subresult, t in zip(nt_result.unbind(), ts):
                t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
                self.assertEqual(nt_subresult, t_result)

            # More complex nt test with different lengths for each tensor
            t0 = torch.randn(4, size, device=device, dtype=dtype, requires_grad=False)
            t1 = torch.randn(10, size, device=device, dtype=dtype, requires_grad=False)
            t2 = torch.randn(7, size, device=device, dtype=dtype, requires_grad=False)
            ts = [t0, t1, t2, t0, t2]
            nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
            layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
            nt_result = layer_norm(nt)
            for nt_subresult, t in zip(nt_result.unbind(), ts):
                t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
                self.assertEqual(nt_subresult, t_result)

            if size <= 128:
                # Test with multidimensional tensors after irregular dim
                # (run only with smaller dimensions to ensure fast execution)
                t0 = torch.randn(
                    4, size, size, 4, device=device, dtype=dtype, requires_grad=False
                )
                t1 = torch.randn(
                    10, size, size, 4, device=device, dtype=dtype, requires_grad=False
                )
                t2 = torch.randn(
                    7, size, size, 4, device=device, dtype=dtype, requires_grad=False
                )
                ts = [t0, t1, t2, t0, t2]
                nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
                layer_norm = torch.nn.LayerNorm(
                    (size, size, 4), device=device, dtype=dtype
                )
                nt_result = layer_norm(nt)
                for nt_subresult, t in zip(nt_result.unbind(), ts):
                    t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0))
                    self.assertEqual(nt_subresult, t_result)

                # Test where the normalized shape covers only the trailing
                # (size, 4) dims rather than all dims after the irregular dim
                layer_norm = torch.nn.LayerNorm((size, 4), device=device, dtype=dtype)
                nt_result = layer_norm(nt)
                for nt_subresult, t in zip(nt_result.unbind(), ts):
                    t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0))
                    self.assertEqual(nt_subresult, t_result)

        for size in (1024, 1023, 513, 512, 256, 128, 2, 4, 32):
            _test(size)

    @dtypes(torch.float)
    @dtypesIfCUDA(torch.float, torch.half)
    @skipMeta
    @torch.inference_mode()
    def test_layer_norm_breaking(self, device, dtype):
        size = 128
        t0 = torch.randn(
            4, size, size, 4, device=device, dtype=dtype, requires_grad=False
        )
        t1 = torch.randn(
            10, size, size, 4, device=device, dtype=dtype, requires_grad=False
        )
        t2 = torch.randn(
            7, size, size, 4, device=device, dtype=dtype, requires_grad=False
        )
        ts = [t0, t1, t2, t0, t2]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        layer_norm = torch.nn.LayerNorm((4, size, size, 4), device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "normalized_shape extends into irregular dimensions for the nested tensor",
            lambda: layer_norm(nt),
        )
        layer_norm = torch.nn.LayerNorm((size + 1, size, 4), device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "The shape at dimension 0",
            lambda: layer_norm(nt),
        )

    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
    def test_embedding(self, device, layout):
        inputs = [
            torch.randint(100, (L,), device=device, dtype=torch.int64)
            for L in torch.randint(5, 50, (8,))
        ]
        x = torch.nested.nested_tensor(
            inputs, device=device, dtype=torch.int64, layout=layout
        )
        emb = torch.nn.Embedding(100, 8, device=device)
        y = emb(x)

        @torch._dynamo.disable
        def check(inputs, y):
            ys = y.unbind()
            for i, inp in enumerate(inputs):
                self.assertEqual(emb(inp), ys[i])

        check(inputs, y)

    @skipMeta
    @torch.inference_mode()
    @dtypes(*floating_types_and_half())
    def test_masked_fill(self, device, dtype):
        # nested tensor masked_fill with a nested mask
        (nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4))
        mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()])
        ref = torch.nested.nested_tensor(
            [t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())]
        )
        out = nt.masked_fill(mask, 0)
        self.assertEqual(ref, out)

    @dtypes(torch.float, torch.float16)
    def test_to_padded_tensor_simple(self, device, dtype):
        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
        ts = list(torch.unbind(t))
        ts[0] = ts[0][:-1]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        for padding_value in (0, 1):
            padded = torch.nested.to_padded_tensor(nt, padding_value)

            correct_output = t.clone()
            if padding_value == 0:
                correct_output[0][-1] = torch.zeros_like(correct_output[0][-1])
            else:
                correct_output[0][-1] = torch.ones_like(correct_output[0][-1])

            self.assertEqual(padded, correct_output)
            self.assertEqual(padded.device, torch.device(device))
            self.assertEqual(padded.dtype, dtype)

    @dtypes(torch.float, torch.float16)
    def test_to_padded_tensor_output_size(self, device, dtype):
        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
        output_size = (4, 6, 5)
        ts = list(torch.unbind(t))
        ts[0] = ts[0][:-1]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        for padding_value in (0, 1):
            padded = torch.nested.to_padded_tensor(
                nt, padding_value, output_size=output_size
            )
            correct_output = (
                torch.ones(output_size, device=device, dtype=dtype) * padding_value
            )
            correct_output[:4, :4, :4] = t.clone()
            if padding_value == 0:
                correct_output[0][3] = torch.zeros_like(correct_output[0][3])
            else:
                correct_output[0][3] = torch.ones_like(correct_output[0][3])

            self.assertEqual(padded, correct_output)
            self.assertEqual(padded.device, torch.device(device))
            self.assertEqual(padded.dtype, dtype)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_dim2(self, device, dtype):
        ts = [
            torch.randn(160, device=device, dtype=dtype),
            torch.randn(1240, device=device, dtype=dtype),
            torch.randn(2400, device=device, dtype=dtype),
        ]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        pad = 42
        correct_output = []
        for t in ts:
            next_output = torch.ones_like(ts[2]) * pad
            correct_output.append(next_output)
            next_output[: t.size(0)].copy_(t)
        correct_output = torch.stack(correct_output)
        padded = torch.nested.to_padded_tensor(nt, pad)
        self.assertEqual(padded, correct_output)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_dim3(self, device, dtype):
        ts = [
            torch.randn(16, 21, device=device, dtype=dtype),
            torch.randn(24, 32, device=device, dtype=dtype),
            torch.randn(40, 53, device=device, dtype=dtype),
        ]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        pad = 42
        correct_output = []
        for t in ts:
            next_output = torch.ones_like(ts[2]) * pad
            correct_output.append(next_output)
            next_output[: t.size(0), : t.size(1)].copy_(t)
        correct_output = torch.stack(correct_output)
        padded = torch.nested.to_padded_tensor(nt, pad)
        self.assertEqual(padded, correct_output)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_dim4(self, device, dtype):
        ts = [
            torch.randn(16, 21, 13, device=device, dtype=dtype),
            torch.randn(24, 32, 14, device=device, dtype=dtype),
            torch.randn(40, 53, 16, device=device, dtype=dtype),
        ]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        pad = 42
        correct_output = []
        for t in ts:
            next_output = torch.ones_like(ts[2]) * pad
            correct_output.append(next_output)
            next_output[: t.size(0), : t.size(1), : t.size(2)].copy_(t)
        correct_output = torch.stack(correct_output)
        padded = torch.nested.to_padded_tensor(nt, pad)
        self.assertEqual(padded, correct_output)

    # TODO: test noncontiguous to_padded_tensor
    # For now this tests the functionality of noncontiguous_to_padded_tensor
    # and the error message of to_padded_tensor, since to_padded_tensor does
    # not support noncontiguous buffers yet.
    @dtypes(torch.float, torch.float16, torch.double)
    @torch.inference_mode()
    def test_to_padded_tensor_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        # test noncontiguous_to_padded_tensor functionality
        self.assertEqual(
            torch.nested.to_padded_tensor(nt_contiguous, 0.0),
            noncontiguous_to_padded_tensor(nt_noncontiguous),
        )
        # test to_padded_tensor error message
        self.assertRaisesRegex(
            RuntimeError,
            r"for now to_padded_tensor only supports contiguous nested tensor",
            lambda: torch.nested.to_padded_tensor(nt_noncontiguous, 0.0),
        )

    @skipMeta
    def test_device_checks(self, device):
        nt = torch.nested.nested_tensor([], device=device)
        is_cuda = "cuda" in str(device)
        self.assertEqual(nt.is_cuda, is_cuda)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_nested_tensor_indexing(self, device, dtype):
        # edge case: empty nested tensor
        nt0 = torch.nested.nested_tensor([])
        self.assertRaises(IndexError, lambda: nt0[0])
        # normal case
        x0 = torch.randn((2, 5), device=device, dtype=dtype)
        x1 = torch.randn((3, 4), device=device, dtype=dtype)
        nt = torch.nested.nested_tensor([x0, x1])
        # single index: only integer indexing is supported in the batch dimension
        self.assertEqual(nt[0], x0)
        self.assertEqual(nt[-1], x1)
        self.assertRaises(IndexError, lambda: nt[2])
        self.assertRaises(IndexError, lambda: nt[-3])
        self.assertRaises(NotImplementedError, lambda: nt[:])
        self.assertEqual(nt[...], nt)
        # tuple of indices: only integer indexing is supported in the batch dimension,
        #                   plus all forms of indexing in the original tensor dimensions
        self.assertEqual(nt[0, 0, 0], x0[0, 0])
        self.assertEqual(nt[0, 1, :], x0[1, :])
        self.assertEqual(nt[1, ...], x1)
        self.assertRaises(IndexError, lambda: nt[1, 4, 2])
        self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1])
        # test select on non-batch dimensions
        self.assertEqual(nt.select(1, 0)[0], x0.select(0, 0))
        self.assertEqual(nt.select(1, 0)[1], x1.select(0, 0))
        self.assertRaises(IndexError, lambda: nt.select(1, 3))
        self.assertEqual(nt.select(2, 0)[0], x0.select(1, 0))
        self.assertEqual(nt.select(2, 0)[1], x1.select(1, 0))
        self.assertRaises(IndexError, lambda: nt.select(2, 5))
        # make sure indexing returns a view
        nt[0].fill_(100.0)
        answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5))
        self.assertEqual(nt[0], answer)
        nt[1, 1, :].fill_(200.0)
        answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4)
        self.assertEqual(nt[1, 1, :], answer)

        # Test that indexing works when requires_grad_(True)
        # previously this was failing because the backward kernel for select.int uses .sizes()
        nt = torch.nested.nested_tensor([x0, x1]).requires_grad_(True)
        self.assertEqual(nt[0], x0)
        self.assertEqual(nt[-1], x1)
        grad_x0 = torch.randn((2, 5), device=device, dtype=dtype)
        nt[0].backward(grad_x0)
        expected_grad = torch.nested.nested_tensor(
            [grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)]
        )
        self.assertEqual(nt.grad, expected_grad)

    @parametrize(
        "func",
        [
            subtest(torch.nn.functional.relu, name="relu"),
            subtest(torch.nn.functional.relu_, name="relu_"),
            subtest(torch.nn.functional.gelu, name="gelu"),
            subtest(torch._C._nn.gelu_, name="gelu_"),
            subtest(torch.tanh, name="tanh"),
            subtest(torch.tanh_, name="tanh_"),
            subtest(torch.neg, name="neg"),
            subtest(torch.nn.functional.silu, name="silu"),
            subtest(partial(torch.nn.functional.silu, inplace=True), name="silu_"),
            subtest(torch.abs, name="abs"),
            subtest(torch.abs_, name="abs_"),
            subtest(torch.sgn, name="sgn"),
            subtest(torch.logical_not, name="logical_not"),
            subtest(torch.sin, name="sin"),
            subtest(torch.cos, name="cos"),
        ],
    )
    def test_activations(self, device, func):
        nt, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device=device, dtype=torch.float32
        )
        nested_result = func(nt)
        self.assertTrue(nested_result.is_nested)
        for t, t_res in zip(nt.unbind(), nested_result.unbind()):
            self.assertEqual(func(t), t_res)
        self.assertRaisesRegex(
            RuntimeError,
            "NestedTensor must be contiguous to get buffer.",
            lambda: func(nt_noncontiguous),
        )

    @parametrize("func", [subtest(torch.ge, name="ge"), subtest(torch.eq, name="eq")])
    def test_binary_ops_with_scalar(self, device, func):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device=device, dtype=torch.float32
        )
        scalar = 0.0

        # should work regardless of contiguity
        for nt in (nt_contiguous, nt_noncontiguous):
            nested_result = func(nt, scalar)
            self.assertTrue(nested_result.is_nested)
            for t, t_res in zip(nt.unbind(), nested_result.unbind()):
                self.assertEqual(func(t, scalar), t_res)

    @dtypes(*floating_types_and_half())
    def test_nested_tensor_chunk(self, device, dtype):
        # Transformer use case
        a = torch.randn(3, 3 * 4, device=device, dtype=dtype)
        b = torch.randn(2, 3 * 4, device=device, dtype=dtype)
        c = torch.randn(1, 3 * 4, device=device, dtype=dtype)
        a_chunks = a.chunk(3, dim=-1)
        b_chunks = b.chunk(3, dim=-1)
        c_chunks = c.chunk(3, dim=-1)

        a_nt = [a_chunks[0], b_chunks[0], c_chunks[0]]
        b_nt = [a_chunks[1], b_chunks[1], c_chunks[1]]
        c_nt = [a_chunks[2], b_chunks[2], c_chunks[2]]

        nt = torch.nested.nested_tensor([a, b, c])
        chunked = nt.chunk(3, dim=-1)

        self.assertEqual(chunked[0], torch.nested.nested_tensor(a_nt))
        self.assertEqual(chunked[1], torch.nested.nested_tensor(b_nt))
        self.assertEqual(chunked[2], torch.nested.nested_tensor(c_nt))

        for chunk in chunked:
            self.assertFalse(chunk.is_contiguous())

        # Failure when chunking on the ragged or batch dimension
        self.assertRaisesRegex(
            RuntimeError,
            "Chunk for nested tensors is currently only supported for the last dimension.",
            lambda: torch.chunk(nt, 5, dim=1),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Chunk for nested tensors is currently only supported for the last dimension.",
            lambda: torch.chunk(nt, 5, dim=0),
        )

        # Failure on non-contiguous nt
        _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "chunk expects `self` to be contiguous.",
            lambda: torch.chunk(nt_noncontiguous, 5, dim=-1),
        )

        # Failure when the number of chunks doesn't evenly divide the last dim
        self.assertRaisesRegex(
            RuntimeError,
            "Chunk for nested tensors is only supported for "
            "nested tensors with trailing dimension divisible by chunks.",
            lambda: torch.chunk(nt, 5, dim=-1),
        )

        # Failure when calling backward on a chunk
        a = torch.randn(3, 3 * 4, device=device, dtype=dtype, requires_grad=True)
        b = torch.randn(2, 3 * 4, device=device, dtype=dtype, requires_grad=True)
        nt_grad = torch.nested.as_nested_tensor([a, b])
        chunked = torch.chunk(nt_grad, 2, dim=-1)
        self.assertRaisesRegex(
            RuntimeError,
            "Nested Strided Tensor doesn't support chunk backward.",
            lambda: chunked[0].backward(chunked[0].clone()),
        )

    @dtypes(*floating_types_and_half())
    def test_nested_tensor_split_with_sizes(self, device, dtype):
        a = torch.randn(3, 20, device=device, dtype=dtype)
        b = torch.randn(2, 20, device=device, dtype=dtype)
        c = torch.randn(1, 20, device=device, dtype=dtype)

        split_sizes = [4, 6, 10]
        a_splits = a.split_with_sizes(split_sizes, dim=-1)
        b_splits = b.split_with_sizes(split_sizes, dim=-1)
        c_splits = c.split_with_sizes(split_sizes, dim=-1)

        nt = torch.nested.nested_tensor([a, b, c])
        nt_splits = nt.split_with_sizes(split_sizes, dim=-1)

        for i, nt_split in enumerate(nt_splits):
            self.assertEqual(
                nt_split,
                torch.nested.nested_tensor([a_splits[i], b_splits[i], c_splits[i]]),
            )
            dense_strides = torch.stack(
                [
                    torch.tensor(a_splits[i].stride()),
                    torch.tensor(b_splits[i].stride()),
                    torch.tensor(c_splits[i].stride()),
                ]
            )
            self.assertEqual(nt_split._nested_tensor_strides(), dense_strides)
            self.assertFalse(nt_split.is_contiguous())

        # Failure calling on ragged dimensions
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes for nested tensors is currently only supported for the last dimension.",
            lambda: torch.split_with_sizes(nt, split_sizes, dim=1),
        )

        # Failure calling on the batch (non-last) dimension
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes for nested tensors is currently only supported for the last dimension.",
            lambda: torch.split_with_sizes(nt, split_sizes, dim=0),
        )

        # Failure on non-contiguous nt
        _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes expects `self` to be contiguous.",
            lambda: torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1),
        )

        # Failure when calling with split_sizes that don't cover the full dim size
        bad_split_sizes = [4, 6, 9]  # don't add up to 20
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes expects split_sizes to sum exactly to 20",
            lambda: torch.split_with_sizes(nt, bad_split_sizes, dim=-1),
        )

    @dtypes(torch.float, torch.float16, torch.double)
    @torch.inference_mode()
    def test_nested_tensor_indexing_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        self.assertEqual(nt_contiguous.size(0), nt_noncontiguous.size(0))
        n = nt_contiguous.size(0)
        for i in range(n):
            self.assertEqual(nt_contiguous[i], nt_noncontiguous[i])

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    @parametrize("transpose", [True, False])
    def test_nested_tensor_add(self, device, dtype, transpose):
        if transpose:
            a = torch.randn(2, 2, 2, device=device, dtype=dtype)
            b = torch.rand(2, 2, 2, device=device, dtype=dtype)
            c = a.transpose(-1, -2).contiguous()
            d = b.transpose(-1, -2).contiguous()
            nt1 = torch.nested.nested_tensor([a, b, a, b])
            nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
        else:
            (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        out = nt1 + nt2
        self.assertEqual(ref, out)

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    @parametrize("transpose", [True, False])
    def test_nested_tensor_sub(self, device, dtype, transpose):
        if transpose:
            a = torch.randn(2, 2, 2, device=device, dtype=dtype)
            b = torch.rand(2, 2, 2, device=device, dtype=dtype)
            c = a.transpose(-1, -2).contiguous()
            d = b.transpose(-1, -2).contiguous()
            nt1 = torch.nested.nested_tensor([a, b, a, b])
            nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
        else:
            (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 - t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        out = nt1 - nt2
        self.assertEqual(ref, out)

    @onlyCUDA
    @dtypes(torch.float, torch.float16)
    @torch.inference_mode()
    @parametrize("embedding_dim", [8, 128, 256, 384])
    def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim):
        def _test_add_mul(nt, t):
            ref_add = torch.nested.nested_tensor(
                [t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]
            )
            ref_mul = torch.nested.nested_tensor(
                [t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]
            )
            self.assertEqual(nt.add(t), ref_add)
            self.assertEqual(nt.mul(t), ref_mul)

        batch_size = 32
        seq_lens = torch.randint(low=0, high=10, size=(batch_size,))

        # [B, *, D], [B, 1, D] case
        ts = [torch.randn((seq_len, embedding_dim)) for seq_len in seq_lens]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        t = torch.randn((batch_size, 1, embedding_dim), device=device, dtype=dtype)
        _test_add_mul(nt, t)

        # [B, *], [B, 1] case
        ts = [torch.randn(seq_len) for seq_len in seq_lens]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        t = torch.randn((batch_size, 1), device=device, dtype=dtype)
        _test_add_mul(nt, t)

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_mul(self, device, dtype):
        # nested tensor * nested tensor
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        out = nt1 * nt2
        self.assertEqual(ref, out)
        # nested tensor * scalar
        number = 10.0
        scalar = torch.tensor(number).to(dtype).to(device)
        ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
        out_number0 = nt1 * number
        out_number1 = number * nt1
        out_scalar0 = nt1 * scalar
        out_scalar1 = scalar * nt1
        self.assertEqual(out_number0, ref)
        self.assertEqual(out_number1, ref)
        self.assertEqual(out_scalar0, ref)
        self.assertEqual(out_scalar1, ref)
        # error case: numel == 1 but dim > 0
        vector = torch.tensor([number]).to(dtype).to(device)
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a nested self and non-nested other",
            lambda: nt1.mul(vector),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a non-nested self and nested other",
            lambda: vector.mul(nt1),
        )

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_div(self, device, dtype):
        nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4))
        scale = 4.0
        ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()])
        out = nt / 4.0
        self.assertEqual(ref, out)
        ref_transposed = ref.transpose(1, 2)
        out = nt.transpose(1, 2) / 4.0
        self.assertEqual(ref_transposed, out)

        ref = torch.nested.nested_tensor(
            [t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())]
        )
        out = nt / nt2
        self.assertEqual(ref, out)

        out = nt.transpose(1, 2) / nt2.transpose(1, 2)
        self.assertEqual(ref.transpose(1, 2), out)

        nt_transpose_copy = torch.nested.nested_tensor(
            [t.transpose(0, 1) for t in nt.unbind()]
        )

        self.assertRaisesRegex(
            RuntimeError,
            "div requires strides to match when given NestedTensors",
            lambda: nt_transpose_copy.transpose(1, 2) / nt2,
        )

        nt = torch.nested.nested_tensor(
            [torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype
        )
        nt_chunks = nt.chunk(2, -1)
        self.assertRaisesRegex(
            RuntimeError,
            "div requires offsets to match when given NestedTensors",
            lambda: nt_chunks[0] / nt_chunks[1],
        )

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_add_in_place(self, device, dtype):
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        nt1 += nt2
        self.assertEqual(ref, nt1)

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_mul_in_place(self, device, dtype):
        # nested tensor * nested tensor
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        nt1 *= nt2
        self.assertEqual(ref, nt1)
        # nested tensor * scalar
        number = 10.0
        scalar = torch.tensor(number).to(dtype).to(device)
        ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
        out_number = nt1.clone()
        out_number *= number
        out_scalar = nt1.clone()
        out_scalar *= scalar
        self.assertEqual(out_number, ref)
        self.assertEqual(out_scalar, ref)
        self.assertRaisesRegex(
            RuntimeError,
            r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]",
            lambda: scalar.mul_(nt1),
        )
        # error case: numel == 1 but dim > 0
        vector = torch.tensor([number]).to(dtype).to(device)
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a nested self and non-nested other",
            lambda: nt1.mul_(vector),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a non-nested self and nested other",
            lambda: vector.mul_(nt1),
        )

    @onlyCPU
    @skipMeta
    @dtypes(torch.float)
    def test_nested_tensor_sum_dim(self, device, dtype):
        params = ((2, (1, 1)), (4, (4, 4)), (10, (3, 5, 7)))

        def test_sum(device, dtype, ntensors, max_sizes, dim, keepdim=True):
            nt = random_nt(device, dtype, ntensors, max_sizes, require_non_empty=False)
            nt2 = nt.clone()
            ub2 = nt2.unbind()
            nt.requires_grad_(True)
            for t in ub2:
                t.requires_grad_(True)
            nt_sum = nt.sum(dim=dim, keepdim=keepdim)
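            # each unbound component has one fewer dim than the nt, so the reduced
            # (last) dim maps to -1 for the per-component reference sums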
            ub2_sum = [t.sum(-1, keepdim=keepdim) for t in ub2]
            self.assertEqual(nt_sum, torch.nested.nested_tensor(ub2_sum))

            # test backward
            # generate gradient tensor that has the same size as the output
            size = nt_sum._nested_tensor_size()
            gt2 = []
            for i in range(ntensors):
                gt2.append(torch.randn(size[i].tolist(), device=device, dtype=dtype))
            gt = torch.nested.nested_tensor(gt2).clone()
            nt_sum.backward(gt)
            for t2, g2 in zip(ub2_sum, gt2):
                t2.backward(g2)
            self.assertEqual(nt.grad, torch.nested.nested_tensor([t.grad for t in ub2]))
            return

        for ntensors, max_sizes in params:
            test_sum(device, dtype, ntensors, max_sizes, len(max_sizes))

        # Test error inputs
        with self.assertRaisesRegex(
            RuntimeError, "NestedTensor can only be reduced across the last"
        ):
            torch.nested.nested_tensor(
                [torch.tensor([3, 4, 5]), torch.tensor([1, 2])]
            ).sum(0, keepdim=True)

        with self.assertRaisesRegex(
            RuntimeError, "NestedTensor only allows reduction of a single"
        ):
            torch.nested.nested_tensor(
                [torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])]
            ).sum([0, 1], keepdim=True)

        with self.assertRaisesRegex(
            RuntimeError, "NestedTensor always requires keepdim=True for now."
        ):
            torch.nested.nested_tensor(
                [torch.tensor([3, 4, 5]), torch.tensor([1, 2])]
            ).sum(-1)

    @dtypes(torch.float, torch.float16)
    def test_contiguous(self, device, dtype):
        # Since we don't have access to the buffer in Python, it is harder to show
        # what we are testing for. When we call chunk on a consistent dim of an NT
        # with chunk_size > 1, the resulting tensors are views of the original NT
        # whose numel is now less than the size of the buffer. Clone was previously
        # creating a new NT with a buffer that was the same size as the original.
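        # For example, with components of shape (2, 20) and (4, 20) the buffer holds
        # 2*20 + 4*20 = 120 elements, while each of the 5 chunks below views only
        # 2*4 + 4*4 = 24 of them.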
        nt_contiguous = torch.nested.nested_tensor(
            [
                torch.randn(2, 20, device=device, dtype=dtype),
                torch.randn(4, 20, device=device, dtype=dtype),
            ]
        )
        # Split up the last dimension, which has a consistent size of 20, into 5 chunks
        chunks = nt_contiguous.chunk(5, dim=-1)

        # Check that the chunks are not contiguous, and become contiguous after calling contiguous()
        for chunk in chunks:
            self.assertFalse(chunk.is_contiguous())
            self.assertTrue(chunk.contiguous().is_contiguous())

    @dtypes(torch.float, torch.float16)
    @skipMeta
    def test_clone(self, device, dtype):
        nt1 = random_nt(device, dtype, 4, (4, 4), (1, 1))
        nt2 = nt1.clone()
        # Verify the values match
        self.assertEqual(nt1, nt2)
        # Verify modifying nt2 doesn't affect nt1
        nt2.mul_(nt1)
        ub1 = nt1.unbind()
        ub2 = nt2.unbind()
        for i in range(len(ub1)):
            self.assertNotEqual(ub1[i], ub2[i])

        nt1.clone(memory_format=torch.preserve_format)
        msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast"
        with self.assertRaisesRegex(RuntimeError, msg):
            nt1.clone(memory_format=torch.channels_last)

    # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
    @decorateIf(xfailIfTorchDynamo, lambda params: params["layout"] == torch.jagged)
    @dtypes(torch.float, torch.double)
    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
    def test_dropout(self, device, dtype, layout):
        # edge case: empty nested tensor
        # TODO: support empty NT in jagged layout
        if layout == torch.strided:
            nt0 = torch.nested.nested_tensor([], layout=layout)
            y = torch.nn.functional.dropout(nt0, 0.5)
            self.assertEqual(nt0, y)
        # normal nested tensor
        ntensors = 4
        if layout == torch.jagged:
            nt = random_nt(device, dtype, ntensors, (4, 4), (0, 3), layout=layout)
        else:
            nt = random_nt(device, dtype, ntensors, (4, 4), layout=layout)
        # edge case: invalid dropout
        self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
        self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
        # edge case: no dropout
        dropouter = torch.nn.Dropout(0.0)
        y0 = dropouter(nt)
        y1 = torch.nn.functional.dropout(nt, 0.0)
        self.assertEqual(nt, y0)
        self.assertEqual(nt, y1)
        # edge case: all dropout
        dropouter = torch.nn.Dropout(1.0)
        y0 = dropouter(nt)
        y1 = torch.nn.functional.dropout(nt, 1.0)
        nt0 = torch.zeros_like(nt)
        self.assertEqual(nt0, y0)
        self.assertEqual(nt0, y1)
        # normal case: normal dropout
        p = 0.2
        y = torch.nn.functional.dropout(nt, p)
        if layout == torch.jagged:
            expect = torch.where(y == 0.0, y, nt)
            expect /= 1.0 - p
            self.assertEqual(y, expect)
        else:
            expect = nt.clone()
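            # rebuild the expected output entry by entry: positions dropped in y are
            # zeroed, surviving entries are scaled by 1 / (1 - p)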
            for i in range(ntensors):
                actual_tensor = y[i].view(-1)
                expect_tensor = expect[i].view(-1)
                for j in range(actual_tensor.shape[0]):
                    if actual_tensor[j].item() == 0.0:
                        expect_tensor[j] = 0.0
                    else:
                        expect_tensor[j] /= 1.0 - p
            self.assertEqual(y, expect)
        with freeze_rng_state():
            dropouter = torch.nn.Dropout(p)
            y0 = dropouter(nt)
        with freeze_rng_state():
            y1 = torch.nn.functional.dropout(nt, p)
        self.assertEqual(y0, y1)

    @dtypes(torch.float, torch.double)
    def test_dropout_noncontiguous(self, device, dtype):
        ntensors = 4
        nt0 = random_nt(device, dtype, ntensors, (4, 4))
        nt1 = nt0.transpose(-1, -2)
        p = 0.3
        with freeze_rng_state():
            dropouter = torch.nn.Dropout(p)
            y0 = dropouter(nt0)
        with freeze_rng_state():
            y1 = torch.nn.functional.dropout(nt1, p).transpose(-1, -2)
        self.assertEqual(y0, y1)

    # cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    def test_softmax(self, device, dtype):
        # normal nested tensor
        ntensors = 4
        nt = random_nt(device, dtype, ntensors, (4, 4))
        # error case: softmax across nested dimension
        self.assertRaisesRegex(
            RuntimeError,
            "Cannot apply softmax across nested dimension 0",
            lambda: torch.nn.functional.softmax(nt, 0),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Cannot apply softmax across nested dimension 0",
            lambda: torch.nn.functional.softmax(nt, -3),
        )
        # error case: dimension out of range
        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3))
        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4))
        # normal case: should equal softmax of the tensor padded with -inf
        softmaxer = torch.nn.Softmax(1)
        y0 = softmaxer(nt)
        y1 = torch.nn.functional.softmax(nt, 1)
        self.assertEqual(y0, y1)
        pt = torch.nested.to_padded_tensor(nt, float("-inf"))
        # if an entire slice is padded, then softmax will return 0.0 / 0.0 = nan for it;
        # those entries don't exist in the nested tensor, so they should compare as 0.0
        expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0)
        self.assertEqual(torch.nested.to_padded_tensor(y0, 0.0), expect)
        # edge case: empty nested tensor
        nt0 = torch.nested.nested_tensor([])
        y = torch.nn.functional.softmax(nt0, 1)
        self.assertEqual(nt0, y)
        # edge case: nesting scalars
        nt1 = torch.nested.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)])
        self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0))
        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1))

    @dtypes(torch.float, torch.double)
    @torch.inference_mode()
    def test_softmax_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        self.assertEqual(
            torch.nn.functional.softmax(nt_contiguous, -1),
            torch.nn.functional.softmax(nt_noncontiguous, -1),
        )

    def _test_bmm(self, device, dtype):
        # error case: not 3D tensors
        nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype)
        nt1 = torch.nested.nested_tensor(
            [torch.randn(2), torch.randn(3)], device=device, dtype=dtype
        )
        nt2 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
        )
        self.assertRaisesRegex(
            RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt0)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt1)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt2)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt0)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt1)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt2)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt0)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt1)
        )
        # error case: incompatible batch size
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))],
            device=device,
            dtype=dtype,
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.",
            lambda: nt0.bmm(nt1),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.",
            lambda: nt1.bmm(nt0),
        )
        # error case: underlying matrices cannot be multiplied
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)",
            lambda: nt0.bmm(nt0),
        )
        # normal nested tensor
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype
        )
        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(
            torch.nested.to_padded_tensor(nt1, 0.0)
        )
        if dtype == torch.float16:
            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
        else:
            self.assertEqual(actual, expect)

        # nested tensor bmm normal tensor
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 7)), torch.randn((3, 7))], device=device, dtype=dtype
        )
        nt1 = torch.rand(2, 7, 5, dtype=dtype, device=device)
        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1)
        if dtype == torch.float16:
            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
        else:
            self.assertEqual(actual, expect)

        # nested tensor bmm normal tensor with non-contiguous view
        nt1 = torch.rand(2, 5, 7, dtype=dtype, device=device)
        nt1 = nt1.transpose(1, 2)
        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1)
        if dtype == torch.float16:
            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
        else:
            self.assertEqual(actual, expect)

        # normal tensor bmm nested tensor
        nt0 = torch.rand(2, 5, 7, dtype=dtype, device=device)
        nt1 = torch.nested.nested_tensor(
            [torch.randn((7, 6)), torch.randn((7, 5))], device=device, dtype=dtype
        )
        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
        expect = nt0.bmm(torch.nested.to_padded_tensor(nt1, 0.0))
        if dtype == torch.float16:
            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
        else:
            self.assertEqual(actual, expect)

        # test tensorcore path
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype
        )
        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(
            torch.nested.to_padded_tensor(nt1, 0.0)
        )
        if dtype == torch.float16:
            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
        else:
            self.assertEqual(actual, expect)

    @onlyCUDA
    @dtypes(torch.float, torch.double, torch.float16)
    def test_bmm_cuda(self, device, dtype):
        self._test_bmm(device, dtype)

    @onlyCPU
    # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    def test_bmm_cpu(self, device, dtype):
        self._test_bmm(device, dtype)

    # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    def test_bmm_noncontiguous(self, device, dtype):
        nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3), device, dtype
        )
        nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair(
            (6, 7), device, dtype
        )
        self.assertEqual(
            nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
            nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous),
        )

    @dtypes(torch.float, torch.double)
    def test_matmul_with_bmm_path(self, device, dtype):
        def unbind_rebind_matmul(nt1, nt2):
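            # reference implementation: matmul each unbound component pair, then
            # re-nest the results; NT matmul should agree with this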
            t1s = nt1.unbind()
            t2s = nt2.unbind()
            out_ts = [t1.matmul(t2) for t1, t2 in zip(t1s, t2s)]
            return torch.nested.nested_tensor(out_ts)

        # [N, n_head, *, head_dim], [N, n_head, head_dim, *]
        Ns = [1, 2, 5]
        n_heads = np.random.randint(2, 5)
        head_dim = 3
        t1s = []
        t2s = []
        for N in Ns:
            for _ in range(N):
                seq_len1 = np.random.randint(2, 5)
                seq_len2 = np.random.randint(2, 5)
                t1s.append(torch.randn(n_heads, seq_len1, head_dim))
                t2s.append(torch.randn(n_heads, head_dim, seq_len2))
            nt1 = torch.nested.nested_tensor(t1s, device=device, dtype=dtype)
            nt2 = torch.nested.nested_tensor(t2s, device=device, dtype=dtype)
            self.assertEqual(torch.matmul(nt1, nt2), unbind_rebind_matmul(nt1, nt2))

        # test with noncontiguous
        t3s = []
        t4s = []
        for _ in range(N):
            seq_len = np.random.randint(2, 5)
            t3s.append(torch.randn(seq_len, n_heads, head_dim))
            t4s.append(torch.randn(seq_len, n_heads, head_dim))
        nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose(
            1, 2
        )
        nt4 = (
            torch.nested.nested_tensor(t4s, device=device, dtype=dtype)
            .transpose(1, 2)
            .transpose(2, 3)
        )
        self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4))

    # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    def test_matmul(self, device, dtype):
        # error case: one is nested but the other is not
        nt = torch.nested.nested_tensor(
            [torch.randn(2), torch.randn(3)], device=device, dtype=dtype
        )
        t = torch.randn(4, device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both to be nested, but got a nested self and non-nested other",
            lambda: torch.matmul(nt, t),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both to be nested, but got a non-nested self and nested other",
            lambda: torch.matmul(t, nt),
        )
        # error case: not 3+D tensors
        nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype)
        nt1 = torch.nested.nested_tensor(
            [torch.randn(2), torch.randn(3)], device=device, dtype=dtype
        )
        nt2 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
            lambda: torch.matmul(nt0, nt0),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
            lambda: torch.matmul(nt0, nt1),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
            lambda: torch.matmul(nt0, nt2),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
            lambda: torch.matmul(nt1, nt0),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
            lambda: torch.matmul(nt1, nt1),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
            lambda: torch.matmul(nt1, nt2),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+",
            lambda: torch.matmul(nt2, nt0),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+",
            lambda: torch.matmul(nt2, nt1),
        )
        # error case: incompatible batch size
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))],
            device=device,
            dtype=dtype,
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
            lambda: torch.matmul(nt0, nt1),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
            lambda: torch.matmul(nt1, nt0),
        )
        # error case: incompatible batch sizes that cannot broadcast
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 2, 4)), torch.randn((2, 3, 4))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((3, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype
        )
        self.assertRaisesRegex(
            RuntimeError,
            "matmul(): For nested tensors, batch dimensions must have the same sizes,",
            lambda: torch.matmul(nt0, nt1),
        )
        # error case: incompatible batch sizes that should technically broadcast
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 2, 4)), torch.randn((1, 3, 4))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((1, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype
        )
        self.assertRaisesRegex(
            RuntimeError,
            "matmul(): For nested tensors, batch dimensions must have the same sizes,",
            lambda: torch.matmul(nt0, nt1),
        )
        # error case: underlying matrices cannot be multiplied
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
        )
        self.assertRaisesRegex(
            RuntimeError,
            "matmul(): Nested tensors cannot be matrix multiplied",
            lambda: torch.matmul(nt0, nt0),
        )
        # normal nested tensor: 3D
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype
        )
        actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
        expect = torch.matmul(
            torch.nested.to_padded_tensor(nt0, 0.0),
            torch.nested.to_padded_tensor(nt1, 0.0),
        )
        self.assertEqual(actual, expect)
        # normal nested tensor: 4D (with testing for batch_size=1)
        nt0 = torch.nested.nested_tensor(
            [torch.randn((1, 2, 4)), torch.randn((8, 3, 7))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((1, 4, 6)), torch.randn((8, 7, 5))], device=device, dtype=dtype
        )
        actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
        expect = torch.matmul(
            torch.nested.to_padded_tensor(nt0, 0.0),
            torch.nested.to_padded_tensor(nt1, 0.0),
        )
        self.assertEqual(actual, expect)
        # normal nested tensor: 5D
        nt0 = torch.nested.nested_tensor(
            [torch.randn((8, 9, 2, 4)), torch.randn((8, 9, 3, 7))],
            device=device,
            dtype=dtype,
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((8, 9, 4, 6)), torch.randn((8, 9, 7, 5))],
            device=device,
            dtype=dtype,
        )
        actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
        expect = torch.matmul(
            torch.nested.to_padded_tensor(nt0, 0.0),
            torch.nested.to_padded_tensor(nt1, 0.0),
        )
        self.assertEqual(actual, expect)

    # only supported on CUDA for now
    @dtypes(torch.float, torch.double)
    def test_matmul_nt_with_broadcasted_t(self, device, dtype):
        # NT (B, *, C, D) with T (D, E) broadcasting case
        nt = random_nt_from_dims([3, None, 4, 5], device=device, dtype=dtype)
        t = torch.randn(5, 6, device=device, dtype=dtype)
        output = torch.matmul(nt, t)

        # should be equivalent to matmul-ing each component with the dense tensor
        self.assertEqual(nt.size(0), output.size(0))
        for component, out_component in zip(nt, output):
            self.assertEqual(out_component, torch.matmul(component, t))

    # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    def test_matmul_noncontiguous(self, device, dtype):
        nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3), device, dtype
        )
        nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair(
            (6, 7), device, dtype
        )
        self.assertEqual(
            torch.matmul(nt0_contiguous.transpose(-1, -2), nt1_contiguous),
            torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous),
        )

    @dtypes(torch.float, torch.double)
    def test_linear(self, device, dtype):
        a = torch.randn(1, 2, device=device, dtype=dtype)
        b = torch.randn(2, 2, device=device, dtype=dtype)
        c = torch.randn(3, 2, device=device, dtype=dtype)
        nt = torch.nested.nested_tensor([a, b, c])

        weight = torch.randn(2, 2, device=device, dtype=dtype)
        bias = torch.randn(2, device=device, dtype=dtype)
        # success case
        torch.functional.F.linear(nt, weight, bias)

        # invalid nested tensor dimension
        msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2"
        nt1 = torch.nested.nested_tensor(
            [
                torch.randn(1, device=device, dtype=dtype),
                torch.randn(2, device=device, dtype=dtype),
            ]
        )
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.functional.F.linear(nt1, weight, bias)

        # invalid weight shape
        msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3"
        weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.functional.F.linear(nt, weight1, bias)

        # inconsistent last dim of nested tensor
        msg = r"Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals:"
        nt2 = torch.nested.nested_tensor(
            [
                torch.randn(1, 2, device=device, dtype=dtype),
                torch.randn(2, 3, device=device, dtype=dtype),
            ]
        )
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.functional.F.linear(nt2, weight, bias)

        # Mismatch of nested tensor last dim and weight dimension
        weight2 = torch.randn(2, 4, device=device, dtype=dtype)
        msg = (
            r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'"
            r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4"
        )
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.functional.F.linear(nt, weight2, bias)

        # Nested tensor input and nested weight
        nt_weight = nt.clone()
        msg = r"Linear does not support nested weight when input is a nested tensor."
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.functional.F.linear(nt, nt_weight, bias)

    # TODO: test noncontiguous linear
    # For now this tests the error message of linear
    # since linear does not support noncontiguous buffer yet
    @dtypes(torch.float, torch.double)
    def test_linear_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        weight = torch.randn((8, 5), device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError,
            r"for now linear only supports contiguous nested tensor",
            lambda: torch.nn.functional.linear(nt_noncontiguous, weight),
        )

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_zero_numel_errors(self, device, dtype):
        ts = [torch.ones(1, 0), torch.ones(0, 0)]
        nt = torch.nested.nested_tensor(
            ts, device=device, dtype=dtype, layout=torch.strided
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"at least one constituent tensor should have non-zero numel",
            lambda: torch.nested.to_padded_tensor(nt, 0.0),
        )

    @dtypes(torch.float, torch.float16, torch.double)
    def test_transpose(self, device, dtype):
        nt = random_nt(device, dtype, 4, (4, 4))
        # error case: transpose nested dimension
        self.assertRaisesRegex(
            RuntimeError,
            "Nested tensor dimension 0 cannot be transposed",
            lambda: nt.transpose(0, 1),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Nested tensor dimension 0 cannot be transposed",
            lambda: nt.transpose(1, -3),
        )
        # error case: dimension out of range
        self.assertRaises(IndexError, lambda: nt.transpose(1, 3))
        self.assertRaises(IndexError, lambda: nt.transpose(-4, -1))
        # normal case
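        # transposing the last two dims yields a non-contiguous NT, so pad it via the
        # helper and compare against the transpose of the padded dense tensor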
        ntT = nt.transpose(-1, -2)
        ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
        pt = torch.nested.to_padded_tensor(nt, 0.0)
        ptT = pt.transpose(-1, -2)
        self.assertEqual(ptT, ptT_from_ntT)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_squeeze_unsqueeze(self, device, dtype):
        a = torch.arange(6).reshape(2, 3)
        b = torch.arange(15).reshape(5, 3)
        nt = torch.nested.nested_tensor([a, b], device=device, dtype=dtype)
        # error case: squeeze no dimension
        self.assertRaisesRegex(
            RuntimeError,
            "For nested tensors, squeeze without the dim argument",
            lambda: nt.squeeze(),
        )
        # error case: squeeze nested dimension
        self.assertRaisesRegex(
            RuntimeError,
            "For nested tensors, squeezing dimension 0",
            lambda: nt.squeeze(0),
        )
        # error case: dimension out of range
        self.assertRaises(IndexError, lambda: nt.squeeze(3))
        # error case: squeeze nested tensor of singleton tensors
        c = torch.ones(1)
        nt_singleton = torch.nested.nested_tensor([c, c], device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "For nested tensors, squeezing a nested tensor of singleton",
            lambda: nt_singleton.squeeze(1),
        )

        # squeezing a dim which does not have size 1 should be a no-op
        nt2 = nt.squeeze(-1)
        self.assertEqual(nt, nt2)

        # test cases that should work
        nt_sizes = nt._nested_tensor_size()
        nt_strides = nt._nested_tensor_strides()
        for i in range(-2, 4):
            if i == 0:
                # cannot unsqueeze batch dim
                continue
            nt_unsqueezed = nt.unsqueeze(i)
            # negative dim will correspond to unsqueeze() applied at dim = dim + nt.dim() + 1
            wrapped_i = i + nt.dim() + 1 if i < 0 else i
            # the column index into the nt size tensor requires subtracting 1 to ignore the batch dim
            size_idx = wrapped_i - 1
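            # e.g. for this 3-dim nt, i == -1 wraps to 3 and size_idx == 2, i.e. the
            # new trailing dim of the unsqueezed result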
            self.assertEqual(
                nt_unsqueezed._nested_tensor_size()[:, size_idx],
                torch.ones(2, dtype=torch.long),
            )
            unsqueezed_stride = nt_unsqueezed._nested_tensor_strides()[:, size_idx]
            if i == nt.ndim or i == -1:
                self.assertEqual(unsqueezed_stride, torch.ones(2, dtype=torch.long))
            else:
                stride_col_after = nt_strides[:, size_idx]
                size_col_after = nt_sizes[:, size_idx]
                self.assertEqual(unsqueezed_stride, stride_col_after * size_col_after)
            nt_squeezed = nt_unsqueezed.squeeze(i)
            self.assertEqual(nt_squeezed, nt)
            self.assertEqual(nt_squeezed._nested_tensor_size(), nt_sizes)
            self.assertEqual(nt_squeezed._nested_tensor_strides(), nt_strides)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_transpose_inference_mode_interaction(self, device, dtype):
        nt = random_nt(device, dtype, 4, (4, 4))
        # Construct in default mode and transpose while in inference mode
        with torch.inference_mode():
            ntT = nt.transpose(-1, -2)
            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
            pt = torch.nested.to_padded_tensor(nt, 0.0)
            ptT = pt.transpose(-1, -2)
            self.assertEqual(ptT, ptT_from_ntT)

        # Construct and transpose while in inference mode
        with torch.inference_mode():
            nt = random_nt(device, dtype, 4, (4, 4))
            ntT = nt.transpose(-1, -2)
            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
            pt = torch.nested.to_padded_tensor(nt, 0.0)
            ptT = pt.transpose(-1, -2)
            self.assertEqual(ptT, ptT_from_ntT)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_view(self, device, dtype):
        nt = random_nt(device, dtype, 4, (4, 4))
        # error case: empty shape
        self.assertRaisesRegex(
            RuntimeError,
            r"shape '\[\]' is invalid for a nested tensor",
            lambda: nt.view(()),
        )
        # error case: empty nested tensor
        nt_empty = torch.nested.nested_tensor([])
        self.assertRaisesRegex(
            RuntimeError,
            "empty nested tensor cannot be reshaped",
            lambda: nt_empty.view(-1),
        )
        # error case: -1 for batch size
        self.assertRaisesRegex(
            RuntimeError,
            r"view: For now nested view cannot change or infer the implicit batch dimension",
            lambda: nt.view(-1, 2, 3),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"shape '\[.*\]' is invalid for input of size [0-9]+",
            lambda: nt.view(4, 2, 3),
        )
        # normal case
        x0 = torch.randn((2, 20), device=device, dtype=dtype)
        x1 = torch.randn((3, 20), device=device, dtype=dtype)
        nt = torch.nested.nested_tensor([x0, x1])
        pt = torch.nested.to_padded_tensor(nt, 0.0)
        # error case: trying to reshape across the batch dim into an otherwise valid shape
        self.assertRaisesRegex(
            RuntimeError,
            r"For now nested view cannot change or infer the implicit batch dimension",
            lambda: nt.transpose(-1, -2).view(40, -1),
        )
        # inherit only the ragged dimension
        # (2, 20) -> (2, 5, 4)
        # (3, 20) -> (3, 5, 4)
        nt1 = nt.view(2, -1, 5, 4)
        # padded: (2, 3, 20) -> (2, 3, 5, 4)
        pt1 = pt.view(2, -1, 5, 4)
        self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)

        # more than one -1 (even for "old" dims) should fail;
        # this attempts to do (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2),
        # but we ban "inherit old behavior" for >1 dimension
        self.assertRaisesRegex(
            RuntimeError,
            r"only one dimension can be inferred",
            lambda: nt1.view(2, -1, -1, 2, 2),
        )

    @dtypes(torch.float, torch.float16, torch.double)
    def test_view_inference_mode_interaction(self, device, dtype):
        # Construct in default mode and view while in inference mode
        nt = torch.nested.nested_tensor(
            [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype
        )
        with torch.inference_mode():
            ntT = nt.view(2, -1, 4, 5)
            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
            pt = torch.nested.to_padded_tensor(nt, 0.0)
            ptT = pt.view(2, -1, 4, 5)
            self.assertEqual(ptT, ptT_from_ntT)
        # Construct and view while in inference mode
        with torch.inference_mode():
            nt = torch.nested.nested_tensor(
                [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype
            )
            ntT = nt.view(2, -1, 4, 5)
            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
            pt = torch.nested.to_padded_tensor(nt, 0.0)
            ptT = pt.view(2, -1, 4, 5)
            self.assertEqual(ptT, ptT_from_ntT)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_reshape(self, device, dtype):
        nt = random_nt(device, dtype, 4, (4, 4))
        # error case: empty shape
        self.assertRaisesRegex(
            RuntimeError,
            r"shape '\[\]' is invalid for a nested tensor",
            lambda: nt.reshape(()),
        )
        # error case: empty nested tensor
        nt_empty = torch.nested.nested_tensor([])
        self.assertRaisesRegex(
            RuntimeError,
            "empty nested tensor cannot be reshaped",
            lambda: nt_empty.reshape(-1),
        )
        # error case: -1 for batch size
        self.assertRaisesRegex(
            RuntimeError,
            r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
            lambda: nt.reshape(-1, 2, 3),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"shape '\[.*\]' is invalid for input of size [0-9]+",
            lambda: nt.reshape(4, 2, 3),
        )
        # normal case
        x0 = torch.randn((2, 20), device=device, dtype=dtype)
        x1 = torch.randn((3, 20), device=device, dtype=dtype)
        nt = torch.nested.nested_tensor([x0, x1])  # (2, (2, 3), 20)
        pt = torch.nested.to_padded_tensor(nt, 0.0)
        # error case: trying to reshape across the batch dim into an otherwise valid shape
        self.assertRaisesRegex(
            RuntimeError,
            r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
            lambda: nt.transpose(-1, -2).reshape(40, -1),
        )
        # inherit only the ragged dimension
        # (2, 20) -> (2, 5, 4)
        # (3, 20) -> (3, 5, 4)
        nt1 = nt.reshape(2, -1, 5, 4)
        # padded: (2, 3, 20) -> (2, 3, 5, 4)
        pt1 = pt.reshape(2, -1, 5, 4)
        self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)

        # more than one -1 (even for "old" dims) should fail;
        # this attempts to do (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2),
        # but we ban "inherit old behavior" for >1 dimension
        self.assertRaisesRegex(
            RuntimeError,
            r"only one dimension can be inferred",
            lambda: nt1.reshape(2, -1, -1, 2, 2),
        )

    def test_nested_masked_select(self, device):
        t = torch.randn([3, 3], device=device)
        mask = torch.tensor([False], device=device)

        njt = torch.nested.masked_select(t, mask)
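        # the single-element False mask broadcasts across every row of the 3x3 input,
        # so nothing is selected and each of the 3 per-row offsets stays at 0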
        self.assertEqual(njt.values(), torch.tensor([], device=device))
        self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 0], device=device))

        mask = torch.tensor([[False], [False], [True]], device=device)
        njt = torch.nested.masked_select(t, mask)
        self.assertEqual(njt.values(), t[-1], atol=0.1, rtol=0.1)
        self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 3], device=device))

        mask = torch.tensor(
            [[False, False, True], [True, False, True], [False, False, True]],
            device=device,
        )
        njt = torch.nested.masked_select(t, mask)
        self.assertEqual(njt.values(), t.masked_select(mask))
        self.assertEqual(njt.offsets(), torch.tensor([0, 1, 3, 4], device=device))

        t = torch.randn([2, 3, 3, 1], device=device)
        mask = torch.tensor(
            [
                [
                    [[True], [False], [True]],
                    [[True], [False], [True]],
                    [[True], [False], [True]],
                ],
                [
                    [[False], [True], [True]],
                    [[False], [True], [True]],
                    [[True], [True], [True]],
                ],
            ],
            device=device,
        )
        njt = torch.nested.masked_select(t, mask)
        self.assertEqual(njt.values(), t.masked_select(mask))
        self.assertEqual(
            njt.offsets(),
            torch.tensor(
                [0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 11, 12, 13],
                device=device,
            ),
        )

    @dtypes(torch.float, torch.float16, torch.double)
    def test_narrow(self, device, dtype):
        nt = random_nt_from_dims([5, None, None, None], device=device, dtype=dtype)

        # narrow on dim=0 from start to end
        bounds = [(0, 5), (0, 3), (1, 2), (1, 5), (2, 4)]
        for start, end in bounds:
            length = end - start
            narrowed = nt.narrow(dim=0, start=start, length=length)
            # ensure output is a view
            self.assertTrue(narrowed._base is nt)
            for nc, c in zip(narrowed.unbind(), nt.unbind()[start:end]):
                self.assertEqual(nc, c)

        # dim != 0 is not supported
        for dim in range(1, nt.dim()):
            with self.assertRaisesRegex(
                RuntimeError, "only dim=0 supported for nested tensors"
            ):
                nt.narrow(dim=dim, start=0, length=1)

        # error case: non-contiguous NT
        _, nt_noncont = random_nt_noncontiguous_pair((2, 3, 4))
        with self.assertRaisesRegex(
            RuntimeError, "only contiguous nested tensors supported"
        ):
            nt_noncont.narrow(dim=0, start=0, length=1)

    @parametrize("input_dim", [3, 4])
    def test_scaled_dot_product_attention(self, device, input_dim):
        def rand_tensor(*shape):
            return torch.randn(shape, device=device)

        E = 8
        if input_dim == 3:
            # Shape: (N, L, E); ragged L
            query = torch.nested.nested_tensor(
                [rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)]
            )

            # Shape: (N, S, E); ragged S
            key = torch.nested.nested_tensor(
                [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]
            )
            value = torch.nested.nested_tensor(
                [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]
            )
        elif input_dim == 4:
            # In the 4D case both L and S are ragged
            # Shape: (N, N', L, E); ragged N' and L
            query = torch.nested.nested_tensor(
                [rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)]
            )
            # Shape: (N, N', S, E); ragged N' and S
            key = torch.nested.nested_tensor(
                [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]
            )
            value = torch.nested.nested_tensor(
                [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]
            )
        else:
            self.fail(f"Invalid input_dim {input_dim} encountered in SDP test")

        def rand_mask(size):
            return torch.randint(0, 2, size=size, dtype=torch.bool, device=device)

        # Shape: (N, L, S); ragged L and S matching above
        attn_mask = torch.nested.nested_tensor(
            [rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))]
        )

        dropout_p = 0.0  # no dropout for reproducibility

        # Success case: no attn_mask set and is_causal=False.
        actual = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p
        )

        expected_outputs = []
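        # reference: run SDPA on each (q, k, v) component individually with an added
        # batch dim of 1, then re-nest the per-component outputs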
        for q, k, v in zip(query.unbind(), key.unbind(), value.unbind()):
            output = torch.nn.functional.scaled_dot_product_attention(
                q.unsqueeze(0),
                k.unsqueeze(0),
                v.unsqueeze(0),
                attn_mask=None,
                dropout_p=dropout_p,
            )
            expected_outputs.append(output.squeeze(0))
        expected_output_nested = torch.nested.nested_tensor(expected_outputs)
        self.assertEqual(actual, expected_output_nested)

        # Error case: explicit attn_mask set.
        with self.assertRaisesRegex(
            RuntimeError, "not supported when an explicit attn_mask is set"
        ):
            torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=attn_mask, dropout_p=dropout_p
            )

        # Error case: is_causal=True.
        with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"):
            torch.nn.functional.scaled_dot_product_attention(
                query, key, value, dropout_p=dropout_p, is_causal=True
            )

    @dtypes(torch.float, torch.float16, torch.double)
    def test_empty_like(self, device, dtype):
        ntensors = 4
        nt = random_nt(device, dtype, ntensors, (4, 4))

        # Create empty on same device as original nested tensor
        nt_empty = torch.empty_like(nt)
        assert nt.is_same_size(nt_empty)
        self.assertEqual(nt.dtype, nt_empty.dtype)
        self.assertEqual(nt.device, nt_empty.device)
        self.assertEqual(nt.layout, nt_empty.layout)

        if torch.cuda.is_available():
            if device == "cpu":
                nt_cuda = torch.empty_like(nt, device="cuda")
                self.assertEqual(torch.device("cuda").type, nt_cuda.device.type)
            else:
                nt_cpu = torch.empty_like(nt, device="cpu")
                self.assertEqual(torch.device("cpu").type, nt_cpu.device.type)

        # Check changing dtype of empty_like nested tensor output
        dtype_set = {torch.float, torch.float16, torch.double}
        for other_dtype in dtype_set - {dtype}:
            nt_empty_other_dtype = torch.empty_like(nt, dtype=other_dtype)
            self.assertEqual(nt.dtype, dtype)
            self.assertEqual(nt_empty_other_dtype.dtype, other_dtype)
            self.assertEqual(nt.device, nt_empty.device)
            self.assertEqual(nt.layout, nt_empty.layout)

        # Create tensor for autograd
        nt_empty_req_grad = torch.empty_like(nt, requires_grad=True)
        self.assertEqual(nt_empty_req_grad.requires_grad, True)

        # Test that empty_like works on noncontiguous nested tensors
        nt_cont, nt_noncont = random_nt_noncontiguous_pair((2, 3, 6, 7))
        nt_empty = torch.empty_like(nt_cont)
        assert nt_cont.is_same_size(nt_empty)
        nt_empty_non_contig = torch.empty_like(nt_noncont)
        assert nt_noncont.is_same_size(nt_empty_non_contig)

        # Test the contiguous memory format option
        nt_empty_contig = torch.empty_like(
            nt_cont, memory_format=torch.contiguous_format
        )
        assert nt_cont.is_same_size(nt_empty_contig)
        assert nt_empty_contig.is_contiguous()

        nt_empty_non_contig = torch.empty_like(
            nt_noncont, memory_format=torch.contiguous_format
        )
        assert nt_noncont.is_same_size(nt_empty_non_contig)
        assert nt_empty_non_contig.is_contiguous()

        # Test other memory formats fail
        self.assertRaises(
            RuntimeError,
            lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last),
        )
        self.assertRaises(
            RuntimeError,
            lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last),
        )
        self.assertRaises(
            RuntimeError,
            lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d),
        )
        self.assertRaises(
            RuntimeError,
            lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d),
        )


@markDynamoStrictTest
class TestNestedTensorAutograd(NestedTensorTestCase):
    # Note [Gradcheck args check_batched_grad=False]: the common_utils testing version of gradcheck
    # includes the default parameters used for testing ops with gradcheck. However, nested tensors
    # do not support the stack op, so check_batched_grad is turned off for these tests.
    def _create_leaf_nested_tensor_from_list(self, tensor_device, requires_grad=False):
        return torch.nested.nested_tensor(
            [torch.randn(1, 2), torch.randn(7, 8)],
            requires_grad=requires_grad,
            device=tensor_device,
        )

    def _create_nested_tensor_from_list(self, tensor_device, requires_grad=False):
        return torch.nested.as_nested_tensor(
            [
                torch.randn(1, 2, requires_grad=requires_grad),
                torch.randn(7, 8, requires_grad=requires_grad),
            ],
            device=tensor_device,
        )

    def _create_nested_tensor_from_mask(self, tensor_device, requires_grad=False):
        data = torch.randn(2, 3, 4, requires_grad=requires_grad, device=tensor_device)
        mask = torch.ones_like(data[:, :, 0]).bool()
        return torch._nested_tensor_from_mask(data, mask)

    def test_as_nested_tensor_propagates_gradients(self, device):
        a = torch.arange(3, dtype=torch.float, device=device)
        b = torch.arange(5, dtype=torch.float, device=device)
        nt = torch.nested.as_nested_tensor([a, b])
        # tensors with requires_grad=False are leaves
        self.assertTrue(nt.is_leaf)
        self.assertTrue(not nt.requires_grad)

        a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
        b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)
        nt2 = torch.nested.as_nested_tensor([a, b])
        fake_grad = torch.nested.nested_tensor(
            [torch.ones_like(a), torch.zeros_like(b)], device=device
        )
        nt2.backward(fake_grad)
        self.assertEqual(a.grad, fake_grad[0])
        self.assertEqual(b.grad, fake_grad[1])

    def test_nested_tensor_generates_leaf(self, device):
        a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
        b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)

        nt = torch.nested.nested_tensor([a, b], requires_grad=False)
        self.assertTrue(nt.is_leaf)
        self.assertTrue(not nt.requires_grad)

        nt2 = torch.nested.nested_tensor([a, b], requires_grad=True)
        self.assertTrue(nt2.is_leaf)
        self.assertTrue(nt2.requires_grad)

        fake_grad = torch.nested.nested_tensor(
            [torch.ones_like(a), torch.zeros_like(b)], device=device
        )
        nt2.backward(fake_grad)
        self.assertEqual(nt2.grad, fake_grad)
        self.assertEqual(a.grad, None)
        self.assertEqual(b.grad, None)

    def test_set_requires_grad_from_list(self, device):
        nt = self._create_nested_tensor_from_list(device)
        nt.requires_grad_()
        assert nt.requires_grad

    def test_set_requires_grad_from_mask(self, device):
        nt = self._create_nested_tensor_from_mask(device)
        nt.requires_grad_()
        assert nt.requires_grad

    def test_backward_for_add_op(self, device):
        nt_1 = self._create_nested_tensor_from_mask(device)
        nt_2 = self._create_nested_tensor_from_mask(device)

        nt_1.requires_grad_()
        c = nt_1 + nt_2

        assert nt_1.requires_grad
        assert c.requires_grad
        grad_output = self._create_nested_tensor_from_mask(device)
        c.backward(grad_output)

        # Gradcheck doesn't work with nested tensors yet.
        # d/dnt_1 (nt_1 + nt_2) = 1 * grad_output
        self.assertEqual(nt_1.grad, grad_output)

    def test_backward_for_sub_op(self, device):
        nt_1 = self._create_nested_tensor_from_mask(device)
        nt_2 = self._create_nested_tensor_from_mask(device)

        nt_1.requires_grad_()
        nt_2.requires_grad_()
        c = nt_1 - nt_2

        assert nt_1.requires_grad
        assert nt_2.requires_grad
        assert c.requires_grad
        grad_output = self._create_nested_tensor_from_mask(device)
        c.backward(grad_output)

        self.assertEqual(nt_1.grad, grad_output)
        self.assertEqual(nt_2.grad, -1 * grad_output)

    def test_backward_sub_strided(self, device):
        a = torch.nested.nested_tensor(
            [torch.randn(9, 2, 4), torch.randn(12, 2, 4)],
            requires_grad=True,
            device=device,
        )
        b = torch.nested.nested_tensor(
            [torch.randn(9, 4, 2), torch.randn(12, 4, 2)],
            requires_grad=True,
            device=device,
        )
        c = a - b.transpose(-1, -2)
        grad_output = c.clone()
        c.backward(grad_output)
        self.assertEqual(a.grad, grad_output)
        self.assertEqual(b.grad, -1 * grad_output.transpose(-1, -2))

    def test_backward_add_strided(self, device):
        a = torch.nested.nested_tensor(
            [torch.randn(9, 2, 4), torch.randn(12, 2, 4)],
            requires_grad=True,
            device=device,
        )
        b = torch.nested.nested_tensor(
            [torch.randn(9, 4, 2), torch.randn(12, 4, 2)],
            requires_grad=True,
            device=device,
        )
        c = a + b.transpose(-1, -2)
        grad_output = c.clone()
        c.backward(grad_output)
        self.assertEqual(a.grad, grad_output)
        self.assertEqual(b.grad, grad_output.transpose(-1, -2))

    # Test Factory Functions
    def test_nested_tensor_to_padded_tensor(self, device):
        for padding_val in [0, 1]:
            nt = self._create_leaf_nested_tensor_from_list(
                tensor_device=device, requires_grad=True
            )

            out = torch.nested.to_padded_tensor(nt, padding_val)
            grad_output = torch.ones(out.shape, device=device)
            out.backward(grad_output)

            self.assertEqual(
                nt.grad,
                torch.nested.nested_tensor(
                    [torch.ones(1, 2), torch.ones(7, 8)], device=device
                ),
            )

    def test_nested_tensor_from_mask_and_to_padded(self, device):
        N, L, D = 2, 4, 4
        mask = torch.ones(N, L, device=device)
        for i in range(1, N):
            end = torch.randint(1, L - 1, (1,), device=device)
            mask[i, end:] = 0

        mask[0, :] = 1
        mask = mask.bool()

        data = torch.randn(
            N, L, D, requires_grad=True, dtype=torch.float64, device=device
        )

        def grad_test_func(inpt):
            nt = torch._nested_tensor_from_mask(inpt, mask)
            # This implicitly tests to_padded_tensor grads
            return torch.nested.to_padded_tensor(nt, 0)

        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_nested_tensor_from_padded(self, device):
        nested_size = torch.tensor([[1, 2], [2, 2]])
        padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64, device=device)
        padded_tensor[0, 1, :] = 0
        padded_tensor.requires_grad_()

        def grad_test_func(tensor, nested_size):
            nt = torch._nested_from_padded(
                tensor, nested_size, fuse_transform_0213=False
            )
            # This implicitly tests to_padded_tensor grads
            return torch.nested.to_padded_tensor(nt, 0)

        data = (padded_tensor, nested_size)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_nested_tensor_from_padded_fused(self, device):
        nested_size = torch.tensor([[1, 8], [2, 8]])
        padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64, device=device)
        padded_tensor[0, 1, :] = 0
        padded_tensor.requires_grad_()

        def grad_test_func(tensor, nested_size):
            nt = torch._nested_from_padded(
                tensor, nested_size, fuse_transform_0213=True
            )
            # This implicitly tests to_padded_tensor grads
            return torch.nested.to_padded_tensor(nt, 0)

        data = (padded_tensor, nested_size)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_nested_tensor_from_list(self, device):
        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            c = torch.nested.as_nested_tensor([a, b, c])
            # This implicitly tests to_padded_tensor grads
            return torch.nested.to_padded_tensor(c, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
    def test_dropout_backward(self, layout):
        if layout == torch.jagged:
            nt = torch.nested.nested_tensor(
                [torch.randn((2, 5)), torch.randn((3, 5))],
                requires_grad=True,
                layout=layout,
            )
        else:
            nt = torch.nested.nested_tensor(
                [torch.randn((2, 5)), torch.randn((3, 4))],
                requires_grad=True,
                layout=layout,
            )
        p = 0.2
        y = torch.nn.functional.dropout(nt, p)
        y.backward(nt.clone().detach())
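        # With inverted dropout, y = nt * mask / (1 - p); feeding grad_output = nt
        # through backward yields nt.grad = nt * mask / (1 - p) == y.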
        self.assertEqual(nt.grad, y)

    def test_nested_tensor_bmm_gradcheck(self, device):
        a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
        d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c, d):
            nt0 = torch.nested.as_nested_tensor([a, b])
            nt1 = torch.nested.as_nested_tensor([c, d])
            result = nt0.bmm(nt1)
            return torch.nested.to_padded_tensor(result, 0.0)

        data = (a, b, c, d)
        assert torch.autograd.gradcheck(grad_test_func, inputs=data)

    def test_nested_tensor_bmm_backward(self, device):
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 6)), torch.randn((3, 6))],
            requires_grad=True,
            device=device,
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((6, 4)), torch.randn((6, 5))],
            requires_grad=True,
            device=device,
        )
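        # Build dense (padded) reference tensors with their own autograd graphs; the NT
        # gradients are compared against this padded baseline below.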
        with torch.no_grad():
            pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
            pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)

        ynt = nt0.bmm(nt1)
        ypt = pt0.bmm(pt1)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad)
        self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad)

    def test_nested_tensor_matmul_gradcheck(self, device):
        a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
        d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c, d):
            nt0 = torch.nested.as_nested_tensor([a, b])
            nt1 = torch.nested.as_nested_tensor([c, d])
            result = torch.matmul(nt0, nt1)
            return torch.nested.to_padded_tensor(result, 0.0)

        data = (a, b, c, d)
        assert torch.autograd.gradcheck(grad_test_func, inputs=data)

    def test_nested_tensor_matmul_backward(self, device):
        nt0 = torch.nested.nested_tensor(
            [torch.randn((7, 2, 6)), torch.randn((7, 3, 6))],
            requires_grad=True,
            device=device,
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((7, 6, 4)), torch.randn((7, 6, 5))],
            requires_grad=True,
            device=device,
        )
        with torch.no_grad():
            pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
            pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)

        ynt = torch.matmul(nt0, nt1)
        ypt = torch.matmul(pt0, pt1)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad)
        self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad)

    def test_nested_tensor_transpose_gradcheck(self, device):
        a = torch.randn(2, 5, requires_grad=True, device=device)
        b = torch.randn(3, 4, requires_grad=True, device=device)

        def grad_test_func(a, b):
            nt = torch.nested.as_nested_tensor([a, b])
            result = nt.transpose(-2, -1).transpose(-2, -1)
            return torch.nested.to_padded_tensor(result, 0.0)

        data = (a, b)
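        # float32 inputs need a coarser finite-difference eps than gradcheck's default
        # for numerically stable gradient estimates.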
        assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)

    def test_nested_tensor_transpose_backward(self, device):
        nt = torch.nested.nested_tensor(
            [torch.randn((2, 5)), torch.randn((3, 4))],
            requires_grad=True,
            device=device,
        )
        with torch.no_grad():
            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)

        ynt = nt.transpose(-2, -1)
        ypt = pt.transpose(-2, -1)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)

    def test_nested_tensor_reshape_gradcheck(self, device):
        a = torch.randn(2, 6, requires_grad=True, device=device)
        b = torch.randn(3, 6, requires_grad=True, device=device)

        def grad_test_func(a, b):
            nt = torch.nested.as_nested_tensor([a, b])
            result = nt.reshape(2, -1, 2, 3)
            return torch.nested.to_padded_tensor(result, 0.0)

        data = (a, b)
        assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)

    def test_nested_tensor_reshape_backward(self, device):
        nt = torch.nested.nested_tensor(
            [torch.randn((2, 6)), torch.randn((3, 6))],
            requires_grad=True,
            device=device,
        )
        with torch.no_grad():
            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)

        ynt = nt.reshape(2, -1, 2, 3)
        ypt = pt.reshape(2, -1, 2, 3)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)

    def test_nested_tensor_squeeze_backward(self, device):
        nt = torch.nested.nested_tensor(
            [torch.randn((2, 6, 1)), torch.randn((3, 6, 1))],
            requires_grad=True,
            device=device,
        )
        with torch.no_grad():
            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)

        ynt = nt.squeeze(-1)
        ypt = pt.squeeze(-1)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)

    def test_nested_tensor_squeeze_gradcheck(self, device):
        a = torch.randn(
            (2, 6, 1), dtype=torch.float64, requires_grad=True, device=device
        )
        b = torch.randn(
            (3, 6, 1), dtype=torch.float64, requires_grad=True, device=device
        )

        def grad_test_func(a, b):
            nt = torch.nested.as_nested_tensor([a, b])
            result = nt.squeeze(-1)
            return torch.nested.to_padded_tensor(result, 0.0)

        assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)

    def test_nested_tensor_unsqueeze_backward(self, device):
        nt = torch.nested.nested_tensor(
            [torch.randn((2, 6)), torch.randn((3, 6))],
            requires_grad=True,
            device=device,
        )
        with torch.no_grad():
            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)

        ynt = nt.unsqueeze(2)
        ypt = pt.unsqueeze(2)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)

    def test_nested_tensor_unsqueeze_gradcheck(self, device):
        a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True, device=device)
        b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True, device=device)

        def grad_test_func(a, b):
            nt = torch.nested.as_nested_tensor([a, b])
            result = nt.unsqueeze(-1)
            return torch.nested.to_padded_tensor(result, 0.0)

        assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)

    def test_nested_tensor_linear(self, device):
        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)

        weight = torch.randn(
            2, 2, requires_grad=True, dtype=torch.float64, device=device
        )
        bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c, weight, bias=None):
            nt = torch.nested.as_nested_tensor([a, b, c])
            # This implicitly tests to_padded_tensor grads
            d = torch.functional.F.linear(nt, weight, bias)
            return torch.nested.to_padded_tensor(d, 0)

        data = (a, b, c, weight, bias)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

        # Test linear with no bias added
        data = (a, b, c, weight)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_nested_tensor_linear_plus_transpose(self, device):
        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)

        weight = torch.randn(
            2, 2, requires_grad=True, dtype=torch.float64, device=device
        )
        bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c, weight, bias=None):
            nt = torch.nested.as_nested_tensor([a, b, c])
            # This implicitly tests to_padded_tensor grads
            d = torch.functional.F.linear(nt, weight, bias)
            d = d.transpose(-1, -2).contiguous()
            return torch.nested.to_padded_tensor(d, 0)

        data = (a, b, c, weight, bias)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

        # Test linear with no bias added
        data = (a, b, c, weight)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_nested_tensor_softmax(self, device):
        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c, dim):
            nt = torch.nested.as_nested_tensor([a, b, c])
            # This implicitly tests to_padded_tensor grads
            d = torch.functional.F.softmax(nt, dim=dim)
            return torch.nested.to_padded_tensor(d, 0)

        # softmax over last dim
        data = (a, b, c, -1)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_nested_tensor_linear_backward(self, device):
        a = torch.randn(1, 2, requires_grad=False, device=device)
        b = torch.randn(2, 2, requires_grad=False, device=device)
        c = torch.randn(3, 2, requires_grad=False, device=device)

        weight = torch.randn(2, 2, requires_grad=True, device=device)
        bias = torch.randn(2, requires_grad=True, device=device)
        nt = torch.nested.as_nested_tensor([a, b, c], device=device)

        out = torch.functional.F.linear(nt, weight, bias)

        out.backward(out.clone())

        assert weight.grad is not None
        assert bias.grad is not None

        assert a.grad is None
        assert b.grad is None
        assert c.grad is None

    def test_values_grad_with_broadcast(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            buffer = nt.values()
            return buffer.sum()

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_to_buffer_series_ops_grad_with_broadcast(self, device):
        a = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            buffer = nt.values()
            buffer = buffer * 2
            return buffer.exp()

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_unbind_flow_through(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            ntT = nt.transpose(-1, -2)
            unbound = ntT.unbind()
            d = unbound[0]
            d = torch.pow(d, 2)
            return d

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_split_with_sizes_flow_through(self, device):
        a = torch.randn(2, 5, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 5, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 5, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            splits = nt.split_with_sizes([2, 3], dim=-1)
            unbound = splits[1].unbind()
            d = unbound[0]
            d = torch.pow(d, 2)
            return d

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_indexing_backward(self, device):
        x0 = torch.randn((2, 5))
        x1 = torch.randn((3, 4))
        nt = torch.nested.nested_tensor([x0, x1], device=device, requires_grad=True)
        self.assertEqual(nt[0], x0)
        self.assertEqual(nt[-1], x1)
        grad_x0 = torch.randn((2, 5), device=device)
        nt[0].backward(grad_x0)
        expected_grad = torch.nested.nested_tensor(
            [grad_x0, torch.zeros((3, 4), device=device)]
        )
        self.assertEqual(nt.grad, expected_grad)

    def test_masked_fill_backward(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            mask = nt.detach().clone().to(bool)
            out = nt.masked_fill(mask, 0)
            out = torch.nested.to_padded_tensor(out, 0)
            return out

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_gelu_backward(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            nt_gelu = torch.nn.functional.gelu(nt)
            return torch.nested.to_padded_tensor(nt_gelu, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_relu_backward(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            nt_relu = torch.nn.functional.relu(nt)
            return torch.nested.to_padded_tensor(nt_relu, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_silu_backward(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            nt_silu = torch.nn.functional.silu(nt)
            return torch.nested.to_padded_tensor(nt_silu, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_abs_backward(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            nt_abs = torch.abs(nt)
            return torch.nested.to_padded_tensor(nt_abs, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    # Previously would error when input NT doesn't require grad
    # NotImplementedError: Cannot access storage of UndefinedTensorImpl
    def test_layer_norm_backward_edge_case(self, device):
        size = 4
        a = torch.randn(
            1, 2, size, requires_grad=False, dtype=torch.float64, device=device
        )
        nt = torch.nested.nested_tensor([a])
        nt_layer_norm = torch.nn.LayerNorm(
            nt.size(-1), device=device, dtype=torch.float64
        )
        out = nt_layer_norm(nt)
        out.backward(out.clone())

    def test_accumulate_grad_different_strides(self, device):
        a = torch.rand(1, 4, 2, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.rand(1, 8, 2, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b):
            nt_1 = torch.nested.as_nested_tensor([a, b])
            nt_2 = nt_1.clone()
            out = torch.nn.functional.scaled_dot_product_attention(nt_1, nt_2, nt_2)
            return torch.nested.to_padded_tensor(out, 0)

        data = (a, b)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    # https://github.com/pytorch/pytorch/issues/95562
    @skipIfSlowGradcheckEnv
    @parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 4, 2])
    def test_layer_norm_backward(self, device, size):
        a = torch.randn(
            1, 2, size, requires_grad=True, dtype=torch.float64, device=device
        )
        b = torch.randn(
            2, 2, size, requires_grad=True, dtype=torch.float64, device=device
        )
        c = torch.randn(
            3, 2, size, requires_grad=True, dtype=torch.float64, device=device
        )

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            layer_norm = torch.nn.LayerNorm(
                nt.size(-1), device=device, dtype=torch.float64
            )
            nt_layer_norm = layer_norm(nt)
            return torch.nested.to_padded_tensor(nt_layer_norm, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    # https://github.com/pytorch/pytorch/issues/95562
    @skipIfSlowGradcheckEnv
    # Could either mark slow or reduce size
    @parametrize("size", [128, 32, 4, 2])
    def test_layer_norm_backward_5d(self, device, size):
        a = torch.randn(
            4, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
        )
        b = torch.randn(
            7, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
        )
        c = torch.randn(
            10, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
        )

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            layer_norm = torch.nn.LayerNorm(
                (size, size, nt.size(-1)), device=device, dtype=torch.float64
            )
            nt_layer_norm = layer_norm(nt)
            return torch.nested.to_padded_tensor(nt_layer_norm, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)


# Found in torch/testing/_comparison.py
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}


def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    deviation = true_value - computed_value
    deviation = torch.abs(deviation / true_value)
    # Fill in the nans with the default rtol
    torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype])
    return deviation.max().item()


def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    deviation = true_value - computed_value
    atol = torch.abs(deviation).max().item()
    return atol


def get_tolerances(
    true_value: torch.Tensor,
    computed_value: torch.Tensor,
    fudge_factor: Optional[float] = None,
) -> Tuple[float, float]:
    """Returns the absolute and relative tolerances for comparing two tensors."""
    fudge_factor = fudge_factor if fudge_factor is not None else 1.0
    atol = get_atol(true_value, computed_value)
    rtol = get_rtol(true_value, computed_value)

    atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
    rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
    # torch.isclose() has weird behavior when rtol becomes very large; see:
    # https://github.com/pytorch/pytorch/issues/102400
    if rtol > 1e30:
        rtol = default_rtol[computed_value.dtype]
    return atol, rtol


# We can probably parametrize existing tests instead of having a separate
# test class as we begin to support more ops. Also, maybe rewrite with OpInfos.
@markDynamoStrictTest
class TestNestedTensorSubclass(NestedTensorTestCase):
    # TODO: consolidate with the below
    def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True):
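        # nested_size is ((s_0, ..., s_{B-1}), D_1, ..., D_k); e.g. ((2, 3, 4), 3, 4)
        # produces component tensors of shapes (2, 3, 4), (3, 3, 4), and (4, 3, 4).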
        Ds = nested_size[1:]
        out = []
        for s in nested_size[0]:
            out.append(
                torch.randn(
                    s,
                    *Ds,
                    requires_grad=requires_grad,
                    device=device,
                    dtype=torch.float64,
                )
            )
        return out

    def _get_example_tensor_lists(
        self,
        include_list_of_lists=True,
        include_requires_grad=True,
        include_inner_dim_size_1=False,
        include_2d_tensor=False,
    ):
        def _make_tensor(
            *shape, include_requires_grad=include_requires_grad, requires_grad=True
        ):
            return torch.randn(
                *shape,
                requires_grad=(requires_grad if include_requires_grad else False),
            )

        # Purposefully introduce mixed requires_grad settings for the components
        # when include_requires_grad=True.
        example_lists = [
            # (B, *, D) with B=4
            [
                _make_tensor(2, 5),
                _make_tensor(3, 5, requires_grad=False),
                _make_tensor(4, 5, requires_grad=False),
                _make_tensor(6, 5),
            ],
            # (B, *, D_0, D_1) with B=5
            [
                _make_tensor(2, 5, 6),
                _make_tensor(3, 5, 6),
                _make_tensor(4, 5, 6, requires_grad=False),
                _make_tensor(5, 5, 6),
                _make_tensor(6, 5, 6),
            ],
            # (B, *, D_0, D_1, D_2) with B=6
            [
                _make_tensor(2, 5, 6, 7),
                _make_tensor(3, 5, 6, 7),
                _make_tensor(4, 5, 6, 7, requires_grad=False),
                _make_tensor(5, 5, 6, 7),
                _make_tensor(6, 5, 6, 7),
                _make_tensor(7, 5, 6, 7),
            ],
        ]

        if include_list_of_lists:
            example_lists.append(
                # (B, *, D) with B=3 in list form
                [
                    _make_tensor(2, 5, requires_grad=False).tolist(),
                    _make_tensor(3, 5).tolist(),
                    _make_tensor(4, 5).tolist(),
                ]
            )

        if include_inner_dim_size_1:
            example_lists.append(
                [
                    _make_tensor(2, 1),
                    _make_tensor(3, 1, requires_grad=False),
                    _make_tensor(4, 1, requires_grad=False),
                    _make_tensor(6, 1),
                ]  # (B, *, 1)
            )
            example_lists.append(
                [
                    _make_tensor(2, 5, 1),
                    _make_tensor(3, 5, 1, requires_grad=False),
                    _make_tensor(4, 5, 1, requires_grad=False),
                    _make_tensor(6, 5, 1),
                ]  # (B, *, 5, 1)
            )

        if include_2d_tensor:
            example_lists.append(
                [
                    _make_tensor(2),
                    _make_tensor(3, requires_grad=False),
                    _make_tensor(4, requires_grad=False),
                    _make_tensor(6),
                ]  # (B, *)
            )

        return example_lists

    def test_tensor_attributes(self, device):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        _offsets = nt.offsets()

        for op in (
            torch.ops.aten.is_non_overlapping_and_dense.default,
            torch.ops.aten.sym_size.default,
            torch.ops.aten.dim.default,
            torch.ops.aten.numel.default,
            torch.ops.aten.sym_numel.default,
            torch.ops.aten.sym_stride.default,
            torch.ops.aten.sym_storage_offset.default,
        ):
            op(nt)

        with self.assertRaisesRegex(
            RuntimeError, "directly calling torch.ops.aten.size"
        ):
            torch.ops.aten.size.default(nt)

        nested_int = torch.nested._internal.nested_tensor.get_tensor_symint(
            _offsets, coeff=1
        )
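        # The ragged dim is reported as a symbolic "nested int" tied to the offsets
        # tensor, so it compares equal across NTs that share the same offsets.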
        self.assertEqual(nt.size(), (3, nested_int, 3))
        self.assertEqual(nt.shape, (3, nested_int, 3))
        self.assertEqual(nt.dim(), 3)
        self.assertEqual(nt.numel(), 27)

    @parametrize("nt_dim", [3, 4, 5])
    def test_linear(self, device, nt_dim):
        if nt_dim == 3:
            fixed_shape = (3,)
        elif nt_dim == 4:
            fixed_shape = (4, 3)
        elif nt_dim == 5:
            fixed_shape = (5, 4, 3)

        a = torch.randn(
            2, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
        )
        b = torch.randn(
            3, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
        )
        c = torch.randn(
            4, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
        )
        weight = torch.randn(
            4, 3, requires_grad=True, dtype=torch.float64, device=device
        )

        def grad_test_func(a, b, c, weight):
            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
            out = torch.nn.functional.linear(nt, weight)
            return out.values()

        gradcheck(grad_test_func, inputs=(a, b, c, weight), check_batched_grad=False)

    def test_unary_pointwise(self, device):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
            out = torch.nn.functional.silu(nt.sin().cos())
            return out.values()

        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)

    def test_unary_pointwise_transposed_inputs(self, device):
        a, b, c = (
            torch.randn(
                i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
            )
            for i in range(3)
        )

        nt = torch.nested.nested_tensor(
            [a.detach(), b.detach(), c.detach()], layout=torch.jagged
        )
        nt_t = nt.transpose(1, 2)
        self.assertFalse(nt_t.is_contiguous())
        out = torch.nn.functional.silu(nt_t.sin().cos())
        self.assertEqual(
            out.is_contiguous(),
            torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous(),
        )

        self.assertEqual(nt_t.shape, out.shape)

        a, b, c = (
            torch.randn(
                i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
            )
            for i in range(3)
        )

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
            nt_t = nt.transpose(1, 2)
            out = torch.nn.functional.silu(nt_t.sin().cos())
            return out.values()

        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)

    def test_binary_pointwise(self, device):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)

        # Incorrect usage: the shape check will fail if the offsets tensors are not
        #                  the same exact tensor object
        nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        nt2 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)

        self.assertRaisesRegex(
            RuntimeError,
            "cannot call binary pointwise function .* with inputs of shapes",
            lambda: nt1 * nt2,
        )

        # Correct usage: chain the calls using the same offsets tensor object
        def grad_test_func(a, b, c):
            nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
            # TODO: Switch to public API that takes in (values, offsets) once it exists
            nt2, offsets = jagged_from_list([a, b, c], nt1.offsets())
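            # Reusing nt1's offsets tensor gives nt2 the same nested int for its ragged
            # dim, so the binary pointwise shape check passes.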
            out = nt1 * nt2
            return out.values()

        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)

    def test_binary_pointwise_transposed(self, device):
        a, b, c = (
            torch.randn(i + 2, 5, dtype=torch.float64, device=device) for i in range(3)
        )

        nt1, offsets = jagged_from_list([a, b, c], None)
        nt2, offsets = jagged_from_list([a, b, c], offsets)

        nt1_t = nt1.transpose(1, 2)
        nt2_t = nt2.transpose(1, 2)

        # out = nt1_t * nt2_t
        # self.assertFalse(nt1_t.is_contiguous())
        # self.assertEqual(out.is_contiguous(), (b.transpose(-1, -2) * b.transpose(-1, -2)).is_contiguous())
        # self.assertEqual(out.shape, nt1_t.shape)

        self.assertRaisesRegex(
            RuntimeError,
            "cannot call binary pointwise function mul.Tensor with inputs of shapes",
            lambda: nt1 * nt2_t,
        )

        a, b, c = (
            torch.randn(
                i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
            )
            for i in range(3)
        )

        # Correct usage: chain the calls using the same offsets tensor object
        def grad_test_func(a, b, c):
            nt1, offsets = jagged_from_list([a, b, c], None)
            nt2, offsets = jagged_from_list([a, b, c], offsets)
            nt1_t = nt1.transpose(1, 2)
            nt2_t = nt2.transpose(1, 2)
            out = nt1_t * nt2_t
            return out.values()

        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)

    def test_split(self, device):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)

        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        out = torch.split(nt, 2, -1)
        self.assertEqual(len(out), 2)
        self.assertEqualIgnoringNestedInts(
            out[0],
            torch.nested.as_nested_tensor(
                [a[:, 0:2], b[:, 0:2], c[:, 0:2]], layout=torch.jagged
            ),
        )
        self.assertEqualIgnoringNestedInts(
            out[1],
            torch.nested.as_nested_tensor(
                [a[:, 2:], b[:, 2:], c[:, 2:]], layout=torch.jagged
            ),
        )

        with self.assertRaisesRegex(
            RuntimeError,
            r"split\(\): not supported for NestedTensor on dim=1",
        ):
            torch.split(nt, 2, 1)

    def test_split_with_sizes(self, device):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)

        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        out = torch.split(nt, [1, 2], -1)
        self.assertEqual(len(out), 2)
        self.assertEqualIgnoringNestedInts(
            out[0],
            torch.nested.as_nested_tensor(
                [a[:, 0:1], b[:, 0:1], c[:, 0:1]], layout=torch.jagged
            ),
        )
        self.assertEqualIgnoringNestedInts(
            out[1],
            torch.nested.as_nested_tensor(
                [a[:, 1:], b[:, 1:], c[:, 1:]], layout=torch.jagged
            ),
        )
        with self.assertRaisesRegex(
            RuntimeError,
            r"split_with_sizes\(\): not supported for NestedTensor on dim=1",
        ):
            torch.split(nt, [1, 2], 1)

    def test_softmax(self, device):
        nt = random_nt_from_dims(
            [3, None, 5],
            device=device,
            dtype=torch.float32,
            layout=torch.jagged,
            requires_grad=True,
        )

        # operate on dim=2
        output = nt.softmax(dim=2)

        @torch._dynamo.disable
        def _compare_to_ref(nt, output, dim):
            for in_component, out_component in zip(nt.unbind(), output.unbind()):
                self.assertEqual(in_component.softmax(dim=dim), out_component)

        # dim=2 -> dim=1 after unbind
        _compare_to_ref(nt, output, dim=1)

        # operate on dim=-1
        output2 = nt.softmax(dim=-1)
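        # dim=-1 resolves to dim=2 here, so the two results should match; like
        # _compare_to_ref above, the comparison runs eagerly (outside dynamo).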
        torch._dynamo.disable(self.assertEqual)(output, output2)
        _compare_to_ref(nt, output2, dim=-1)

        def grad_test_func(a, b):
            nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)
            out = nt.softmax(dim=-1)
            return out.values()

        a = torch.rand(4, 5, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.rand(8, 5, requires_grad=True, dtype=torch.float64, device=device)
        gradcheck(grad_test_func, inputs=(a, b), check_batched_grad=False)

    def test_views_inherit_ragged_dim(self, device):
        # view
        nt = random_nt_from_dims(
            [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged
        )
        # inherit ragged dim via -1
        view = nt.view(4, -1, 80)
        self.assertEqual(nt.shape[1], view.shape[1])
        # inherit batch and ragged dims via -1
        view2 = nt.view(-1, -1, 80)
        self.assertEqual(nt.shape[:2], view2.shape[:2])

        # expand
        nt = random_nt_from_dims(
            [3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged
        )
        # inherit batch and ragged dims via -1
        view = nt.expand(-1, -1, 5)
        self.assertEqual(nt.shape[:2], view.shape[:2])

    def test_view_ragged_idx_not_one(self, device):
        nt = random_nt_from_dims(
            [2, None, 20], device=device, dtype=torch.float32, layout=torch.jagged
        )

        view_transposed = nt.transpose(1, 2).view(2, 20, nt.size(1))
        self.assertEqual((2, 20, nt.size(1)), view_transposed.size())
        self.assertEqual(view_transposed._base, nt._base)

    def test_unsafe_view(self, device):
        nt = random_nt_from_dims(
            [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged
        )
        # basic view
        view1 = torch.ops.aten._unsafe_view(nt, (4, -1, 80))
        self.assertEqual((4, nt.size(1), 80), tuple(view1.size()))
        # _unsafe_view differs from view in that the view information is not tracked
        self.assertTrue(view1._base is None)

        # test _unsafe_view when ragged_idx != 1; currently only the identity view is supported
        nt_t = nt.transpose(1, 2)
        view2 = torch.ops.aten._unsafe_view(nt_t, (4, 8, nt.size(1), 10))
        self.assertEqual((4, 8, nt.size(1), 10), tuple(view2.size()))
        self.assertTrue(view2._base is None)

    @xfailIfTorchDynamo
    @parametrize("requires_grad", [False, True])
    def test_reshape_decomp(self, device, requires_grad):
        # contiguous NT should result in view.
        nt = (
            random_nt_from_dims(
                [3, None, 10],
                device=device,
                dtype=torch.float32,
                layout=torch.jagged,
            )
            .detach()
            .requires_grad_(requires_grad)
        )
        view = nt.reshape(-1, -1, 5, 2)
        self.assertEqual(view.shape[:2], nt.shape[:2])
        self.assertTrue(view._is_view() and view._base is nt)
        # make sure gradients flow back
        if requires_grad:
            view.backward(torch.ones_like(view))
            self.assertEqual(nt.grad, torch.ones_like(nt))

        # non-contiguous NT should result in contiguous copy
        nt = random_nt_from_dims(
            [3, None, 5, 2],
            device=device,
            dtype=torch.float32,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        nt_noncontig = nt.transpose(-1, -2)
        self.assertFalse(nt_noncontig.is_contiguous())
        copy = nt_noncontig.reshape(-1, -1, 10)
        self.assertTrue(copy.is_contiguous())
        self.assertEqual(copy.shape[:2], nt.shape[:2])
        # make sure gradients flow back
        if requires_grad:
            copy.backward(torch.ones_like(copy))
            self.assertEqual(nt.grad, torch.ones_like(nt))

    def test_flatten_decomp(self, device):
        nt = random_nt_from_dims(
            [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged
        )
        flattened = nt.flatten(-2, -1)
        self.assertEqual(flattened.shape, nt.view(3, -1, 10).shape)

        nt = random_nt_from_dims(
            [3, None, 5, 2, 6], device=device, dtype=torch.float32, layout=torch.jagged
        )
        flattened = nt.flatten(-3, -2)
        self.assertEqual(flattened.shape, nt.view(3, -1, 10, 6).shape)

    def test_chunk(self, device):
        # non-NJT case
        t = torch.randn(10, 4, 5, requires_grad=True)
        t_list = t.chunk(3, dim=0)
        loss = t_list[0].sum() + t_list[2].sum()
        loss.backward()

        # normal case
        D = 30
        B = 8
        nt = random_nt_from_dims(
            [B, None, D],
            device=device,
            dtype=torch.float32,
            layout=torch.jagged,
            requires_grad=True,
        )
        NUM_CHUNKS = 3
        chunks = nt.chunk(NUM_CHUNKS, dim=-1)
        self.assertEqual(len(chunks), NUM_CHUNKS)
        for i in range(NUM_CHUNKS):
            self.assertEqual(chunks[i].shape[-1], D // NUM_CHUNKS)

        # test chunk_backward
        values = torch.randn(
            5, 11, dtype=torch.float64, device=device, requires_grad=True
        )
        offsets = torch.tensor([0, 2, 3, 5], device=device)

        def grad_test_func(values, offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            chunks = nt.chunk(3, dim=-1)
            return chunks[0].values().sum()

        assert gradcheck(
            grad_test_func,
            inputs=(values, offsets),
            check_batched_grad=False,
        )

        # chunk on batch dim
        chunks = nt.chunk(NUM_CHUNKS, dim=0)
        self.assertEqual(len(chunks), NUM_CHUNKS)
        chunk_size = math.ceil(B / NUM_CHUNKS)
        for i in range(NUM_CHUNKS):
            if i < NUM_CHUNKS - 1:
                self.assertEqual(chunks[i].shape[0], chunk_size)
            else:
                self.assertEqual(chunks[i].shape[0], B - chunk_size * (NUM_CHUNKS - 1))
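            # Each chunk's offsets are rebased to start at 0 relative to the chunk's
            # first component, hence the subtraction below.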
            offsets_expected = (
                nt._offsets[i * chunk_size + 1 : (i + 1) * chunk_size + 1]
                - nt._offsets[i * chunk_size]
            )
            self.assertEqual(chunks[i]._offsets[1:], offsets_expected)
        self.assertEqual(nt._values, torch.cat([x._values for x in chunks], dim=0))

        with self.assertRaisesRegex(
            RuntimeError,
            "dim != 0 INTERNAL ASSERT FAILED .* Nested Tensor doesn't support chunk backward on dim=0 yet.",
        ):
            # doesn't support backward for chunk (dim=0) yet
            loss = (
                chunks[0].values().sum()
                + chunks[1].values().sum()
                + chunks[2].values().sum()
            )
            loss.backward()

        # chunk on ragged dim not supported
        with self.assertRaisesRegex(
            RuntimeError, "chunk.* not supported for NestedTensor on dim=1"
        ):
            nt.chunk(2, dim=1)

    def test_squeeze(self, device):
        B = 4
        D = 6
        # squeeze middle dim
        nt = random_nt_from_dims(
            [B, None, 1, D], device=device, dtype=torch.float32, layout=torch.jagged
        )
        j0 = nt.shape[1]

        for dim_arg in [-2, 2]:
            out = nt.squeeze(dim_arg)
            self.assertEqual(out.shape, (B, j0, D))
            self.assertEqual(out.unsqueeze(-2), nt)

        # squeeze last dim
        nt = random_nt_from_dims(
            [B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged
        )
        j1 = nt.shape[1]

        for dim_arg in [-1, 2]:
            out = nt.squeeze(dim_arg)
            self.assertEqual(out.shape, (B, j1))
            self.assertEqual(out.unsqueeze(-1), nt)

        # squeeze on batch dim not supported
        with self.assertRaisesRegex(
            RuntimeError, "squeeze.* not supported for NestedTensor on dim=0"
        ):
            nt.squeeze(0)

        # squeeze on ragged dim not supported
        with self.assertRaisesRegex(
            RuntimeError, "squeeze.* not supported for NestedTensor on dim=1"
        ):
            nt.squeeze(1)

    def test_binary_pointwise_broadcasting(self, device):
        # (B, j0, 3, 4)
        ts = self._get_list_for_jagged_tensor(
            ((2, 3, 4), 3, 4), device, requires_grad=True
        )
        # (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
        # (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
        # (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?)
        # Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)
        t_sizes = (
            (4,),
            (1, 4),
            (3, 1),
            (1, 3, 1),
            (1, 1, 1, 4),
            # (1, 1, 1, 1, 4), (unsupported today)
        )

        def grad_test_func(t, *ts):
            nt = torch.nested.as_nested_tensor(list(ts), layout=torch.jagged)
            out = nt + t
            return out.values()

        for t_size in t_sizes:
            t = torch.rand(
                t_size, requires_grad=True, device=device, dtype=torch.float64
            )
            gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)

    def test_threshold_backward(self, device):
        ts1 = self._get_list_for_jagged_tensor(
            ((2, 3, 4), 16), device=device, requires_grad=False
        )
        ts2 = self._get_list_for_jagged_tensor(
            ((2, 3, 4), 16), device=device, requires_grad=False
        )

        nt1, offsets = jagged_from_list(ts1, None)
        nt2, offsets = jagged_from_list(ts2, offsets)
        buf1 = nt1.values().detach().clone()
        buf2 = nt2.values().detach().clone()

        res_nt = torch.ops.aten.threshold_backward(nt1, nt2, 0.0)
        res_dense = torch.ops.aten.threshold_backward(buf1, buf2, 0.0)

        self.assertEqual(res_dense, res_nt.values())

    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_jagged_op_different_output_shape_dim(
        self, device, dtype, keepdim, requires_grad, components_require_grad, func
    ):
        """
        Operator passes when reducing on valid reduction dimensions.
        This test is for operators which return an output tensor with a shape different from the input tensor.
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        op_name = get_op_name(func)

        ts = self._get_list_for_jagged_tensor(
            ((2, 3, 4), 3, 4), device=device, requires_grad=True
        )  # (B, j0, 3, 4)

        # verify correctness of shapes (assuming that ragged_idx == 1)
        if op_name == "sum":
            reduce_dims = (
                ((0, 1), (3, 4), (1, 1, 3, 4), (0,)),  # batch, ragged
                ((2, 3), (3, None), (3, None, 1, 1), (1, 2)),  # non-batch, non-batch
                ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)),  # batch, ragged, non-batch
                ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)),  # batch, ragged, non-batch
                (
                    (0, 1, 2, 3),
                    (),
                    (1, 1, 1, 1),
                    (0, 1, 2),
                ),  # batch, ragged, non-batch, non-batch
                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),  # non-batch
            )  # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None
        elif op_name == "mean":
            reduce_dims = (
                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),
                ((3,), (3, None, 3), (3, None, 3, 1), (2,)),
            )
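        # Example decoding (sum case): rd=(0, 1) reduces the batch and ragged dims of the
        # (3, j0, 3, 4) input, giving shape (3, 4) ((1, 1, 3, 4) with keepdim); the
        # equivalent reduction on nt.values() uses dim=(0,).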

        for rd, ref_shape_no_keepdim, ref_shape_keepdim, _ in reduce_dims:
            nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
            out = func(nt, dim=rd, keepdim=keepdim)
            ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
            if not torch.compiler.is_compiling():  # if not using torch dynamo
                self.assertEqual(len(out.shape), len(ref_shape))
                for o, r in zip(out.shape, ref_shape):
                    if r is not None:
                        self.assertEqual(o, r)
                    else:
                        self.assertTrue(isinstance(o, torch.SymInt))

        # verify correctness of values
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,
        )
        for tensor_list, reduce_dim_tuple in itertools.product(
            tensor_lists, reduce_dims
        ):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple

            if nt.dim() > reduce_dim[-1]:
                out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
                if nt._ragged_idx in reduce_dim:  # raggedness reduced away
                    out_expected = func(
                        nt.values(), dim=reduce_dim_expected, keepdim=keepdim
                    )
                    self.assertTrue(torch.allclose(out_actual, out_expected))
                else:  # raggedness preserved
                    out_expected = func(nt.values(), dim=reduce_dim_expected)
                    self.assertTrue(
                        torch.allclose(
                            out_actual.values().view(-1), out_expected.view(-1)
                        )
                    )

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_softmax_dim(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
    ):
        """
        Softmax passes when reducing on valid reduction dimensions.
        """
        ts = self._get_list_for_jagged_tensor(
            ((2, 3, 4), 3, 4), device=device, requires_grad=True
        )  # (B, j0, 3, 4)

        output_shape = (3, None, 3, 4)

        # verify correctness of shapes (assuming that ragged_idx == 1)
        reduce_dims = (
            (2, 1),
            (3, 2),
        )  # (reduction dimension, effective reduction dimension for baseline)
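        # The baseline is computed on nt.values(), which folds the batch and ragged
        # dims into a single packed dim, so each reduction dim shifts down by one.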

        for reduce_dim, _ in reduce_dims:
            nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
            out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim)
            torch._dynamo.disable(self.assertEqual)(
                len(out_actual.shape), len(output_shape)
            )  # disable if running on dynamo
            for dim_actual, dim_expected in zip(out_actual.shape, output_shape):
                if dim_expected is not None:
                    self.assertEqual(dim_actual, dim_expected)
                else:
                    self.assertTrue(isinstance(dim_actual, torch.SymInt))

        # verify correctness of values
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,
        )
        for tensor_list, reduce_dim_tuple in itertools.product(
            tensor_lists, reduce_dims
        ):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            reduce_dim, reduce_dim_expected = reduce_dim_tuple

            if nt.dim() > reduce_dim:
                out_actual = torch.nn.functional.softmax(
                    nt, dim=reduce_dim
                )  # nested tensor
                out_expected = torch.nn.functional.softmax(
                    nt.values(), dim=reduce_dim_expected
                )  # dense tensor of dimensions 1 less than out_actual
                self.assertTrue(
                    torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
                )

    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_op_dim_reduce_ragged_idx_1_different_output_shape(
        self, device, dtype, keepdim, requires_grad, components_require_grad, func
    ):
        """
        Operator on NestedTensor passes when trying to reduce across the ragged dimension, where ragged_idx == 1.
        This test is for operators which return an output tensor with a shape different from the input tensor.
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        op_name = get_op_name(func)

        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
        )
        reduce_dim = (1,)  # ragged

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
            out_expected = torch.cat(
                [func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0) for t in nt.unbind()]
            )

            self.assertFalse(
                out_actual.is_nested,
                f"{op_name}(): reducing a nested tensor along the ragged dimension should produce a dense tensor",
            )  # output is a dense tensor
            self.assertTrue(torch.allclose(out_actual, out_expected))

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_softmax_dim_reduce_ragged_idx_1(
        self, device, dtype, requires_grad, components_require_grad
    ):
        """
        Softmax on NestedTensor passes when trying to reduce across the ragged dimension, where ragged_idx == 1.
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
            include_2d_tensor=True,  # (B, *)
        )
        reduce_dim = 1  # ragged

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim)
            out_expected = torch.cat(
                [
                    torch.nn.functional.softmax(t, dim=reduce_dim - 1)
                    for t in nt.unbind()
                ]
            )

            self.assertTrue(
                out_actual.is_nested,
                "softmax(): reducing a nested tensor along the ragged dimension should produce a nested tensor",
            )  # output is a nested tensor
            self.assertTrue(torch.allclose(out_actual.values(), out_expected))

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_softmax_reduce_batch_dim(
        self, device, dtype, requires_grad, components_require_grad
    ):
        """
        Softmax on NestedTensor fails when trying to reduce across the batch dimension.
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
        )
        reduce_dim = 0  # batch

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            with self.assertRaisesRegex(
                RuntimeError,
                "not supported when reducing across the batch dimension for NestedTensor",
            ):
                out = torch.nn.functional.softmax(nt, dim=reduce_dim)

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_layer_norm_reduce_ragged_idx_1(
        self, device, dtype, requires_grad, components_require_grad
    ):
        """
        Layer normalization on NestedTensor passes when trying to normalize across the ragged dimension, where ragged_idx == 1.
        """

        # requires_grad = False does not currently work with dynamo tests and throws this error:
        #   AssertionError: SymInts must use SymNodeVariable.
        #   If the underlying value is static, we will create a ConstantVariable and specialize.
        if torch._dynamo.is_compiling() and not requires_grad:
            return

        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
        )

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if (
                nt.dim() >= 3
            ):  # layer norm only works for tensors with 3 or more dimensions
                normalized_shape = nt.shape[nt._ragged_idx :]

                out_actual = torch.nn.functional.layer_norm(
                    nt, normalized_shape=normalized_shape
                )
                out_expected = torch.cat(
                    [
                        torch.nn.functional.layer_norm(t, normalized_shape=t.shape)
                        for t in nt.unbind()
                    ]
                )  # e.g. in 3D tensor (B, *, M), performs layer normalization on B 2D tensors (*, M)

                self.assertTrue(
                    out_actual.is_nested,
                    "layer_norm(): normalizing a nested tensor over the ragged dimension should produce a nested tensor",
                )  # output is a nested tensor
                self.assertEqual(out_actual._values.shape, out_expected.shape)
                self.assertTrue(torch.allclose(out_actual.values(), out_expected))

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_layer_norm_2d_input(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
    ):
        """
        Layer normalization on NestedTensor fails when trying to operate on a 2-dimensional tensor
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
            include_2d_tensor=True,  # (B, *)
        )

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() <= 2:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported for NestedTensor objects with 2 or fewer dimensions",
                ):
                    out = torch.nn.functional.layer_norm(
                        nt, normalized_shape=(nt.shape[nt._ragged_idx],)
                    )

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_layer_norm_operate_on_batch_dim(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
    ):
        """
        Layer normalization on NestedTensor fails when trying to operate on the batch dimension
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
            include_2d_tensor=True,  # (B, *)
        )

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > 2:  # cannot perform layer normalization on 2D tensors
                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported when normalizing over the batch dimension for NestedTensor",
                ):
                    out = torch.nn.functional.layer_norm(nt, normalized_shape=nt.shape)

    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize(
        "transpose_offset", [1, 2]
    )  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
        func,
        transpose_offset,
    ):
        """
        Operator on NestedTensor passes when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1
        This test is for operators which return an output tensor with a shape different from the input tensor.
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        op_name = get_op_name(func)

        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
            include_2d_tensor=True,  # (B, *)
        )

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > nt._ragged_idx + transpose_offset:
                nt_transposed = nt.transpose(
                    nt._ragged_idx, nt._ragged_idx + transpose_offset
                )
                reduce_dim = (nt_transposed._ragged_idx,)  # ragged

                out_actual = func(nt_transposed, dim=reduce_dim, keepdim=keepdim)
                out_expected = torch.cat(
                    [
                        func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0)
                        for t in nt_transposed.unbind()
                    ]
                )

                self.assertFalse(
                    out_actual.is_nested,
                    f"{op_name}(): reducing a nested tensor along the ragged dimension should produce a dense tensor",
                )  # output is a dense tensor
                self.assertTrue(torch.allclose(out_actual, out_expected, rtol=1e-4))

    @dtypes(torch.float32)
    @parametrize(
        "transpose_offset", [1, 2]
    )  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
        transpose_offset,
    ):
        """
        Softmax on NestedTensor fails when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1
        This test is for operators which return an output tensor with the same shape as the input tensor.
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
        )

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > nt._ragged_idx + transpose_offset:
                nt_transposed = nt.transpose(
                    nt._ragged_idx, nt._ragged_idx + transpose_offset
                )
                reduce_dim = nt_transposed._ragged_idx  # ragged

                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor",
                ):
                    out = torch.nn.functional.softmax(nt_transposed, dim=reduce_dim)

    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_op_dim_transpose_non_ragged_dim_different_output_shape(
        self, device, dtype, keepdim, requires_grad, components_require_grad, func
    ):
        """
        Operator passes when reducing transposed nested tensors on valid reduction dimensions.
        This test is for operators which return an output tensor with a shape different from the input tensor.
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        # verify correctness of shapes (assuming that ragged_idx == 1)
        if get_op_name(func) == "sum":
            reduce_dims = (
                ((0, 1), (3, 4), (1, 1, 3, 4), (0,)),  # batch, ragged
                ((2, 3), (3, None), (3, None, 1, 1), (1, 2)),  # non-batch, non-batch
                ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)),  # batch, ragged, non-batch
                ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)),  # batch, ragged, non-batch
                (
                    (0, 1, 2, 3),
                    (),
                    (1, 1, 1, 1),
                    (0, 1, 2),
                ),  # batch, ragged, non-batch, non-batch
                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),  # non-batch
            )  # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None
        elif get_op_name(func) == "mean":
            reduce_dims = (
                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),
                ((3,), (3, None, 3), (3, None, 3, 1), (2,)),
            )

        # verify correctness of values
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
        )
        for tensor_list, reduce_dim_tuple in itertools.product(
            tensor_lists, reduce_dims
        ):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            ).transpose(-1, -2)

            reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple

            if nt.dim() > max(
                reduce_dim[-1], nt._ragged_idx + 2
            ):  # ensure that transposed dimensions are non-batch, non-ragged dimensions
                out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
                if nt._ragged_idx in reduce_dim:  # raggedness reduced away
                    out_expected = func(
                        nt.values(), dim=reduce_dim_expected, keepdim=keepdim
                    )
                    self.assertTrue(torch.allclose(out_actual, out_expected))
                else:  # raggedness preserved
                    out_expected = func(nt.values(), dim=reduce_dim_expected)
                    self.assertTrue(
                        torch.allclose(
                            out_actual.values().view(-1), out_expected.view(-1)
                        )
                    )

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_softmax_dim_transpose_non_ragged_dim(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
    ):
        """
        Softmax passes when reducing transposed nested tensors on valid reduction dimensions.
        This test is for operators which return an output tensor with the same shape as the input tensor.
        """
        # verify correctness of shapes (assuming that ragged_idx == 1)
        reduce_dims = (
            (2, 1),
            (3, 2),
        )  # (reduction dimension, effective reduction dimension for baseline)

        # verify correctness of values
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
        )
        for tensor_list, reduce_dim_tuple in itertools.product(
            tensor_lists, reduce_dims
        ):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            ).transpose(-1, -2)

            reduce_dim, reduce_dim_expected = reduce_dim_tuple

            if nt.dim() > max(reduce_dim, nt._ragged_idx + 2):
                out_actual = torch.nn.functional.softmax(
                    nt, dim=reduce_dim
                )  # nested tensor
                out_expected = torch.nn.functional.softmax(
                    nt.values(), dim=reduce_dim_expected
                )  # dense tensor of dimensions 1 less than out_actual

                self.assertTrue(
                    torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
                )

    @dtypes(torch.float32)
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_sum_dim_reduce_ragged_and_non_batch(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
    ):
        """
        Sum on NestedTensor fails when trying to reduce across ragged and non-batch dimensions
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False, include_requires_grad=components_require_grad
        )
        reduce_dims = (
            (1, 2),  # ragged, non-batch
            (1, 3),  # ragged, non-batch
        )

        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > reduce_dim[-1]:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported along a ragged and non-batch dimension for NestedTensor",
                ):
                    out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)

    @dtypes(torch.float32)
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_sum_dim_reduce_batch_and_non_batch(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
    ):
        """
        Sum on NestedTensor fails when trying to reduce across batch and non-batch dimensions
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False, include_requires_grad=components_require_grad
        )
        reduce_dims = (
            (0, 2),  # batch, non-batch
            (0, 3),  # batch, non-batch
        )

        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > reduce_dim[-1]:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported along the batch dimension but not the ragged dimension for NestedTensor",
                ):
                    out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)

    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_op_dim_reduce_batch_only_different_output_shape(
        self, device, dtype, keepdim, requires_grad, components_require_grad, func
    ):
        """
        Operator on NestedTensor fails when trying to reduce across the batch dimension.
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False, include_requires_grad=components_require_grad
        )
        reduce_dim = (0,)  # batch

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            with self.assertRaisesRegex(
                RuntimeError,
                "not supported along the batch dimension but not the ragged dimension for NestedTensor",
            ):
                out = func(nt, dim=reduce_dim, keepdim=keepdim)

    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_op_dim_with_lengths_different_output_shape(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
        func,
    ):
        """
        Operator on NestedTensor fails when trying to reduce a nested tensor with lengths,
        i.e. a nested tensor with holes, if reducing on the ragged dimension.
        This test is for operators which return an output tensor with different shape than the input tensor.
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        reduce_dims = ((1,), (2,), (2, 3))

        lengths = torch.randint(5, 10, (20,), device=device)
        offsets = torch.zeros((21,), device=device, dtype=torch.int)
        torch.cumsum(lengths, dim=0, out=offsets[1:])
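        # offsets[0] stays 0 and offsets[i + 1] is the running sum of lengths[: i + 1],
        # e.g. lengths [5, 7, 6, ...] -> offsets [0, 5, 12, 18, ...]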

        values = torch.randn(
            (offsets[-1].item(), 20),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
        )

        nt_with_holes = torch.nested.nested_tensor_from_jagged(
            values,
            offsets,
            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
        )
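        # Each stored length is 2 less than the corresponding offsets.diff() entry, so
        # component i covers only values[offsets[i] : offsets[i] + lengths[i]]; the
        # remaining rows before offsets[i + 1] are unused "holes".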

        for reduce_dim in reduce_dims:
            if nt_with_holes.dim() > reduce_dim[-1]:
                if nt_with_holes._ragged_idx in reduce_dim:
                    with self.assertRaisesRegex(
                        RuntimeError,
                        "not supported where lengths is not None "
                        + "if reducing across the ragged dimension for NestedTensor",
                    ):
                        out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim)
                else:
                    out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim)

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_softmax_dim_with_lengths(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
    ):
        """
        Softmax on NestedTensor fails when trying to reduce a nested tensor with lengths,
        i.e. a nested tensor with holes, if reducing on the ragged dimension.
        """
        reduce_dims = (1, 2, 3)

        lengths = torch.randint(5, 10, (20,), device=device)
        offsets = torch.zeros((21,), device=device, dtype=torch.int)
        torch.cumsum(lengths, dim=0, out=offsets[1:])

        values = torch.randn(
            (offsets[-1].item(), 20),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
        )

        nt_with_holes = torch.nested.nested_tensor_from_jagged(
            values,
            offsets,
            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
        )

        for reduce_dim in reduce_dims:
            if nt_with_holes.dim() > reduce_dim:
                if nt_with_holes._ragged_idx == reduce_dim:
                    with self.assertRaisesRegex(
                        RuntimeError,
                        "not supported where lengths is not None "
                        + "if reducing across the ragged dimension for NestedTensor",
                    ):
                        out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)
                else:
                    out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)

    @skipIfTorchDynamo(
        "ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx] does not currently work "
        + "with dynamo tests and throws this error: `AssertionError: SymInts must use SymNodeVariable. "
        + "If the underlying value is static, we will create a ConstantVariable and specialize.`"
    )
    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_layer_norm_with_lengths(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
    ):
        """
        Layer normalization on NestedTensor fails when trying to operate on a nested tensor with lengths,
        i.e. a nested tensor with holes, if operating on the ragged dimension.
        """

        # create components for nested tensor
        lengths = torch.randint(5, 10, (20,), device=device)
        offsets = torch.zeros((21,), device=device, dtype=torch.int)
        torch.cumsum(lengths, dim=0, out=offsets[1:])
        values = torch.randn(
            (offsets[-1].item(), 10, 30),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
        )

        nt_with_holes = torch.nested.nested_tensor_from_jagged(
            values,
            offsets,
            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
        )

        ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx]

        normalized_shapes = (
            (10, 30),  # normalization on non-ragged dimension passes
            (ragged_size, 10, 30),  # normalization on ragged dimension fails
        )

        for normalized_shape in normalized_shapes:
            if ragged_size in normalized_shape:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported where lengths is not None if operating on the ragged dimension for NestedTensor",
                ):
                    out = torch.nn.functional.layer_norm(
                        nt_with_holes, normalized_shape=normalized_shape
                    )
            else:
                out = torch.nn.functional.layer_norm(
                    nt_with_holes, normalized_shape=normalized_shape
                )

    @dtypes(torch.float32)
    @parametrize("keepdim", [True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_mean_dim_reduce_multiple_dims(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
    ):
        """
        Mean on NestedTensor fails when trying to reduce across multiple dimensions
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False, include_requires_grad=components_require_grad
        )
        reduce_dims = ((0, 1), (2, 3), (2, 3, 4))

        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > reduce_dim[-1]:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported across multiple dimensions for NestedTensor",
                ):
                    out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim)

    @dtypes(torch.float32)
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_mean_dim_keepdim_False(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
    ):
        """
        Mean on NestedTensor fails when keepdim=False
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False, include_requires_grad=components_require_grad
        )
        reduce_dims = ((1,), (2,), (3,))

        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > reduce_dim[-1]:
                if not keepdim:
                    with self.assertRaisesRegex(
                        RuntimeError,
                        "not supported when keepdim=False for NestedTensor",
                    ):
                        out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim)
                else:
                    out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim)

    @dtypes(torch.float, torch.double, torch.half)
    @parametrize("requires_grad", [False, True])
    @parametrize("weights_only", [False, True])
    def test_serialization(self, device, dtype, requires_grad, weights_only):
        def compare_metadata(nt1, nt2):
            self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size())
            self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides())
            self.assertEqual(
                nt1._nested_tensor_storage_offsets(),
                nt2._nested_tensor_storage_offsets(),
            )

        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
        for a in [nt_contiguous, nt_noncontiguous]:
            buffer = io.BytesIO()
            torch.save(a, buffer)
            buffer.seek(0)
            b = torch.load(buffer, weights_only=weights_only)
            # should be both conceptually equal and metadata equivalent
            self.assertEqual(a, b)
            compare_metadata(a, b)
            # should be conceptually equal but not necessarily metadata equivalent
            self.assertEqual(b, nt_contiguous)
            self.assertEqual(b, nt_noncontiguous)

    @unittest.skipIf(
        PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property"
    )
    @onlyCUDA
    def test_pin_memory(self, device):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
        for nt in [nt_contiguous, nt_noncontiguous]:
            self.assertFalse(nt.is_pinned())
            pinned = nt.pin_memory(device)
            self.assertTrue(pinned.is_pinned())
            self.assertEqual(nt, pinned)
            self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
            # test that pin_memory on already pinned tensor has no effect
            self.assertIs(pinned, pinned.pin_memory())
            self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())

    @torch.compiler.disable
    def _validate_nt(
        self,
        nt,
        device,
        dtype,
        layout,
        requires_grad,
        dim,
        batch_size,
        contiguous,
        cached_min_seqlen=None,
        cached_max_seqlen=None,
        base=None,
        ref_nt=None,
    ):
        # Validate a bunch of properties after NT construction.
        device = torch.device(device)
        self.assertEqual(nt.dim(), dim)
        self.assertEqual(nt.device, device)
        self.assertEqual(nt.dtype, dtype)
        self.assertEqual(nt.layout, layout)
        self.assertEqual(nt.requires_grad, requires_grad)
        self.assertEqual(nt.is_contiguous(), contiguous)

        if layout == torch.jagged:
            self.assertEqual(nt._values.device, device)
            self.assertEqual(nt._offsets.device, device)
            self.assertEqual(nt.shape[0], batch_size)
            self.assertTrue(isinstance(nt.shape[1], torch.SymInt))

            if base is not None:
                self.assertTrue(nt._is_view() and nt._base is base)
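                # Replay the view on a fresh base and check that the cached min/max
                # seqlen entries propagate to the replayed view's metadata cache.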
                replay_cache = nt._view_func(torch.randn_like(nt._base))._metadata_cache
                self.assertEqual(
                    "min_seqlen" in replay_cache, cached_min_seqlen is not None
                )
                self.assertEqual(
                    "max_seqlen" in replay_cache, cached_max_seqlen is not None
                )

            self.assertEqual(
                "min_seqlen" in nt._metadata_cache, cached_min_seqlen is not None
            )
            self.assertEqual(
                "max_seqlen" in nt._metadata_cache, cached_max_seqlen is not None
            )

            if cached_min_seqlen is not None:
                self.assertEqual(nt._min_seqlen, cached_min_seqlen)

            if cached_max_seqlen is not None:
                self.assertEqual(nt._max_seqlen, cached_max_seqlen)

        if ref_nt is not None:
            self.assertEqual(nt.size(0), ref_nt.size(0))
            for n1, n2 in zip(nt.unbind(), ref_nt.unbind()):
                self.assertEqual(n1, n2)

    @dtypes(torch.float, torch.double, torch.half)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_jagged_layout_construction_nested_tensor(
        self, device, dtype, requires_grad, components_require_grad
    ):
        for tensor_list in self._get_example_tensor_lists(
            include_list_of_lists=True, include_requires_grad=components_require_grad
        ):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1
            expected_batch_size = len(tensor_list)
            expected_contiguous = True
            expected_min_seqlen = min(
                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
                for t in tensor_list
            )
            expected_max_seqlen = max(
                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
                for t in tensor_list
            )
            self._validate_nt(
                nt,
                device,
                dtype,
                torch.jagged,
                requires_grad,
                expected_dim,
                expected_batch_size,
                expected_contiguous,
                expected_min_seqlen,
                expected_max_seqlen,
            )

            # Make sure grads -don't- flow back into original tensors for nested_tensor()
            if requires_grad:
                (nt * 2).backward(torch.ones_like(nt))
            for t in tensor_list:
                t = t if isinstance(t, torch.Tensor) else torch.as_tensor(t)
                self.assertTrue(t.grad is None)

    @dtypes(torch.float, torch.double, torch.half)
    @parametrize("components_require_grad", [False, True])
    def test_jagged_layout_construction_as_nested_tensor(
        self, device, dtype, components_require_grad
    ):
        # NB: as_nested_tensor(tensor_list) doesn't support lists of lists for tensor_list
        for tensor_list in self._get_example_tensor_lists(
            include_list_of_lists=False, include_requires_grad=components_require_grad
        ):
            nt = torch.nested.as_nested_tensor(
                tensor_list, device=device, dtype=dtype, layout=torch.jagged
            )

            # nt.requires_grad=True should be set if at least one component requires grad
            expected_dim = tensor_list[0].dim() + 1
            expected_batch_size = len(tensor_list)
            expected_contiguous = True
            expected_min_seqlen = min(
                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
                for t in tensor_list
            )
            expected_max_seqlen = max(
                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
                for t in tensor_list
            )
            self._validate_nt(
                nt,
                device,
                dtype,
                torch.jagged,
                components_require_grad,
                expected_dim,
                expected_batch_size,
                expected_contiguous,
                expected_min_seqlen,
                expected_max_seqlen,
            )

            # Make sure grads flow back into original tensors for as_nested_tensor()
            if components_require_grad:
                (nt * 2).backward(torch.ones_like(nt))
                for t in tensor_list:
                    if t.requires_grad:
                        self.assertEqual(t.grad, torch.ones_like(t) * 2)
                    else:
                        self.assertTrue(t.grad is None)

    @xfailIfTorchDynamo
    @unittest.skipIf(
        PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property"
    )
    @onlyCUDA
    def test_jagged_layout_construction_with_pinned_memory(self, device):
        for tensor_list in self._get_example_tensor_lists():
            nt = torch.nested.nested_tensor(
                tensor_list, layout=torch.jagged, device="cpu", pin_memory=True
            )

            expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1
            expected_batch_size = len(tensor_list)
            expected_min_seqlen = min(
                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
                for t in tensor_list
            )
            expected_max_seqlen = max(
                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
                for t in tensor_list
            )
            self._validate_nt(
                nt,
                device="cpu",
                dtype=torch.float32,
                layout=torch.jagged,
                requires_grad=False,
                dim=expected_dim,
                batch_size=expected_batch_size,
                contiguous=True,
                cached_min_seqlen=expected_min_seqlen,
                cached_max_seqlen=expected_max_seqlen,
            )
            self.assertTrue(nt.is_pinned())

    @dtypes(torch.float, torch.double, torch.half)
    @parametrize("requires_grad", [False, True])
    @parametrize("values_is_view", [False, True])
    def test_jagged_view_from_values_offsets(
        self, device, dtype, requires_grad, values_is_view
    ):
        if values_is_view:
            # make values a view of base
            base = torch.randn(
                2, 3, 4, 5, 6, device=device, dtype=dtype, requires_grad=requires_grad
            )
            values = base.flatten(0, -2)
        else:
            values = torch.randn(
                10, 5, device=device, dtype=dtype, requires_grad=requires_grad
            )
        offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64)

        nt = nested_view_from_values_offsets(values, offsets)

        expected_dim = values.dim() + 1
        expected_batch_size = offsets.shape[0] - 1
        expected_base = base if values_is_view else values
        lengths = offsets.diff()
        self._validate_nt(
            nt,
            device,
            dtype,
            torch.jagged,
            requires_grad,
            expected_dim,
            expected_batch_size,
            # ensure NT is a proper view
            base=expected_base,
            contiguous=True,
            # if no min / max are passed, expect the metadata cache to be empty
            cached_min_seqlen=None,
            cached_max_seqlen=None,
        )

        if requires_grad:
            # Make sure grads flow back
            (nt * 2).backward(torch.ones_like(nt))

            @torch.compiler.disable
            def _check_grad(t):
                self.assertTrue(t.grad is not None)
                self.assertEqual(t.grad, torch.ones_like(t) * 2)

            _check_grad(base if values_is_view else values)

    @dtypes(torch.float)
    @parametrize("pass_min_max", [False, True])
    def test_nested_tensor_from_jagged(self, device, dtype, pass_min_max):
        # === construct from (values, offsets) ===
        values = torch.randn(10, 5, device=device, dtype=dtype)
        offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64)

        # compute min / max seqlen
        lengths = offsets.diff()
        min_seqlen = lengths.min().item()
        max_seqlen = lengths.max().item()

        if pass_min_max:
            nt = torch.nested.nested_tensor_from_jagged(
                values, offsets=offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
            )
        else:
            nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)
        self._validate_nt(
            nt,
            device,
            dtype,
            torch.jagged,
            requires_grad=False,
            dim=3,
            batch_size=4,
            contiguous=True,
            cached_min_seqlen=(min_seqlen if pass_min_max else None),
            cached_max_seqlen=(max_seqlen if pass_min_max else None),
            base=values,
        )

        # === construct from (values, offsets, lengths) ===
        lengths = torch.tensor([2, 1, 1, 2], device=device)

        # compute min / max seqlen
        min_seqlen = lengths.min().item()
        max_seqlen = lengths.max().item()

        if pass_min_max:
            nt = torch.nested.nested_tensor_from_jagged(
                values,
                offsets=offsets,
                lengths=lengths,
                min_seqlen=min_seqlen,
                max_seqlen=max_seqlen,
            )
        else:
            nt = torch.nested.nested_tensor_from_jagged(
                values, offsets=offsets, lengths=lengths
            )

        # when both offsets / lengths are specified, expect non-contiguous
        self._validate_nt(
            nt,
            device,
            dtype,
            torch.jagged,
            requires_grad=False,
            dim=3,
            batch_size=4,
            contiguous=False,
            cached_min_seqlen=(min_seqlen if pass_min_max else None),
            cached_max_seqlen=(max_seqlen if pass_min_max else None),
            base=values,
        )
        self.assertIs(nt.lengths(), lengths)

        # === construct from (values, lengths) ===
        values = torch.randn(14, 5, device=device, dtype=dtype)
        lengths = torch.tensor([2, 3, 4, 5], device=device)

        # compute min / max seqlen
        min_seqlen = lengths.min().item()
        max_seqlen = lengths.max().item()

        if pass_min_max:
            nt = torch.nested.nested_tensor_from_jagged(
                values, lengths=lengths, min_seqlen=min_seqlen, max_seqlen=max_seqlen
            )
        else:
            nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)

        # for now, if only lengths is specified, convert to offsets to integrate best with the
        # existing kernels
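        # (offsets are the exclusive prefix sums of lengths: [0, 2, 2 + 3, 2 + 3 + 4, 2 + 3 + 4 + 5])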
        expected_offsets = torch.tensor([0, 2, 5, 9, 14], device=device)
        expected_nt = torch.nested.nested_tensor_from_jagged(
            values, offsets=expected_offsets
        )
        self._validate_nt(
            nt,
            device,
            dtype,
            torch.jagged,
            requires_grad=False,
            dim=3,
            batch_size=4,
            contiguous=True,
            cached_min_seqlen=(min_seqlen if pass_min_max else None),
            cached_max_seqlen=(max_seqlen if pass_min_max else None),
            base=values,
            ref_nt=expected_nt,
        )

        # error case: no offsets or lengths
        with self.assertRaisesRegex(
            RuntimeError, "At least one of offsets or lengths is required"
        ):
            torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None)

    @onlyCPU
    def test_nested_tensor_from_jagged_fx_trace(self, device):
        def fn(x, y):
            return torch.nested.nested_tensor_from_jagged(x, y)

        def user_unwrapped(x, y):
            return fn(x, y)

        with self.assertRaisesRegex(
            RuntimeError,
            "torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace",
        ):
            torch.fx.symbolic_trace(user_unwrapped)

    @dtypes(torch.float, torch.double, torch.half)
    @parametrize("dim", range(5))
    @parametrize(
        "layout",
        [torch.strided, torch.jagged],
        name_fn=lambda l: f"layout_{str(l).split('.')[1]}",
    )
    @parametrize("requires_grad", [False, True])
    @parametrize("contiguous", [False, True])
    def test_as_nested_tensor_from_tensor(
        self, device, dtype, dim, layout, requires_grad, contiguous
    ):
        if dim == 0:
            t = torch.tensor(3.0, requires_grad=requires_grad)
        else:
            t = torch.randn(*(3 for _ in range(dim)), requires_grad=requires_grad)
        assert t.dim() == dim

        if dim < 2:
            # 0-1 dim tensors can't be converted to NTs
            with self.assertRaisesRegex(
                RuntimeError, "Expected tensor argument to have dim"
            ):
                nt = torch.nested.as_nested_tensor(
                    t, device=device, dtype=dtype, layout=layout
                )
            return

        orig_t = t
        if not contiguous:
            t = t.transpose(0, 1)

        nt = torch.nested.as_nested_tensor(t, device=device, dtype=dtype, layout=layout)
        expected_dim = t.dim()
        expected_batch_size = t.size(0)
        expected_seqlen = t.size(1) if layout == torch.jagged else None
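        # Converting a dense (non-ragged) tensor yields components of equal length,
        # so the cached min and max seqlen should both equal t.size(1).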
        self._validate_nt(
            nt,
            device,
            dtype,
            layout,
            requires_grad=requires_grad,
            dim=dim,
            batch_size=expected_batch_size,
            contiguous=True,
            cached_min_seqlen=expected_seqlen,
            cached_max_seqlen=expected_seqlen,
        )

        if torch.device(device) == t.device and dtype == t.dtype and contiguous:
            # should be the non-copying (view) case
            self.assertTrue(nt._is_view() and nt._base is t)

        # should have equivalent components to construction from unbound tensor list
        nt_from_unbind = torch.nested.as_nested_tensor(
            list(t.unbind(0)), device=device, dtype=dtype, layout=layout
        )
        self.assertEqualIgnoringNestedInts(nt, nt_from_unbind)

        # ensure call on a NT with the same properties returns the NT directly
        nt2 = torch.nested.as_nested_tensor(
            nt, device=device, dtype=dtype, layout=layout
        )
        self.assertTrue(nt is nt2)

        # ensure call with device=None uses input tensor device
        nt3 = torch.nested.as_nested_tensor(
            t.to(device=device, dtype=dtype),
            device=None,
            dtype=None,
            layout=layout,
        )
        self._validate_nt(
            nt3,
            device,
            dtype,
            layout,
            requires_grad=requires_grad,
            dim=dim,
            batch_size=expected_batch_size,
            contiguous=True,
            cached_min_seqlen=expected_seqlen,
            cached_max_seqlen=expected_seqlen,
        )

        # we don't support conversion between layouts this way atm
        other_layout = torch.strided if layout == torch.jagged else torch.jagged
        with self.assertRaisesRegex(
            RuntimeError, "Converting between nested tensor layouts is not supported"
        ):
            torch.nested.as_nested_tensor(
                nt, device=device, dtype=dtype, layout=other_layout
            )

        if requires_grad:
            # make sure gradients flow back into inputs
            (nt * 2).backward(torch.ones_like(nt))
            self.assertEqual(orig_t.grad, torch.ones_like(orig_t) * 2)

    @dtypes(torch.double, torch.half)
    @onlyCUDA
    def test_device_dtype_transfer_updates_offsets(self, device, dtype):
        for tensor_list in self._get_example_tensor_lists():
            orig_device = torch.device("cpu")
            orig_dtype = torch.float32
            nt = torch.nested.nested_tensor(
                tensor_list, layout=torch.jagged, device=orig_device, dtype=orig_dtype
            )

            self.assertEqual(torch.int64, nt.offsets().dtype)
            nt = nt.to(device=device).to(dtype=dtype)

            # offsets should still be int64 on the new device
            self.assertEqual(nt.values().device, nt.offsets().device)
            self.assertEqual(torch.int64, nt.offsets().dtype)

    def test_unbind(self, device):
        for tensor_list in self._get_example_tensor_lists():
            nt = torch.nested.nested_tensor(
                tensor_list, layout=torch.jagged, device=device
            )  # ragged_idx = 1
            out = nt.unbind()
            self.assertEqual(len(out), len(tensor_list))
            for i, t in enumerate(out):
                self.assertEqual(t, tensor_list[i])

    @parametrize("ragged_idx", [2, 3])
    def test_unbind_transpose(self, device, ragged_idx):
        for tensor_list in self._get_example_tensor_lists():
            nt = torch.nested.nested_tensor(
                tensor_list, layout=torch.jagged, device=device
            )
            if ragged_idx < nt.dim():
                nt = nt.transpose(1, ragged_idx)  # set ragged_idx
                out = nt.unbind()
                self.assertEqual(len(out), len(tensor_list))
                for i, t in enumerate(out):
                    self.assertEqual(
                        t.transpose(0, ragged_idx - 1), tensor_list[i]
                    )  # transpose back each element of result

    def test_unbind_transpose_ragged_idx_last_dim(self, device):
        for tensor_list in self._get_example_tensor_lists():
            nt = torch.nested.nested_tensor(
                tensor_list, layout=torch.jagged, device=device
            ).transpose(1, -1)  # set ragged_idx = last dimension
            out = nt.unbind()
            self.assertEqual(len(out), len(tensor_list))
            for i, t in enumerate(out):
                self.assertEqual(
                    t.transpose(0, -1), tensor_list[i]
                )  # transpose back each element of result

    def test_unbind_lengths(self, device):
        values = torch.randn(16, 128, device=device)
        offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
        lengths = torch.tensor([6, 2, 1, 2], device=device)
        nt = torch.nested.nested_tensor_from_jagged(
            values, offsets=offsets, lengths=lengths
        )  # 3D nested tensor
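        # Reference components: unbind() should return only the first lengths[i] rows
        # of each component; the rows between offsets[i] + lengths[i] and offsets[i + 1]
        # are holes and must not appear in the output.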

        tensor_list = []
        for i in range(offsets.shape[0] - 1):
            tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i])])

        out = nt.unbind()
        self.assertEqual(len(out), len(tensor_list))
        for i, t in enumerate(out):
            self.assertEqual(t, tensor_list[i])

    def test_unbind_lengths_ragged_idx_1(self, device):
        values = torch.randn(16, 8, 128, device=device)
        offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
        lengths = torch.tensor([6, 2, 1, 2], device=device)
        ragged_idx = 1
        nt = torch.nested._internal.nested_tensor.NestedTensor(
            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
        )  # 4D nested tensor

        tensor_list = []
        for i in range(offsets.shape[0] - 1):
            tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i]), :, :])

        out = nt.unbind()

        self.assertEqual(len(out), len(tensor_list))
        for i, t in enumerate(out):
            self.assertEqual(t, tensor_list[i])

    def test_unbind_lengths_ragged_idx_equals_2_bad_dim(self, device):
        values = torch.randn(16, 8, 128, device=device)
        offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
        lengths = torch.tensor([6, 2, 1, 2], device=device)
        ragged_idx = 2
        nt = torch.nested._internal.nested_tensor.NestedTensor(
            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
        )  # 4D nested tensor
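        # with _ragged_idx=2 the ragged dim of values has size 8, but these
        # offsets/lengths describe a dim of size 16, so unbind() below should
        # reject the mismatch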

        self.assertRaisesRegex(
            RuntimeError,
            r"unbind\(\): nested tensor offsets and lengths.*",
            lambda: nt.unbind(),
        )

    def test_unbind_lengths_ragged_idx_2(self, device):
        values = torch.randn(16, 8, 128, device=device)
        offsets = torch.tensor([0, 2, 4, 8], device=device)
        lengths = torch.tensor([2, 1, 3], device=device)
        ragged_idx = 2
        nt = torch.nested._internal.nested_tensor.NestedTensor(
            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
        )  # 4D nested tensor

        tensor_list = []
        for i in range(offsets.shape[0] - 1):
            tensor_list.append(values[:, offsets[i] : (offsets[i] + lengths[i]), :])

        out = nt.unbind()

        self.assertEqual(len(out), len(tensor_list))
        for i, t in enumerate(out):
            self.assertEqual(t, tensor_list[i])

    def test_unbind_lengths_ragged_idx_3(self, device):
        values = torch.randn(16, 8, 128, device=device)
        offsets = torch.tensor([0, 100, 128], device=device)
        lengths = torch.tensor([50, 28], device=device)
        ragged_idx = 3
        nt = torch.nested._internal.nested_tensor.NestedTensor(
            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
        )  # 4D nested tensor

        tensor_list = []
        for i in range(offsets.shape[0] - 1):
            tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])

        out = nt.unbind()

        self.assertEqual(len(out), len(tensor_list))
        for i, t in enumerate(out):
            self.assertEqual(t, tensor_list[i])

    @skipIfTorchDynamo(
        "TorchDynamo raises an error for ragged_idx == 0 earlier than Torch"
    )
    def test_unbind_lengths_ragged_idx_0(self, device):
        values = torch.randn(16, 8, 128, device=device)
        offsets = torch.tensor([0, 100, 128], device=device)
        lengths = torch.tensor([50, 28], device=device)
        ragged_idx = 0
        nt = torch.nested._internal.nested_tensor.NestedTensor(
            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
        )  # 4D nested tensor
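        # _ragged_idx=0 points at the batch dim, which unbind() treats as out of bounds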

        tensor_list = []
        for i in range(offsets.shape[0] - 1):
            tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])

        self.assertRaisesRegex(
            RuntimeError,
            r"unbind\(\): nested tensor.*out of bounds",
            lambda: nt.unbind(),
        )

    def test_narrow(self, device):
        starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
        lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
        buffer = (
            torch.arange(0, 10, device=device, dtype=torch.int64)
            .unsqueeze(0)
            .expand(5, -1)
            .clone()
            .detach()
        )
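        # narrow keeps lengths[i] elements of row i, starting at starts[i]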
        nt = torch.nested.narrow(buffer, 1, starts, lengths, layout=torch.jagged)

        self.assertTrue(nt._is_view() and nt._base is buffer)

        # TODO: Use this approach when unbind is functional
        # unbound_nt = nt.unbind()
        # for i in range(starts.shape[0]):
        #     expected = torch.arange(
        #         starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64
        #     )
        #     self.assertEqual(expected, unbound_nt[i])
        for i in range(starts.shape[0]):
            self.assertEqual(
                torch.arange(
                    starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64
                ),
                nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])],
            )

    def test_njt_cat(self, device):
        offsets = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
        values_1 = torch.randn(
            3, 2, dtype=torch.float64, device=device, requires_grad=True
        )
        values_2 = torch.randn(
            3, 4, dtype=torch.float64, device=device, requires_grad=True
        )
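        # offsets [0, 2, 3] -> two components; cat along dim=-1 concatenates the
        # feature dims (2 + 4 = 6)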

        def grad_test_func(values_1, values_2, offsets):
            nt_1 = torch.nested.nested_tensor_from_jagged(values_1, offsets)
            nt_2 = torch.nested.nested_tensor_from_jagged(values_2, offsets)
            nt_3 = torch.cat([nt_1, nt_2], dim=-1)
            return nt_3.values()

        assert gradcheck(
            grad_test_func,
            inputs=(values_1, values_2, offsets),
            check_batched_grad=False,
        )

    def test_is_contiguous(self, device):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
        nt_contiguous = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)

        starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
        lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
        narrow_base = (
            torch.arange(0, 10, device=device, dtype=torch.int64)
            .unsqueeze(0)
            .expand(5, -1)
            .clone()
        )
        nt_noncontiguous = torch.nested.narrow(
            narrow_base, 1, starts_nc, lengths_nc, layout=torch.jagged
        )

        starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64)
        lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64)
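        # these starts/lengths select back-to-back regions of the buffer (no holes),
        # so the narrowed NT below should report itself as contiguous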
        nt_contiguous_narrow = torch.nested.narrow(
            narrow_base, 1, starts_c, lengths_c, layout=torch.jagged
        )

        # Test contiguous case
        assert nt_contiguous.is_contiguous()

        # Test narrow case
        assert not nt_noncontiguous.is_contiguous()
        assert nt_contiguous_narrow.is_contiguous()

        # Test querying by memory_format
        self.assertTrue(
            nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(
            not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(
            nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format)
        )

    def test_layout_under_torch_dispatch_mode(self):
        from torch.testing._internal.logging_tensor import (
            capture_logs_with_logging_tensor_mode,
        )

        nt = random_nt_from_dims(
            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
        )

        with capture_logs_with_logging_tensor_mode():
            self.assertEqual(nt.layout, torch.jagged)

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    @parametrize(
        "func", [torch.empty_like, torch.randn_like], name_fn=lambda f: f.__name__
    )
    def test_like_shape(self, func):
        nt = random_nt_from_dims(
            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
        )
        nt_like = func(nt)

        for nt_ub in nt_like.unbind():
            t_like = func(nt_ub)
            self.assertEqual(nt_ub.shape, t_like.shape)

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    @parametrize(
        "func", [torch.ones_like, torch.zeros_like], name_fn=lambda f: f.__name__
    )
    def test_like_value(self, func):
        nt = random_nt_from_dims(
            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
        )
        nt_like = func(nt)

        for nt_ub in nt_like.unbind():
            t_like = func(nt_ub)
            self.assertEqual(nt_ub, t_like)

    def test_noncontiguous_pointwise(self, device):
        a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
        nt = torch.nested.nested_tensor([a, b, c], layout=torch.jagged)
        # transpose ragged dim
        transposed = nt.transpose(1, 2)
        self.assertFalse(transposed.is_contiguous())
        clone = transposed.clone()

        def check_nt_equality(x, y):
            self.assertEqual(x.values(), y.values())
            self.assertEqual(x.offsets(), y.offsets())
            self.assertEqual(x._ragged_idx, y._ragged_idx)
            self.assertEqual(x.shape, y.shape)

        self.assertFalse(clone.is_contiguous())
        check_nt_equality(clone, transposed)

        clone_contig = transposed.clone(memory_format=torch.contiguous_format)
        self.assertTrue(clone_contig.is_contiguous())
        check_nt_equality(clone_contig, transposed)

        detached = transposed.detach()
        self.assertFalse(detached.is_contiguous())
        check_nt_equality(detached, transposed)

    def test_permute(self, device):
        nt = random_nt_from_dims(
            [2, None, 3, 5], device, torch.float32, layout=torch.jagged
        )
        nt_shape = nt.shape
        nt_inner_shape = nt.values().shape
        with self.assertRaisesRegex(
            ValueError,
            r"permute\(\): number of dimensions in the tensor input \(4\) "
            + r"does not match the length of the desired ordering of dimensions \(3\).",
        ):
            nt.permute(0, 2, 1)
        with self.assertRaisesRegex(
            ValueError, r"permute\(\): duplicate dims are not allowed."
        ):
            nt.permute(0, 2, -2, 3)
        with self.assertRaisesRegex(
            ValueError, "Permute is not supported on the batch dimension for jagged NT"
        ):
            nt.permute(1, 0, 2, 3)
        nt_permute = nt.permute(0, 2, 1, -1)
        self.assertEqual(
            nt_permute.shape, (nt_shape[0], nt_shape[2], nt_shape[1], nt_shape[3])
        )
        self.assertEqual(
            nt_permute.values().shape,
            (nt_inner_shape[1], nt_inner_shape[0], nt_inner_shape[2]),
        )
        self.assertEqual(nt_permute._ragged_idx, 2)
        self.assertEqual(nt_permute.permute(0, 2, 1, 3), nt)

    def test_to_dtype(self, device):
        nt = random_nt_from_dims(
            [2, None, 3], device, torch.float32, layout=torch.jagged
        )
        nt_after = nt.to(torch.float64)
        self.assertEqual(torch.float32, nt.dtype)
        self.assertEqual(torch.float64, nt_after.dtype)
        self.assertEqual(torch.float64, nt_after.values().dtype)
        self.assertEqual(torch.int64, nt_after.offsets().dtype)

        noncontiguous_nt = nt.transpose(1, 2)
        noncontiguous_nt_after = noncontiguous_nt.to(torch.bfloat16)
        self.assertEqual(torch.bfloat16, noncontiguous_nt_after.dtype)
        self.assertEqual(torch.bfloat16, noncontiguous_nt_after.values().dtype)
        self.assertEqual(torch.int64, noncontiguous_nt_after.offsets().dtype)

    def test_to_copy(self, device):
        nt = torch.nested.nested_tensor(
            [
                torch.randn(
                    i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device
                )
                for i in range(3)
            ],
            layout=torch.jagged,
        )

        nt_copy_dtype = torch.ops.aten._to_copy(nt, dtype=torch.float16)
        self.assertEqual(torch.float16, nt_copy_dtype.dtype)

        nt_t = nt.transpose(1, 2)
        nt_t_copy_dtype = torch.ops.aten._to_copy(nt_t, dtype=torch.float16)
        self.assertEqual(torch.float16, nt_t_copy_dtype.dtype)

    def test_copy_(self, device):
        offsets = torch.tensor([0, 2, 4], device=device)
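        # a and b are built from the very same offsets tensor, which copy_() requires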
        a = torch.nested.nested_tensor_from_jagged(
            torch.zeros(4, 3, device=device), offsets
        )
        b = torch.nested.nested_tensor_from_jagged(
            torch.ones(4, 3, device=device), offsets
        )
        a.copy_(b)
        torch._dynamo.disable(self.assertEqual)(a, b)

        offsets_2 = torch.tensor([0, 2, 4], device=device)
        c = torch.nested.nested_tensor_from_jagged(
            torch.ones(4, 3, device=device), offsets_2
        )
        # fail when tensors have the same size but not the exact same offset tensor.
        with self.assertRaisesRegex(
            RuntimeError,
            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.",
        ):
            a.copy_(c)

        # fail when tensors have different sizes
        a = a.transpose(1, 2)
        with self.assertRaisesRegex(
            RuntimeError,
            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.",
        ):
            a.copy_(b)

    @skipIfTorchDynamo("Dynamo doesn't know how to trace prof.events()")
    def test_profiler_sequence_nr(self):
        with torch.profiler.profile() as prof:
            values = torch.randn(4, 6, requires_grad=True)
            offsets = torch.tensor([0, 2, 4])
            values = values * 2
            l = torch.nn.Linear(6, 8)
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)

            nt = l(nt)
            val = nt.values()

            loss = val.sum()
            loss.backward()

        fwd_seq_nrs = []
        for evt in prof.events():
            if (
                "linear" in evt.name.lower()
                and "backward" not in evt.name.lower()
                and evt.sequence_nr != -1
            ):
                fwd_seq_nrs.append(evt.sequence_nr)

        bwd_seq_nrs = []
        for evt in prof.events():
            if (
                "linear" in evt.name.lower()
                and "backward" in evt.name.lower()
                and "evaluate_function" not in evt.name.lower()
                and evt.sequence_nr != -1
            ):
                bwd_seq_nrs.append(evt.sequence_nr)

        # There should only be one such event with a sequence number:
        # the PythonTLSSnapshot event. Note that it's not terrible if we end up
        # with multiple events sharing the same sequence number, so we could
        # relax this check if it becomes inconvenient to maintain this property.
        self.assertEqual(len(fwd_seq_nrs), 1)
        self.assertEqual(len(bwd_seq_nrs), 1)
        self.assertEqual(fwd_seq_nrs[0], bwd_seq_nrs[0])

    def test_is_same_size(self, device):
        def get_3_tensors():
            return [
                torch.randn(
                    i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device
                )
                for i in range(3)
            ]

        nt1, offsets1 = jagged_from_list(get_3_tensors(), None)
        nt2, offsets1 = jagged_from_list(get_3_tensors(), offsets1)

        nt3, offsets2 = jagged_from_list(get_3_tensors(), None)
        nt4, offsets2 = jagged_from_list(get_3_tensors(), offsets2)

        def check_size(nt1, nt2, nt3, nt4):
            self.assertTrue(torch.ops.aten.is_same_size(nt1, nt2))
            self.assertTrue(torch.ops.aten.is_same_size(nt3, nt4))
            self.assertFalse(torch.ops.aten.is_same_size(nt1, nt3))

        check_size(nt1, nt2, nt3, nt4)

        nt1_t, nt2_t, nt3_t, nt4_t = (x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4))
        check_size(nt1_t, nt2_t, nt3_t, nt4_t)

    @skipIfTorchDynamo("compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    def test_specialize_dynamic_shape(self, device):
        values = torch.randn((18, 16), device=device)
        offsets = torch.tensor([0, 2, 3, 6, 15, 18], device=device)
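        # 5 components with lengths [2, 1, 3, 9, 3]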
        like_values = torch.randn_like(values)

        # this marks values as dynamic
        nt = torch.nested.nested_tensor_from_jagged(values, offsets)

        def fn(values, same_size):
            # here, the dynamic shape is specialized by same_size's shape
            # https://github.com/pytorch/pytorch/issues/127097
            # make sure this doesn't error out in torch.compile
            return values + same_size

        self.assertEqual(
            fn(values, like_values),
            torch.compile(fn)(values, like_values),
        )

    @skipIfTorchDynamo("compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    def test_specialize_dynamic_shape_recompile(self, device):
        def generate_inp(total_len):
            values = torch.randn((total_len, 16), device=device)
            offsets = torch.tensor([0, 2, 3, 6, 15, total_len], device=device)
            like_values = torch.randn_like(values)
            return values, offsets, like_values

        def check_results(ref_fn, res_fn, args):
            values, offsets, like_values = args
            # this may add dynamic shape markings
            # the goal of this test is to make sure that, whatever markings are
            # there, we eventually stop recompiling as the shape changes.
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)

            self.assertEqual(ref_fn(values, like_values), res_fn(values, like_values))

        def fn(values, same_size):
            return values + same_size

        compile_counter = torch._dynamo.testing.CompileCounter()

        compiled_fn = torch._dynamo.optimize(compile_counter, nopython=True)(fn)
        check_results(fn, compiled_fn, generate_inp(18))
        self.assertEqual(compile_counter.frame_count, 1)

        check_results(fn, compiled_fn, generate_inp(19))
        # we'll probably recompile here with dynamic shapes - it's okay if not though.
        frame_count_2 = compile_counter.frame_count
        self.assertIn(frame_count_2, [1, 2])

        # make sure that by now we've already compiled with dynamic shapes, so additional
        # shapes should not trigger additional recompiles.
        check_results(fn, compiled_fn, generate_inp(20))
        self.assertEqual(compile_counter.frame_count, frame_count_2)

    # Note 1: Math fallback doesn't work with bfloat16 on CUDA
    # Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT
    @unittest.skipIf(
        TEST_WITH_ROCM,
        "ROCm doesn't support flash attention or mem_efficient attention for NT",
    )
    @dtypes(
        *(
            [torch.float16, torch.bfloat16, torch.float32]
            if SM80OrLater
            else [torch.float16, torch.float32]
        )
    )
    def test_sdpa(self, device, dtype):
        batch_size = 1
        emb_dims = 128
        n_heads = 8
        head_dims = emb_dims // n_heads

        sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
        sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)

        query = torch.nn.Linear(
            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
        )
        key = torch.nn.Linear(
            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
        )
        value = torch.nn.Linear(
            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
        )

        # Simplest case: 1 sentence, no batching
        x_d1 = sen1.unsqueeze(0)
        x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged)

        # See note below for why we detach here.
        q_d1 = (
            query(x_d1)
            .view(batch_size, -1, n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        q_d1_t = q_d1.transpose(1, 2)
        k_d1 = (
            key(x_d1)
            .view(batch_size, -1, n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        k_d1_t = k_d1.transpose(1, 2)
        v_d1 = (
            value(x_d1)
            .view(batch_size, -1, n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        v_d1_t = v_d1.transpose(1, 2)

        q_nt = (
            query(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        q_nt_t = q_nt.transpose(1, 2)
        k_nt = (
            key(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        k_nt_t = k_nt.transpose(1, 2)
        v_nt = (
            value(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        v_nt_t = v_nt.transpose(1, 2)

        # High Precision Math Reference
        q_d1_f32 = q_d1.to(torch.float32)
        k_d1_f32 = k_d1.to(torch.float32)
        v_d1_f32 = v_d1.to(torch.float32)
        q_d1_f32_t = q_d1_f32.transpose(1, 2)
        k_d1_f32_t = k_d1_f32.transpose(1, 2)
        v_d1_f32_t = v_d1_f32.transpose(1, 2)
        out_ref = torch.ops.aten._scaled_dot_product_attention_math(
            q_d1_f32_t, k_d1_f32_t, v_d1_f32_t
        )[0]
        grads_ref = torch.autograd.grad(out_ref.sum(), (q_d1_f32, k_d1_f32, v_d1_f32))

        # Low Precision Math Reference
        out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
            q_d1_t, k_d1_t, v_d1_t
        )[0]
        grads_lp_ref = torch.autograd.grad(out_lp_ref.sum(), (q_d1, k_d1, v_d1))

        # Compute tolerances
        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(grads_ref[0], grads_lp_ref[0])
        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(grads_ref[1], grads_lp_ref[1])
        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(grads_ref[2], grads_lp_ref[2])
        grad_atols = [grad_q_ref_atol, grad_k_ref_atol, grad_v_ref_atol]
        grad_rtols = [grad_q_ref_rtol, grad_k_ref_rtol, grad_v_ref_rtol]

        attn_d1 = torch.nn.functional.scaled_dot_product_attention(
            q_d1_t, k_d1_t, v_d1_t
        ).transpose(1, 2)
        attn_nt = torch.nn.functional.scaled_dot_product_attention(
            q_nt_t, k_nt_t, v_nt_t
        ).transpose(1, 2)

        self.assertEqual(
            attn_d1,
            attn_nt.unbind()[0].unsqueeze(0),
            atol=output_ref_atol,
            rtol=output_ref_rtol,
        )

        # Simple case: 2 sentences, no extra params
        x_d2 = sen2.unsqueeze(0)
        x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)

        # NB: we make sure the leaf tensor we compute gradients for is the viewed
        # tensor before it is transposed, because today we cannot backward through
        # view() or unbind() of a transposed tensor.
        q_d2 = (
            query(x_d2)
            .view(batch_size, -1, n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        q_d2_t = q_d2.transpose(1, 2)
        k_d2 = (
            key(x_d2)
            .view(batch_size, -1, n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        k_d2_t = k_d2.transpose(1, 2)
        v_d2 = (
            value(x_d2)
            .view(batch_size, -1, n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        v_d2_t = v_d2.transpose(1, 2)

        q_nt = (
            query(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        q_nt_t = q_nt.transpose(1, 2)
        k_nt = (
            key(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        k_nt_t = k_nt.transpose(1, 2)
        v_nt = (
            value(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        v_nt_t = v_nt.transpose(1, 2)

        attn_d2 = torch.nn.functional.scaled_dot_product_attention(
            q_d2_t, k_d2_t, v_d2_t
        ).transpose(1, 2)
        d1_grads = torch.autograd.grad(attn_d1.sum(), (q_d1, k_d1, v_d1))
        d2_grads = torch.autograd.grad(attn_d2.sum(), (q_d2, k_d2, v_d2))

        # Simple case 3: batch_size = 1, seq_len = 1
        q_3 = torch.randn(1, 8, 16, dtype=dtype, device=device)
        q_nt_3 = torch.nested.as_nested_tensor([q_3], layout=torch.jagged)
        q_nt_3 = q_nt_3.transpose(1, 2)
        attn_out = torch.nn.functional.scaled_dot_product_attention(
            q_nt_3, q_nt_3, q_nt_3
        )
        self.assertEqual(attn_out.shape, q_nt_3.shape)

        def check_forward_backward():
            attn_nt = torch.nn.functional.scaled_dot_product_attention(
                q_nt_t, k_nt_t, v_nt_t
            ).transpose(1, 2)

            attn_nts = attn_nt.unbind()
            self.assertEqual(
                attn_d1,
                attn_nts[0].unsqueeze(0),
                atol=output_ref_atol,
                rtol=output_ref_rtol,
            )
            self.assertEqual(
                attn_d2,
                attn_nts[1].unsqueeze(0),
                atol=output_ref_atol,
                rtol=output_ref_rtol,
            )

            nt_grads = torch.autograd.grad(attn_nt.values().sum(), (q_nt, k_nt, v_nt))
            for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip(
                nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols
            ):
                unbound_nt_grads = nt_grad.unbind()
                self.assertEqual(
                    d1_grad,
                    unbound_nt_grads[0].unsqueeze(0),
                    atol=grad_atol,
                    rtol=grad_rtol,
                )
                self.assertEqual(
                    d2_grad,
                    unbound_nt_grads[1].unsqueeze(0),
                    atol=grad_atol,
                    rtol=grad_rtol,
                )

        # Default
        check_forward_backward()

        # Test that the dispatcher works by enabling only the mem-efficient and math
        # backends (as they are safe for all devices)
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=True, enable_math=True
        ):
            check_forward_backward()

        # Test math fallback
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):
            # Math fallback doesn't work with bfloat16 on CUDA because
            # "group_gemm_dispatch" not implemented for 'BFloat16'
            if not (str(device).startswith("cuda") and dtype == torch.bfloat16):
                check_forward_backward()

    @skipIfTorchDynamo("SDPA test compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    # Guarding with sqrt() doesn't work on ROCm?
    @skipCUDAIfRocm
    @onlyCUDA
    @dtypes(
        *(
            [torch.float16, torch.bfloat16, torch.float32]
            if SM80OrLater
            else [torch.float16, torch.float32]
        )
    )
    def test_sdpa_compile(self, device, dtype):
        batch_size = 1
        emb_dims = 1024
        n_heads = 8
        head_dims = emb_dims // n_heads

        sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
        sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)

        query = torch.nn.Linear(
            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
        )
        key = torch.nn.Linear(
            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
        )
        value = torch.nn.Linear(
            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
        )

        # Dense inputs: one sentence each (batch size 1); the NT batches both sentences
        x_d1 = sen1.unsqueeze(0)
        x_d2 = sen2.unsqueeze(0)
        x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)

        q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
        k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
        v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
        q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
        k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
        v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)

        q_nt = (
            query(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .transpose(1, 2)
        )
        k_nt = (
            key(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .transpose(1, 2)
        )
        v_nt = (
            value(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .transpose(1, 2)
        )

        # High Precision Math Reference
        q_d1_f32 = q_d1.to(torch.float32)
        k_d1_f32 = k_d1.to(torch.float32)
        v_d1_f32 = v_d1.to(torch.float32)
        out_ref = torch.ops.aten._scaled_dot_product_attention_math(
            q_d1_f32, k_d1_f32, v_d1_f32
        )[0]
        # Low Precision Math Reference
        out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
            q_d1, k_d1, v_d1
        )[0]
        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)

        attn_d1 = torch.nn.functional.scaled_dot_product_attention(
            q_d1, k_d1, v_d1
        ).transpose(1, 2)
        attn_d2 = torch.nn.functional.scaled_dot_product_attention(
            q_d2, k_d2, v_d2
        ).transpose(1, 2)

        compiled_sdpa = torch.compile(torch.nn.functional.scaled_dot_product_attention)
        attn_nt = compiled_sdpa(q_nt, k_nt, v_nt).transpose(1, 2)

        attn_nts = attn_nt.unbind()
        self.assertEqual(
            attn_d1,
            attn_nts[0].unsqueeze(0),
            atol=output_ref_atol,
            rtol=output_ref_rtol,
        )
        self.assertEqual(
            attn_d2,
            attn_nts[1].unsqueeze(0),
            atol=output_ref_atol,
            rtol=output_ref_rtol,
        )

    @dtypes(torch.float32, torch.double, torch.half)
    def test_sdpa_with_constant_sequence_length(self, device, dtype):
        # shape (B, P*, S, D)
        # B: batch size
        # P*: ragged number of prompts
        # S: (constant) sequence length
        # D: embedding size
        query = random_nt_from_dims(
            [4, None, 8, 10],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=True,
        )
        key = random_nt_from_similar(query)
        value = random_nt_from_similar(query)
        output = F.scaled_dot_product_attention(query, key, value)
        self.assertTrue(isinstance(output, NestedTensor))
        output.values().sum().backward()

        query_dense = query.clone().detach().requires_grad_(True)
        # should be equivalent to just running the buffers through
        output_dense = F.scaled_dot_product_attention(
            query_dense.values(), key.values(), value.values()
        )
        torch._dynamo.disable(self.assertEqual)(output._values, output_dense)
        output_dense.sum().backward()
        torch._dynamo.disable(self.assertEqual)(query.grad, query_dense.grad)

    @onlyCUDA
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Platform doesn't support flash or mem-efficient attention",
    )
    @dtypes(
        *(
            [torch.float16, torch.bfloat16, torch.float32]
            if SM80OrLater
            else [torch.float16, torch.float32]
        )
    )
    def test_sdpa_with_packed_in_proj(self, device, dtype):
        # shape (B, *, D)
        input_packed = random_nt_from_dims(
            [5, None, 10], device=device, dtype=dtype, layout=torch.jagged
        )

        # Do input projection.
        num_heads = 2
        # should be a multiple of 4 for efficient kernels (e.g. flash / mem-efficient)
        head_dim = 8
        qkv_linear = torch.nn.Linear(10, num_heads * head_dim * 3).to(
            device=device, dtype=dtype
        )

        def in_proj(input_packed, qkv_linear=qkv_linear):
            qkv_post_proj = qkv_linear(input_packed)
            # these are non-contiguous to trigger _is_safe_to_get_storage_as_tensor()
            q, k, v = qkv_post_proj.chunk(3, dim=-1)
            q = q.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
            k = k.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
            v = v.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
            return q, k, v

        q, k, v = in_proj(input_packed)
        output = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

        # compare to individually running unbound components through
        for in_component, out_component in zip(
            input_packed.unbind(), output.transpose(-2, -3).unbind()
        ):
            q, k, v = in_proj(in_component)
            out = F.scaled_dot_product_attention(q, k, v).transpose(-2, -3)

            # Low Precision Math Reference
            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q, k, v)[
                0
            ].transpose(-2, -3)
            output_ref_atol, output_ref_rtol = get_tolerances(
                out, out_lp_ref, fudge_factor=2
            )

            self.assertEqual(
                out, out_component, atol=output_ref_atol, rtol=output_ref_rtol
            )

    @skipIfTorchDynamo("SDPA test compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    # mha_varlen_fwd not supported on ROCm
    @skipCUDAIfRocm
    @onlyCUDA
    @dtypes(
        *(
            [torch.float16, torch.bfloat16, torch.float32]
            if SM80OrLater
            else [torch.float16, torch.float32]
        )
    )
    def test_sdpa_backwards(self, device, dtype):
        values = torch.randn(9, 3, 256, requires_grad=True, device=device, dtype=dtype)
        offsets = torch.tensor([0, 1, 3, 5, 9], device=device, dtype=torch.int64)

        @torch.compile
        def f(values, offsets):
            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
            nt = nt.transpose(-2, -3)
            # purposely introduce a graph break to trigger view replay for the
            # subclass view input
            torch.tensor(1).item()
            output = F.scaled_dot_product_attention(nt, nt, nt).transpose(-2, -3)
            return convert_nt_to_jagged(output)

        output = f(values, offsets)
        output.sum().backward()
        self.assertEqual(values.grad, torch.ones_like(values))

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Platform doesn't support flash or mem-efficient attention",
    )
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    @onlyCUDA
    @skipIfTorchDynamo()
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    def test_sdpa_autocast(self, device):
        def fn_nt(values32, values16, offsets):
            nt32 = convert_jagged_to_nested_tensor(values32, offsets, max_length=16)
            nt16 = convert_jagged_to_nested_tensor(values16, offsets, max_length=16)
            nt32 = nt32.transpose(1, 2)
            nt16 = nt16.transpose(1, 2)
            return F.scaled_dot_product_attention(nt32, nt16, nt32)

        def fn_dense(x32, x16):
            x32 = x32.view(8, 16, 4, 16).transpose(1, 2)
            x16 = x16.view(8, 16, 4, 16).transpose(1, 2)
            return F.scaled_dot_product_attention(x32, x16, x32)

        values32 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float32)
        values16 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float16)
        offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)
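        # 8 equal-length components of 16 rows each: offsets [0, 16, ..., 128]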

        x32 = values32.clone()
        x16 = values16.clone()

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            out_dense_eager = fn_dense(x32, x16)
            out_dense_compiled = torch.compile(fn_dense)(x32, x16)
            out_nt_eager = fn_nt(values32, values16, offsets)
            out_nt_compiled = torch.compile(fn_nt)(values32, values16, offsets)

        self.assertEqual(out_dense_eager, out_dense_compiled)
        self.assertEqual(
            out_dense_eager.transpose(1, 2),
            out_nt_eager.values().transpose(0, 1).view(8, 16, 4, 16),
        )
        self.assertEqual(
            out_dense_eager.transpose(1, 2),
            out_nt_compiled.values().transpose(0, 1).view(8, 16, 4, 16),
        )

        def get_values():
            return tuple(
                x.clone().detach().requires_grad_(True) for x in (values32, values16)
            )

        v32_dense_eager, v16_dense_eager = get_values()
        v32_dense_compile, v16_dense_compile = get_values()
        v32_nt_eager, v16_nt_eager = get_values()
        v32_nt_compile, v16_nt_compile = get_values()

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss_dense_eager = fn_dense(v32_dense_eager, v16_dense_eager).sum()
            loss_dense_compile = torch.compile(fn_dense)(
                v32_dense_compile, v16_dense_compile
            ).sum()
            loss_nt_eager = fn_nt(v32_nt_eager, v16_nt_eager, offsets).values().sum()
            loss_nt_compile = (
                torch.compile(fn_nt)(v32_nt_compile, v16_nt_compile, offsets)
                .values()
                .sum()
            )

        loss_dense_eager.backward()
        loss_dense_compile.backward()
        loss_nt_eager.backward()
        loss_nt_compile.backward()

        self.assertEqual(v32_dense_eager.grad, v32_dense_compile.grad)
        self.assertEqual(v32_dense_eager.grad, v32_nt_eager.grad)
        self.assertEqual(v32_dense_eager.grad, v32_nt_compile.grad)

        self.assertEqual(v16_dense_eager.grad, v16_dense_compile.grad)
        self.assertEqual(v16_dense_eager.grad, v16_nt_eager.grad)
        self.assertEqual(v16_dense_eager.grad, v16_nt_compile.grad)

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Platform doesn't support flash or mem-efficient attention",
    )
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    @onlyCUDA
    @skipIfTorchDynamo()
    def test_sdpa_flop_counter(self, device):
        from torch.utils.flop_counter import FlopCounterMode

        def get_flops(nt):
            flop_counter = FlopCounterMode(display=False)
            with flop_counter:
                ret = torch.nn.functional.scaled_dot_product_attention(nt, nt, nt)
                ret.values().sum().backward()
            return flop_counter.get_total_flops()

        values = torch.randn(
            (8 * 16, 4, 16), requires_grad=True, device=device, dtype=torch.float16
        )
        offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)
        nt = convert_jagged_to_nested_tensor(values, offsets, max_length=16)

        values_meta = torch.randn(
            (8 * 16, 4, 16), requires_grad=True, device="meta", dtype=torch.float16
        )
        offsets_meta = torch.arange(0, 8 * 16 + 1, 16, device="meta", dtype=torch.int32)
        nt_meta = convert_jagged_to_nested_tensor(
            values_meta, offsets_meta, max_length=16
        )

        self.assertEqual(get_flops(nt), get_flops(nt_meta))

    @skipIfTorchDynamo()
    def test_nested_tensor_activation_checkpoint(self, device):
        values = torch.randn(
            9, 3, 256, requires_grad=True, device=device, dtype=torch.float32
        )
        lengths = torch.tensor([1, 2, 3, 3], device=device, dtype=torch.int64)
        offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)
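        # lengths [1, 2, 3, 3] -> offsets [0, 1, 3, 6, 9]; the final offset matches
        # values.shape[0] == 9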

        def fn(values, offsets):
            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
            return convert_nt_to_jagged(nt).sum()

        checkpoint(fn, values, offsets, use_reentrant=False).backward()
        self.assertIsNotNone(values.grad)

        context_fn = partial(
            create_selective_checkpoint_contexts, [torch.ops.aten.cumsum.default]
        )

        values.grad = None

        def fn(values, lengths):
            offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)
            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
            return convert_nt_to_jagged(nt).sum()

        checkpoint(
            fn, values, lengths, use_reentrant=False, context_fn=context_fn
        ).backward()
        self.assertIsNotNone(values.grad)

    # Internally-defined NT use cases are lifted here for maximum test realism.
    # TODO: Remove these when ViewNestedFromBuffer, etc. are deprecated.
    @skipCUDAIfRocm  # not needed
    @skipIfTorchDynamo("compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @parametrize("use_legacy_api", [True, False])
    @skipCPUIf(True, "SDPA Math NT fallback causes failure: see issue #133644")
    def test_dummy_mha_with_nt(self, device, use_legacy_api):
        bs = 3
        d1 = 2
        d2 = 4
        d3 = 16
        n_heads = 2
        d_head = d3 // n_heads
        max_length_1 = 10
        max_length_2 = 20
        torch.manual_seed(0)

        class mha(torch.nn.Module):
            def __init__(self, use_legacy_api) -> None:
                super().__init__()
                torch.manual_seed(0)
                self.linear = torch.nn.Linear(d2, d3, device=device)
                self.use_legacy_api = use_legacy_api

            def forward(self, query, value, offsets):
                value = self.linear(value)
                if self.use_legacy_api:
                    key = convert_jagged_to_nested_tensor_legacy(
                        value, offsets, max_length_1
                    )
                    value = convert_jagged_to_nested_tensor_legacy(
                        value, offsets, max_length_2
                    )
                    query = convert_dense_to_nested_tensor_legacy(query)
                else:
                    key = convert_jagged_to_nested_tensor(value, offsets, max_length_1)
                    value = convert_jagged_to_nested_tensor(
                        value, offsets, max_length_2
                    )
                    query = convert_dense_to_nested_tensor(query)
                q = query.view(bs, -1, n_heads, d_head).transpose(1, 2)
                k = key.view(bs, -1, n_heads, d_head).transpose(1, 2)
                v = value.view(bs, -1, n_heads, d_head).transpose(1, 2)

                with torch.nn.attention.sdpa_kernel(
                    [
                        torch.nn.attention.SDPBackend.FLASH_ATTENTION,
                        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
                    ]
                ):
                    attn_output = torch.nn.functional.scaled_dot_product_attention(
                        q,
                        k,
                        v,
                        attn_mask=None,
                        dropout_p=0.0,
                        is_causal=False,
                    )
                attn_output = attn_output.transpose(1, 2)
                if self.use_legacy_api:
                    attn_output = convert_nt_to_jagged_legacy(attn_output)
                else:
                    attn_output = convert_nt_to_jagged(attn_output)
                return attn_output, key._max_seqlen, value._max_seqlen

        query = torch.rand(bs, d1, d3, device=device)
        value = torch.rand(30, d2, requires_grad=True, device=device)
        # total_length must be > max_length, otherwise the flash_attn backward will fail
        offsets = torch.tensor([0, 2, 3, 30], device=device)
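        # component lengths are [2, 1, 27], for a total length of 30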

        m = mha(use_legacy_api)
        symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
        m = torch.compile(symbolic_traced)
        attn_output, cached_key_max_seqlen, cached_value_max_seqlen = m(
            query, value, offsets
        )
        loss = attn_output.sum()
        # Check that NT can be fx traced and torch.compile, and backward works
        loss.backward()

        # Check that value.requires_grad is not lost after tracing and compiling
        value_grad = value.grad  # save for comparison later
        self.assertIsNotNone(value_grad)
        # check that max_seqlen is cached properly
        self.assertEqual(cached_key_max_seqlen, max_length_1)
        self.assertEqual(cached_value_max_seqlen, max_length_2)

        # check that the output is numerically equivalent to eager mode
        m_eager = mha(use_legacy_api)

        value.grad = None
        attn_output_eager, _, _ = m_eager(query, value, offsets)
        attn_output_eager.sum().backward()
        self.assertTrue(torch.allclose(attn_output_eager, attn_output))
        self.assertTrue(torch.allclose(value_grad, value.grad))

    @dtypes(torch.float32)
    def test_apply_(self, device, dtype):
        nt = random_nt_from_dims(
            [5, None, 10],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=True,
        )

        def f(x):
            return x * 2

        if device != "cpu":
            with self.assertRaisesRegex(
                TypeError, "apply_ is only implemented on CPU tensors"
            ):
                nt.apply_(f)
            return

        before = nt._values.clone().detach()

        nt.apply_(f)
        expected = f(before)
        self.assertEqual(expected, nt._values)
        # apply_ should swap values in-place without appending to autograd graph
        self.assertIsNone(nt.grad)
        self.assertIsNone(nt._values.grad_fn)

    @dtypes(torch.float64, torch.float32, torch.half)
    def test_jagged_padded_dense_conversion_kernels(self, device, dtype):
        values = torch.randn(10, 5, device=device, dtype=dtype)
        offsets = torch.tensor([0, 1, 3, 8, 10], device=device, dtype=torch.int64)
        max_length = offsets.diff().max().item()
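        # per-component lengths are offsets.diff() == [1, 2, 5, 2], so max_length == 5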
        padding_value = 1.3

        # convert jagged -> padded dense
        padded = torch.ops.aten._jagged_to_padded_dense_forward(
            values, [offsets], [max_length], padding_value
        )

        batch_size = offsets.shape[0] - 1
        expected_padded_shape = (batch_size, max_length, values.shape[-1])
        self.assertEqual(padded.shape, expected_padded_shape)

        # convert padded dense -> jagged
        total_L = values.shape[0]
        output_jagged = torch.ops.aten._padded_dense_to_jagged_forward(
            padded, [offsets], total_L
        )

        # should be equivalent to the original values
        self.assertEqual(values, output_jagged)

        # success case: truncate to max length as needed
        trunc_max_length = max_length - 1
        trunc_padded = torch.ops.aten._jagged_to_padded_dense_forward(
            values, [offsets], [trunc_max_length], padding_value
        )
        self.assertEqual(padded[:, :trunc_max_length, :], trunc_padded)

        # specific to CPU impls
        if device == "cpu":
            # error case: multiple offsets on CPU, since the CPU kernels only
            # support a single jagged dim for now
            with self.assertRaisesRegex(
                RuntimeError, "only a single jagged dim is supported"
            ):
                torch.ops.aten._jagged_to_padded_dense_forward(
                    values, [offsets, offsets], [max_length, max_length], padding_value
                )

            with self.assertRaisesRegex(
                RuntimeError, "only a single jagged dim is supported"
            ):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded, [offsets, offsets], total_L
                )

            # error case: > 1D offsets
            offsets2d = offsets.unsqueeze(-1)
            with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"):
                torch.ops.aten._jagged_to_padded_dense_forward(
                    values, [offsets2d], [max_length], padding_value
                )

            with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded, [offsets2d], total_L
                )

            # error case: final offset != total_L
            offsets_wrong = offsets.clone().detach()
            offsets_wrong[-1] = total_L + 1
            with self.assertRaisesRegex(
                RuntimeError, "final offset should match total_L value"
            ):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded, [offsets_wrong], total_L
                )

            # error case: 1D padded input
            padded_wrong = padded.flatten().clone().detach()
            with self.assertRaisesRegex(RuntimeError, "expected padded dim >= 2"):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded_wrong, [offsets], total_L
                )

            # error case: batch item has length > max length
            # max_length is 5 above; 7 here
            offsets_wrong = torch.tensor(
                [0, 1, 8, 9, 10], device=device, dtype=torch.int64
            )
            with self.assertRaisesRegex(RuntimeError, "found batch item of length"):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded, [offsets_wrong], total_L
                )

    @dtypes(torch.float32)
    @skipIfTorchDynamo("Test compiles internally")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    def test_compile_preserves_metadata_cache(self, device, dtype):
        # shape (B, *, D)
        nt = random_nt_from_dims(
            [4, None, 3, 16],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=True,
        )

        # expect min / max seqlen to be stored here
        cache = dict(nt._metadata_cache)

        @torch.compile
        def f(nt):
            q = nt.transpose(-3, -2)
            output = F.scaled_dot_product_attention(q, q, q).transpose(-3, -2)
            return output

        output = f(nt)
        output.backward(torch.ones_like(output))
        self.assertEqual(output._metadata_cache, cache)

    @dtypes(torch.float32)
    @skipIfTorchDynamo("Test compiles internally")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    def test_compile_with_dynamic_max_seq_len(self, device, dtype):
        # shape (B, *, D)
        # max seq len: 18
        nt = torch.nested.nested_tensor(
            [
                torch.randn(2, 5),
                torch.randn(3, 5),
                torch.randn(18, 5),
            ],
            layout=torch.jagged,
        )

        # max seq len: 19
        nt2 = torch.nested.nested_tensor(
            [
                torch.randn(2, 5),
                torch.randn(3, 5),
                torch.randn(19, 5),
            ],
            layout=torch.jagged,
        )

        def f(nt):
            # TODO: Replace with public API when we can use @properties
            return torch.ones_like(nt) * nt._get_max_seqlen()

        for dynamic in [False, True, None]:
            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))

    @dtypes(torch.float32)
    @skipIfTorchDynamo("Test compiles internally")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    def test_compile_with_dynamic_min_seq_len(self, device, dtype):
        # shape (B, *, D)
        # min seq len: 7
        nt = torch.nested.nested_tensor(
            [
                torch.randn(7, 5),
                torch.randn(8, 5),
                torch.randn(9, 5),
            ],
            layout=torch.jagged,
        )

        # min seq len: 8
        nt2 = torch.nested.nested_tensor(
            [
                torch.randn(8, 5),
                torch.randn(9, 5),
                torch.randn(10, 5),
            ],
            layout=torch.jagged,
        )

        def f(nt):
            # TODO: Replace with public API when we can use @properties
            return torch.ones_like(nt) * nt._get_min_seqlen()

        for dynamic in [False, True, None]:
            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))

    @dtypes(torch.float32)
    @skipIfTorchDynamo("Test compiles internally")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype):
        # shape (B, *, D)
        # max seq len: 18
        nt = torch.nested.nested_tensor(
            [
                torch.randn(2, 5),
                torch.randn(3, 5),
                torch.randn(18, 5),
            ],
            layout=torch.jagged,
        )

        # max seq len: 19
        nt2 = torch.nested.nested_tensor(
            [
                torch.randn(2, 5),
                torch.randn(3, 5),
                torch.randn(19, 5),
            ],
            layout=torch.jagged,
        )

        def f(nt):
            nt2 = nt.sin() + 1
            # TODO: Replace with public API when we can use @properties
            return torch.ones_like(nt2) * nt2._get_max_seqlen()

        ref = f(nt)
        output = torch.compile(f, fullgraph=True, dynamic=False)(nt)
        self.assertEqual(ref, output)

        for dynamic in [False, True, None]:
            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))

    @dtypes(torch.float32, torch.double, torch.half)
    def test_unbind_backward(self, device, dtype):
        nt = torch.nested.nested_tensor(
            [
                torch.randn(2, 4, device=device),
                torch.randn(5, 4, device=device),
                torch.randn(3, 4, device=device),
            ],
            layout=torch.jagged,
            requires_grad=True,
        )

        a, b, c = nt.unbind()
        b.sum().backward()

        @torch._dynamo.disable
        def check(nt):
            expected_grad = torch.zeros_like(nt)
            expected_grad.unbind()[1].add_(1.0)
            self.assertEqual(nt.grad, expected_grad)

        check(nt)


FORWARD_FAILURES = {
    # === BEGIN NotImplementedError SECTION ===
    # unary
    "nn.functional.celu",
    "nn.functional.elu",
    "nn.functional.hardshrink",
    "nn.functional.hardsigmoid",
    "nn.functional.hardtanh",
    "nn.functional.logsigmoid",
    "nn.functional.mish",
    "nn.functional.relu6",
    "nn.functional.rrelu",
    "nn.functional.selu",
    "nn.functional.softplus",
    "nn.functional.softshrink",
    "nn.functional.threshold",
    "rad2deg",
    # binary
    "__rsub__",
    "complex",
    "floor_divide",
    "polar",
    "rsub",
    # reduction
    "all",
    "amax",
    "amin",
    "any",
    "argmax",
    "argmin",
    "count_nonzero",
    "linalg.vector_norm",
    "nansum",
    "std",
    "std.unbiased",
    "var",
    "var.unbiased",
    # === BEGIN UNSUPPORTED SECTION ===
    # RuntimeError: mean(): not supported for NestedTensor on dim=1
    "mean",
    # ValueError: expects strided tensor (got torch.jagged tensor)
    "masked.amax",
    "masked.amin",
    "masked.argmax",
    "masked.argmin",
    "masked.logsumexp",
    "masked.mean",
    "masked.norm",
    "masked.prod",
    "masked.std",
    "masked.sum",
    "masked.var",
    # === BEGIN BUG SECTION ===
    # Returns a tuple of Tensors so it doesn't work with NJT's unary pointwise logic
    "frexp",
    # Need to adjust sample input func to pass the right thing
    "nn.functional.prelu",
    # TypeError: fill() received an invalid combination of arguments
    # got (NestedTensor), but expected one of:
    # * (Tensor input, Tensor value)
    # * (Tensor input, Number value)
    "fill",
    # RuntimeError: unsupported tensor layout: Jagged
    "jiterator_binary",
    "jiterator_binary_return_by_ref",
    "jiterator_unary",
    # Bug found: sum() with keepdim=True returns invalid shape
    "sum",
    # RuntimeError: prod(): keepdim=True must be set for NestedTensor
    "prod",
    # RuntimeError: "jagged_to_padded_dense" not implemented for 'Bool'
    "nanmean",
}

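# Op full names expected to fail the backward OpInfo tests; a superset of
# FORWARD_FAILURES.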
BACKWARD_FAILURES = {
    *FORWARD_FAILURES,
    # TODO: categorize these
    "__rpow__",
    "atanh",
    "cdouble",
    "cfloat",
    "chalf",
    "clamp_max",
    "clamp_min",
    "copysign",
    "float_power",
    "max.binary",
    "maximum",
    "min.binary",
    "minimum",
    "pow",
    "sgn",
    "sinc",
    "special.i1",
    "special.i1e",
    # clone() on a "non-contiguous with holes" NJT allocates a new offsets -> new nested int
    # RuntimeError: Function CloneBackward0 returned an invalid gradient at index 0 -
    # got [3, j29, 5] but expected shape compatible with [3, j28, 5]
    "clone",
    # Calling into torch.ops.aten.size directly
    "masked_select",
}

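# Op full names expected to fail the compiled forward OpInfo tests; a superset
# of FORWARD_FAILURES.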
COMPILE_FORWARD_FAILURES = {
    *FORWARD_FAILURES,
    # clone() on "non-contiguous with holes" NJTs currently uses unbind(), leading to a
    # data-dependent error under torch.compile
    "clone",
}

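# Ops whose compiled outputs are compared to eager outputs while ignoring
# nested ints (i.e. only the tensor components are compared).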
COMPARE_TENSOR_COMPONENT_EQUALITY = {
    # masked_select is expected to output a different shape
    "masked_select",
}


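# Decorates a test as an expected failure when the op under test appears in the
# given failure list.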
def withXFails(failure_list):
    return decorateIf(
        unittest.expectedFailure,
        lambda params: params["op"].full_name in failure_list,
    )


# OpInfo-based NJT tests. These tests use an NJT-specific op_db generated from the standard
# op_db. Note that a tradeoff was made between coverage and test runtime:
#   * All tests run with dtype=torch.float32 only
class TestNestedTensorOpInfo(NestedTensorTestCase):
    # TODO: move this
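    # Generates all-ones grad_outputs matching the structure of the op's output,
    # for use with torch.autograd.grad() below.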
    def _gen_grad_outputs(self, out_val):
        if isinstance(out_val, (list, tuple)):
            return tuple(torch.ones_like(c) for c in out_val)
        else:
            return (torch.ones_like(out_val),)

    @withXFails(FORWARD_FAILURES)
    @ops([op for op in njt_op_db if op.supports_njt], allowed_dtypes=(torch.float32,))
    def test_forward(self, device, dtype, op):
        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=False):
            # compare to reference, but expect different nested int
            out = op.op(sample.input, *sample.args, **sample.kwargs)
            out_ref = op.ref(op, sample)
            self.assertEqualIgnoringNestedInts(out, out_ref)

    @withXFails(BACKWARD_FAILURES)
    @ops(
        [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
        allowed_dtypes=(torch.float32,),
    )
    def test_backward(self, device, dtype, op):
        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=True):
            # compare to reference, but expect different nested int
            out = op.op(sample.input, *sample.args, **sample.kwargs)
            out_ref = op.ref(op, sample)
            self.assertEqualIgnoringNestedInts(out, out_ref)

            inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
            g_inps = [
                inp
                for inp in inps
                if isinstance(inp, torch.Tensor) and inp.requires_grad
            ]
            if len(g_inps) > 0:
                grads = torch.autograd.grad(
                    out, inputs=g_inps, grad_outputs=self._gen_grad_outputs(out)
                )

                grads_ref = torch.autograd.grad(
                    out_ref,
                    inputs=g_inps,
                    grad_outputs=self._gen_grad_outputs(out_ref),
                )

                self.assertEqual(grads, grads_ref)

    @withXFails(COMPILE_FORWARD_FAILURES)
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    @ops([op for op in njt_op_db if op.supports_njt], allowed_dtypes=(torch.float32,))
    def test_compile_forward(self, device, dtype, op):
        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=False):
            torch.compiler.reset()

            op_fn = op.op

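            # wrap the OpInfo op in a plain local function so it can be compiled below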
            def f(*args, **kwargs):
                return op_fn(*args, **kwargs)

            compiled_f = torch.compile(
                f, fullgraph=True, backend="aot_eager_decomp_partition"
            )

            out_ref = f(sample.input, *sample.args, **sample.kwargs)
            out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)

            if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
                self.assertEqualIgnoringNestedInts(out_compile, out_ref)
            else:
                self.assertEqual(out_compile, out_ref)

    @withXFails(BACKWARD_FAILURES)
    @ops(
        [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
        allowed_dtypes=(torch.float32,),
    )
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    def test_compile_backward(self, device, dtype, op):
        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=True):
            torch.compiler.reset()

            op_fn = op.op

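            # wrap the OpInfo op in a plain local function so it can be compiled below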
            def f(*args, **kwargs):
                return op_fn(*args, **kwargs)

            compiled_f = torch.compile(
                f, fullgraph=True, backend="aot_eager_decomp_partition"
            )

            out_ref = f(sample.input, *sample.args, **sample.kwargs)
            out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)

            self.assertEqual(out_compile, out_ref)

            inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
            g_inps = [
                inp
                for inp in inps
                if isinstance(inp, torch.Tensor) and inp.requires_grad
            ]
            if len(g_inps) > 0:
                grads_compile = torch.autograd.grad(
                    out_compile,
                    inputs=g_inps,
                    grad_outputs=self._gen_grad_outputs(out_compile),
                )

                grads_ref = torch.autograd.grad(
                    out_ref, inputs=g_inps, grad_outputs=self._gen_grad_outputs(out_ref)
                )

                self.assertEqual(grads_compile, grads_ref)


instantiate_parametrized_tests(TestNestedTensor)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
instantiate_device_type_tests(TestNestedTensorAutograd, globals())
instantiate_device_type_tests(TestNestedTensorSubclass, globals())
instantiate_device_type_tests(TestNestedTensorOpInfo, globals())

if __name__ == "__main__":
    run_tests()
