# Owner(s): ["module: inductor"]
import contextlib
import os
import subprocess
import sys
from unittest.mock import patch

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.testing import rand_strided
from torch._inductor import config
from torch._inductor.codecache import PyCodeCache
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache
from torch.testing import FileCheck
from torch.testing._internal.common_device_type import expectedFailureXPU
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


class TestKernelBenchmark(TestCase):
    """Tests for the benchmark harness emitted into Inductor-compiled modules.

    ``config.benchmark_kernel`` is patched to ``True`` for the whole class.
    Most tests compile a small function with ``torch.compile``, look up the
    generated module in ``PyCodeCache``, run that module's file as a script in
    a subprocess, and check the text it prints (e.g. the reported bandwidth).
    """

    device_type = GPU_TYPE

    @classmethod
    def setUpClass(cls):
        """Enable ``config.benchmark_kernel`` for every test in the class."""
        # NOTE(review): super().setUpClass() is not called here (and
        # super().tearDownClass() is not called below) -- confirm that skipping
        # the base-class fixtures is intentional.
        cls.exit_stack = contextlib.ExitStack()
        cls.exit_stack.enter_context(patch.object(config, "benchmark_kernel", True))

    @classmethod
    def tearDownClass(cls):
        """Undo the config patch installed by ``setUpClass``."""
        cls.exit_stack.close()

    def setUp(self):
        """Start each test with an empty ``PyCodeCache``.

        ``get_compiled_module`` expects exactly one benchmarkable module in the
        cache, so modules left over from earlier tests must be dropped.
        """
        super().setUp()
        PyCodeCache.cache.clear()

    @staticmethod
    def _run_python_file(script_path, *script_args, env=None):
        """Run ``script_path`` with the current interpreter and return its output.

        The command is passed as an argument list rather than a
        whitespace-split string, so an interpreter or script path that
        contains spaces is still passed as a single argument. stderr is
        merged into stdout.

        Args:
            script_path: path of the Python file to execute.
            *script_args: extra command line arguments for the script.
            env: optional environment mapping; ``None`` inherits the current one.

        Returns:
            The combined stdout/stderr of the child process as ``bytes``.

        Raises:
            subprocess.CalledProcessError: if the script exits with a non-zero
                status.
        """
        return subprocess.check_output(
            [sys.executable, script_path, *script_args],
            stderr=subprocess.STDOUT,
            env=env,
        )

    def get_compiled_module(self):
        """Return the single cached module exposing ``benchmark_compiled_module``.

        Fails the test if ``PyCodeCache`` holds zero or more than one such
        module.
        """
        compiled_module = None
        for candidate in PyCodeCache.cache.values():
            if hasattr(candidate, "benchmark_compiled_module"):
                self.assertIsNone(compiled_module, "Found multiple compiled modules")
                compiled_module = candidate

        self.assertIsNotNone(compiled_module)
        return compiled_module

    def verify_compiled_kernels(self, GB_count=1):
        """Run the compiled module with ``-kc`` and count ``GB/s`` in its output.

        Args:
            GB_count: exact number of ``GB/s`` occurrences expected in the
                subprocess output.
        """
        compiled_module = self.get_compiled_module()

        # now run the compiled module in subprocess and check its output
        bench_out = self._run_python_file(compiled_module.__file__, "-kc").decode()

        # make sure we have the bandwidth information in the output
        FileCheck().check_count(
            "GB/s",
            GB_count,
            exactly=True,
        ).run(bench_out)

    def verify_remove_inductor_deps(self, compiled_module):
        """Check that the cleaned-up Triton version of a module still runs.

        Steps:
          1. Run the generated module with ``TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1``.
          2. Produce ``<module file>.cleaned`` via ``get_clean_triton``.
          3. Assert the cleaned source contains neither ``@triton_heuristics``
             nor ``.run(``.
          4. Run the cleaned file as a script.

        Args:
            compiled_module: module object whose ``__file__`` is the generated
                source file.

        Returns:
            The cleaned Triton source returned by ``get_clean_triton``.

        Raises:
            subprocess.CalledProcessError: re-raised (after printing the
                captured output) if either subprocess fails.
        """
        module_path = compiled_module.__file__
        cleaned_path = f"{module_path}.cleaned"

        try:
            self._run_python_file(
                module_path,
                env={**os.environ, "TORCHINDUCTOR_DUMP_LAUNCH_PARAMS": "1"},
            )
        except subprocess.CalledProcessError as e:
            print(
                "Failed when runinng triton code with TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1",
                e,
            )
            print(e.output.decode())
            # bare raise keeps the original traceback intact
            raise

        from torch.utils._get_clean_triton import get_clean_triton

        cleaned_triton = get_clean_triton(module_path, cleaned_path)
        self.assertNotIn("@triton_heuristics", cleaned_triton)
        self.assertNotIn(".run(", cleaned_triton)

        try:
            self._run_python_file(cleaned_path)
        except subprocess.CalledProcessError as e:
            print("Failed when when running cleaned triton", e)
            print(e.output.decode())
            print(cleaned_triton)
            raise
        return cleaned_triton

    def check_bandwidth(self, compiled_module, num_gb):
        """Run the compiled module with ``-k`` and check the reported GB figure.

        Args:
            compiled_module: module object whose ``__file__`` is executed.
            num_gb: expected amount, formatted exactly as it should appear in
                the output; ``"<num_gb> GB "`` must occur exactly once.
        """
        # now run the compiled module in subprocess and check its output
        bench_out = self._run_python_file(compiled_module.__file__, "-k").decode()

        # make sure we have the bandwidth information in the output
        FileCheck().check_count(
            f"{num_gb} GB ",
            1,
            exactly=True,
        ).run(bench_out)

    def test_pw_kernel_benchmark(self):
        """A pointwise sin + cos function reports ``GB/s`` exactly once."""

        @torch.compile
        def f(x):
            return torch.sin(x) + torch.cos(x)

        f(torch.rand(2, 3).to(device=GPU_TYPE))
        self.verify_compiled_kernels()

    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    @fresh_inductor_cache()
    def test_matmul_triton_kernel_benchmark(self):
        """relu(a @ b) with max-autotune on the TRITON backend reports ``GB/s`` once."""
        M = 12544
        N = 256
        K = 64
        a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(N, K, dtype=torch.float16, device=GPU_TYPE).t()

        @torch.compile
        def f(a, b):
            return torch.relu(a @ b)

        f(a, b)
        self.verify_compiled_kernels()

    @expectedFailureXPU
    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    @fresh_inductor_cache()
    def test_mm_triton_kernel_benchmark(self):
        """mm on a narrowed input: check the ``GB/s`` count and the emitted grid."""
        M = 2048
        N = 2432
        K = 1949
        K_2 = 3581
        # NOTE(review): the device is hard-coded to "cuda" rather than GPU_TYPE;
        # the test is marked expectedFailureXPU -- confirm this is intentional.
        a = rand_strided((M, K_2), (K_2, 1), device="cuda", dtype=torch.float16)
        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float16)

        @torch.compile
        def f(a, b):
            a_1 = torch.narrow(a, 1, 0, K)
            c = torch.mm(a_1, b)
            return c

        f(a, b)
        self.verify_compiled_kernels(GB_count=3)

        # make sure we correctly generate the grid info
        compiled_module = self.get_compiled_module()
        with open(compiled_module.__file__) as src_file:
            source_code = src_file.read()
        meta = [line for line in source_code.split("\n") if "meta0 = {" in line]
        self.assertTrue(meta, "No 'meta0 = {' line found in the generated module")
        scope = {}
        from torch._inductor.kernel.mm_common import mm_grid

        # The executed line comes from the module this test just generated;
        # running it rebuilds the ``meta0`` dict that mm_grid needs.
        exec(meta[0], scope)
        grid = mm_grid(M, N, scope["meta0"])
        FileCheck().check_count(
            f"grid={grid}",
            2,
            exactly=True,
        ).run(source_code)

    def test_matmul_bandwidth_computation(self):
        """
        The test does a matmul and then mul. Without max-autotune, we use
        the matmul in aten. So there is a single triton kernel for mul.
        The kernel we generated is like:

            @triton.jit
            def triton_(in_out_ptr0, xnumel, XBLOCK : tl.constexpr):

        Note the in_out_ptr0 argument. It's for a 1000x1000 tensor, but it's
        inplace updated, so when computing the bandwidth, we should count
        the total memory access as 2 * 1000 * 1000 * 4 = 8MB. This amount is
        what this test asserts.
        """
        # The matmul precision is process-wide state; restore it afterwards so
        # this test does not leak the setting into later tests.
        previous_precision = torch.get_float32_matmul_precision()
        self.addCleanup(torch.set_float32_matmul_precision, previous_precision)
        torch.set_float32_matmul_precision("high")  # suggested by a warning

        @torch.compile
        def f(x, y):
            z = x @ y
            w = z * z
            return w

        M, N, K = 1000, 1000, 10
        x = torch.rand(M, K).to(device=GPU_TYPE)
        y = torch.rand(K, N).to(device=GPU_TYPE)
        f(x, y)

        compiled_module = self.get_compiled_module()

        self.check_bandwidth(compiled_module, "0.008")

    def test_unused_input_bandwidth_computation(self):
        """An input the function never reads (``b``) must not be counted."""
        M, N = 5, 1000000

        @torch.compile
        def f(a, b, c):
            return a + c

        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        torch._dynamo.mark_dynamic(c, 0)
        f(a, b, c)

        compiled_module = self.get_compiled_module()
        # num_gb = size_a + size_c + size_out
        # num_gb = (5 * 1000000 + 5 * 1000000 + 5 * 1000000) * 2 / 1e9
        #        = 0.030
        self.check_bandwidth(compiled_module, "0.030")

    def test_reduction_bandwidth_computation(self):
        """Bandwidth for a sum over dim 1 counts the input plus the reduced output."""

        @torch.compile
        def f(a):
            return torch.sum(a, dim=1)

        a = torch.rand(1000, 20, 1000, dtype=torch.float16, device=GPU_TYPE)
        f(a)

        compiled_module = self.get_compiled_module()
        # num_gb = size_a + size_out
        # num_gb = (1000 * 20 * 1000 + 1000 * 1000) * 2 / 1e9
        #        = 0.042
        self.check_bandwidth(compiled_module, "0.042")

    @config.patch(max_autotune=True)
    def test_fused_layernorm_bandwidth_computation(self):
        """Bandwidth for add + layer_norm + sigmoid + mul under max-autotune."""
        M, N = 10, 1000000

        @torch.compile
        def f(a, b, c, d):
            x0 = a + b
            x1 = torch.nn.functional.layer_norm(
                x0, normalized_shape=(N,), weight=c, bias=d, eps=1e-05
            )
            x2 = torch.sigmoid(x1)
            return x0 * x2

        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        d = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        f(a, b, c, d)

        compiled_module = self.get_compiled_module()
        # num_gb = size_a + size_b + size_c + size_d + size_out
        # num_gb = (10 * 1000000 + 1000000 + 1000000 + 1000000 + 10 * 1000000) * 2 / 1e9
        #        = 0.046
        self.check_bandwidth(compiled_module, "0.046")

    def test_slice_add_cat_bandwidth_computation(self):
        """Bandwidth for narrow + broadcast add + cat with dynamic dim 0."""
        M, N = 5, 1000000

        @torch.compile
        def f(a, b, c):
            x0 = torch.narrow(b, 1, N, N)
            # broadcasting
            x1 = x0 + c
            return torch.cat([a, x1], dim=1)

        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(M, N * 5, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        f(a, b, c)

        compiled_module = self.get_compiled_module()
        # we overestimate the size of "slice_b" due to torch.cat
        # num_gb = size_a + size_slice_b + size_c + size_out
        # num_gb = (5 * 1000000 + 5 * 2000000 + 1000000 + 5 * 2000000) * 2 / 1e9
        #        = 0.052
        self.check_bandwidth(compiled_module, "0.052")

    def test_slice_add_bandwidth_computation(self):
        """Only the narrowed slice of ``b`` is counted, not the full tensor."""
        M, N = 5, 1000000

        @torch.compile
        def f(a, b, c):
            x0 = torch.narrow(b, 1, N, N)
            return a + x0 + c

        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(M, N * 5, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        f(a, b, c)

        compiled_module = self.get_compiled_module()
        # num_gb = size_a + size_slice_b + size_c + out_size
        # num_gb = (5 * 1000000 + 5 * 1000000 + 1000000 + 5 * 1000000) * 2 / 1e9
        #        = 0.032
        self.check_bandwidth(compiled_module, "0.032")

    def test_mm_slice_add_bandwidth_computation(self):
        """mm plus two different slices of ``c``: both slices are counted."""
        M, N, K = 1000, 1000, 30

        @torch.compile
        def f(a, b, c):
            x0 = torch.mm(a, b)
            x1 = torch.narrow(c, 1, 20 * N, N)
            x2 = torch.narrow(c, 1, 21 * N, N)
            return x0 + x1 + x2

        a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, N * 100, dtype=torch.float16, device=GPU_TYPE)
        f(a, b, c)

        compiled_module = self.get_compiled_module()
        # torch.mm becomes an extern kernel, so we measure the nbytes
        # for the pointwise add kernel:
        # num_gb = x0 + 2 * size_slice_c + size_out
        # num_gb = (1000 * 1000 + 2 * 1000 * 1000 + 1000 * 1000) * 2 / 1e9
        #        = 0.008
        num_gb = "0.008"
        if GPU_TYPE == "xpu":
            # In the XPU backend, mm + add + add is fused as addmm + add,
            # whereas CUDA prefers not to fuse add + mm; see
            # `should_prefer_unfused_addmm` in torch/_inductor/fx_passes/post_grad.py
            num_gb = "0.006"

        self.check_bandwidth(compiled_module, num_gb)

    def test_mm_slice_add_bandwidth_computation_2(self):
        """mm plus the same slice of ``c`` twice: the slice is counted once."""
        M, N, K = 1000, 1000, 30

        @torch.compile
        def f(a, b, c):
            x0 = torch.mm(a, b)
            x1 = torch.narrow(c, 1, 20 * N, N)
            x2 = torch.narrow(c, 1, 20 * N, N)
            return x0 + x1 + x2

        a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, N * 100, dtype=torch.float16, device=GPU_TYPE)
        f(a, b, c)

        compiled_module = self.get_compiled_module()
        # torch.mm becomes an extern kernel, so we measure the nbytes
        # for the pointwise add kernel:
        # num_gb = x0 + size_slice_c + size_out
        # num_gb = (1000 * 1000 + 1000 * 1000 + 1000 * 1000) * 2 / 1e9
        #        = 0.006
        # note that we only count one size_slice_c because two accesses
        # have the same index.
        self.check_bandwidth(compiled_module, "0.006")

    @expectedFailureXPU
    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_slice_mm_bandwidth_computation(self):
        """mm on a narrowed input under max-autotune counts only the slice."""
        M, N, K = 1000, 2000, 3000

        @torch.compile
        def f(a, b):
            x = torch.narrow(a, 1, K, K)
            return torch.mm(x, b)

        a = torch.rand(M, 3 * K, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(a, 0)
        f(a, b)

        compiled_module = self.get_compiled_module()

        # c[1000, 2000] = x[1000, 3000] @ b[3000, 2000]
        # num_gb = (1000 * 2000 + 1000 * 3000 + 3000 * 2000) * 2 / 1e9
        #        = 0.022
        self.check_bandwidth(compiled_module, "0.022")

    def test_star_dep(self):
        """
        Test the bandwidth estimation for StarDep
        """

        @torch.compile
        def f(a, b):
            a[b] = 3.0

        a = torch.rand(10000, 5000, device=GPU_TYPE)
        b = torch.randint(
            0, 10000, [20000], device=GPU_TYPE, dtype=torch.int32
        ).unsqueeze(1)
        f(a, b)
        compiled_module = self.get_compiled_module()
        # 20000 * 4 = 80KB for b
        # 20000 * 5000 * 4 = 200MB for a
        self.check_bandwidth(compiled_module, "0.200")

    def test_split_scan(self):
        """Bandwidth for a cumsum over a flattened 10000 * 5000 float tensor."""

        @torch.compile
        def f(a):
            return a.cumsum(-1)

        a = torch.rand(10000, 5000, device=GPU_TYPE)
        f(a.reshape(-1))
        compiled_module = self.get_compiled_module()
        # 10000 * 5000 * 4 = 200 MB for a
        # Double that for output as well
        self.check_bandwidth(compiled_module, "0.400")

    @config.patch("triton.unique_kernel_names", True)
    @config.patch(benchmark_kernel=False)
    @config.patch(compile_threads=1)
    def test_remove_inductor_deps(self):
        """The cleaned Triton code for a single pointwise kernel still runs."""

        @torch.compile
        def f(a):
            return a.cos().sin()

        a = torch.randn(5, device=GPU_TYPE)
        f(a)
        compiled_module = self.get_compiled_module()
        self.verify_remove_inductor_deps(compiled_module)

    @config.patch("triton.unique_kernel_names", True)
    @config.patch(benchmark_kernel=False)
    @config.patch(compile_threads=1)
    def test_remove_inductor_deps_multiple_kernels(self):
        """The cleaned Triton code still runs for mm / pointwise / softmax mixes."""

        @torch.compile
        def f(a):
            a = torch.mm(a, a)
            a = a.cos().sin()
            a = torch.mm(a, a)
            a = torch.softmax(a, dim=-1)
            return a

        a = torch.randn(5, 5, device=GPU_TYPE)
        f(a)
        compiled_module = self.get_compiled_module()
        self.verify_remove_inductor_deps(compiled_module)

    @config.patch("triton.unique_kernel_names", True)
    @config.patch(benchmark_kernel=False)
    @config.patch(compile_threads=1)
    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_remove_inductor_deps_templates(self):
        """The cleaned Triton code still runs with max-autotune on the TRITON backend."""

        @torch.compile
        def f(a):
            a = torch.mm(a, a)
            a = a.cos()
            a = torch.mm(a, a)
            a = a.sin()
            return a

        a = torch.randn(128, 128, device=GPU_TYPE)
        f(a)
        compiled_module = self.get_compiled_module()
        self.verify_remove_inductor_deps(compiled_module)

    @config.patch("triton.unique_kernel_names", True)
    @config.patch(benchmark_kernel=False)
    @config.patch(compile_threads=1)
    def test_remove_inductor_deps_scalar(self):
        """The cleaned Triton code still runs when the inputs are 0-dim tensors."""

        @torch.compile
        def f(a, b):
            return a + b

        a = torch.tensor(1.0, device=GPU_TYPE)
        b = torch.tensor(2.0, device=GPU_TYPE)
        f(a, b)
        compiled_module = self.get_compiled_module()
        self.verify_remove_inductor_deps(compiled_module)


# Script entry point: the suite is only launched when HAS_GPU is truthy;
# otherwise running this file does nothing.
if __name__ == "__main__" and HAS_GPU:
    run_tests()
