import time

import torch
import torch.utils
from functorch.compile import aot_function, tvm_compile


# Benchmark inputs. With these shapes ``a * b`` broadcasts to (2000, 2000, 4).
# Only ``a`` requires grad, so backward passes differentiate w.r.t. ``a``.
a = torch.randn((2000, 1, 4), requires_grad=True)
b = torch.randn((1, 2000, 4))


def f(a):
    """Return the elementwise product of ``a`` and the global ``b``, summed over dim 0."""
    product = a * b
    return torch.sum(product, dim=0)


# Forward/backward compilers for ``aot_function``: TVM with target="llvm".
fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
# To measure AOT tracing without TVM, uncomment these pass-through compilers
# (they return the traced graph unchanged). They must stay *before* the
# ``aot_function`` call below, otherwise they have no effect.
# fw_compiler = lambda x, _: x
# bw_compiler = lambda x, _: x
compiled_f = aot_function(f, fw_compiler, bw_compiler)

# Number of timed iterations used by ``bench`` and ``bench_jax``.
iters = 10

# Warm-up call so one-time tracing/compilation is not included in the timings.
out = compiled_f(a)
out.sum().backward()
# Drop the warm-up gradient; otherwise the first timed backward pass in
# ``bench`` would accumulate into it instead of starting from a clean state.
a.grad = None


def bench(func):
    """Time ``iters`` forward+backward passes of ``func`` on the global ``a``.

    Each iteration computes ``func(a).sin()``, backpropagates its sum, and
    then resets ``a.grad`` to ``None`` so gradients do not accumulate across
    iterations.

    Args:
        func: Callable taking the tensor ``a`` and returning a tensor.

    Returns:
        Elapsed time in seconds (also printed).
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    begin = time.perf_counter()
    for _ in range(iters):
        out = func(a).sin()
        out.sum().backward()
        a.grad = None
    elapsed = time.perf_counter() - begin
    print(elapsed)
    return elapsed


def bench_jax():
    """Time the JAX equivalent of ``bench`` on copies of the globals ``a``/``b``.

    Computes the gradient of ``sin((x * b).sum(axis=0)).sum()`` w.r.t. ``x``
    with ``jax.jit(jax.grad(...))``, running it ``iters`` times after one
    warm-up call.

    Returns:
        Elapsed time in seconds (also printed).
    """
    import jax
    import jax.numpy as jnp

    jax_a = jnp.array(a.detach().numpy())
    jax_b = jnp.array(b.detach().numpy())

    def loss(x):
        return jnp.sin((x * jax_b).sum(axis=0)).sum()

    jit_grad = jax.jit(jax.grad(loss))

    # Warm-up: the first call compiles. JAX dispatches execution
    # asynchronously, so block here to keep the warm-up run out of the timed
    # region. This also guarantees ``out`` is bound even when iters == 0.
    out = jit_grad(jax_a)
    out.block_until_ready()

    begin = time.perf_counter()
    for _ in range(iters):
        out = jit_grad(jax_a)
    # Block on the final result before stopping the clock (async dispatch).
    out.block_until_ready()
    elapsed = time.perf_counter() - begin
    print(elapsed)
    return elapsed


# Eager PyTorch first, then the AOT-compiled version of the same function.
for fn in (f, compiled_f):
    bench(fn)
# bench_jax()  # uncomment to also time the jax.jit(jax.grad(...)) baseline
