# Owner(s): ["module: intel"]

import itertools
import math
import random
from functools import partial
from itertools import product

import numpy as np

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    precisionOverride,
)
from torch.testing._internal.common_utils import iter_indices, run_tests, TestCase


class TestBasicGEMM(TestCase):
    def _test_addmm_addmv(
        self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None
    ):
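        """Check f(t, m, v) (e.g. addmm/addmv, optionally with a fused
        activation) against the NumPy reference alpha * (m @ v) + beta * t,
        for both the functional and the out= variants."""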
        dtype = t.dtype
        numpy_dtype = dtype
        if dtype in {torch.bfloat16, torch.half}:
            numpy_dtype = torch.float
        if dtype.is_complex:
            alpha = 0.9 + 0.3j if alpha is None else alpha
            beta = 0.5 + 0.6j if beta is None else beta
        else:
            alpha = 1.2 if alpha is None else alpha
            beta = 0.8 if beta is None else beta
        if activation == "gelu":
            res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True)
        else:
            res1 = f(t, m, v, alpha=alpha, beta=beta)
        res2 = torch.full_like(res1, math.nan)
        if transpose_out:
            res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
        if activation == "gelu":
            f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True)
        else:
            f(t, m, v, alpha=alpha, beta=beta, out=res2)
        res3 = alpha * (
            m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy()
        )
        if beta != 0:
            res3 += (beta * t).to(numpy_dtype).cpu().numpy()
        if activation == "relu":
            res3 = res3 * (res3 > 0)
        elif activation == "gelu":
            res3_t = torch.from_numpy(res3).to(dtype)
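            # CUDA's fused GELU epilogue uses the tanh approximation, so the
            # reference mirrors it for CUDA inputs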
            approximate = "tanh" if t.is_cuda else "none"
            res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
            res3 = res3_t.to(numpy_dtype).cpu().numpy()
        else:
            assert activation is None, f"unsupported activation {activation}"
        res3 = torch.from_numpy(res3).to(dtype)
        self.assertEqual(res1, res2)
        self.assertEqual(res1, res3)

    def _test_addmm_impl(self, func, activation, device, dtype):
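        """Exercise func (an addmm-like op) with contiguous, vector-bias,
        0-strided, beta=0/NaN, and transposed inputs."""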
        M = torch.randn(10, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)

        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
        V = torch.randn(25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)

        # Test 0-strided
        M = (
            torch.randn(10, 1, device="cpu", dtype=torch.float32)
            .to(dtype)
            .expand(10, 25)
            .to(device)
        )
        m1 = (
            torch.randn(10, 1, device="cpu", dtype=torch.float32)
            .to(dtype)
            .expand(10, 50)
            .to(device)
        )
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)

        # Test beta=0, M=nan
        M = (
            torch.full((10, 25), math.nan, device="cpu", dtype=torch.float32)
            .to(dtype)
            .to(device)
        )
        m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation)

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):

            def maybe_transpose(cond, m):
                if not cond:
                    return m
                return m.t().clone(memory_format=torch.contiguous_format).t()

            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
            self._test_addmm_addmv(
                func, M, m1, m2, transpose_out=t4, activation=activation
            )

            if t1:
                # use vector V instead of matrix M (CUDA epilogue fusion path);
                # the result does not depend on t1, so gate on t1 to avoid
                # running identical cases twice
                self._test_addmm_addmv(
                    func,
                    V,
                    m1,
                    m2,
                    beta=1,
                    transpose_out=t4,
                    activation=activation,
                )

    @precisionOverride(
        {
            torch.float: 1e-4,
            torch.half: 1e-1,
        }
    )
    @dtypes(torch.float32, torch.half)
    def test_addmm(self, device, dtype):
        self._test_addmm_impl(torch.addmm, None, device, dtype)

    @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4})
    @dtypes(torch.bfloat16, torch.half, torch.float)
    def test_addmv(self, device, dtype):
        # use torch.randn(...).to(bfloat16) instead of torch.randn(..., dtype=bfloat16)
        # since randn has historically lacked bfloat16 support on some backends;
        # the "*0.2" scaling reduces errors for low precision
        ts = [
            0.2 * torch.randn(50, device=device).to(dtype),
            0.2 * torch.randn(1, device=device).to(dtype).expand(50),
        ]
        vs = [
            0.2 * torch.randn(100, device=device).to(dtype),
            0.2
            * torch.ones(1, device=device)
            .to(dtype)
            .expand(100),  # to reduce errors for low precision
        ]
        ms = [
            # 0d
            0.2
            * torch.ones((), device=device)
            .to(dtype)
            .expand(50, 100),  # to reduce errors for low precision
            # 1d
            0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100),
            # this initialization reduces errors for low precision for broadcasted matrices
            # by making sure that intermediate and result values are exactly representable
            # in low precision type
            0.2
            * torch.randint(3, (50, 1), dtype=torch.float, device=device)
            .to(dtype)
            .expand(50, 100),
            # 2d
            0.2 * torch.randn((50, 100), device=device).to(dtype),
            0.2 * torch.randn((100, 50), device=device).to(dtype).t(),
        ]
        for m, v, t in itertools.product(ms, vs, ts):
            self._test_addmm_addmv(torch.addmv, t, m, v)
        # Test beta=0, t=nan
        t = torch.full((50,), math.nan, device=device).to(dtype)
        for m, v in itertools.product(ms, vs):
            self._test_addmm_addmv(torch.addmv, t, m, v, beta=0)

    @dtypes(
        torch.half,
        torch.float32,
    )
    def test_mm(self, device, dtype):
        def _test_mm(n, m, p, dtype, genf):
            # naive reference multiply; accumulate in float when inputs are half
            def matrixmultiply(mat1, mat2):
                n = mat1.size(0)
                m = mat1.size(1)
                p = mat2.size(1)
                dtype_ = torch.float if dtype == torch.half else dtype
                if dtype == torch.half:
                    mat1 = mat1.float()
                    mat2 = mat2.float()
                res = torch.zeros(n, p, dtype=dtype_, device=device)
                for i, j in iter_indices(res):
                    res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m))
                return res.half() if dtype == torch.half else res

            # contiguous case
            mat1 = genf(n, m)
            mat2 = genf(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non contiguous case 1
            mat1 = genf(n, m)
            mat2 = genf(p, m).t()
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non contiguous case 2
            mat1 = genf(m, n).t()
            mat2 = genf(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non contiguous case 3
            mat1 = genf(m, n).t()
            mat2 = genf(p, m).t()
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # test with zero stride
            mat1 = genf(n, m)
            mat2 = genf(m, 1).expand(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # explicitly exercise the _out variant in torch.mm().
            # contiguous case
            mat1 = genf(n, m)
            mat2 = genf(m, p)
            res = genf(n, p)
            torch.mm(mat1, mat2, out=res)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # explicitly exercise the _out variant in torch.mm().
            # non contiguous case 3
            mat1 = genf(m, n).t()
            mat2 = genf(p, m).t()
            res = genf(n, p)
            torch.mm(mat1, mat2, out=res)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

        def genf_int(x, y):
            return torch.randint(0, 100, (x, y), dtype=dtype, device=device)

        def genf_bfloat(x, y):
            return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1

        def genf_float(x, y):
            return torch.randn(x, y, dtype=dtype, device=device)

        def genf_Half(x, y):
            return torch.randn(x, y, dtype=dtype, device=device)

        for n, m, p in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]:
            if (dtype == torch.int32) or (dtype == torch.int64):
                genf = genf_int
            elif dtype == torch.bfloat16:
                genf = genf_bfloat
            elif dtype == torch.half:
                genf = genf_Half
            else:
                genf = genf_float

            _test_mm(n, m, p, dtype, genf)

    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_bmm(self, device, dtype):
        batch_sizes = [1, 10]
        M, N, O = 23, 15, 12
        numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32

        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_inputs(num_batches):
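            # Yield (b1, b2) pairs covering permuted/transposed layouts,
            # broadcasted (expanded) inputs, and zero-sized batch dimensions.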
            # transposed tensors
            for perm1, perm2 in itertools.product(
                itertools.permutations((0, 1, 2)), repeat=2
            ):
                b1 = make_tensor(
                    (num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1
                )
                b2 = make_tensor(
                    (num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1
                )
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                yield b1, b2
            # broadcasting tensors
            for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
                shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
                b1 = make_tensor(
                    shape1, dtype=dtype, device=device, low=-0.1, high=0.1
                ).expand(num_batches, M, N)
                b2 = make_tensor(
                    shape2, dtype=dtype, device=device, low=-0.1, high=0.1
                ).expand(num_batches, N, O)
                yield b1, b2
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = torch.randn(shape1, dtype=dtype, device=device)
                b2 = torch.randn(shape2, dtype=dtype, device=device)
                yield b1, b2

        for num_batches in batch_sizes:
            for (b1, b2), perm3 in itertools.product(
                generate_inputs(num_batches), itertools.permutations((0, 1, 2))
            ):
                res1 = torch.bmm(b1, b2)
                res2 = (
                    torch.full(
                        (num_batches, M, O), math.nan, dtype=dtype, device=device
                    )
                    .permute(perm3)
                    .contiguous()
                    .permute(invert_perm(perm3))
                )
                torch.bmm(b1, b2, out=res2)
                expect = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                self.assertEqual(expect, res1)
                self.assertEqual(expect, res2)

                if self.device_type == "cuda":
                    # check that mixed arguments are rejected
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu()))
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2))
                    self.assertRaises(
                        RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu())
                    )

    def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):
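        """Run func ('addbmm' or 'baddbmm') through its in-place, deprecated
        positional-overload, functional, and out= variants and compare each
        result against ref."""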
        getattr(out_tensor, func + "_")(b1, b2)
        self.assertEqual(out_tensor, ref)
        res3 = out_tensor.clone()

        with self.assertWarnsOnceRegex(
            UserWarning, f"This overload of {func}_ is deprecated"
        ):
            getattr(out_tensor, func + "_")(1, b1, b2)
        self.assertEqual(out_tensor, ref * 2)
        getattr(res3, func + "_")(b1, b2, beta=1)
        self.assertEqual(out_tensor, res3)

        with self.assertWarnsOnceRegex(
            UserWarning, f"This overload of {func}_ is deprecated"
        ):
            getattr(out_tensor, func + "_")(1.0, 0.5, b1, b2)
        self.assertEqual(out_tensor, ref * 2.5)
        getattr(res3, func + "_")(b1, b2, beta=1.0, alpha=0.5)
        self.assertEqual(out_tensor, res3)

        with self.assertWarnsOnceRegex(
            UserWarning, f"This overload of {func} is deprecated"
        ):
            self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2))

        res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=0.5)
        self.assertEqual(res4, ref * 3)

        nan = torch.full_like(out_tensor, math.nan)
        res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1)
        self.assertEqual(res5, ref)

        if b1.is_complex():
            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1j, alpha=0.5j)
            self.assertEqual(res6, out_tensor * 0.1j + 0.5j * ref)
        else:
            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1, alpha=0.5)
            self.assertEqual(res6, out_tensor * 0.1 + 0.5 * ref)

        res7 = torch.full_like(out_tensor, math.nan)
        getattr(torch, func)(nan, b1, b2, beta=0, out=res7)
        self.assertEqual(res7, ref)

    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_addbmm(self, device, dtype):
        num_batches = 2
        M, N, O = 16, 17, 18

        # is_supported is hardcoded to True here, so the unsupported-dtype
        # error path below is never exercised on this backend
        is_supported = True

        if not is_supported:
            b1 = make_tensor(
                (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1
            )
            b2 = make_tensor(
                (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1
            )
            t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1)
            self.assertRaisesRegex(
                RuntimeError,
                "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
                lambda: torch.addbmm(t, b1, b2),
            )
            return

        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_tensor():
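            # Yield (b1, b2, ref, out_tensor) covering transposed, broadcasted,
            # and zero-sized inputs; ref is the NumPy batched matmul summed over
            # the batch dimension, as addbmm does.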
            numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
            # transposed tensors
            for perm1, perm2 in itertools.product(
                itertools.permutations((0, 1, 2)), repeat=2
            ):
                for perm3 in itertools.permutations((0, 1)):
                    b1 = (
                        make_tensor(
                            (num_batches, M, N),
                            dtype=dtype,
                            device=device,
                            low=-1,
                            high=1,
                        )
                        * 0.1
                    )
                    b2 = (
                        make_tensor(
                            (num_batches, N, O),
                            dtype=dtype,
                            device=device,
                            low=-1,
                            high=1,
                        )
                        * 0.1
                    )
                    b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                    b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                    ref = (
                        torch.from_numpy(
                            b1.to(numpy_dtype).cpu().numpy()
                            @ b2.to(numpy_dtype).cpu().numpy()
                        )
                        .to(device=device, dtype=dtype)
                        .sum(0)
                    )
                    out_tensor = (
                        torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3)
                    )
                    yield b1, b2, ref, out_tensor
            # broadcasting tensors
            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
                b1 = (
                    make_tensor(
                        shape1, dtype=dtype, device=device, low=-1, high=1
                    ).expand(num_batches, M, N)
                    * 0.1
                )
                b2 = (
                    make_tensor(
                        shape2, dtype=dtype, device=device, low=-1, high=1
                    ).expand(num_batches, N, O)
                    * 0.1
                )
                ref = (
                    torch.from_numpy(
                        b1.to(numpy_dtype).cpu().numpy()
                        @ b2.to(numpy_dtype).cpu().numpy()
                    )
                    .to(device=device, dtype=dtype)
                    .sum(0)
                )
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = (
                    make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1)
                    * 0.1
                )
                b2 = (
                    make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1)
                    * 0.1
                )
                ref = (
                    torch.from_numpy(
                        b1.to(numpy_dtype).cpu().numpy()
                        @ b2.to(numpy_dtype).cpu().numpy()
                    )
                    .to(device=device, dtype=dtype)
                    .sum(0)
                )
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor

        for b1, b2, ref, out_tensor in generate_tensor():
            self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor)

    @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5})
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_baddbmm(self, device, dtype):
        num_batches = 10
        M, N, O = 12, 8, 50

        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_tensor():
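            # Yield (b1, b2, ref, out_tensor) covering transposed, broadcasted,
            # and zero-sized inputs; ref is the per-batch NumPy matmul.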
            numpy_dtype = (
                dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32
            )
            # transposed tensors
            for perm1, perm2, perm3 in itertools.product(
                itertools.permutations((0, 1, 2)), repeat=3
            ):
                b1 = make_tensor(
                    (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1
                )
                b2 = make_tensor(
                    (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1
                )
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                out_tensor = (
                    out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3))
                )
                yield b1, b2, ref, out_tensor
            # broadcasting tensors
            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
                b1 = make_tensor(
                    shape1, dtype=dtype, device=device, low=-1, high=1
                ).expand(num_batches, M, N)
                b2 = make_tensor(
                    shape2, dtype=dtype, device=device, low=-1, high=1
                ).expand(num_batches, N, O)
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2)
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2)
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor

        for b1, b2, ref, out_tensor in generate_tensor():
            self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)

    def test_tensordot(self, device):
        a = torch.arange(60.0, device=device).reshape(3, 4, 5)
        b = torch.arange(24.0, device=device).reshape(4, 3, 2)
        c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
        cn = torch.from_numpy(
            np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=([1, 0], [0, 1]))
        )
        self.assertEqual(c, cn)

        cout = torch.zeros((5, 2), device=device)
        torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout)
        self.assertEqual(c, cout)

        a = torch.randn(2, 3, 4, 5, device=device)
        b = torch.randn(4, 5, 6, 7, device=device)
        c = torch.tensordot(a, b, dims=2).cpu()
        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=2))

        with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"):
            torch.tensordot(a, b, dims=-1)

        self.assertEqual(c, cn)
        c = torch.tensordot(a, b).cpu()
        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
        self.assertEqual(c, cn)

        a = torch.tensordot(torch.tensor(0.0), torch.tensor(0.0), 0)
        an = torch.from_numpy(
            np.tensordot(
                np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0
            )
        )
        self.assertEqual(a, an)

    @dtypes(torch.float)
    @precisionOverride({torch.float32: 1e-4})
    def test_1_sized_with_0_strided(self, device, dtype):
        a = make_tensor((8, 1, 64), dtype=dtype, device=device)
        a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
        b = make_tensor((8, 64, 512), dtype=dtype, device=device)
        b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512])
        res = torch.bmm(a_strided, b_strided)
        expect = torch.from_numpy(a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to(
            device=device, dtype=dtype
        )
        self.assertEqual(expect, res)

    def _select_broadcastable_dims(self, dims_full=None):
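        """Randomly pick shapes dims_small and dims_large that broadcast
        together to dims_full (which is itself randomized when not given)."""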
        # select full dimensionality
        if dims_full is None:
            dims_full = []
            ndims = random.randint(1, 4)
            dims_full = [random.randint(1, 8) for _ in range(ndims)]
        else:
            ndims = len(dims_full)

        # select actual dimensions for ops:
        # larger: full ndims, individual sizes may be reduced
        # smaller: possibly reduced ndims, sizes may be reduced
        smaller_ndims = random.randint(1, ndims)
        dims_small = []
        dims_large = []
        for i in range(ndims - 1, -1, -1):
            j = random.randint(1, 3)
            if j == 1:  # no reduced singleton dimension
                ds = dims_full[i]
                dl = dims_full[i]
            elif j == 2:  # larger may have reduced singleton dimension
                ds = dims_full[i]
                dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
            elif j == 3:  # smaller may have reduced singleton dimension
                ds = 1
                dl = dims_full[i]
            dims_large = [dl] + dims_large
            if len(dims_small) < smaller_ndims:
                dims_small = [ds] + dims_small
        return (dims_small, dims_large, dims_full)

    def test_broadcast_fused_matmul(self, device):
        fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]

        for fn in fns:
            batch_dim = random.randint(1, 8)
            n_dim = random.randint(1, 8)
            m_dim = random.randint(1, 8)
            p_dim = random.randint(1, 8)

            def dims_full_for_fn():
                if fn == "baddbmm":
                    return (
                        [batch_dim, n_dim, p_dim],
                        [batch_dim, n_dim, m_dim],
                        [batch_dim, m_dim, p_dim],
                    )
                elif fn == "addbmm":
                    return (
                        [n_dim, p_dim],
                        [batch_dim, n_dim, m_dim],
                        [batch_dim, m_dim, p_dim],
                    )
                elif fn == "addmm":
                    return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
                elif fn == "addmv":
                    return ([n_dim], [n_dim, m_dim], [m_dim])
                elif fn == "addr":
                    return ([n_dim, m_dim], [n_dim], [m_dim])
                else:
                    raise AssertionError("unknown function")

            (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
            (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)

            t0_small = torch.randn(*t0_dims_small, device=device).float()
            t1 = torch.randn(*t1_dims, device=device).float()
            t2 = torch.randn(*t2_dims, device=device).float()

            t0_full = t0_small.expand(*t0_dims_full).to(device)

            fntorch = getattr(torch, fn)
            r0 = fntorch(t0_small, t1, t2)
            r1 = fntorch(t0_full, t1, t2)
            self.assertEqual(r0, r1)

    @dtypes(torch.float32)
    def test_strided_mm_bmm(self, device, dtype):
        # Tests strided view case with stride smaller than corresponding dimension size
        x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype, device=device)
        new_shape = [2, 2, 2]
        new_stride = [3, 1, 1]
        sx = torch.as_strided(x, size=new_shape, stride=new_stride)

        torch_fn = lambda x: torch.bmm(x, x)  # noqa: E731
        np_fn = lambda x: np.matmul(x, x)  # noqa: E731
        self.compare_with_numpy(torch_fn, np_fn, sx)

        torch_fn = lambda x: torch.mm(x, x)  # noqa: E731
        self.compare_with_numpy(torch_fn, np_fn, sx[0])

    def test_mm_empty_inputs_mixed_dtype_errors(self, device):
        a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
        b = torch.randn(10, 20, dtype=torch.float32, device=device)
        with self.assertRaisesRegex(
            RuntimeError, "expected .* and .* to have the same dtype, but got:"
        ):
            torch.mm(a, b)

    def test_matmul_45724(self, device):
        # https://github.com/pytorch/pytorch/issues/45724
        a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
        b = torch.rand(65537, 64, 22, device=device, dtype=torch.half)
        c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device)
        cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).half()
        torch.matmul(a, b, out=c)
        self.assertEqual(c, cpu_result)

    @dtypes(
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
        torch.float64,
    )
    def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
        batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
        batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
        input_tensor = torch.rand((1, 2, 2), device=device).to(dtype)
        if dtype != torch.float32:
            with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"):
                y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0)
        else:
            out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan)
            y_ref = torch.bmm(batch1, batch2)
            y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
            self.assertEqual(out, y_ref)

    @dtypes(torch.float)
    def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
        for shape in [[3, 2, 2], [2, 20, 20]]:
            mat1, mat2 = (
                torch.randn(shape, dtype=dtype, device=device) for _ in range(2)
            )
            inputs = [
                torch.randn(shape, dtype=dtype, device=device),
                torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan),
            ]
            outs = [
                None,
                torch.randn(shape, dtype=dtype, device=device),
                torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan),
            ]
            options = itertools.product(inputs, outs)
            for input, out in options:
                y_ref = torch.bmm(mat1, mat2)
                y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out)
                self.assertEqual(y_ref, y)

    @dtypes(torch.float)
    def test_addmm_sizes(self, device, dtype):
        for m in [0, 1, 25]:
            for n in [0, 1, 10]:
                for k in [0, 1, 8]:
                    M = torch.randn(n, m, device=device).to(dtype)
                    m1 = torch.randn(n, k, device=device).to(dtype)
                    m2 = torch.randn(k, m, device=device).to(dtype)
                    self._test_addmm_addmv(torch.addmm, M, m1, m2)

                    m1 = torch.randn(n, k + 1, device=device).to(dtype)
                    m2 = torch.randn(k, m, device=device).to(dtype)
                    self.assertRaisesRegex(
                        RuntimeError,
                        f"{n}x{k + 1}.*{k}x{m}",
                        lambda: torch.addmm(M, m1, m2),
                    )
                    self.assertRaisesRegex(
                        RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2)
                    )

    @precisionOverride(
        {
            torch.double: 1e-8,
            torch.float: 1e-4,
            torch.bfloat16: 5e-2,
            torch.half: 5e-2,
            torch.cfloat: 1e-4,
            torch.cdouble: 1e-8,
        }
    )
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_addmm_gelu(self, device, dtype):
        self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)

    @precisionOverride(
        {
            torch.double: 1e-8,
            torch.float: 1e-4,
            torch.bfloat16: 5e-2,
            torch.half: 5e-2,
            torch.cfloat: 1e-4,
            torch.cdouble: 1e-8,
        }
    )
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_addmm_relu(self, device, dtype):
        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)

    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
        # tests an (o, s) matrix times an (s,)-sized vector; o is the output size, s is the summed size
        o = 5
        s = 3
        a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s)
        x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype)
        y_data = torch.ones(o, device=device, dtype=dtype)
        control = torch.tensor(
            [15.0, 33.0, 51.0, 69.0, 87.0], device=device, dtype=dtype
        )

        def _test(row_major, incx, incy, lda_tail):
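            # Embed a, x and y in larger NaN-filled buffers so the underlying
            # GEMV sees non-trivial lda/incx/incy; the NaN padding must never
            # be read.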
            if row_major:
                a_storage = torch.full(
                    (o, s + lda_tail), float("nan"), device=device, dtype=dtype
                )
            else:
                a_storage = torch.full(
                    (s, o + lda_tail), float("nan"), device=device, dtype=dtype
                ).permute(1, 0)
            a = a_storage[:o, :s].copy_(a_data)

            x_storage = torch.full((s, incx), float("nan"), device=device, dtype=dtype)
            x = x_storage[:, 0].copy_(x_data)

            y_storage = torch.full((o, incy), float("nan"), device=device, dtype=dtype)
            y = y_storage[:, 0].copy_(y_data)

            self._test_addmm_addmv(torch.addmv, y, a, x)

        for row_major, incx, incy, lda_tail in itertools.product(
            (False, True), (1, 2), (1, 2), (0, 1)
        ):
            _test(row_major, incx, incy, lda_tail)

    @precisionOverride(
        {
            torch.double: 1e-8,
            torch.float: 1e-4,
            torch.bfloat16: 0.6,
            torch.half: 1e-1,
            torch.cfloat: 1e-4,
            torch.cdouble: 1e-8,
        }
    )
    @dtypes(torch.bfloat16, torch.half, torch.float32)
    def test_corner_cases_of_cublasltmatmul(self, device, dtype):
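        # Smoke test: no numerical checks, just make sure these shapes (known
        # cuBLASLt corner cases on CUDA) execute without error.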
        # common case
        M = torch.randn(128, device=device).to(dtype)
        m1 = torch.randn(2048, 2400, device=device).to(dtype)
        m2 = torch.randn(128, 2400, device=device).to(dtype)
        torch.nn.functional.linear(m1, m2, M)
        # Ntrans_B has ld >> rows
        m1 = torch.rand([128, 2400]).to(dtype).to(device).t()
        m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340]
        M = torch.rand([128]).to(dtype).to(device)
        torch.addmm(M, m2.t(), m1)
        # trans_A has ld >> rows
        m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t()
        m2 = torch.randn(2048, 2400, device=device).to(dtype)
        M = torch.rand([128]).to(dtype).to(device)
        torch.addmm(M, m2, m1)
        # large tensor dim > 65535
        M = torch.randn(16, device=device).to(dtype)
        m1 = torch.randn(32, 131071, device=device).to(dtype)
        m2 = torch.randn(16, 131071, device=device).to(dtype)
        torch.nn.functional.linear(m1, m2, M)

    def test_blas_empty(self, device):
        def fn(torchfn, *args, test_out=False, **kwargs):
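            # Materialize tuple arguments as random tensors of that shape, then
            # call torchfn; with test_out=True the call is repeated through the
            # out= variant.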
            def call_torch_fn(*args, **kwargs):
                return torchfn(
                    *tuple(
                        torch.randn(shape, device=device)
                        if isinstance(shape, tuple)
                        else shape
                        for shape in args
                    ),
                    **kwargs,
                )

            result = call_torch_fn(*args, **kwargs)
            if not test_out:
                return result
            else:
                out = torch.full_like(result, math.nan)
                call_torch_fn(*args, **kwargs, out=out)
                return out

        # mm, addmm
        self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape)
        self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape)
        self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape)
        self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape)
        self.assertEqual(
            torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6))
        )
        self.assertEqual(
            torch.zeros((5, 6), device=device),
            fn(torch.mm, (5, 0), (0, 6), test_out=True),
        )

        self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape)
        self.assertEqual((0, 1), fn(torch.addmm, (1,), (0, 17), (17, 1)).shape)
        t = torch.randn((5, 6), device=device)
        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6)))
        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True))

        # mv, addmv
        self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape)
        self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape)
        self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,)))
        self.assertEqual(
            torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True)
        )

        self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape)
        t = torch.randn((3,), device=device)
        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,)))
        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True))

        # bmm, baddbmm
        self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape)
        self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape)
        self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape)
        self.assertEqual(
            torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6))
        )
        self.assertEqual(
            torch.zeros((3, 5, 6), device=device),
            fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True),
        )

        self.assertEqual(
            (0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape
        )
        self.assertEqual(
            (3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape
        )
        self.assertEqual(
            (0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape
        )
        self.assertEqual(
            (3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape
        )
        c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5)
        self.assertEqual(
            -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2)
        )  # Issue #33467
        self.assertEqual(
            -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True)
        )  # Issue #33467

        # addbmm
        self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape)
        self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape)
        t = torch.randn((5, 6), device=device)
        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6)))
        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True))

        # matmul
        self.assertEqual(torch.tensor(0.0, device=device), fn(torch.matmul, (0,), (0,)))
        self.assertEqual(
            torch.tensor(0.0, device=device),
            fn(torch.matmul, (0,), (0,), test_out=True),
        )
        self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape)
        self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape)
        self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape)
        self.assertEqual(
            torch.zeros((5, 3, 4), device=device),
            fn(torch.matmul, (5, 3, 0), (5, 0, 4)),
        )
        self.assertEqual(
            torch.zeros((5, 3, 4), device=device),
            fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True),
        )

        # dot
        self.assertEqual(torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,)))
        self.assertEqual(
            torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,), test_out=True)
        )

    def test_large_bmm_backward(self, device):
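        # .mT.contiguous().mT keeps A's logical shape but stores its last two
        # dimensions in transposed (column-major) order.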
        A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
        B = torch.randn([1, 1024, 65536], device=device, requires_grad=True)
        G = torch.randn([1024, 2, 65536], device=device)

        # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
        (A @ B).backward(G)

    def test_large_bmm_mm_backward(self, device):
        A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
        B = torch.randn([1024, 65536], device=device, requires_grad=True)
        G = torch.randn([1024, 2, 65536], device=device)

        # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
        (A @ B).backward(G)

    def check_single_matmul(self, x, y):
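        """Compare torch.matmul(x, y) and its out= variant against np.matmul,
        scaling atol with the reduction size for floating/complex dtypes."""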
        def assertEqual(answer, expected):
            if x.dtype.is_floating_point or x.dtype.is_complex:
                k = max(x.shape[-1], 1)  # Scale the atol with the size of the matrix
                self.assertEqual(
                    answer,
                    expected,
                    msg=f"{x.shape} x {y.shape} = {answer.shape}",
                    atol=k * 5e-5,
                    rtol=1e-4,
                )
            else:
                self.assertEqual(
                    answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}"
                )

        # test x @ y
        expected = np.matmul(x.cpu(), y.cpu())
        ans = torch.matmul(x, y)
        self.assertTrue(ans.is_contiguous())
        assertEqual(ans, expected)

        # test out
        out = torch.empty_like(ans)
        ans = torch.matmul(x, y, out=out)
        self.assertIs(ans, out)
        self.assertTrue(ans.is_contiguous())
        assertEqual(ans, expected)

    def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3):
        """
        Generates sequences of tuples (x, y) of with size(x) = x_dim and
        size(y) <= y_dim that are compatible wrt. matmul
        """
        assert x_dim >= 1
        assert y_dim >= 2
        x = x_dim
        for y in range(1, y_dim + 1):
            for batch, mn in product(
                product(range(batch_size), repeat=max(x - 2, y - 2, 0)),
                product(range(matrix_size), repeat=min(y, 2)),
            ):
                if x == 1:
                    size_x = mn[:1]
                    size_y = batch + mn
                    yield size_x, size_y
                else:
                    for k in range(matrix_size):
                        size_x = (k,) + mn[:1]
                        if x > 2:
                            size_x = batch[-(x - 2) :] + size_x
                        size_y = mn
                        if y > 2:
                            size_y = batch[-(y - 2) :] + size_y
                        yield size_x, size_y

    @dtypes(torch.float)
    def test_matmul_small_brute_force_1d_Nd(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for (size_x, size_y), nctg_x, nctg_y in product(
            self.gen_sizes_matmul(1), (True, False), (True, False)
        ):
            x = make_arg(size_x, noncontiguous=nctg_x)
            y = make_arg(size_y, noncontiguous=nctg_y)
            self.check_single_matmul(x, y)

    @dtypes(torch.float)
    def test_matmul_small_brute_force_2d_Nd(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for (size_x, size_y), nctg_x, nctg_y in product(
            self.gen_sizes_matmul(2), (True, False), (True, False)
        ):
            x = make_arg(size_x, noncontiguous=nctg_x)
            y = make_arg(size_y, noncontiguous=nctg_y)
            self.check_single_matmul(x, y)

    @dtypes(torch.float)
    def test_matmul_small_brute_force_3d_Nd(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for (size_x, size_y), nctg_x, nctg_y in product(
            self.gen_sizes_matmul(3), (True, False), (True, False)
        ):
            x = make_arg(size_x, noncontiguous=nctg_x)
            y = make_arg(size_y, noncontiguous=nctg_y)
            self.check_single_matmul(x, y)

    @dtypes(torch.float)
    def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
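        # matmul with out= must raise when the inputs require grad, but work
        # with detached inputs or under torch.no_grad().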
        a = torch.empty(
            (256, 512), device=device, dtype=dtype, requires_grad=True
        ).unsqueeze(0)
        b = torch.empty(
            (4, 128, 512), device=device, dtype=dtype, requires_grad=True
        ).transpose(-1, -2)
        c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0)

        torch.matmul(a.detach(), b.detach(), out=c)

        with self.assertRaisesRegex(
            RuntimeError,
            "functions with out=... arguments don't support automatic differentiation",
        ):
            torch.matmul(a, b, out=c)

        with torch.no_grad():
            torch.matmul(a, b, out=c)


instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True)

if __name__ == "__main__":
    run_tests()
