import dataclasses
import itertools
import platform
import time
from typing import Optional, Tuple

from mixtral_moe_model import ConditionalFeedForward, Transformer as MixtralMoE
from mixtral_moe_quantize import (
    ConditionalFeedForwardInt8,
    WeightOnlyInt8QuantHandler as MixtralMoEWeightOnlyInt8QuantHandler,
)
from model import Transformer as LLaMA
from quantize import WeightOnlyInt8QuantHandler as LLaMAWeightOnlyInt8QuantHandler

import torch
import torch._inductor.config


torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
torch._inductor.config.assert_indirect_indexing = False


@dataclasses.dataclass
class GPTModelConfig:
    name: str
    module: type
    mode: Optional[str]
    quantizer: type
    token_per_sec: float
    memory_bandwidth: float
    compilation_time: float


def device_sync(device):
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif "cpu" in device:
        pass
    else:
        print(f"device={device} is not yet suppported")


def get_arch_name() -> str:
    if torch.cuda.is_available():
        return torch.cuda.get_device_name()
    else:
        # This returns x86_64 or arm64 (for aarch64)
        return platform.machine()


def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


@torch.compile(fullgraph=True)
def prefill(
    model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> torch.Tensor:
    # input_pos: [B, S]
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)[0]


@torch.compile(fullgraph=True, mode="reduce-overhead")
def decode_one_token(
    model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    # input_pos: [B, 1]
    assert input_pos.shape[-1] == 1
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)


def decode_n_tokens(
    model: torch.nn.Module,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    **sampling_kwargs,
):
    new_tokens, new_probs = [], []
    for i in range(num_new_tokens):
        with torch.nn.attention.sdpa_kernel(
            torch.nn.attention.SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, **sampling_kwargs
            )
            input_pos += 1
            new_tokens.append(next_token.clone())
            new_probs.append(next_prob.clone())
            cur_token = next_token.view(1, -1)

    return new_tokens, new_probs


@torch.no_grad()
def generate(
    model: torch.nn.Module, prompt: torch.Tensor, max_new_tokens: int, **sampling_kwargs
) -> torch.Tensor:
    device, dtype = prompt.device, prompt.dtype
    T = prompt.size(0)
    T_new = T + max_new_tokens
    max_seq_length = min(T_new, model.config.block_size)

    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
    seq[T] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)

    generated_tokens, _ = decode_n_tokens(
        model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs
    )
    seq[T + 1 :] = torch.cat(generated_tokens)
    return seq


def _load_model(x: GPTModelConfig, device="cuda", precision=torch.bfloat16):
    with torch.device("meta"):
        model = x.module.from_name(x.name)
    model = model.to(dtype=precision)

    if x.mode == "int8":
        print("Using int8 weight-only quantization!")
        model = x.quantizer(model).convert_for_runtime()

    state_dict = model.state_dict()
    for k, v in state_dict.items():
        state_dict[k] = torch.nn.Parameter(
            torch.randn(v.shape, device=device).to(dtype=v.dtype),
            requires_grad=v.requires_grad,
        )
    model.load_state_dict(state_dict, assign=True)
    return model.eval()


# Only count activated parameters and buffers.
def _get_model_size(model):
    model_size = 0
    for name, child in model.named_children():
        if not isinstance(child, torch.nn.Embedding):
            model_size += sum(
                p.numel() * p.dtype.itemsize
                for p in itertools.chain(child.parameters(), child.buffers())
            )

    # Remove the inactivated experts from the model size if this is mixture of experts
    # architecture, since only activated experts are loaded.
    if hasattr(model.config, "num_experts"):
        config = model.config
        for submodule in model.modules():
            if isinstance(
                submodule, (ConditionalFeedForward, ConditionalFeedForwardInt8)
            ):
                model_size -= (
                    sum(
                        p.numel() * p.dtype.itemsize
                        for p in itertools.chain(
                            submodule.parameters(), child.buffers()
                        )
                    )
                    * (config.num_experts - config.num_activated_experts)
                    / config.num_experts
                )

    return model_size


def run_experiment(
    x: GPTModelConfig,
    num_samples: int = 5,
    max_new_tokens: int = 200,
    top_k: int = 200,
    temperature: float = 0.8,
    device: str = "cuda",
) -> None:
    print(f"Loading model {x.name}")
    t0 = time.time()
    model = _load_model(x, device=device)
    device_sync(device=device)  # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    prompt = torch.tensor(
        [1, 15043, 29892, 590, 1024, 338], device=device, dtype=torch.int32
    )
    prompt_length = prompt.size(0)

    torch.manual_seed(1234)
    model_size = _get_model_size(model)

    aggregate_metrics = {"tokens_per_sec": [], "memory_bandwidth": []}
    start = -1
    compilation_time = None

    for i in range(start, num_samples):
        device_sync(device=device)  # MKG

        t0 = time.perf_counter()
        y = generate(
            model, prompt, max_new_tokens, temperature=temperature, top_k=top_k
        )

        if i == -1:
            compilation_time = time.perf_counter() - t0
            print(f"Compilation time: {compilation_time:.2f} seconds")
            continue

        device_sync(device=device)  # MKG
        t = time.perf_counter() - t0
        tokens_generated = y.size(0) - prompt_length
        tokens_sec = tokens_generated / t
        aggregate_metrics["tokens_per_sec"].append(tokens_sec)
        aggregate_metrics["memory_bandwidth"].append(model_size * tokens_sec / 1e9)

    token_per_sec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item()
    memory_bandwidth = torch.mean(
        torch.tensor(aggregate_metrics["memory_bandwidth"])
    ).item()
    print(f"Average tokens/sec: {token_per_sec:.2f} tokens/sec")
    print(f"Average bandwidth achieved: {memory_bandwidth:.02f} GB/s")
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
    return token_per_sec, memory_bandwidth, compilation_time


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_bf16(device: str = "cuda"):
    from benchmark import Experiment

    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "bfloat16",
        LLaMAWeightOnlyInt8QuantHandler,
        94,
        1253,
        162,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_int8(device: str = "cuda"):
    from benchmark import Experiment

    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "int8",
        LLaMAWeightOnlyInt8QuantHandler,
        144,
        957,
        172,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_mixtral_8x7b_int8(device: str = "cuda"):
    from benchmark import Experiment

    # We reduced the original number of layers from 32 to 16 to adapt CI memory limitation.
    model = GPTModelConfig(
        "Mixtral-8x7B-v0.1",
        MixtralMoE,
        "int8",
        MixtralMoEWeightOnlyInt8QuantHandler,
        175,
        1130,
        162,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]
