# Owner(s): ["module: sparse"]
import itertools
import random
import unittest

import torch
from torch import nn
import torch.nn.functional as F

from torch.sparse import (
    SparseSemiStructuredTensor,
    SparseSemiStructuredTensorCUSPARSELT,
    SparseSemiStructuredTensorCUTLASS,
    to_sparse_semi_structured,
)

from torch.sparse._semi_structured_conversions import (
    sparse_semi_structured_from_dense_cutlass,
    _sparse_semi_structured_tile,
    _compute_compressed_swizzled_bitmask,
)

from torch.testing import make_tensor
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)

from torch.testing._internal.common_dtype import all_types_and_complex
import torch._dynamo.test_case
from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    subtest,
    TestCase,
    TEST_WITH_ROCM,
    IS_WINDOWS,
)

import pytest

from torch.utils._triton import has_triton

# Maps backend name -> SparseSemiStructuredTensor subclass; populated below
# based on the detected GPU architecture and library availability.
SEMI_STRUCTURED_SUPPORTED_BACKENDS = dict()

_IS_SM8X = False
_IS_SM9X = False

if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
    _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9

    # CUTLASS kernels only work for Ampere
    if _IS_SM8X:
        SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS

    # add cuSPARSELt tests if available
    if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X):
        SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT

inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.int8)
training_dtypes = dtypes(torch.float16, torch.bfloat16)
parametrize_backends = parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)

atol_rtol_kw = {
    torch.float16: {
        "rtol": 1e-3,
        "atol": 1e-3,
    },
    torch.bfloat16: {
        "rtol": 1e-1,
        "atol": 1e-1,
    },
}
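
# These tolerances are splatted into torch.testing.assert_close(...) calls below,
# e.g. torch.testing.assert_close(out, ref, **atol_rtol_kw[dtype]); bfloat16 gets
# much looser bounds than float16 since it carries far fewer mantissa bits.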

def sparse24_largest_mask_2d(original):
    sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(original)
    return sparse.to_dense().bool()

def sparsify24_dense(original):
    return sparse24_largest_mask_2d(original) * original

def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means the matrix is also 2:4 and 4:8 sparse.
    """

    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for _ in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )
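

# A small, CPU-only illustration (not exercised by any test here) of why a 1:2
# sparse mask is automatically 2:4 and 4:8 sparse: every aligned group of m
# elements is a union of aligned groups of 2, each holding at most one nonzero.
# The helper below is introduced just for this illustration.
def _is_n_to_m_sparse(row, n, m):
    # True iff every aligned group of `m` entries has at most `n` nonzeros.
    return all(sum(1 for v in row[i:i + m] if v) <= n for i in range(0, len(row), m))

# e.g. with row = [1, 0, 0, 1, 1, 0, 1, 0], all of
# _is_n_to_m_sparse(row, 1, 2), _is_n_to_m_sparse(row, 2, 4),
# and _is_n_to_m_sparse(row, 4, 8) hold.
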

def rand_sparse_semi_structured(r, c, dtype, device, choice=None):
    pattern = '2by4' if dtype != torch.float32 else '1by2'
    if pattern == '1by2':
        ksparse = 2
        choices = [
            [0, 1],
            [1, 0]
        ]
    elif pattern == '2by4':
        ksparse = 4
        choices = [
            [1, 1, 0, 0],
            [1, 0, 1, 0],
            [1, 0, 0, 1],
            [0, 1, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 1]
        ]
    mask_entries = [choice or random.choice(choices) for _ in range(r * c // ksparse)]
    mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # make sure the only zeros are the ones introduced by the mask
    dense = dense.masked_fill(~mask, 0)
    return dense


def rand_sparse_semi_structured_all_patterns(r, c, dtype, device):
    pattern = '2by4' if dtype != torch.float32 else '1by2'
    if pattern == '1by2':
        ksparse = 2
        choices = [
            [[0, 0], [0, 1]],
            [[0, 1], [0, 1]],
            [[1, 0], [1, 0]],
            [[1, 1], [1, 0]]
        ]
    elif pattern == '2by4':
        ksparse = 4
        choices = [
            [[0, 0, 0, 0], [0, 0, 1, 1]],
            [[0, 0, 0, 1], [0, 0, 1, 1]],
            [[0, 0, 1, 0], [0, 0, 1, 1]],
            [[0, 0, 1, 1], [0, 0, 1, 1]],
            [[0, 1, 0, 0], [0, 1, 1, 0]],
            [[0, 1, 0, 1], [0, 1, 0, 1]],
            [[0, 1, 1, 0], [0, 1, 1, 0]],
            [[0, 1, 1, 1], [0, 1, 0, 1]],
            [[1, 0, 0, 0], [1, 0, 1, 0]],
            [[1, 0, 0, 1], [1, 0, 0, 1]],
            [[1, 0, 1, 0], [1, 0, 1, 0]],
            [[1, 0, 1, 1], [1, 0, 0, 1]],
            [[1, 1, 0, 0], [1, 1, 0, 0]],
            [[1, 1, 0, 1], [1, 1, 0, 0]],
            [[1, 1, 1, 0], [1, 1, 0, 0]],
            [[1, 1, 1, 1], [1, 1, 0, 0]],
        ]
    mask_rows = [random.randint(0, len(choices) - 1) for _ in range(r * c // ksparse)]

    COL_INV, COL_VAL = 0, 1
    mask_entries_inv = [choices[i][COL_INV] for i in mask_rows]
    mask_entries_val = [choices[i][COL_VAL] for i in mask_rows]
    mask_inv = torch.tensor(mask_entries_inv, dtype=torch.bool).view(r, c).to(device)
    mask_val = torch.tensor(mask_entries_val, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1   # make sure the only zeros are the ones introduced by the masks below
    dense_inv = dense.masked_fill(~mask_inv, 0)
    dense_val = dense_inv.masked_fill(~mask_val, 0)

    return dense_inv, dense_val
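
# Note on the tables above: each entry pairs a possible dense mask pattern with
# the canonical 2:4 (or 1:2) pattern that conversion to the semi-structured
# format is expected to produce for it, e.g. dense [1, 1, 1, 1] maps to
# [1, 1, 0, 0], i.e. only the first two nonzeros survive.
# test_conversions_all_patterns below relies on this pairing.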


class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):

    def setUp(self):
        if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0:
            self.skipTest('semi-structured sparsity has no available backend!')
        super().setUp()

    def tearDown(self):
        super().tearDown()

    @staticmethod
    def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
        """
        Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
        We expect:
            (1) The sparse tensor subclass should turn nn.Linear into its sparse addmm/mm op
                (`aten._sparse_semi_structured_addmm` for CUTLASS, `aten._cslt_sparse_mm` for cuSPARSELt) + `aten.contiguous()`
            (2) Inductor should fuse the .contiguous() call into the relu
        """

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(128, 128)

            def forward(self, x):
                x = self.linear(x)
                x = x.contiguous()
                return torch.nn.functional.relu(x)

        input = torch.rand(dense_input_shape, device="cuda").half()
        model = Model().eval().cuda().half()
        mod_linear = model.linear
        m, n = mod_linear.weight.shape
        mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
        # set masked weight
        mod_linear.weight = nn.Parameter(mod_linear.weight * mask)

        dense_result = model(input)
        mod_linear.weight = nn.Parameter(SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].from_dense(mod_linear.weight))
        sparse_result = model(input)

        model = torch.compile(model, backend="inductor", fullgraph=True)
        sparse_compile_result = model(input)

        # test that sparse_compile_result and dense_result are numerically close
        torch.testing.assert_close(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
        # assert sparse_result and sparse_compile_result have the same strides,
        # as meta registrations may return contiguous tensors when the output is transposed
        # https://github.com/pytorch/pytorch/pull/114477
        assert sparse_result.stride() == sparse_compile_result.stride()

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    def test_mlp_contiguous_relu_compile_cusparselt(self):
        """
        test for cuSPARSELt meta registrations (_cslt_sparse_mm) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cusparselt", dense_input_shape)


    @unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine")
    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    def test_mlp_contiguous_relu_compile_cutlass(self):
        """
        test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)


    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    def test_sp24_compile(self) -> None:
        x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
        e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)

        def fn(x, e):
            y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
            y = y.t()
            return x @ y

        # Eager
        output = fn(x, e)
        output.backward(output)
        # Torch compile
        output = torch.compile(fn)(x, e)
        output.backward(output)

class TestSparseSemiStructured(TestCase):

    def setUp(self):
        if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0:
            self.skipTest('semi-structured sparsity has no available backend!')
        if IS_WINDOWS:
            self.skipTest("torch.compile not supported on windows")
        super().setUp()

    @inference_dtypes
    @parametrize_backends
    def test_to_sparse_semi_structured(self, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        assert A.shape == A_sparse.shape
        assert A.device == A_sparse.device
        assert A.dtype == A_sparse.dtype

        assert isinstance(A, torch.Tensor)
        assert isinstance(A_sparse, SparseSemiStructuredTensor)

    @inference_dtypes
    @parametrize_backends
    @parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)])
    def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B) is correct for float16/bfloat16 and raises an error for int8
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        # int8 matmul is not supported on GPU for this NN layout, so we expect an error
        if dtype is torch.int8:
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
                    sparse_result = torch.mm(A_sparse, B)
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    sparse_result = torch.mm(A_sparse, B)
        else:
            dense_result = torch.mm(A, B)
            sparse_result = torch.mm(A_sparse, B)
            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize_backends
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
        and raises an error for int8 + padding
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        # int matmul is not supported on GPU, so for the supported int8 case below
        # the dense reference is computed on CPU and copied over
        if dtype is torch.int8 and dense_input_shape in {(1, 128)}:
            # padding with int8 throws an error because transposing B yields a contiguous output
            # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
                    sparse_result = torch.mm(A_sparse, B.t())
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    sparse_result = torch.mm(A_sparse, B.t())
        elif dtype is torch.int8:
            # test transpose
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
            sparse_result = torch.mm(A_sparse, B.t())
            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
        else:
            # test transpose
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A_sparse, B.t())
            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
        """
        Ensure torch.mm(A_sparse.t(), B) raises an error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
        ):
            torch.mm(A_sparse.t(), B)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse.t()) is correct
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
            sparse_result = torch.mm(A, B_sparse.t())
        else:
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A, B_sparse.t())

        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse) raises an error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
        ):
            sparse_result = torch.mm(A, B_sparse)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize("inference_mode", [subtest(True), subtest(False)])
    @parametrize_backends
    def test_linear(self, dense_input_shape, inference_mode, device, backend):
        """
        Test that nn.Linear produces the same numerics with dense and sparse weights
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        input = torch.rand(dense_input_shape, device=device).half()
        model = nn.Linear(128, 256).to(device).half()
        m, n = model.weight.shape
        mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
        # set masked weight
        model.weight = nn.Parameter(model.weight * mask)

        dense_result = model(input)

        model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))

        if inference_mode:
            with torch.inference_mode():
                sparse_result = model(input)
        else:
            sparse_result = model(input)

        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize_backends
    def test_mlp(self, device, dense_input_shape, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        input = torch.rand(dense_input_shape, device=device).half()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .half()
            .to(device)
        )

        for i in range(2):
            m, n = model[i].weight.shape
            mask = rand_sparse_semi_structured_mask(
                m, n, device=device, dtype=torch.bool
            )
            # set masked weight
            model[i].weight = nn.Parameter(model[i].weight * mask)

        dense_result = model(input)

        for i in range(2):
            model[i].weight = nn.Parameter(to_sparse_semi_structured(model[i].weight))

        sparse_result = model(input)

        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize_backends
    def test_values(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.values().shape == (128, 64)
        assert (A_sparse.values() == 1).all()

    @parametrize_backends
    def test_indices(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.indices().shape == (128, 8)

    @inference_dtypes
    @parametrize_backends
    def test_min_sparse_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        config = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS[dtype]
        A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
        A_sparse = to_sparse_semi_structured(A)
        B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
        if dtype == torch.int8:
            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int8)
            # int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R
            B_t = B.t().contiguous()
            sparse_res = torch.mm(A_sparse, B_t.t())
        else:
            dense_res = torch.mm(A, B)
            sparse_res = torch.mm(A_sparse, B)
        torch.testing.assert_close(sparse_res, dense_res, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize_backends
    def test_unsupported_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(2, 2, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"):
            A_sparse = to_sparse_semi_structured(A)

    @dtypes(*all_types_and_complex())
    @parametrize_backends
    def test_unsupported_dtype(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device)

        if dtype not in SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS:
            with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"):
                A_sparse = to_sparse_semi_structured(A)
        else:
            A_sparse = to_sparse_semi_structured(A)

    @parametrize_backends
    def test_unsupported_dim(self, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = torch.rand(128, 128, 128, device=device, dtype=torch.float16)

        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
            A_sparse = to_sparse_semi_structured(A)


def create_random_mask(shape) -> torch.Tensor:
    r = random.Random(0)
    mask = torch.zeros(shape, dtype=torch.bool)
    for line in range(mask.shape[0]):
        for col in range(0, mask.shape[1], 4):
            sparsity = r.choice(
                [
                    [False, False, True, True],
                    [False, True, False, True],
                    [True, False, False, True],
                    [False, True, True, False],
                    [True, False, True, False],
                    [True, True, False, False],
                ]
            )
            mask[line, col : col + 4] = torch.tensor(sparsity, dtype=torch.bool)
    return mask
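
# Note: create_random_mask is deterministic across calls (its RNG is seeded with 0),
# so, e.g., create_random_mask((4, 8)).equal(create_random_mask((4, 8))) holds, and
# every row is a concatenation of valid 2:4 patterns.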

class TestSparseSemiStructuredTraining(TestCase):

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest("SparseSemiStructuredTensor training only supported on SM8x (Ampere)")
        if IS_WINDOWS:
            self.skipTest('CUTLASS not supported on windows')
        super().setUp()


    @training_dtypes
    def test_prune_dense_static_sort(self, dtype) -> None:
        # Ideally we would like to clone and compare, but that won't work because the sorting order will be different.
        # Instead, we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
        dense = torch.randn(128, 128, device="cuda", dtype=dtype)
        pruned = _sparse_semi_structured_tile(dense)

        # CUTLASS
        reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(pruned, algorithm="largest_abs_values_greedy")
        torch.testing.assert_close(pruned, reference_cutlass.to_dense())

        packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
        packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
        meta_cutlass = meta_cutlass.as_strided(reference_cutlass.meta.shape, reference_cutlass.meta.stride())
        meta_t_cutlass = meta_t_cutlass.as_strided(reference_cutlass.meta_t.shape, reference_cutlass.meta_t.stride())
        compressed_swizzled_bitmask = _compute_compressed_swizzled_bitmask(pruned)
        compressed_swizzled_bitmask = compressed_swizzled_bitmask.as_strided(reference_cutlass.compressed_swizzled_bitmask.shape,
                                                                             reference_cutlass.compressed_swizzled_bitmask.stride())
        cutlass = SparseSemiStructuredTensorCUTLASS(dense.shape,
                                                    packed_cutlass,
                                                    meta_cutlass,
                                                    packed_t_cutlass,
                                                    meta_t_cutlass,
                                                    compressed_swizzled_bitmask)
        torch.testing.assert_close(reference_cutlass.to_dense(), cutlass.to_dense())

        # CUSPARSELT
        reference_cusparselt = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(pruned,
                                                                                            algorithm="largest_abs_values_greedy")
        torch.testing.assert_close(pruned, reference_cusparselt.to_dense())

        packed_cusparselt = torch._cslt_compress(pruned)
        packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
        cusparselt = SparseSemiStructuredTensorCUSPARSELT(dense.shape,
                                                          packed_cusparselt,
                                                          None,
                                                          packed_t_cusparselt,
                                                          None,
                                                          compressed_swizzled_bitmask)
        torch.testing.assert_close(reference_cusparselt.to_dense(), cusparselt.to_dense())



    @training_dtypes
    @parametrize_backends
    def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
        inp = torch.tensor(
            [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
            device="cuda",
            dtype=dtype,
        )
        inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1)
        sInp = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(inp, algorithm="largest_abs_values_greedy")

        mask = sInp.to_dense() / inp
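        # Why this pattern: picking entries greedily in order of decreasing
        # magnitude, with a 2:4 budget along both the rows and the columns of
        # the 4x4 tile, keeps 10, 5, 4, 4, 3, 3, -3 and then 0.6 (every larger
        # remaining candidate lands in an already-full row or column), which is
        # exactly the mask asserted below.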
        assert mask[:4, :4].int().tolist() == [
            [1, 1, 0, 0],
            [0, 1, 1, 0],
            [0, 0, 1, 1],
            [1, 0, 0, 1],
        ]

    @training_dtypes
    def test_gemm(self, dtype) -> None:
        M, N, K = 32, 32, 64
        a = torch.randn([M, K], device="cuda", dtype=dtype)
        b = torch.randn([K, N], device="cuda", dtype=dtype)
        mask = rand_sparse_semi_structured_mask(M, K, dtype=torch.bool)

        a.masked_fill_(~mask, 0)

        a_sparse = to_sparse_semi_structured(a)

        masked_a = a * mask
        ref_out = masked_a @ b
        sp24_out = a_sparse @ b
        torch.testing.assert_close(ref_out, sp24_out, **atol_rtol_kw[dtype])


    @training_dtypes
    @parametrize_backends
    def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
        M, N = 128, 256
        # Construct a so that all 16 values within each 4x4 tile are distinct;
        # largest-abs 2:4 pruning then keeps exactly 8 elements per tile, with no ties
        a = (4 * torch.arange(8))[:, None] + torch.arange(8)[None, :]
        a = a.repeat(M // 8, N // 8)
        assert a.shape == (M, N)
        a = a.cuda().to(dtype)
        b = torch.randn([a.shape[1], 128], device="cuda", dtype=dtype)

        a_sparse = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(a)

        mask_dense = sparse24_largest_mask_2d(a).to(dtype)

        if backend == "cutlass":
            assert isinstance(a_sparse, SparseSemiStructuredTensorCUTLASS)
            (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(
                mask_dense, use_cutlass=True)

            sparse_mask = SparseSemiStructuredTensorCUTLASS(
                mask_dense.shape,
                packed=packed,
                meta=meta,
                packed_t=packed_t,
                meta_t=meta_t,
                compressed_swizzled_bitmask=bitmask,
            )
            torch.testing.assert_close(a_sparse.meta.view(torch.short), sparse_mask.meta)

        ref_gemm = (mask_dense * a) @ b
        pack_gemm = a_sparse @ b
        torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])

    @training_dtypes
    def test_pack_both_ways_id(self, dtype) -> None:
        N = 512
        torch.manual_seed(0)
        a = torch.randn([N, N], dtype=dtype, device="cuda")
        b = torch.eye(N, dtype=dtype, device="cuda")

        packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4]
        # Heuristic check that both packings kept the same values: their sums match
        torch.testing.assert_close(
            packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum()
        )

        mask_dense = sparse24_largest_mask_2d(a.to(dtype))

        ref_gemm = mask_dense * a
        # Test A@B
        pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed, meta).t()
        max_diff = (ref_gemm - pack_gemm).abs().argmax()
        torch.testing.assert_close(
            ref_gemm,
            pack_gemm,
            **atol_rtol_kw[dtype],
            msg=f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})",
        )
        # Test A.t() @ B
        pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t)
        max_diff = (ref_gemm - pack_gemm).abs().argmax()
        torch.testing.assert_close(
            ref_gemm,
            pack_gemm,
            **atol_rtol_kw[dtype],
            msg=f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})",
        )

    @training_dtypes
    def test_pack_both_ways_edge_case1(self, dtype) -> None:
        # In this case, the heuristic will keep 7 values out of 16
        # instead of 8. Let's see how the kernel handles this.
        quad = torch.tensor(
            [
                [2, -1, -2, -3],  # Should be packed as `2 <null>`
                [-1, 8, -1, 6],
                [-1, -1, 4, 5],
                [-1, 3, 7, -1],
            ],
            dtype=dtype,
            device="cuda",
        )
        a = torch.randn([32, 64], dtype=dtype, device="cuda")
        a[:4, :4] = quad
        packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4]
        # Check first line in A
        assert packed[0, 0].item() == 2
        assert packed[0, 1].item() == 0
        # And first column in A.t
        assert packed_t[0, 0].item() == 2
        assert packed_t[0, 1].item() == 0

    @training_dtypes
    def test_sp24_apply(self, dtype) -> None:
        M, N = 256, 1024
        x = torch.randn([M, N], dtype=dtype, device="cuda")
        (
            packed,
            meta,
            packed_t,
            meta_t,
            bitmask,
        ) = torch._sparse_semi_structured_tile(x)
        # _sparse_semi_structured_apply should reproduce, from the bitmask alone,
        # the packed buffers computed by _sparse_semi_structured_tile above
        packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
        torch.testing.assert_close(packed, packed2)
        torch.testing.assert_close(packed_t, packed_t2)

    @training_dtypes
    def test_sp24_apply_dense(self, dtype) -> None:
        M, N = 256, 1024
        x = torch.randn([M, N], dtype=dtype, device="cuda")
        (
            packed,
            meta,
            packed_t,
            meta_t,
            bitmask,
        ) = torch._sparse_semi_structured_tile(x)

        expected = SparseSemiStructuredTensorCUTLASS(
            x.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        ).to_dense()

        packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
        sparse = SparseSemiStructuredTensorCUTLASS(
            x.shape,
            packed=packed2,
            meta=meta,
            packed_t=packed_t2,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )

        # _sparse_semi_structured_apply_dense applies the same bitmask but returns
        # the masked dense tensor directly rather than packed buffers
        dense = torch._sparse_semi_structured_apply_dense(x, bitmask)

        torch.testing.assert_close(dense, expected)
        torch.testing.assert_close(sparse.to_dense(), expected)


    @training_dtypes
    def test_sp24_matmuls(self, dtype) -> None:
        M, N, K = 64, 256, 1024
        a = torch.randn([M, K], device="cuda", dtype=dtype)
        b = torch.randn([K, N], device="cuda", dtype=dtype)
        a_m = sparse24_largest_mask_2d(a)
        b_m = sparse24_largest_mask_2d(b)
        (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(a)
        a_s = SparseSemiStructuredTensorCUTLASS(
            a.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )
        (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(b)
        b_s = SparseSemiStructuredTensorCUTLASS(
            b.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )

        torch.testing.assert_close(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1.5e-1)
        torch.testing.assert_close(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1.5e-1)
        torch.testing.assert_close(
            a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1.5e-1
        )
        torch.testing.assert_close(
            a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
        )

    def test_sp24_matmuls_mat_vec(self) -> None:
        a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
        b = torch.randn([128], device="cuda", dtype=torch.float16)
        a_m = sparse24_largest_mask_2d(a)
        a_s = to_sparse_semi_structured(a)

        with pytest.raises(NotImplementedError):
            torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])


    def test_sp24_matmuls_bmm(self) -> None:
        a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
        b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
        a_m = sparse24_largest_mask_2d(a)
        a_s = to_sparse_semi_structured(a)

        with pytest.raises(NotImplementedError):
            torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])

class TestSparseSemiStructuredCUTLASS(TestCase):
    """
    This contains CUTLASS specific tests for
         - torch._sparse_semi_structured_linear
         - torch._sparse_semi_structured_mm
         - torch._sparse_semi_structured_addmm
    """
    def setUp(self):
        if "cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            self.skipTest('CUTLASS not enabled')
        super().setUp()

    @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows don't support CUTLASS")
    @inference_dtypes
    def test_linear_cutlass(self, device, dtype):

        def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol):
            weight = rand_sparse_semi_structured(m, k, dtype, device)
            input = make_tensor((*batch_shape, n, k), dtype=dtype, device=device)
            bias = make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None

            dtype_dense = torch.float32
            input_dense = input.to(dtype_dense)
            weight_dense = weight.to(dtype_dense)
            bias_dense = bias.to(dtype_dense) if add_bias else None
            output0 = torch.nn.functional.linear(input_dense, weight_dense, bias=bias_dense)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output0 = relu(output0)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output0 = silu(output0)

            compressed = to_sparse_semi_structured(weight)

            weight_sparse = compressed.values()
            meta = compressed.indices()

            output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation,
                                                           out_dtype=dtype_out if dtype == torch.int8 else None)
            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

        if dtype == torch.float32:
            # Inputs are converted to TF32 internally for sparse GEMM,
            # so have the dense GEMM do the same for matching results.
            orig = torch.backends.cuda.matmul.allow_tf32
            torch.backends.cuda.matmul.allow_tf32 = True

        batch_shapes = [[], [3], [3, 1]]
        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 5e-3, 5e-3
        elif dtype == torch.float32:
            rtol, atol = 1e-3, 75e-2
        for batch_shape, m, n, k, add_bias, activation in \
                itertools.product(batch_shapes, range(3), range(3), range(3), (False, True), activations):
            if activation == "silu" and dtype == torch.int8:
                continue  # SiLU not supported for integer inputs

            m = 2 ** m * 32
            n = 2 ** n * 32
            k = 2 ** k * 128
            run_test(batch_shape, m, n, k, device, dtype, dtype_out[dtype], add_bias, activation, rtol, atol)

        if dtype == torch.float32:
            torch.backends.cuda.matmul.allow_tf32 = orig


    @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows don't support CUTLASS")
    @parametrize("backend", ["cutlass"])
    @inference_dtypes
    def test_sparse_semi_structured_ops_cutlass(self, device, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")

        def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol):
            mat1 = rand_sparse_semi_structured(m, k, dtype, device)
            # mat2 is transposed since the int8 case supports only the row-major/column-major operand combination
            mat2 = make_tensor((n, k), dtype=dtype, device=device).t()
            input = make_tensor((m,), dtype=dtype_out, device=device) if use_input else None

            if use_input:
                if dtype.is_floating_point:
                    alpha = 1.3
                    beta = -0.7
                else:
                    alpha = 2
                    beta = -3

            dtype_dense = torch.float32
            mat1_dense = mat1.to(dtype_dense)
            mat2_dense = mat2.to(dtype_dense)
            if not use_input:
                output0 = torch.mm(mat1_dense, mat2_dense)
            else:
                input_dense = input.to(dtype_dense)[:, None]
                output0 = torch.addmm(input_dense, mat1_dense, mat2_dense, alpha=alpha, beta=beta)

            compressed = to_sparse_semi_structured(mat1)

            mat1_sparse = compressed.values()
            mat1_meta = compressed.indices()

            if not use_input:
                output1 = torch._sparse_semi_structured_mm(mat1_sparse, mat1_meta, mat2, out_dtype=dtype_out)
            else:
                output1 = torch._sparse_semi_structured_addmm(
                    input, mat1_sparse, mat1_meta, mat2, alpha=alpha, beta=beta, out_dtype=dtype_out
                )
            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

        if dtype == torch.float32:
            # Inputs are converted to TF32 internally for sparse GEMM,
            # so have the dense GEMM do the same for matching results.
            orig = torch.backends.cuda.matmul.allow_tf32
            torch.backends.cuda.matmul.allow_tf32 = True

        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 5e-3, 5e-3
        elif dtype == torch.float32:
            rtol, atol = 1e-3, 75e-2
        for m, n, k, use_input in \
                itertools.product(range(3), range(3), range(3), (False, True)):
            m = 2 ** m * 32
            n = 2 ** n * 32
            k = 2 ** k * 128
            run_test(m, n, k, device, dtype, dtype_out[dtype], use_input, rtol, atol)

        if dtype == torch.float32:
            torch.backends.cuda.matmul.allow_tf32 = orig


    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @inference_dtypes
    def test_conversions(self, device, dtype):

        def run_test(r, c, device, dtype):
            dense_ref = rand_sparse_semi_structured(r, c, dtype, device)

            compressed = to_sparse_semi_structured(dense_ref)

            # The torch.ops.aten._to_sparse_semi_structured operator uses CUTLASS
            # to convert the given dense matrix into the corresponding pair of
            # sparse and metadata matrices; the latter is used here as the
            # reference against which to compare the metadata matrix produced by
            # the SparseSemiStructuredTensor conversion.
            _, meta_ref = torch.ops.aten._to_sparse_semi_structured(dense_ref)

            meta = compressed.indices()
            torch.testing.assert_close(meta, meta_ref, rtol=0, atol=0)

            dense = compressed.to_dense()
            torch.testing.assert_close(dense, dense_ref, rtol=0, atol=0)

        shapes = [[32, 128], [32, 256], [64, 128], [64, 256]]
        for r, c in shapes:
            run_test(r, c, device, dtype)

    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @inference_dtypes
    def test_conversions_all_patterns(self, device, dtype):
        r, c = 32, 128

        dense_inv, dense_val = rand_sparse_semi_structured_all_patterns(r, c, dtype, device)

        compressed = to_sparse_semi_structured(dense_inv)
        dense = compressed.to_dense()

        torch.testing.assert_close(dense, dense_val, rtol=0, atol=0)



CUSPARSELT_NUM_ALG_IDS = 4
CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
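# These out_dtypes are exercised by the int8 mixed-dtype tests below; the dense
# references there are computed in int64 on CPU and then cast to out_dtype.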


class TestSparseSemiStructuredCUSPARSELT(TestCase):
    """
    This contains cuSPARSELt specific tests for
        - torch._cslt_compress
        - torch._cslt_sparse_mm
    """
    def setUp(self):
        if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            self.skipTest('cuSPARSELt not enabled')
        super().setUp()

    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
    @parametrize("dense_input_shape", [(128, 128)])
    def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
        A_compressed = torch._cslt_compress(A)

        B = torch.rand(dense_input_shape, device=device).to(torch.int8)

        dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=out_dtype)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype)
        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @unittest.skip("cuSPARSELt v0.6.x does not support bfloat/float16 alpha scaling")
    @training_dtypes
    def test_cslt_sparse_mm_alpha(self, dtype, device):
        A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
        B = torch.ones((256, 128), device=device).to(dtype)
        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()
        bias = torch.ones(128, device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, bias=bias)

        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled * torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)

        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
    def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
        A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
        B = torch.ones((128, 256), device=device).to(torch.int8).t()
        alpha = torch.Tensor([2**(-i) if out_dtype is not torch.int32 else 1
                              for i in range(128)]).cuda()

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=out_dtype).cpu()

        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
        dense_result = dense_result.to(out_dtype)

        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
    @inference_dtypes
    def test_cslt_sparse_mm_alg_id(self, device, dtype, alg_id):
        # alg_id=3 not supported for float32 dtype
        if dtype == torch.float32 and alg_id == 3:
            return
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)

        dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)

        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    def test_cslt_sparse_mm_search(self, device, dtype):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
        # cuSPARSELt v0.4.0 has a bug: although there are 5 alg_ids, an error is
        # raised when the last one (4) is used. cuSPARSELt v0.5.0 has only 4 alg_ids
        # total, so the +1 here should be removed when we update.
        # TODO: move this check into the cuSPARSELt backend
        assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)

    def test_cusparselt_backend(self):
        version = _get_torch_cuda_version()
        assert torch.backends.cusparselt.is_available()

        # CUDA 11.8 has cuSPARSELt v0.4.0 support
        if version == (11, 8):
            assert torch.backends.cusparselt.version() == 400
        # CUDA 12.1 has cuSPARSELt v0.5.2 support
        elif version == (12, 1):
            assert torch.backends.cusparselt.version() == 502
        # CUDA 12.4+ has cuSPARSELt v0.6.2 support
        elif version >= (12, 4):
            assert torch.backends.cusparselt.version() == 602
        else:
            assert torch.backends.cusparselt.version() is None

if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) > 0:
    instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
if "cutlass" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
    instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda")
    instantiate_device_type_tests(TestSparseSemiStructuredTraining, globals(), only_for="cuda")
if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
    instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()
