import argparse
import random

import torch


def bench(nt_a, nt_b, niter):
    """Return the mean GPU time of ``nt_a.bmm(nt_b)`` in milliseconds.

    One untimed call is issued first as a warmup, then ``niter`` calls are
    timed between two CUDA timing events.

    Args:
        nt_a: Left operand; any object with a ``bmm`` method (nested tensors
            in this script).
        nt_b: Right operand passed to ``nt_a.bmm``.
        niter: Number of timed iterations; must be at least 1.

    Returns:
        Average elapsed time per iteration, in milliseconds (float).

    Raises:
        ValueError: If ``niter`` is less than 1 (the average would divide by
            zero). Raised before any GPU work is launched.
    """
    if niter < 1:
        raise ValueError(f"niter must be >= 1, got {niter}")

    # Untimed warmup so first-call overhead is excluded from the measurement.
    nt_a.bmm(nt_b)

    # Drain pending work so the start event does not absorb earlier kernels.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(niter):
        # Result is intentionally discarded; only the kernel time matters.
        nt_a.bmm(nt_b)
    end_event.record()
    # Both events must have completed before elapsed_time can be queried.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / niter


def sweep_n(niter, dtype):
    """Benchmark nested-tensor bmm over a range of batch sizes; print CSV rows.

    For each batch size ``ntensor``, this builds ``ntensor`` random matrices
    of shape (256, L), with each L drawn from ``random.randint(100, 200)``.
    The matrices are packed into the nested tensor ``nt_a`` and their
    transposes into ``nt_b``, both on CUDA with the requested ``dtype``.
    The product ``nt_a.bmm(nt_b)`` is timed with ``bench``. One
    comma-separated line is printed per batch size: ntensor, dtype, the
    min / mean / max of the L values, and the runtime returned by ``bench``.

    Args:
        niter: Number of timed iterations forwarded to ``bench``.
        dtype: torch dtype used for both nested tensors.
    """
    batch_sizes = (4, 8, 16, 32, 64, 128, 256)
    nested_kwargs = {"dtype": dtype, "device": "cuda"}

    for ntensor in batch_sizes:
        # Each matrix has a fixed 256 rows and a random column count L.
        mats = []
        for _ in range(ntensor):
            mats.append(torch.randn(256, random.randint(100, 200)))

        nt_a = torch.nested.nested_tensor(mats, **nested_kwargs)
        nt_b = torch.nested.nested_tensor([m.t() for m in mats], **nested_kwargs)

        runtime = bench(nt_a, nt_b, niter)

        # Each row is presumably the (256, L) shape of one component, so
        # column 1 holds the random length L.
        lengths = torch.ops.aten._nested_tensor_size(nt_a)[:, 1]
        row = (
            ntensor,
            dtype,
            lengths.min().item(),
            lengths.float().mean().item(),
            lengths.max().item(),
            runtime,
        )
        print(",".join(str(field) for field in row))


if __name__ == "__main__":
    # Seed both RNGs: `random` picks the matrix lengths, torch picks the values.
    random.seed(123)
    torch.manual_seed(123)
    parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
    parser.add_argument(
        "--niter",
        default=10,
        type=int,
        help="number of timed bmm iterations per configuration (default: 10)",
    )

    args = parser.parse_args()
    niter = args.niter
    # Fail with a clear CLI message instead of a deep traceback.
    if niter < 1:
        parser.error(f"--niter must be >= 1, got {niter}")
    if not torch.cuda.is_available():
        parser.error("this benchmark requires a CUDA device")

    # CSV header; column order matches the rows printed by sweep_n.
    print("ntensor,dtype,min_length,mean_length,max_length,runtime")
    sweep_n(niter, torch.float32)
    sweep_n(niter, torch.float16)
