# Owner(s): ["module: named tensor"]

import unittest
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import get_all_device_types
from collections import namedtuple, OrderedDict
import itertools
import functools
import torch
from torch import Tensor
import torch.nn.functional as F
from multiprocessing.reduction import ForkingPickler
import pickle
import io
import sys
import warnings


def pass_name_to_python_arg_parser(name):
    """Push ``name`` through torch's Python argument parser as a dimension name.

    A throwaway 1-D tensor of size 2 is created with ``name`` as its only
    dimension name and immediately discarded. Nothing is returned; whatever
    ``torch.empty`` raises for an unacceptable name propagates to the caller.
    """
    torch.empty(2, names=(name,))


def flatten(lst):
    """Return one flat list holding the items of every sub-iterable of ``lst``.

    Only a single level of nesting is removed; the relative order is kept.
    """
    flat = []
    for sublist in lst:
        flat.extend(sublist)
    return flat


# Lightweight (name, lambd) record: `name` is an op's name and `lambd` a callable
# that invokes it (see its use when building the op table in test_binary_ops).
# NOTE(review): the namedtuple typename is 'TestCase' although it is bound to
# `Function`; this only shows up in its repr — confirm whether that is intended.
Function = namedtuple('TestCase', ['name', 'lambd'])


def parse_compressed_namedshape(string):
    """Parse a compact textual description of a (possibly named) tensor shape.

    The metalanguage is a comma-separated list of dims, each written either as
    ``size`` or as ``name:size``:

    * ``'N:3,C:2'``       -> names ``('N', 'C')``, sizes ``(3, 2)``
    * ``'None:3,None:2'`` -> names ``(None, None)``, sizes ``(3, 2)``
      (the literal text ``None`` denotes an unnamed dim)
    * ``'3,2'``           -> names ``None``, sizes ``[3, 2]``
    * ``''``              -> names ``None``, sizes ``[]``

    Whitespace around the whole string and around each name is ignored.

    Returns:
        A ``(names, sizes)`` pair. ``names`` is ``None`` when the string has
        no ``:`` at all; otherwise it is a tuple with one entry per dim. The
        pair is always a real tuple, so it can be unpacked, indexed, compared
        and iterated more than once.

    Raises:
        ValueError: if a size is not an integer, or if the named form is used
            and some dim is not exactly ``name:size``.
    """
    def parse_name(maybe_name):
        maybe_name = maybe_name.strip()
        return None if maybe_name == 'None' else maybe_name

    string = string.strip()

    # '' -> no names, zero-dim shape
    if not string:
        return None, []

    # '3, 2' -> sizes only; names=None is what gets passed to the factory
    if ':' not in string:
        return None, [int(size) for size in string.split(',')]

    pairs = [dim.split(':') for dim in string.split(',')]
    # Materialize the transposed pairs instead of handing back a one-shot
    # zip iterator, so every branch returns the same kind of object.
    names, sizes = zip(*[(parse_name(name), int(size)) for name, size in pairs])
    return names, sizes


def create(namedshape, factory=torch.randn):
    """Build a tensor from a compressed named-shape string.

    ``namedshape`` (a str) is handed to ``parse_compressed_namedshape``; the
    sizes and names it yields are forwarded as ``factory(sizes, names=names)``
    and the factory's result is returned.
    """
    dim_names, sizes = parse_compressed_namedshape(namedshape)
    tensor = factory(sizes, names=dim_names)
    return tensor


def out_fn(operator):
    """Adapt ``operator`` so its output argument is supplied first, positionally.

    The returned callable accepts ``(out, *inputs)`` and performs
    ``operator(*inputs, out=out)``; it carries ``operator``'s metadata via
    ``functools.wraps``.
    """
    @functools.wraps(operator)
    def fn(*inputs):
        remaining = inputs[1:]
        destination = inputs[0]
        return operator(*remaining, out=destination)

    return fn


class TestNamedTensor(TestCase):
    def test_aaa_must_run_first_check_experimental_warning(self):
        """Creating a named tensor emits exactly one 'experimental feature' warning."""
        # TODO(rzou): It would be nice for this to be a "real" python warning.
        # Right now this error message only prints once and doesn't respect
        # warnings.simplefilter behavior (where python users can control whether
        # or not to display warnings once, all the time, or never).
        with warnings.catch_warnings(record=True) as caught:
            torch.randn(3, 3, names=('N', 'C'))
            self.assertEqual(len(caught), 1)
            first_message = str(caught[0].message)
            self.assertTrue(first_message.startswith(
                'Named tensors and all their associated APIs are an experimental feature'))

    def test_trivial(self):
        pass

    def _test_name_inference(self, op, args=(), expected_names=(), device='cpu',
                             maybe_raises_regex=None):
        casted_args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg
                       for arg in args]
        if maybe_raises_regex is not None:
            with self.assertRaisesRegex(RuntimeError, maybe_raises_regex):
                result = op(*args)
            return
        result = op(*args)
        self.assertEqual(result.names, expected_names,
                         msg=f'Name inference for {op.__name__} on device {device} failed')

    # TODO(rzou): Some form of this check should be added to self.assertEqual.
    # Right now I don't know what it should look like.
    def assertTensorDataAndNamesEqual(self, x, y):
        """Assert that ``x`` and ``y`` have equal names and, once names are stripped, equal data."""
        self.assertEqual(x.names, y.names)
        # Compare the payload on name-free views of both tensors.
        self.assertEqual(x.rename(None), y.rename(None))

    def _test_factory(self, factory, device):
        x = factory([], device=device)
        self.assertEqual(x.names, ())

        x = factory(1, 2, 3, device=device)
        self.assertEqual(x.names, (None, None, None))

        x = factory(1, 2, 3, names=None, device=device)
        self.assertEqual(x.names, (None, None, None))

        x = factory(1, 2, 3, names=('N', 'T', 'D'), device=device)
        self.assertEqual(x.names, ('N', 'T', 'D'))

        x = factory(1, 2, 3, names=('N', None, 'D'), device=device)
        self.assertEqual(x.names, ('N', None, 'D'))

        x = factory(1, 2, 3, names=('_1', 'batch9', 'BATCH_5'), device=device)
        self.assertEqual(x.names, ('_1', 'batch9', 'BATCH_5'))

        with self.assertRaisesRegex(RuntimeError,
                                    'a valid identifier contains only'):
            x = factory(2, names=('1',), device=device)

        with self.assertRaisesRegex(RuntimeError,
                                    'a valid identifier contains only'):
            x = factory(2, names=('?',), device=device)

        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
            x = factory(2, 1, names=('N',), device=device)

        with self.assertRaisesRegex(TypeError, 'invalid combination of arguments'):
            x = factory(2, 1, names='N', device=device)

        with self.assertRaisesRegex(RuntimeError, 'construct a tensor with duplicate names'):
            x = factory(2, 1, 1, names=('N', 'C', 'N'), device=device)

        names64 = ['A' * i for i in range(1, 65)]
        x = factory([1] * 64, names=names64, device=device)
        self.assertEqual(x.names, names64)

        with self.assertRaisesRegex(
                RuntimeError,
                'only support up to 64 dims'):
            names65 = ['A' * i for i in range(1, 66)]
            x = factory([1] * 65, names=names64, device=device)

    @skipIfTorchDynamo("not a bug: Dynamo causes the refcounts to be different")
    def test_none_names_refcount(self, N=10):
        """Check that reading ``.names`` on unnamed tensors does not keep adding references to ``None``.

        ``N`` is how many times the probe is repeated; the test passes when
        ``sys.getrefcount(None)`` grew by less than ``N / 2`` over all repetitions.
        """
        def scope():
            # Both the tensor and the names tuple die when this function returns.
            unnamed = torch.empty(2, 3)
            unnamed.names  # materialize [None, None]

        prev_none_refcnt = sys.getrefcount(None)
        # Ran it N times to reduce flakiness
        [scope() for i in range(N)]
        after_none_refcnt = sys.getrefcount(None)
        self.assertTrue(after_none_refcnt - prev_none_refcnt < N / 2,
                        msg='Using tensor.names should not change '
                            'the refcount of Py_None')

    def test_has_names(self):
        unnamed = torch.empty(2, 3)
        none_named = torch.empty(2, 3, names=(None, None))
        partially_named = torch.empty(2, 3, names=('N', None))
        fully_named = torch.empty(2, 3, names=('N', 'C'))

        self.assertFalse(unnamed.has_names())
        self.assertFalse(none_named.has_names())
        self.assertTrue(partially_named.has_names())
        self.assertTrue(fully_named.has_names())

    def test_py3_ellipsis(self):
        """``refine_names`` accepts the ``...`` literal as a glob over the middle dims."""
        refined = torch.randn(2, 3, 5, 7).refine_names('N', ..., 'C')
        expected = ['N', None, None, 'C']
        self.assertEqual(refined.names, expected)

    def test_refine_names(self):
        """Table-driven checks of ``Tensor.refine_names`` successes and rejections."""
        # Each row: (compressed shape of the input, names passed to refine_names,
        #            expected result names or None, expected error regex or None).
        # Rows run in order.
        cases = [
            # Unnamed tensor -> fully named tensor
            ('None:1,None:2,None:3', ['N', 'C', 'H'], ['N', 'C', 'H'], None),
            # Named tensor refined with its own names
            ('N:1,C:2,H:3', ['N', 'C', 'H'], ['N', 'C', 'H'], None),
            # Partially named tensor -> more names filled in
            ('None:1,C:2,None:3', [None, 'C', 'H'], [None, 'C', 'H'], None),
            # Number of names differs from the number of dims
            ('None:2,None:3', ['N', 'C', 'H'], None, "different number of dims"),
            # Cannot change Tensor[D] to Tensor[N]
            ('D:3', ['N'], None, "is different from"),
            # Cannot change Tensor[D] to Tensor[None]
            ('D:3', [None], None, "'D' is more specific than None"),
            # globbing behavior exists
            ('None:1,None:1,None:2,None:3', ['...', 'C', 'H'], [None, None, 'C', 'H'], None),
        ]

        for namedshape, new_names, expected, error_regex in cases:
            call_args = [create(namedshape)] + new_names
            if error_regex is None:
                self._test_name_inference(Tensor.refine_names, call_args, expected)
            else:
                self._test_name_inference(Tensor.refine_names, call_args,
                                          maybe_raises_regex=error_regex)

    def test_detach(self):
        """Both ``detach_`` and ``detach`` keep the names of their input."""
        names = ['N']
        # In-place variant first, then the out-of-place one; each gets a fresh input.
        for detach_op in (Tensor.detach_, Tensor.detach):
            fresh_input = torch.randn(3, requires_grad=True, names=names)
            self._test_name_inference(detach_op, [fresh_input], names)

    def test_index_fill(self):
        """``index_fill_`` / ``index_fill`` addressed by dim name preserve the tensor's names."""
        expected_names = ('N', 'C')
        for device in get_all_device_types():
            x = torch.randn(3, 5, device=device, names=expected_names)

            # In-place variant first, then out-of-place; each with a Python
            # scalar and with a 0-dim tensor as the fill value.
            for method_name in ('index_fill_', 'index_fill'):
                fill = getattr(x, method_name)
                for make_value in (lambda: 5, lambda: torch.tensor(4.)):
                    index = torch.tensor([0, 1], device=device)
                    output = fill('C', index, make_value())
                    self.assertEqual(output.names, expected_names)

    def test_equal(self):
        """``torch.equal`` on same-data tensors is true only when their names also agree."""
        # (names given to the left operand, whether it should equal the ('N', 'C') right operand)
        cases = [
            (('N', 'C'), True),
            (('M', 'C'), False),
            ((None, 'C'), False),
        ]
        for device in get_all_device_types():
            tensor = torch.randn(2, 3, device=device)
            other = tensor.clone()

            for left_names, should_match in cases:
                outcome = torch.equal(tensor.rename(*left_names), other.rename('N', 'C'))
                if should_match:
                    self.assertTrue(outcome)
                else:
                    self.assertFalse(outcome)

    def test_squeeze(self):
        """Squeezing drops the names of the removed dims and keeps the others."""
        x = create('N:3,C:1,H:1,W:1')

        by_name = x.squeeze('C')
        self.assertEqual(by_name.names, ['N', 'H', 'W'])

        all_unit_dims = x.squeeze()
        self.assertEqual(all_unit_dims.names, ['N'])

    def test_repr(self):
        """``repr`` shows a ``names=`` suffix only when some dim has a real name."""
        with_names = torch.zeros(2, 3).rename_('N', 'C')
        self.assertEqual(
            repr(with_names),
            "tensor([[0., 0., 0.],\n        [0., 0., 0.]], names=('N', 'C'))")

        plain_repr = "tensor([[0., 0., 0.],\n        [0., 0., 0.]])"

        never_named = torch.zeros(2, 3)
        self.assertEqual(repr(never_named), plain_repr)

        # Names that are all None print exactly like an unnamed tensor.
        all_none_names = torch.zeros(2, 3).rename_(None, None)
        self.assertEqual(repr(all_none_names), plain_repr)

    def test_diagonal(self):
        """``diagonal`` removes the two diagonal dims and appends one (optionally named) dim."""
        named_tensor = torch.zeros(2, 3, 5, 7, names=list('ABCD'))

        default_dims = named_tensor.diagonal()
        self.assertEqual(default_dims.names, ['C', 'D', None])

        positional_dims = named_tensor.diagonal(1, 3)
        self.assertEqual(positional_dims.names, ['A', 'C', None])

        named_dims = named_tensor.diagonal(outdim='E', dim1='B', dim2='D')
        self.assertEqual(named_dims.names, ['A', 'C', 'E'])

    def test_max_pooling(self):
        """1d/2d/3d max pooling (with and without indices) propagates the input names."""
        plain_pools = (F.max_pool1d, F.max_pool2d, F.max_pool3d)
        index_pools = (F.max_pool1d_with_indices,
                       F.max_pool2d_with_indices,
                       F.max_pool3d_with_indices)

        for device in get_all_device_types():
            # One (input, kernel size) pair per pooling dimensionality.
            pool_inputs = [
                (torch.zeros(2, 3, 5, device=device, names=list('ABC')), 2),
                (torch.zeros(2, 3, 5, 7, device=device, names=list('ABCD')), [2, 2]),
                (torch.zeros(2, 3, 5, 7, 9, device=device, names=list('ABCDE')), [2, 2, 2]),
            ]

            for pool, (tensor, kernel) in zip(plain_pools, pool_inputs):
                self.assertEqual(pool(tensor, kernel).names, tensor.names)

            for pool, (tensor, kernel) in zip(index_pools, pool_inputs):
                values, indices = pool(tensor, kernel)
                self.assertEqual(values.names, tensor.names)
                self.assertEqual(indices.names, tensor.names)

    def test_max_pooling_without_names_does_not_warn(self):
        """Forward + backward of ``max_pool2d`` on an unnamed tensor records no warnings."""
        for device in get_all_device_types():
            unnamed_input = torch.zeros(2, 3, 5, 7, device=device, requires_grad=True)
            with warnings.catch_warnings(record=True) as recorded:
                warnings.simplefilter("always")
                pooled = F.max_pool2d(unnamed_input, [2, 2])
                pooled.sum().backward()
                self.assertEqual(len(recorded), 0)

    def test_no_save_support(self):
        """``torch.save`` of a named tensor is rejected with an "NYI" RuntimeError."""
        sink = io.BytesIO()
        tensor_with_names = torch.zeros(2, 3, names=('N', 'C'))
        with self.assertRaisesRegex(RuntimeError, "NYI"):
            torch.save(tensor_with_names, sink)

    def test_no_pickle_support(self):
        """Pickling a named tensor is rejected with an "NYI" RuntimeError."""
        tensor_with_names = torch.zeros(2, 3, names=('N', 'C'))
        with self.assertRaisesRegex(RuntimeError, "NYI"):
            pickle.dumps(tensor_with_names)

    def test_no_multiprocessing_support(self):
        """Dumping a named tensor through ``ForkingPickler`` is rejected with an "NYI" RuntimeError."""
        sink = io.BytesIO()
        tensor_with_names = torch.zeros(2, 3, names=('N', 'C'))
        with self.assertRaisesRegex(RuntimeError, "NYI"):
            pickler = ForkingPickler(sink, pickle.HIGHEST_PROTOCOL)
            pickler.dump(tensor_with_names)

    def test_big_tensor_repr_has_names(self):
        def check_repr(named_tensor):
            unnamed_tensor = named_tensor.rename(None)
            names_tag = f'names={named_tensor.names}'
            self.assertIn(names_tag, repr(named_tensor))

        check_repr(torch.randn(128, 3, 64, 64, names=('N', 'C', 'H', 'W')))

    def test_noncontig_contiguous(self):
        """``contiguous()`` on a transposed (non-contiguous) named tensor keeps its names."""
        # This type of contiguous is special-cased and therefore needs its own test
        for device in get_all_device_types():
            transposed = torch.randn(2, 3, device=device).t()
            transposed.rename_('N', 'C')
            made_contiguous = transposed.contiguous()
            self.assertEqual(made_contiguous.names, ('N', 'C'))

    def test_copy_transpose(self):
        """``copy_`` from a transposed source leaves the destination with the expected names."""
        # This type of copy is special-cased and therefore needs its own test
        # (names of destination, names of source, names expected afterwards)
        cases = [
            (('N', 'C'), ('N', 'C'), ('N', 'C')),
            (None, ('N', 'C'), ('N', 'C')),
        ]
        for self_names, other_names, expected_names in cases:
            destination = torch.empty(2, 5, names=self_names)
            source = torch.empty(5, 2).t().rename_(*other_names)
            destination.copy_(source)
            self.assertEqual(destination.names, expected_names)

    def test_rename_(self):
        """In-place ``rename_`` sets or clears names and rejects wrong counts and duplicates."""
        tensor = torch.empty(1, 1, names=('N', 'C'))

        cleared = tensor.rename_(None)
        self.assertEqual(cleared.names, (None, None))

        renamed = tensor.rename_('H', 'W')
        self.assertEqual(renamed.names, ('H', 'W'))

        # (names to apply, regex of the expected RuntimeError)
        rejected = [
            (('N', 'C', 'W'), 'Number of names'),
            (('N', 'N'), 'duplicate names'),
        ]
        for bad_names, error_regex in rejected:
            with self.assertRaisesRegex(RuntimeError, error_regex):
                tensor.rename_(*bad_names)

    def test_rename(self):
        """Out-of-place ``rename``: new names on a view, original untouched, bad input rejected."""
        original_names = ('N', 'C')
        tensor = torch.empty(1, 1, names=original_names)

        for new_names, expected in (((None,), (None, None)),
                                    (('H', 'W'), ('H', 'W'))):
            self.assertEqual(tensor.rename(*new_names).names, expected)

        # Check that we didn't modify tensor.names
        self.assertEqual(tensor.names, original_names)

        for bad_names, error_regex in ((('N', 'C', 'W'), 'Number of names'),
                                       (('N', 'N'), 'duplicate names')):
            with self.assertRaisesRegex(RuntimeError, error_regex):
                tensor.rename(*bad_names)

        # Positional names and a rename map cannot be mixed in one call.
        with self.assertRaisesRegex(RuntimeError, 'either positional args or keyword args'):
            tensor.rename(None, N='batch')

        # rename returns a view on the tensor
        base_ptr_after_named = tensor.rename('H', 'W').data_ptr()
        self.assertEqual(base_ptr_after_named, tensor.data_ptr())
        base_ptr_after_cleared = tensor.rename(None).data_ptr()
        self.assertEqual(base_ptr_after_cleared, tensor.data_ptr())

    def test_rename_globber(self):
        """``rename`` with a ``'...'`` glob keeps the names of the dims the glob covers."""
        scalar = torch.randn([])
        unnamed_tensor = torch.empty(1, 1, 1, 1)
        named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W'))

        # (tensor, positional arguments to rename, expected names); rows run in order.
        cases = [
            (scalar, (None,), []),
            (scalar, ('...',), []),

            # Check that it works with unnamed tensors
            (unnamed_tensor, ('...',), unnamed_tensor.names),
            (unnamed_tensor, ('...', 'H', 'W'), [None, None, 'H', 'W']),
            (unnamed_tensor, ('N', '...', 'W'), ['N', None, None, 'W']),
            (unnamed_tensor, ('N', 'C', '...'), ['N', 'C', None, None]),

            # Check that it works with named tensors
            (named_tensor, ('...',), named_tensor.names),
            (named_tensor, ('...', 'width'), ['N', 'C', 'H', 'width']),
            (named_tensor, ('batch', 'channels', '...', 'width'),
             ['batch', 'channels', 'H', 'width']),
            (named_tensor, ('batch', '...'), ['batch', 'C', 'H', 'W']),

            # Test empty glob
            (unnamed_tensor, ('...', None, None, None, None), [None, None, None, None]),
            (named_tensor, ('N', 'C', 'H', '...', 'W'), ['N', 'C', 'H', 'W']),
        ]
        for tensor, rename_args, expected in cases:
            self.assertEqual(tensor.rename(*rename_args).names, expected)

        # Multiple globs throw
        with self.assertRaisesRegex(RuntimeError, 'More than one '):
            named_tensor.rename('...', 'channels', '...')

    def test_rename_rename_map(self):
        """Keyword-form ``rename(old='new')``: unknown dims are rejected, known ones renamed on a view."""
        scalar = torch.randn([])
        unnamed_tensor = torch.empty(1, 1, 1, 1)
        named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W'))

        # (tensor, rename map, regex of the expected RuntimeError)
        rejected = [
            (scalar, dict(N='batch'), "dim 'N' does not exist"),
            (unnamed_tensor, dict(N='batch'), "dim 'N' does not exist"),
            (named_tensor, dict(B='batch'), "dim 'B' does not exist"),
            (named_tensor, dict(H='height', B='batch'), "dim 'B' does not exist"),
        ]
        for tensor, rename_map, error_regex in rejected:
            with self.assertRaisesRegex(RuntimeError, error_regex):
                tensor.rename(**rename_map)

        # The renamed tensor shares storage with the original.
        renamed_ptr = named_tensor.rename(N='batch').data_ptr()
        self.assertEqual(renamed_ptr, named_tensor.data_ptr())

        single = named_tensor.rename(N='batch')
        self.assertEqual(single.names, ['batch', 'C', 'H', 'W'])

        double = named_tensor.rename(N='batch', H='height')
        self.assertEqual(double.names, ['batch', 'C', 'height', 'W'])

    def test_set_names_property(self):
        """Assigning to ``tensor.names`` updates names and validates count and uniqueness."""
        tensor = torch.empty(1, 1, names=('N', 'C'))

        # (value assigned to .names, names expected to be read back)
        for assigned, expected in ((None, (None, None)),
                                   (('N', 'W'), ('N', 'W'))):
            tensor.names = assigned
            self.assertEqual(tensor.names, expected)

        # (value assigned to .names, regex of the expected RuntimeError)
        for assigned, error_regex in ((['N', 'C', 'W'], 'Number of names'),
                                      (['N', 'N'], 'duplicate names')):
            with self.assertRaisesRegex(RuntimeError, error_regex):
                tensor.names = assigned

    def test_factory_edge_cases(self):
        """Run the shared factory checks against ``torch.empty`` on every available device type."""
        device_types = get_all_device_types()
        for device in device_types:
            self._test_factory(torch.empty, device)

    def test_factory_coverage(self):
        """Factories given ``names=`` must match the unnamed factory result renamed afterwards."""
        dim_names = ('N', 'T', 'D')

        def check_seeded(factory, device):
            # Re-seed before each call so random factories produce identical data.
            torch.manual_seed(0)
            with_names_kwarg = factory(1, 2, 3, names=dim_names, device=device)

            torch.manual_seed(0)
            renamed_afterwards = factory(1, 2, 3, device=device).rename_(*dim_names)

            self.assertTensorDataAndNamesEqual(with_names_kwarg, renamed_afterwards)

        varargs_factories = [torch.ones, torch.rand, torch.randn, torch.zeros]
        for factory in varargs_factories:
            for device in get_all_device_types():
                check_seeded(factory, device)

        # Test torch.full
        for device in get_all_device_types():
            with_names_kwarg = torch.full([1, 2, 3], 2., names=dim_names, device=device)
            renamed_afterwards = torch.full([1, 2, 3], 2., device=device).rename_(*dim_names)
            self.assertTensorDataAndNamesEqual(with_names_kwarg, renamed_afterwards)

    def test_tensor_from_lists(self):
        """``torch.tensor`` built from nested lists honours ``names=`` and checks their count."""
        # (data, names) pairs whose rank matches the number of names
        for data, names in (([[1]], ('N', 'C')),
                            ([1], ('N',))):
            tensor = torch.tensor(data, names=names)
            self.assertEqual(tensor.names, names)

        # Two names for a 1-D tensor must be rejected.
        too_many_names = ('N', 'C')
        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
            torch.tensor([1], names=too_many_names)

    @unittest.skipIf(not TEST_NUMPY, "no numpy")
    def test_tensor_from_numpy(self):
        import numpy as np
        arr = np.array([[1]])
        names = ('N', 'C')
        tensor = torch.tensor([[1]], names=names)
        self.assertEqual(tensor.names, names)

    def test_tensor_from_tensor(self):
        """``torch.tensor`` built from an unnamed tensor honours ``names=``."""
        names = ('N', 'C')
        source = torch.randn(1, 1)
        copy_with_names = torch.tensor(source, names=names)
        self.assertEqual(copy_with_names.names, names)

    def test_tensor_from_named_tensor(self):
        """``torch.tensor`` built from a named tensor keeps its names and rejects conflicting ones."""
        source_names = ('N', 'D')

        source = torch.randn(1, 1, names=source_names)
        self.assertEqual(torch.tensor(source).names, source_names)

        # there's no way to distinguish between names=None and not passing in names.
        # If the user passes in names=None they are asking for trouble.
        source = torch.randn(1, 1, names=source_names)
        self.assertEqual(torch.tensor(source, names=None).names, source_names)

        source = torch.randn(1, 1, names=source_names)
        with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
            torch.tensor(source, names=('N', 'C'))

    def test_size(self):
        """``size(name)`` looks a dim up by name and raises for names that are absent."""
        t = torch.empty(2, 3, 5, names=('N', None, 'C'))

        for dim_name, expected_size in (('N', 2), ('C', 5)):
            self.assertEqual(t.size(dim_name), expected_size)

        with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
            t.size('channels')

        # A tensor without any names cannot resolve a name either.
        nameless = torch.empty(2, 3, 4)
        with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
            nameless.size('N')

    def test_stride(self):
        """``stride(name)`` looks a dim up by name and raises for names that are absent."""
        t = torch.empty(2, 3, 5, names=('N', None, 'C'))

        for dim_name, expected_stride in (('N', 3 * 5), ('C', 1)):
            self.assertEqual(t.stride(dim_name), expected_stride)

        with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
            t.stride('channels')

        # A tensor without any names cannot resolve a name either.
        nameless = torch.empty(2, 3, 4)
        with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
            nameless.stride('N')

    def test_transpose_variants(self):
        """``transpose`` (by name or index) and ``t()`` swap the corresponding names."""
        four_dim = torch.randn(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))

        swapped_by_name = four_dim.transpose('N', 'C')
        self.assertEqual(swapped_by_name.names, ['C', 'N', 'H', 'W'])

        swapped_by_index = four_dim.transpose(1, 3)
        self.assertEqual(swapped_by_index.names, ['N', 'W', 'H', 'C'])

        matrix = torch.randn(2, 3, names=('N', 'C'))
        self.assertEqual(matrix.t().names, ['C', 'N'])

    def test_resize(self):
        """``resize_`` / ``resize_as_`` on named tensors: same-shape resizes keep names, others raise."""
        for device in get_all_device_types():
            named = torch.randn(2, names=('N',), device=device)
            # Resizing to the current shape is allowed and keeps the names.
            named.resize_([2])
            self.assertEqual(named.names, ['N'])

            with self.assertRaisesRegex(RuntimeError, "Cannot resize named tensor"):
                named.resize_([3])

            other_named = torch.randn(2, names=('N',), device=device)
            named.resize_as_(other_named)
            # `named` is the tensor that was resized, so it is the one whose
            # names must be checked (checking `other_named` proves nothing).
            self.assertEqual(named.names, ['N'])

            unnamed = torch.randn(2, device=device)
            with self.assertRaisesRegex(
                    RuntimeError, r'names .* are not the same as the computed output names'):
                named.resize_as_(unnamed)

            # An unnamed tensor resized as a named one picks up its names.
            unnamed = torch.randn(1, device=device)
            unnamed.resize_as_(named)
            self.assertEqual(unnamed.names, ['N'])

    def test_cdist(self):
        """``torch.cdist`` output names: batch dims, then the row-group dim of each input."""
        left_names = ('M', 'N', 'first_group', 'features')
        right_names = ('N', 'second_group', 'features')
        for device in get_all_device_types():
            tensor = torch.randn(3, 1, 2, 7, names=left_names, device=device)
            other = torch.randn(5, 11, 7, names=right_names, device=device)
            distances = torch.cdist(tensor, other)
            self.assertEqual(distances.names, ['M', 'N', 'first_group', 'second_group'])

    def test_info_smoke(self):
        """Touch every informational attribute/method on a named tensor; passing means none raised.

        No results are asserted; each line is evaluated only for the side
        effect of not throwing.
        """
        # Smoke test for info functions / methods / attributes on named tensors.
        tensor = torch.empty(1, 1, names=('N', 'D'))

        tensor.device
        tensor.dtype
        tensor.get_device()
        tensor.is_complex()
        tensor.is_floating_point()
        tensor.is_nonzero()
        torch.is_same_size(tensor, tensor)
        torch.is_signed(tensor)
        tensor.layout
        tensor.numel()
        tensor.dim()
        tensor.element_size()
        tensor.is_contiguous()
        tensor.is_cuda
        tensor.is_leaf
        tensor.is_pinned()
        tensor.is_shared()
        tensor.is_sparse
        tensor.ndimension()
        tensor.nelement()
        tensor.shape
        tensor.size()
        tensor.size(1)
        tensor.storage()
        tensor.storage_offset()
        tensor.storage_type()
        tensor.stride()
        tensor.stride(1)
        tensor.data
        tensor.data_ptr()
        tensor.ndim
        tensor.item()
        tensor.type()
        # NOTE(review): is_shared() was already probed above; this repeat is harmless.
        tensor.is_shared()
        tensor.is_signed()

    def test_autograd_smoke(self):
        """Smoke-test autograd on named tensors: hooks, backward and autograd attributes do not raise."""
        leaf = torch.randn(3, 3, names=('N', 'D'), requires_grad=True)

        cloned = leaf.clone()
        cloned.retain_grad()
        # Identity hook: only checks that registering and running a hook works.
        cloned.register_hook(lambda grad: grad)

        loss = cloned.sum()
        loss.backward()

        # autograd related attributes
        non_leaf = torch.empty(1, 1, names=('N', 'D'), requires_grad=True).relu()
        non_leaf.output_nr
        non_leaf.grad_fn
        non_leaf.requires_grad

    def test_split_fns_propagates_names(self):
        """Every piece produced by ``split`` / ``chunk`` carries the names of the source tensor."""
        splitters = [
            lambda t: t.split(1, 0),
            lambda t: t.split([1, 1], 1),
            lambda t: t.chunk(2, 0),
        ]

        for device in get_all_device_types():
            source = torch.empty(2, 2, names=('N', 'D'), device=device)
            for splitter in splitters:
                for piece in splitter(source):
                    self.assertEqual(piece.names, source.names)

    def test_any_all(self):
        """Full reductions ``any()`` / ``all()`` return a zero-dim result with no names."""
        for device in get_all_device_types():
            flags = torch.zeros(3, dtype=torch.bool, device=device, names=('C',))
            for reduced in (flags.any(), flags.all()):
                self.assertEqual(reduced.names, [])

    def test_addcmul_addcdiv(self):
        """``addcmul`` / ``addcdiv`` (functional, ``out=`` and in-place) keep the shared names."""
        for device in get_all_device_types():
            names = ['N']
            a = torch.rand(3, device=device, names=names)
            b = torch.rand(3, device=device, names=names)
            # avoid division by 0
            c = torch.rand(3, device=device, names=names).clamp_min_(0.1)
            out = torch.randn(3, device=device, names=names)

            # addcmul first, then addcdiv; within each: functional, out=, in-place on `a`.
            for op_name in ('addcmul', 'addcdiv'):
                functional = getattr(torch, op_name)
                self.assertEqual(functional(a, b, c).names, names)
                self.assertEqual(functional(a, b, c, out=out).names, names)
                inplace = getattr(a, op_name + '_')
                self.assertEqual(inplace(b, c).names, names)

    def test_binary_ops(self):
        """Check name inference for a table of binary ops in method, in-place and ``out=`` forms.

        Each op in the table is run through three checks defined below:
        ``test_basic``, ``test_wildcard`` and ``test_mixed_unnamed_named``.
        """
        def test_basic(op):
            # Fully named operands: matching or broadcastable names succeed,
            # unrelated or reordered names must raise.
            a = torch.empty(2, 3, names=('N', 'C'))
            b = torch.empty(3, 2, names=('C', 'N'))
            c = torch.empty(3, names=('C',))
            d = torch.empty(5, names=('W',))

            self.assertEqual(op(a, a).names, ('N', 'C'))
            self.assertEqual(op(a, c).names, ('N', 'C'))
            # TODO: dynamo will throw a slightly different
            # error message because it's adding fake tensors
            # `must match the size of` portion is the dynamo error
            with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"):
                op(a, d)
            with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"):
                op(a, b)

        def test_wildcard(op):
            # None names act as wildcards that unify with a real name at the same position.
            a = torch.empty(2, 3, names=('N', 'C'))
            c = torch.empty(2, 3, names=(None, 'C'))
            self.assertEqual(op(a, c).names, ('N', 'C'))

            b = torch.empty(2, 3)
            self.assertEqual(op(a, b).names, ('N', 'C'))

            d = torch.empty(2, 3, names=('C', None))
            with self.assertRaisesRegex(RuntimeError, "Misaligned"):
                op(d, c)

        def test_mixed_unnamed_named(op, is_inplace):
            # One named operand combined with unnamed operands of 1, 2 and 3 dims, in both orders.
            named2 = torch.randn(1, 1, names=('N', 'C'))
            unnamed1 = torch.randn(1)
            unnamed2 = torch.randn(1, 1)
            unnamed3 = torch.randn(1, 1, 1)

            def compute_expected_names(tensor, other):
                # Exactly one of the two operands has names.
                assert tensor.has_names() ^ other.has_names()
                named = tensor if tensor.has_names() else other
                unnamed = other if tensor.has_names() else tensor
                unnamed_dim = unnamed.dim()
                if unnamed_dim > named.dim():
                    # Extra leading dims contributed by the unnamed operand stay unnamed.
                    return [None] * (unnamed_dim - named.dim()) + list(named.names)
                else:
                    return named.names

            inputs = itertools.chain(
                itertools.product([named2], [unnamed1, unnamed2, unnamed3]),
                itertools.product([unnamed1, unnamed2, unnamed3], [named2]),
            )
            if is_inplace:
                # In-place ops have the constraint that they must not change shape.
                inputs = [(a, b) for (a, b) in inputs if a.dim() >= b.dim()]

            for tensor, other in inputs:
                expected_names = compute_expected_names(tensor, other)
                self.assertEqual(op(tensor, other).names, expected_names)

        def method(name, *args, **kwargs):
            # Op invoked as a method of the first operand: a.<name>(b, ...).
            return [Function(name, lambda a, b: getattr(a, name)(b, *args, **kwargs))]

        def function(name, *args, **kwargs):
            # Op invoked as a torch-level function: torch.<name>(a, b, ...).
            return [Function(name, lambda a, b: getattr(torch, name)(a, b, *args, **kwargs))]

        def out_function(name, *args, **kwargs):
            # Op invoked as torch.<name>(a, b, ..., out=result) into a fresh empty tensor.
            # Note: this local `out_fn` shadows the module-level helper of the same name.
            out_fn = getattr(torch, name)

            def fn(a, b):
                result = torch.empty([0], dtype=a.dtype, device=a.device)
                out_fn(a, b, *args, out=result, **kwargs)
                return result

            return [Function(name, fn)]

        def fn_method_and_inplace(name, *args, **kwargs):
            # Method form, in-place method form (name + '_') and out= form of one op.
            return (
                method(name, *args, **kwargs) +
                method(name + '_', *args, **kwargs) +
                out_function(name, *args, **kwargs)
            )

        tests = [
            fn_method_and_inplace('add'),
            fn_method_and_inplace('div'),
            fn_method_and_inplace('mul'),
            fn_method_and_inplace('sub'),
            fn_method_and_inplace('pow'),
            fn_method_and_inplace('atan2'),
            method('copy_'),
            function('floor_divide'),
            function('true_divide'),
        ]
        tests = flatten(tests)

        for name, op in tests:
            test_basic(op)
            test_wildcard(op)
            # In-place variants are recognised by the trailing underscore in their name.
            test_mixed_unnamed_named(op, is_inplace=name.endswith('_'))

    def test_logical_ops(self):
        """``logical_xor/and/or`` propagate names in functional, in-place and ``out=`` forms."""
        # Implemented via TensorIterator, so just check that each version
        # (out-of-place, inplace, out=) propagates names.
        def bool_zeros(*args, **kwargs):
            return torch.zeros(*args, dtype=torch.bool, **kwargs)

        def fresh_operands():
            # Two new named boolean tensors per call, so in-place ops never share inputs.
            return (create('N:2,C:3', bool_zeros), create('N:2,C:3', bool_zeros))

        result_names = ['N', 'C']
        for op in ('logical_xor', 'logical_and', 'logical_or'):
            self._test_name_inference(
                getattr(torch, op), fresh_operands(), expected_names=result_names)

            self._test_name_inference(
                getattr(Tensor, op + '_'), fresh_operands(), expected_names=result_names)

            # The out= tensor starts empty and unnamed and comes first in the argument list.
            out_first_args = (create('0', bool_zeros),) + fresh_operands()
            self._test_name_inference(
                lambda out, x, y: getattr(torch, op)(x, y, out=out),
                out_first_args,
                expected_names=result_names)

    def test_pow_special(self):
        # Some pow cases don't go through TensorIterator, so they are
        # covered explicitly here.
        for device in get_all_device_types():
            named = torch.randn(2, 3, names=('N', 'C'), device=device)
            empty_out = torch.randn([0], device=device)

            # (base, exponent) combinations; each gets its own out tensor.
            for base, exponent in ((named, 0), (named, 1), (1, named)):
                result = torch.pow(base, exponent, out=empty_out.clone())
                self.assertEqual(result.names, named.names)

    def test_out_fn_semantics(self):
        # Uses torch.abs to check how the names of an out= tensor interact
        # with the names computed from the input.
        op = torch.abs
        unnamed = torch.randn(3, 2)
        none_named = torch.randn(3, 2, names=(None, None))
        fully_named = torch.randn(3, 2, names=('N', 'C'))
        partially_named = torch.randn(3, 2, names=('N', None))

        # (input, out) pairs that must be rejected with a name mismatch.
        mismatched_pairs = [
            (partially_named, fully_named),
            (fully_named, partially_named),
            (none_named, fully_named),
            (unnamed, fully_named),
        ]
        for source, destination in mismatched_pairs:
            with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
                op(source, out=destination)

        # (input, extra kwargs used to build the out tensor) pairs that must
        # succeed; the out tensor ends up with the input's names (or none).
        accepted = [
            (unnamed, {}),
            (fully_named, {'names': (None, None)}),
            (fully_named, {}),
            (unnamed, {'names': (None, None)}),
        ]
        for source, out_kwargs in accepted:
            destination = torch.randn(3, 2, **out_kwargs)
            op(source, out=destination)
            if source is unnamed:
                self.assertFalse(destination.has_names())
            else:
                self.assertEqual(destination.names, source.names)

    def test_unary_propagate_names_fns(self):
        # Runs a table of unary-style operations on a named tensor and checks
        # that each result carries the same names as its input.
        def _test(testcase, names=('N', 'D'), device='cpu'):
            tensor = torch.empty([2 for _ in names], names=names, device=device)
            try:
                out = testcase.lambd(tensor)
            except RuntimeError as err:
                # Re-raise with the case name so the failing op is obvious.
                raise RuntimeError(f'{testcase.name}: {err}') from err
            self.assertEqual(out.names, tensor.names, msg=testcase.name)

        def fn(name, *args, **kwargs):
            # torch.<name>(tensor, *args, **kwargs)
            def call(t):
                return getattr(torch, name)(t, *args, **kwargs)

            return [Function(name, call)]

        def method(name, *args, **kwargs):
            # tensor.<name>(*args, **kwargs)
            def call(t):
                return getattr(t, name)(*args, **kwargs)

            return [Function(name, call)]

        def out_function(name, *args, **kwargs):
            # torch.<name>(tensor, *args, out=<empty tensor>, **kwargs)
            torch_op = getattr(torch, name)

            def call_with_out(tensor):
                destination = torch.empty([0], dtype=tensor.dtype, device=tensor.device)
                torch_op(tensor, *args, out=destination, **kwargs)
                return destination

            return [Function(name + '_out', call_with_out)]

        def fn_method_and_inplace(name, *args, **kwargs):
            # Method, in-place method and out= variants of the same op.
            return flatten([
                method(name, *args, **kwargs),
                method(name + '_', *args, **kwargs),
                out_function(name, *args, **kwargs),
            ])

        # All of these operate on 2x2 tensors.
        tests = [
            # unary pointwise
            fn_method_and_inplace('abs'),
            fn_method_and_inplace('acos'),
            fn_method_and_inplace('asin'),
            fn_method_and_inplace('atan'),
            fn_method_and_inplace('ceil'),
            fn_method_and_inplace('clamp', -1, 1),
            fn_method_and_inplace('clamp_min', -2),
            fn_method_and_inplace('clamp_max', 2),
            method('cauchy_'),
            method('clone'),
            method('contiguous'),
            fn_method_and_inplace('cos'),
            fn_method_and_inplace('cosh'),
            fn_method_and_inplace('digamma'),
            fn_method_and_inplace('erf'),
            fn_method_and_inplace('erfc'),
            fn_method_and_inplace('erfinv'),
            fn_method_and_inplace('exp'),
            fn_method_and_inplace('expm1'),
            method('exponential_'),
            fn_method_and_inplace('floor'),
            fn_method_and_inplace('frac'),
            method('geometric_', p=0.5),
            fn_method_and_inplace('lgamma'),
            fn_method_and_inplace('log'),
            fn_method_and_inplace('log10'),
            fn_method_and_inplace('log1p'),
            fn_method_and_inplace('log2'),
            method('log_normal_'),
            fn_method_and_inplace('neg'),
            method('normal_'),
            [Function('polygamma', lambda t: torch.polygamma(1, t))],
            method('polygamma_', 1),
            fn_method_and_inplace('reciprocal'),
            method('random_', 0, 1),
            method('random_', 1),
            method('random_'),
            method('relu_'),
            method('requires_grad_'),
            method('relu'),
            fn_method_and_inplace('round'),
            fn_method_and_inplace('rsqrt'),
            fn_method_and_inplace('sigmoid'),
            fn_method_and_inplace('sign'),
            fn_method_and_inplace('sin'),
            fn_method_and_inplace('sinh'),
            fn_method_and_inplace('sqrt'),
            fn_method_and_inplace('tan'),
            fn_method_and_inplace('tanh'),
            fn('threshold', 0, 1),
            fn('threshold_', 0, 1),
            out_function('threshold', 0, 1),
            fn_method_and_inplace('trunc'),
            method('uniform_'),
            method('zero_'),
            method('fill_', 1),
            method('fill_', torch.tensor(3.14)),

            # conversions
            method('to', dtype=torch.long),
            method('to', device='cpu'),
            method('to', torch.empty([])),
            method('bool'),
            method('byte'),
            method('char'),
            method('cpu'),
            method('double'),
            method('float'),
            method('long'),
            method('half'),
            method('int'),
            method('short'),
            method('type', dtype=torch.long),

            # cumsum and cumprod
            fn('cumsum', 0),
            fn('cumsum', 'D'),
            out_function('cumsum', 'D'),
            fn('cumprod', 0),
            fn('cumprod', 'D'),
            out_function('cumprod', 'D'),

            # views
            method('narrow', 0, 0, 1),

            # creation functions
            fn('empty_like'),
            fn('zeros_like'),
            fn('ones_like'),
            fn('full_like', 3.14),
            fn('rand_like'),
            fn('randn_like'),

            # bernoulli variants
            method('bernoulli_', 0.5),
            method('bernoulli_', torch.tensor(0.5)),

            method('softmax', dim=1),
            method('softmax', dim='D'),
            method('log_softmax', dim=1),
            method('log_softmax', dim='D'),

            [Function('F.dropout(inplace)', lambda t: F.dropout(t, p=0.5, inplace=True))],
            [Function('F.dropout(outplace)', lambda t: F.dropout(t, p=0.5, inplace=False))],
        ]

        # Every case runs on every available device.
        devices = get_all_device_types()
        for testcase in flatten(tests):
            for device in devices:
                _test(testcase, device=device)

    def test_cummax_cummin(self):
        # cummax/cummin return a (values, indices) pair; both outputs must
        # carry the names of the input on every available device.
        def test_ops(op):
            for device in get_all_device_types():
                names = ('N', 'D')
                # Build the input on the device under test. Previously
                # `device` was never used, so every iteration ran on the
                # default (CPU) device.
                tensor = torch.rand(2, 3, names=names, device=device)
                result = op(tensor, 0)
                self.assertEqual(result[0].names, names)
                self.assertEqual(result[1].names, names)
        test_ops(torch.cummax)
        test_ops(torch.cummin)

    def test_logcumsumexp(self):
        # logcumsumexp along a named dim must keep the input's names on every
        # available device.
        for device in get_all_device_types():
            names = ('N', 'D')
            # Pass device= so the loop really exercises each device instead
            # of repeating the CPU case.
            tensor = torch.rand(2, 3, names=names, device=device)
            result = torch.logcumsumexp(tensor, 'D')
            self.assertEqual(result.names, names)

    def test_bitwise_not(self):
        # The method, out= and in-place variants of bitwise_not must all
        # propagate names on every available device.
        for device in get_all_device_types():
            names = ('N', 'D')
            # Both the input and the out tensor live on the device under
            # test; previously `device` was ignored and only CPU was covered.
            tensor = torch.zeros(2, 3, names=names, dtype=torch.bool, device=device)
            result = torch.empty(0, dtype=torch.bool, device=device)

            self.assertEqual(tensor.bitwise_not().names, names)
            self.assertEqual(torch.bitwise_not(tensor, out=result).names, names)
            self.assertEqual(tensor.bitwise_not_().names, names)

    def test_logical_not(self):
        # The method, out= and in-place variants of logical_not must all
        # propagate names on every available device.
        for device in get_all_device_types():
            names = ('N', 'D')
            # Both the input and the out tensor live on the device under
            # test; previously `device` was ignored and only CPU was covered.
            tensor = torch.zeros(2, 3, names=names, dtype=torch.bool, device=device)
            result = torch.empty(0, dtype=torch.bool, device=device)

            self.assertEqual(tensor.logical_not().names, names)
            self.assertEqual(torch.logical_not(tensor, out=result).names, names)
            self.assertEqual(tensor.logical_not_().names, names)

    def test_bernoulli(self):
        # Both Tensor.bernoulli() and torch.bernoulli(..., out=) must
        # propagate the input's names on every available device.
        for device in get_all_device_types():
            names = ('N', 'D')
            # Create input and out tensor on the device under test; the
            # loop variable was previously unused, so only CPU was covered.
            tensor = torch.rand(2, 3, names=names, device=device)
            result = torch.empty(0, device=device)
            self.assertEqual(tensor.bernoulli().names, names)

            torch.bernoulli(tensor, out=result)
            self.assertEqual(result.names, names)

    def test_flatten(self):
        # Covers the named overloads of Tensor.flatten and their error cases.
        tensor = torch.randn(2, 3, 5, 7, 11, names=('N', 'C', 'D', 'H', 'W'))
        expected_values = tensor.rename(None).view(2, 3, -1)

        overloads = [
            ('D', 'W', 'features'),          # basic: start/end given by name
            (2, 4, 'features'),              # int overload
            (['D', 'H', 'W'], 'features'),   # list overload
        ]
        for flatten_args in overloads:
            flattened = tensor.flatten(*flatten_args)
            self.assertEqual(flattened.names, ['N', 'C', 'features'])
            self.assertEqual(flattened.rename(None), expected_values)

        # Non-contiguous flatten: N and H are not "adjacent" in memory.
        sentences = torch.randn(2, 3, 5, 7, names=('N', 'T', 'H', 'D')).transpose('T', 'H')
        flattened = sentences.flatten('N', 'H', 'N_H')
        self.assertEqual(flattened.names, ['N_H', 'T', 'D'])

        # Invalid dim lists and the error each one must produce.
        invalid = [
            (['D', 'L'], "Name 'L' not found in"),
            (['D', 'W'], "must be consecutive in"),
            (['H', 'D', 'W'], "must be consecutive in"),
        ]
        for dims, pattern in invalid:
            with self.assertRaisesRegex(RuntimeError, pattern):
                tensor.flatten(dims, 'features')

    def test_flatten_nodims(self):
        # Flattening an empty list of dims must be rejected.
        unnamed = torch.empty((2, 3))
        no_dims = ()
        with self.assertRaisesRegex(RuntimeError, "cannot be empty"):
            unnamed.flatten(no_dims, 'abcd')

    def test_flatten_index_error(self):
        # Invalid start/end dims must raise the same error with and without
        # an output name.
        tensor = torch.randn(1, 2)
        out_of_range = r"Dimension out of range \(expected to be in range of \[-2, 1\], but got 2\)"
        wrong_order = r"flatten\(\) has invalid args: start_dim cannot come after end_dim"

        cases = [
            (IndexError, out_of_range, (0, 2)),
            (IndexError, out_of_range, (0, 2, 'N')),
            (RuntimeError, wrong_order, (1, 0)),
            (RuntimeError, wrong_order, (1, 0, 'N')),
        ]
        for exc_type, pattern, flatten_args in cases:
            with self.assertRaisesRegex(exc_type, pattern):
                tensor.flatten(*flatten_args)

    def test_unflatten(self):
        # Each row: (input sizes, input names, dim to unflatten, namedshape,
        #            expected sizes, expected names).
        # The namedshape is given as tuples, lists or a mix, and may contain
        # a single -1 placeholder.
        ones_cases = [
            ((4,), ('A',), 'A', (('A', 2), ('B', 2)),
             (2, 2), ('A', 'B')),
            ((4,), ('A',), 'A', [('A', 2), ('B', 2)],
             (2, 2), ('A', 'B')),
            ((4,), ('A',), 'A', (['A', 2], ['B', 2]),
             (2, 2), ('A', 'B')),
            ((2, 10), ('A', 'B'), 'B', (['B1', -1],),
             (2, 10), ('A', 'B1')),
            ((2, 3 * 4 * 5 * 6), ('A', 'B'), 'B',
             (['B1', 3], ['B2', 4], ['B3', -1], ['B4', 6]),
             (2, 3, 4, 5, 6), ('A', 'B1', 'B2', 'B3', 'B4')),
            ((2, 0), ('A', 'B'), 'B',
             (['B1', 3], ['B2', -1], ['B3', 4]),
             (2, 3, 0, 4), ('A', 'B1', 'B2', 'B3')),
            ((2, 4), ('A', 'B'), 'B', (('B1', 2), ('B2', 2)),
             (2, 2, 2), ('A', 'B1', 'B2')),
        ]
        for sizes, names, dim, namedshape, out_sizes, out_names in ones_cases:
            actual = torch.ones(sizes, names=names).unflatten(dim, namedshape)
            expected = torch.ones(out_sizes, names=out_names)
            self.assertTrue(torch.equal(actual, expected))

        # Invalid: named dim combined with plain (unnamed) sizes.
        with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"):
            torch.tensor([1], names=('A',)).unflatten('A', (1, 1))

        # Invalid: named tensor, int dim, plain (unnamed) sizes.
        with self.assertRaisesRegex(RuntimeError, r"input is a named tensor but no names were given for unflattened sizes"):
            torch.tensor([1], names=("A",)).unflatten(0, (1, 1))

        # Invalid: sizes that cannot multiply up to the dim's size.
        with self.assertRaisesRegex(RuntimeError,
                                    r"Provided sizes \[3, -1\] don't multiply up to the "
                                    r"size of dim 1 \('B': 4\) in Tensor\['A', 'B'\]"):
            torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 3), ('B2', -1)))

        # Invalid: -1 cannot be inferred when another size is 0.
        with self.assertRaisesRegex(RuntimeError,
                                    r"the unspecified dimension size -1 can be any value and is ambiguous"):
            torch.ones(2, 0, names=('A', 'B')).unflatten('B', (('B1', 0), ('B2', -1)))

        tensor = torch.randn(7, 2 * 3 * 5, 11, names=('N', 'D', 'K'))

        # Each row: (dim, namedshape, expected names, expected shape).
        randn_cases = [
            # accepts OrderedDict
            ('D', OrderedDict((('C', 2), ('H', 3), ('W', 5))),
             ('N', 'C', 'H', 'W', 'K'), (7, 2, 3, 5, 11)),
            # Unflatten left-most
            ('N', (('N', 7), ('H', 1)),
             ('N', 'H', 'D', 'K'), (7, 1, 2 * 3 * 5, 11)),
            # Unflatten right-most
            ('K', (('K', 11), ('H', 1)),
             ('N', 'D', 'K', 'H'), (7, 2 * 3 * 5, 11, 1)),
        ]
        for dim, namedshape, out_names, out_shape in randn_cases:
            unflattened = tensor.unflatten(dim, namedshape)
            self.assertEqual(unflattened.names, out_names)
            self.assertEqual(unflattened.shape, out_shape)

        # Each row: (namedshape, expected error pattern).
        invalid = [
            ((('H', 3), ('W', 5)), "don't multiply up to"),
            (None, 'sizes must be non-empty'),
            (OrderedDict(), 'non-empty'),
        ]
        for namedshape, pattern in invalid:
            with self.assertRaisesRegex(RuntimeError, pattern):
                tensor.unflatten('D', namedshape)

    def test_unsupported_op_error_msg(self):
        # Ops without named tensor support must say so in their error message.
        named = torch.randn(3, 3, names=('N', 'C'))

        unsupported_calls = [
            (r"pdist.+is not yet supported with named tensors",
             lambda: torch.pdist(named)),
            (r"as_strided_.+is not yet supported with named tensors",
             lambda: named.as_strided_((3, 3), (3, 1))),
        ]
        for pattern, call in unsupported_calls:
            with self.assertRaisesRegex(RuntimeError, pattern):
                call()

    def test_reduction_fns(self):
        # Table-driven check of name inference for reduction ops. Each Case
        # lists which optional features (full reduce, multi-dim reduce, out=,
        # keepdim) the op supports, and only those checks are run for it.
        input_names = ('N', 'C', 'L')

        def make_input(device, **kwargs):
            return torch.empty(2, 3, 5, names=input_names, device=device, **kwargs)

        def check_output(output, expected_names):
            # Ops may return a single tensor or a sequence of tensors.
            outputs = [output] if isinstance(output, torch.Tensor) else output
            for out in outputs:
                self.assertEqual(out.names, expected_names)

        def sum_all_outputs(output):
            # Collapse one tensor or a sequence of tensors into a scalar.
            if isinstance(output, torch.Tensor):
                return output.sum()
            return functools.reduce(lambda total, out: out + total, output, 0).sum()

        # Ops for which dim=None is accepted and means "reduce everything".
        ops_support_dim_none = [
            'sum',
            'mean',
            'std',
            'var',
            'std_mean',
            'var_mean',
            'nanmean',
            'nansum',
        ]

        def test_simple_reduce(op, device):
            t = make_input(device)
            for dim, expected in ((1, ['N', 'L']), (-1, ['N', 'C']), ('C', ['N', 'L'])):
                check_output(op(t, dim), expected)

            if op.__name__ in ops_support_dim_none:
                check_output(op(t, None), [])
            else:
                with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
                    op(t, None)

            with self.assertRaisesRegex(RuntimeError, "Name 'H' not found"):
                op(t, 'H')

        def test_autograd_supports_dimname_overload(op, device):
            t = make_input(device, requires_grad=True)
            sum_all_outputs(op(t, 'C')).backward()
            self.assertIsNotNone(t.grad)

        def test_complete_reduce(op, device):
            check_output(op(make_input(device)), [])

        def test_multidim_reduce(op, device):
            t = make_input(device)
            for dims, expected in (([1, 2], ['N']), ([0, -1], ['C']), (['C', 'L'], ['N'])):
                check_output(op(t, dims), expected)
            with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
                op(t, [None, 'C'])

        def test_out_variant(op, output_lambda, device):
            t = make_input(device)
            # Ops with several outputs supply a factory for their out tensors.
            out = output_lambda(t) if output_lambda else torch.empty([0], device=device)
            op(t, 'C', out=out)
            check_output(out, ['N', 'L'])

        def test_keepdim(op, device):
            check_output(op(make_input(device), 'C', keepdim=True), ['N', 'C', 'L'])

        def values_and_indices(t):
            # Empty out tensors for ops returning a (values, indices) pair.
            values = torch.empty([0], device=t.device)
            indices = torch.empty([0], device=t.device, dtype=torch.long)
            return (values, indices)

        def kthvalue_wrapper(tensor, *args, **kwargs):
            # Pin k to 1 so kthvalue can be called like the other reductions.
            return torch.kthvalue(tensor, 1, *args, **kwargs)

        Case = namedtuple('Case', [
            'op',
            'supports_complete_reduce',
            'supports_multidim_reduce',
            'supports_out_variant',
            'supports_keepdim',
            'output_lambda',
        ])

        tests = [
            Case(torch.sum, True, True, True, True, None),
            Case(torch.prod, True, False, True, True, None),
            Case(torch.mean, True, True, True, True, None),
            Case(torch.var, True, True, True, True, None),
            Case(torch.std, True, True, True, True, None),
            Case(torch.std_mean, True, True, False, True, None),
            Case(torch.var_mean, True, True, False, True, None),
            Case(torch.min, True, False, True, True, values_and_indices),
            Case(torch.max, True, False, True, True, values_and_indices),
            Case(torch.unbind, False, False, False, False, None),
            Case(torch.logsumexp, False, True, True, True, None),
            Case(torch.mode, False, False, True, True, values_and_indices),
            Case(kthvalue_wrapper, False, False, True, True, values_and_indices),
            Case(torch.median, True, False, True, True, values_and_indices),
            Case(torch.nanmedian, True, False, True, True, values_and_indices),
        ]

        devices = get_all_device_types()
        for testcase in tests:
            op = testcase.op
            for device in devices:
                test_simple_reduce(op, device)
                test_autograd_supports_dimname_overload(op, device)

                if testcase.supports_keepdim:
                    test_keepdim(op, device)
                if testcase.supports_out_variant:
                    test_out_variant(op, testcase.output_lambda, device)
                if testcase.supports_complete_reduce:
                    test_complete_reduce(op, device)
                if testcase.supports_multidim_reduce:
                    test_multidim_reduce(op, device)

    def test_masked_select(self):
        # Name inference for masked_select: when it succeeds the result is
        # expected to have a single unnamed dim.
        def mask(namedshape, *names):
            # Boolean mask built from an unnamed tensor, then given `names`.
            return (create(namedshape) > 0).rename(*names)

        cases = [
            # simple
            (torch.masked_select,
             (create('N:2,C:3'), mask('2,3', 'N', 'C')),
             {'expected_names': [None]}),
            # left broadcast
            (torch.masked_select,
             (create('C:3'), mask('2,3', 'N', 'C')),
             {'expected_names': [None]}),
            # right broadcast
            (torch.masked_select,
             (create('N:2,C:3'), mask('3', 'C')),
             {'expected_names': [None]}),
            # error
            (torch.masked_select,
             (create('N:2,C:3'), mask('3', 'D')),
             {'maybe_raises_regex': 'do not match'}),
            # out=
            (out_fn(torch.masked_select),
             (create('0'), create('N:2,C:3'), mask('2,3', 'N', 'C')),
             {'expected_names': [None]}),
        ]
        for op, inputs, expectation in cases:
            self._test_name_inference(op, inputs, **expectation)

    def test_cat(self):
        # Name inference for torch.cat, including its error cases and out=.
        cases = [
            # simple
            (torch.cat,
             [[create('N:2,C:3'), create('N:2,C:3')]],
             {'expected_names': ['N', 'C']}),
            # error: zero dim
            (torch.cat,
             [[create(''), create('')]],
             {'maybe_raises_regex': 'zero-dim'}),
            # error: names don't match
            (torch.cat,
             [[create('N:2,C:3'), create('C:3,N:2')]],
             {'maybe_raises_regex': 'do not match'}),
            # error: different number of dims
            (torch.cat,
             [[create('N:2,C:3'), create('C:3')]],
             {'maybe_raises_regex': 'must have same number of dimensions'}),
            # out=
            (out_fn(torch.cat),
             [create('0'), [create('N:2,C:3'), create('N:2,C:3')]],
             {'expected_names': ['N', 'C']}),
        ]
        for op, inputs, expectation in cases:
            self._test_name_inference(op, inputs, **expectation)

    def test_masked_fill(self):
        # Name inference for masked_fill / masked_fill_, including the cases
        # where broadcasting or mismatched names must be rejected.
        def mask(namedshape, *names):
            # Boolean mask built from an unnamed tensor, then given `names`.
            return (create(namedshape) > 0).rename(*names)

        fill_value = 3.14
        cases = [
            # simple
            (Tensor.masked_fill,
             (create('N:2,C:3'), mask('2,3', 'N', 'C'), fill_value),
             {'expected_names': ['N', 'C']}),
            # left broadcast
            (Tensor.masked_fill,
             (create('C:3'), mask('2,3', 'N', 'C'), fill_value),
             {'maybe_raises_regex': "must be less than or equal to"}),
            # right broadcast
            (Tensor.masked_fill,
             (create('N:2,C:3'), mask('3', 'C'), fill_value),
             {'expected_names': ['N', 'C']}),
            # error
            (Tensor.masked_fill,
             (create('N:2,C:3'), mask('3', 'D'), fill_value),
             {'maybe_raises_regex': 'do not match'}),
            # inplace
            (Tensor.masked_fill_,
             (create('N:2,C:3'), mask('2,3', 'N', 'C'), fill_value),
             {'expected_names': ['N', 'C']}),
            # inplace, computed names don't match output tensor names
            (Tensor.masked_fill_,
             (create('N:2,None:3'), mask('2,3', 'N', 'C'), fill_value),
             {'maybe_raises_regex': "not the same as the computed output names"}),
        ]
        for op, inputs, expectation in cases:
            self._test_name_inference(op, inputs, **expectation)


    def test_using_seen_interned_string_doesnt_bump_refcount(self):
        # Passing a name that has already been used as a tensor name must
        # leave that string's reference count unchanged.
        # NOTE: this test compares sys.getrefcount values, so the exact
        # sequence of statements below matters.
        def see_name():
            # Use the name once (from a separate scope) so that it counts as
            # "seen" before the measurement starts.
            seen_name = 'N'
            pass_name_to_python_arg_parser(seen_name)

        see_name()
        seen_name = 'N'
        old_refcnt = sys.getrefcount(seen_name)

        # Use the already-seen name as a tensor name a second time.
        pass_name_to_python_arg_parser(seen_name)

        new_refcnt = sys.getrefcount(seen_name)
        self.assertEqual(new_refcnt, old_refcnt)

    # This test is failing on Python 3.12: https://github.com/pytorch/pytorch/issues/119464
    @unittest.skipIf(sys.version_info >= (3, 12), "Failing on python 3.12+")
    def test_using_unseen_interned_string_bumps_refcount_permanently(self):
        # The first time a string literal is used as a tensor name, its
        # reference count is expected to go up by exactly one and stay there
        # after the call has returned.
        # Please don't use this as a name in a different test.
        unseen_name = 'abcdefghi'
        old_refcnt = sys.getrefcount(unseen_name)

        # First (and only) use of this string as a tensor name.
        pass_name_to_python_arg_parser(unseen_name)

        new_refcnt = sys.getrefcount(unseen_name)
        self.assertEqual(new_refcnt, old_refcnt + 1)

    # This test is failing on Python 3.12: https://github.com/pytorch/pytorch/issues/119464
    @unittest.skipIf(sys.version_info >= (3, 12), "Failing on python 3.12+")
    def test_using_unseen_uninterned_string_refcounts(self):
        # Passes a runtime-built string as a tensor name and checks which
        # string object gains a reference: the equal-valued literal should,
        # the runtime-built object should not.
        # Please don't use this as a name in a different test.
        # non-compile-time constants are not interned
        unseen_name = ''.join(['abc', 'def', 'ghi', 'jkl'])
        interned_unseen_name = 'abcdefghijkl'
        # Same value, but two distinct string objects.
        self.assertFalse(unseen_name is interned_unseen_name)

        old_uninterned_refcnt = sys.getrefcount(unseen_name)
        old_interned_refcnt = sys.getrefcount(interned_unseen_name)

        # Use the runtime-built (uninterned) object as the tensor name.
        pass_name_to_python_arg_parser(unseen_name)

        new_uninterned_refcnt = sys.getrefcount(unseen_name)
        new_interned_refcnt = sys.getrefcount(interned_unseen_name)

        # Internally, PyTorch should not hold a reference to the uninterned string
        self.assertEqual(new_uninterned_refcnt, old_uninterned_refcnt)

        # Instead, we should hold a new reference to the interned version.
        self.assertEqual(new_interned_refcnt, old_interned_refcnt + 1)

    def _test_select(self, device):
        # Shared body for the select() tests: selecting along 'C' (by index
        # or by name) drops that name; a None dim is rejected.
        x = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device)

        for dim in (1, 'C'):
            selected = x.select(dim, 1)
            self.assertEqual(selected.names, ('N', 'H', 'W'))

        with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
            x.select(None, 1)

    def test_select(self):
        # CPU entry point for the shared select() checks.
        self._test_select(device='cpu')

    @unittest.skipIf(not TEST_CUDA, 'no CUDA')
    def test_select_cuda(self):
        # CUDA entry point for the shared select() checks.
        self._test_select(device='cuda')

    def _test_as_strided(self, device):
        # Shared body for the as_strided tests: the result of as_strided on
        # a named tensor is expected to be unnamed.
        source = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device)
        flat_size = 2 * 3 * 4 * 5
        strided = source.as_strided([flat_size], [1])
        self.assertEqual(strided.names, (None,))

    def test_as_strided(self):
        # CPU entry point for the shared as_strided checks.
        self._test_as_strided(device='cpu')

    @unittest.skipIf(not TEST_CUDA, 'no CUDA')
    def test_as_strided_cuda(self):
        # CUDA entry point for the shared as_strided checks.
        self._test_as_strided(device='cuda')

    def test_no_jit_tracer_support(self):
        # Tracing a function that uses names must fail with a clear error.
        def make_named(x):
            # Creates a named tensor inside the traced function.
            return torch.full(x.shape, 2., names=('N',))

        def select_by_name(x):
            # Looks a dimension up by name inside the traced function.
            return x.select('N', 1)

        for traced_fn in (make_named, select_by_name):
            with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'):
                example = torch.randn(3)
                torch.jit.trace(traced_fn, example_inputs=example)

    def test_no_jit_script_support(self):
        # Scripted functions must reject named tensors both as inputs and as
        # return values.
        @torch.jit.script
        def add_one(x):
            return x + 1

        # A named tensor passed into a scripted function.
        named_input = torch.randn(2, 3, names=('N', 'C'))
        with self.assertRaisesRegex(RuntimeError, 'NYI'):
            add_one(named_input)

        @torch.jit.ignore
        def attach_names(x):
            # Runs as plain Python, so it can set names behind the
            # compiler's back.
            x.names = ('N', 'C')

        @torch.jit.script
        def return_named(tensor):
            attach_names(tensor)
            return tensor

        # A scripted function that hands back a tensor which became named.
        with self.assertRaisesRegex(RuntimeError, "NYI"):
            return_named(torch.randn(1, 1))

    def test_align_to(self):
        # Each row: (input namedshape, requested order, expected shape).
        # The output names are expected to equal the requested order.
        valid = [
            # trivial
            ('N:3', ('N',), [3]),
            # unsqueeze behavior
            ('N:3', ('N', 'D'), [3, 1]),
            # transpose behavior
            ('N:3,C:2', ('C', 'N'), [2, 3]),
            # unsqueeze / transpose
            ('C:2,N:3,H:5', ('N', 'H', 'W', 'C'), [3, 5, 1, 2]),
        ]
        for namedshape, order, expected_shape in valid:
            aligned = create(namedshape).align_to(*order)
            self.assertEqual(aligned.names, list(order))
            self.assertEqual(aligned.shape, expected_shape)

        # Each row: (input namedshape, requested order, expected error).
        invalid = [
            # All input dimensions must be named
            ('None:2,C:3', ('N', 'C'),
             "All input dims must be named. Found unnamed dim at index 0"),
            # not enough names
            ('N:2,C:3', ('C',), "Cannot find dim 'N'"),
            # names not found
            ('N:2,C:3', ('D', 'N'), "Cannot find dim 'C'"),
        ]
        for namedshape, order, pattern in invalid:
            with self.assertRaisesRegex(RuntimeError, pattern):
                create(namedshape).align_to(*order)

    def test_align_to_ellipsis(self):
        # '...' in align_to stands for the dims not mentioned explicitly, in
        # the order they appear in the input.
        tensor = create('N:7,H:3,W:5,C:2')

        # Each row: (requested order, expected names, expected shape).
        valid = [
            # ... = ['N', 'H', 'W', 'C']
            (('...',), ['N', 'H', 'W', 'C'], [7, 3, 5, 2]),
            # ... = ['H', 'C']
            (('...', 'W', 'N'), ['H', 'C', 'W', 'N'], [3, 2, 5, 7]),
            # ... = ['N', 'W']
            (('H', 'C', '...'), ['H', 'C', 'N', 'W'], [3, 2, 7, 5]),
            # ... = ['H', 'C']
            (('W', '...', 'N'), ['W', 'H', 'C', 'N'], [5, 3, 2, 7]),
            # ... = []
            (('N', '...', 'C', 'D', 'H', 'W'), ['N', 'C', 'D', 'H', 'W'], [7, 2, 1, 3, 5]),
        ]
        for order, expected_names, expected_shape in valid:
            aligned = tensor.align_to(*order)
            self.assertEqual(aligned.names, expected_names)
            self.assertEqual(aligned.shape, expected_shape)

        # Input tensor partially named
        partially_named = create('None:2,None:3,None:5,C:7')
        aligned = partially_named.align_to('C', '...')
        self.assertEqual(aligned.names, ['C', None, None, None])
        self.assertEqual(aligned.shape, [7, 2, 3, 5])

        # Each row: (tensor, requested order, expected error).
        invalid = [
            (partially_named, ('C', None, '...'),
             "order of dimensions cannot contain a None"),
            # Input order partially named
            (tensor, ('...', 'N', None), "cannot contain a None name"),
            # Input order duplicate names
            (tensor, ('...', 'N', 'N'), "duplicate names"),
        ]
        for source, order, pattern in invalid:
            with self.assertRaisesRegex(RuntimeError, pattern):
                source.align_to(*order)

    def test_align_as(self):
        # align_as calls align_to internally, and align_to already has
        # substantial tests, so a single basic case is enough here.
        source = create('C:2,N:3,H:5')
        template = create('N:1,H:1,W:1,C:1')

        aligned = source.align_as(template)

        self.assertEqual(aligned.names, ['N', 'H', 'W', 'C'])
        self.assertEqual(aligned.shape, [3, 5, 1, 2])

    @unittest.skip("Not implemented yet")
    def test_align_tensors_two_inputs(self):
        # Table-driven check of torch.align_tensors on two inputs. Every case
        # builds an empty tensor from `tensor_namedshape` (names, sizes) and
        # aligns it against an all-size-1 tensor named `align_names`.
        # Currently skipped (see the decorator).
        Case = namedtuple('Case', [
            'tensor_namedshape',
            'align_names',
            'expected_sizes',
            'expected_error',
        ])

        def aligns(tensor_namedshape, align_names, expected_sizes):
            # Case that must succeed and produce `expected_sizes`.
            return Case(tensor_namedshape, align_names, expected_sizes, None)

        def rejects(tensor_namedshape, align_names, expected_error):
            # Case that must raise a RuntimeError matching `expected_error`.
            return Case(tensor_namedshape, align_names, None, expected_error)

        def check(case):
            names, sizes = case.tensor_namedshape
            tensor = torch.empty(*sizes, names=names)
            other = torch.empty([1] * len(case.align_names), names=case.align_names)
            if case.expected_error is None:
                aligned, _ = torch.align_tensors(tensor, other)
                self.assertEqual(aligned.shape, case.expected_sizes)
                self.assertEqual(aligned.names, case.align_names)
            else:
                with self.assertRaisesRegex(RuntimeError, case.expected_error):
                    torch.align_tensors(tensor, other)

        cases = [
            # basic tests
            aligns((['C'], [2]), ['C'], [2]),
            rejects((['C'], [2]), ['D'], 'not a subsequence'),

            # single-dim alignment test
            aligns((['C'], [2]), ['N', 'C'], [1, 2]),
            aligns((['N'], [2]), ['N', 'C'], [2, 1]),

            # multiple dim alignment test
            aligns((['N', 'C'], [2, 3]), ['N', 'H', 'C', 'W'], [2, 1, 3, 1]),
            rejects((['N', 'C'], [2, 3]), ['C', 'H', 'N', 'W'], 'not a subsequence'),

            # scalar tensor tests
            aligns((None, [[]]), ['N', 'C'], [1, 1]),
            aligns(([], [[]]), [None, None], [1, 1]),

            # unnamed tensor tests
            aligns((None, [2, 3]), [None, None], [2, 3]),
            aligns((None, [2, 3]), [None, None, None], [1, 2, 3]),
            rejects((None, [2]), ['N'], 'not a subsequence'),

            # unnamed dim alignment tests
            aligns(([None], [2]), ['N', None], [1, 2]),
            aligns(([None], [2]), ['N', None, None, None], [1, 1, 1, 2]),
            aligns((['N'], [2]), ['N', None, None, None], [2, 1, 1, 1]),
            aligns(([None, 'N', None], [2, 3, 5]), [None, None, 'N', None], [1, 2, 3, 5]),
            rejects(([None], [2]), [None, 'N'], 'absolute position from the right'),
            rejects((None, [2]), [None, 'N'], 'absolute position from the right'),
            rejects(([None, 'N'], [2, 3]), [None, 'C', 'N'], 'absolute position from the right'),
        ]

        for case in cases:
            check(case)

    @unittest.skip("Not implemented yet")
    def test_align_tensors(self):
        # Compares torch.align_tensors on three inputs against a reference
        # that aligns every input to the names of the input with the most
        # dims. Currently skipped (see the decorator).
        def reference_fn(*tensors):
            # max() returns the first of several equally long candidates,
            # i.e. the earliest tensor with the greatest number of names.
            widest = max(tensors, key=lambda t: len(t.names))
            return [t.align_to(*widest.names) for t in tensors]

        inputs = (
            torch.empty(1, 1, names=('N', 'H')),
            torch.empty(2, 3, 5, names=('N', 'C', 'H')),
            torch.empty(2, names=('N',)),
        )
        aligned = torch.align_tensors(*inputs)
        for actual, expected in zip(aligned, reference_fn(*inputs)):
            self.assertTensorDataAndNamesEqual(actual, expected)

    def test_mm(self):
        # Name inference for torch.mm on every available device type: the
        # expected output names are the first name of the left matrix and the
        # last name of the right matrix; an unnamed operand contributes None.
        for device in get_all_device_types():
            check = functools.partial(self._test_name_inference, device=device)

            check(torch.mm,
                  args=(create('N:3,C:2'), create('W:2,H:5')),
                  expected_names=('N', 'H'))

            # left arg is unnamed
            check(torch.mm,
                  args=(create('3,2'), create('W:2,H:5')),
                  expected_names=(None, 'H'))

            # right arg is unnamed
            check(torch.mm,
                  args=(create('N:3,C:2'), create('2,5')),
                  expected_names=('N', None))

            # out=
            check(out_fn(torch.mm),
                  args=(create('0'), create('N:3,C:2'), create('W:2,H:5')),
                  expected_names=('N', 'H'))

            # the output would be named ('N', 'N'), which must be rejected
            check(torch.mm,
                  args=(create('N:3,C:2'), create('W:2,N:5')),
                  maybe_raises_regex='with duplicate names')

    def test_expand(self):
        # Name inference for Tensor.expand on every available device type:
        # existing names are expected to be kept, and dims that expand adds
        # on the left are expected to be unnamed.
        for device in get_all_device_types():
            check = functools.partial(self._test_name_inference, device=device)

            check(Tensor.expand,
                  args=(create('D:1'), [3]),
                  expected_names=('D',))

            check(Tensor.expand,
                  args=(create('H:3,W:2'), [10, 3, 3, 2]),
                  expected_names=(None, None, 'H', 'W'))

            check(Tensor.expand,
                  args=(create('3, 2'), [10, 3, 3, 2]),
                  expected_names=(None, None, None, None))

    def test_addmm(self):
        # Name inference for torch.addmm (functional, out= and in-place) on
        # every available device type. In all passing cases the expected
        # output names are ('N', 'H'), whether the bias is fully named,
        # unnamed or partially named.
        for device in get_all_device_types():
            check = functools.partial(self._test_name_inference, device=device)

            # full names
            check(torch.addmm,
                  args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,H:5')),
                  expected_names=('N', 'H'))

            # no name on bias
            check(torch.addmm,
                  args=(create('3,5'), create('N:3,C:2'), create('W:2,H:5')),
                  expected_names=('N', 'H'))

            # partially named bias
            check(torch.addmm,
                  args=(create('N:3,None:5'), create('N:3,C:2'), create('W:2,H:5')),
                  expected_names=('N', 'H'))

            # out=
            check(out_fn(torch.addmm),
                  args=(create('0'), create('N:3,None:5'), create('N:3,C:2'), create('W:2,H:5')),
                  expected_names=('N', 'H'))

            # inplace
            check(torch.Tensor.addmm_,
                  args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,H:5')),
                  expected_names=('N', 'H'))

            # the matrix product would be named ('N', 'N'), which must be rejected
            check(torch.addmm,
                  args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,N:5')),
                  maybe_raises_regex='with duplicate names')

    def test_bmm(self):
        # Name inference for torch.bmm on every available device type,
        # covering named/unnamed operands, the out= variant and three error
        # cases (duplicate output names, mismatched batch names, misaligned
        # batch name).
        for device in get_all_device_types():
            check = functools.partial(self._test_name_inference, device=device)

            # full names
            check(torch.bmm,
                  args=(create('N:7,A:3,B:2'), create('N:7,A:2,B:5')),
                  expected_names=('N', 'A', 'B'))

            # no name on left tensor
            check(torch.bmm,
                  args=(create('7,3,2'), create('N:7,A:2,B:5')),
                  expected_names=('N', None, 'B'))

            # no name on right tensor
            check(torch.bmm,
                  args=(create('N:7,A:3,B:2'), create('7,2,5')),
                  expected_names=('N', 'A', None))

            # out=
            check(out_fn(torch.bmm),
                  args=(create('0'), create('N:7,A:3,B:2'), create('N:7,A:2,B:5')),
                  expected_names=('N', 'A', 'B'))

            # duplicate names after mm
            check(torch.bmm,
                  args=(create('N:7,A:3,B:2'), create('N:7,B:2,A:5')),
                  maybe_raises_regex='with duplicate names')

            # matching error (batch dimensions must be alignable)
            check(torch.bmm,
                  args=(create('N:3,A:3,B:3'), create('M:3,A:3,B:3')),
                  maybe_raises_regex='do not match')

            # misalignment (batch dimension is getting contracted)
            check(torch.bmm,
                  args=(create('N:3,A:3,B:3'), create('None:3,N:3,B:3')),
                  maybe_raises_regex='misaligned')

    def test_matmul(self):
        # Name inference for torch.matmul on every available device type,
        # grouped by operand rank: <1D inputs (rejected), 1D @ 1D, ND @ 1D,
        # 1D @ ND, 2D @ 2D, batched ND @ ND, the out= variant, and the
        # batched error cases.
        for device in get_all_device_types():
            # input tensors are less than 1D
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create(''), create('A:2')),
                maybe_raises_regex='at least 1D')
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('A:2'), create('')),
                maybe_raises_regex='at least 1D')

            # 1D @ 1D
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('A:2'), create('B:2')),
                expected_names=[])

            # ND @ 1D
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('A:3,C:2'), create('B:2')),
                expected_names=['A'])
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('A:5,C:3,D:2'), create('B:2')),
                expected_names=['A', 'C'])

            # 1D @ ND
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('C:2'), create('A:2,B:3')),
                expected_names=['B'])
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('C:2'), create('A:3,B:2,D:5')),
                expected_names=['A', 'D'])

            # 2D @ 2D
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('A:3,B:2'), create('A:2,B:3')),
                expected_names=['A', 'B'])
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('A:3,B:2'), create('B:2,A:5')),
                maybe_raises_regex='with duplicate names')

            # ND @ ND where N >= 2
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('C:5,A:3,B:2'), create('A:2,B:3')),
                expected_names=['C', 'A', 'B'])
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('C:5,A:3,B:2'), create('None:1,A:2,B:3')),
                expected_names=['C', 'A', 'B'])
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('C:5,A:3,B:2'), create('None:2,None:1,A:2,B:3')),
                expected_names=[None, 'C', 'A', 'B'])

            # out=
            self._test_name_inference(
                out_fn(torch.matmul), device=device,
                args=(create('0'), create('N:7,A:3,B:2'), create('N:7,A:2,B:5')),
                expected_names=('N', 'A', 'B'))

            # duplicate names after mm
            # This case previously called torch.bmm (already covered by
            # test_bmm), so the batched matmul path was not exercised here.
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('N:7,A:3,B:2'), create('N:7,B:2,A:5')),
                maybe_raises_regex='with duplicate names')

            # misalignment (batch dimension is getting contracted)
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('N:3,A:3,B:3'), create('A:3,N:3,B:3')),
                maybe_raises_regex='do not match')

    def test_mv(self):
        # Name inference for torch.mv on every available device type: the
        # expected output name is the first name of the matrix (None when the
        # matrix is unnamed); the vector's name does not appear in the result.
        for device in get_all_device_types():
            check = functools.partial(self._test_name_inference, device=device)

            check(torch.mv,
                  args=(create('N:3,C:2'), create('W:2')),
                  expected_names=('N',))

            # left arg is unnamed
            check(torch.mv,
                  args=(create('3,2'), create('W:2')),
                  expected_names=(None,))

            # right arg is unnamed
            check(torch.mv,
                  args=(create('N:3,C:2'), create('2')),
                  expected_names=('N',))

            # out=
            check(out_fn(torch.mv),
                  args=(create('0'), create('N:3,C:2'), create('W:2')),
                  expected_names=('N',))

    def test_addmv(self):
        # Name inference for torch.addmv (functional, out= and in-place) on
        # every available device type. Every case expects the single output
        # name 'N', with or without a name on the bias.
        for device in get_all_device_types():
            check = functools.partial(self._test_name_inference, device=device)

            # full names
            check(torch.addmv,
                  args=(create('N:3'), create('N:3,C:2'), create('H:2')),
                  expected_names=['N'])

            # no name on bias
            check(torch.addmv,
                  args=(create('3'), create('N:3,C:2'), create('H:2')),
                  expected_names=('N',))

            # out=
            check(out_fn(torch.addmv),
                  args=(create('0'), create('N:3'), create('N:3,C:2'), create('H:2')),
                  expected_names=('N',))

            # inplace
            check(torch.Tensor.addmv_,
                  args=(create('N:3'), create('N:3,C:2'), create('H:2')),
                  expected_names=('N',))

    def test_autograd_ignores_names(self):
        # sigmoid forward is supported by named tensors, but sigmoid_backward
        # is not (see native_functions.yaml). Test that autograd ignores names
        # and that the sigmoid_backward succeeds.
        x = torch.randn(3, 3, names=('N', 'C'), requires_grad=True)
        x.sigmoid().sum().backward()

    def test_tensor_grad_is_unnamed(self):
        # Backpropagates through a product of an unnamed and a named leaf and
        # verifies that neither leaf's .grad carries names.
        unnamed = torch.randn(3, 3, names=(None, None), requires_grad=True)
        named = torch.randn(3, 3, names=('N', 'C'), requires_grad=True)
        loss = (unnamed * named).sum()
        loss.backward()

        # Check that names weren't propagated
        for leaf in (named, unnamed):
            self.assertEqual(leaf.grad.names, [None, None])

    def test_autograd_warns_named_grad(self):
        # Passing a named tensor as the gradient to backward() must emit
        # exactly one warning whose text starts with the expected prefix.
        leaf = torch.randn(3, 3, names=('N', 'C'))
        grad_with_names = leaf.clone()
        leaf.requires_grad_()

        with warnings.catch_warnings(record=True) as caught:
            # Cause all warnings to always be triggered.
            warnings.simplefilter("always")
            leaf.clone().backward(grad_with_names)
            self.assertEqual(len(caught), 1)
            message = str(caught[0].message)
            self.assertTrue(
                message.startswith('Autograd was passed a named grad tensor'))

    def test_nyi_dimname_overload_msg(self):
        # Calling squeeze_ with a dimension name must raise a RuntimeError
        # whose message matches the expected "You passed a dimname" text.
        expected_message = "squeeze: You passed a dimname"
        unnamed = torch.randn(3, 3)
        with self.assertRaisesRegex(RuntimeError, expected_message):
            unnamed.squeeze_("N")

    def test_dot(self):
        # torch.dot ignores the names of both tensors: two differently named
        # vectors are accepted and the result is expected to have no names.
        for device in get_all_device_types():
            lhs, rhs = create('C:2'), create('W:2')
            self._test_name_inference(
                torch.dot,
                args=(lhs, rhs),
                expected_names=[],
                device=device)

    def test_comparison_ops(self):
        # Comparison results must keep the names of the named operand. This is
        # checked for the Python comparison operators against a named tensor,
        # a Python number and a 0-dim tensor, for the out= variants of the
        # torch comparison functions, and for torch.isnan / torch.isinf.
        expected = ['N', 'C']
        operators = (
            lambda lhs, rhs: lhs == rhs,
            lambda lhs, rhs: lhs != rhs,
            lambda lhs, rhs: lhs > rhs,
            lambda lhs, rhs: lhs < rhs,
            lambda lhs, rhs: lhs >= rhs,
            lambda lhs, rhs: lhs <= rhs,
        )
        out_variants = (torch.eq, torch.ne, torch.lt, torch.gt, torch.le, torch.ge)
        predicates = (torch.isnan, torch.isinf)

        for device in get_all_device_types():
            a = torch.randn(3, 3, names=('N', 'C'), device=device)
            b = torch.randn(3, 3, names=('N', 'C'), device=device)
            scalar = torch.randn([], device=device)

            # tensor, Python number, then 0-dim tensor on the right-hand side
            for rhs in (b, 1, scalar):
                for compare in operators:
                    self.assertEqual(compare(a, rhs).names, expected)

            # The same output tensor is reused for every out= call.
            res = torch.empty(3, 3, dtype=torch.bool, device=device)
            for compare_out in out_variants:
                compare_out(a, b, out=res)
                self.assertEqual(res.names, expected)

            for predicate in predicates:
                self.assertEqual(predicate(a).names, expected)

    def test_support_device_named_grad(self):
        # Attaching names to a tensor on the 'meta' device must be rejected,
        # whichever way the names are supplied: in-place rename, the `names`
        # setter, or the `names=` factory argument.
        #
        # Each attempt gets its own assertRaisesRegex context. A `with` block
        # is left at the first exception, so with all three statements in one
        # block only the first one would ever be executed and checked.
        expected_error = 'NYI: named tensors only support CPU, CUDA'

        named_tensor = torch.randn(3, 3, device='meta')
        with self.assertRaisesRegex(RuntimeError, expected_error):
            named_tensor.rename_('N', 'C')

        with self.assertRaisesRegex(RuntimeError, expected_error):
            named_tensor.names = ['N', 'C']

        with self.assertRaisesRegex(RuntimeError, expected_error):
            named_tensor = torch.randn(3, 3, device='meta', names=['N', 'C'])


# When this file is executed as a script, hand control to run_tests
# (imported from torch.testing._internal.common_utils at the top of the file).
if __name__ == '__main__':
    run_tests()
