# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Adapted from gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
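#
# Example invocation (script path and file paths are illustrative):
#   python generate.py \
#       --gguf_file /path/to/model-q4_0.gguf \
#       --tokenizer_path /path/to/tokenizer.model \
#       --prompt "Hello, my name is"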
import argparse

from typing import Optional, Tuple

import torch

from executorch.examples.models.llama.experimental.load_gguf_q4_0 import load_gguf_q4_0
from sentencepiece import SentencePieceProcessor


def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
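    # Gumbel-max-style trick: dividing the probabilities by i.i.d.
    # Exponential(1) noise and taking the argmax draws from the categorical
    # distribution without calling torch.multinomial (which would force a
    # device synchronization).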
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
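    """Scale logits by 1/temperature, optionally keep only the top_k largest
    entries (the rest set to -inf), and return softmax probabilities."""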
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
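    """Sample a next-token id from the logits at the last sequence position."""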
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


def encode_tokens(tokenizer, string, bos=True, device="cpu"):
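    """Encode a string with the SentencePiece tokenizer, optionally
    prepending the BOS id, and return the ids as an int tensor."""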
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)


def decode_one_token(
    model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
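    """Run the model over the full token sequence x (no KV cache) and sample
    the next token from the final position's logits."""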
    logits = model(x)
    return sample(logits, **sampling_kwargs)


def prefill(model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
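    """Run the model over the prompt and return the first sampled token."""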
    return decode_one_token(model, x, **sampling_kwargs)[0]


def decode_n_tokens(
    model: torch.nn.Module,
    cur_token: torch.Tensor,
    num_new_tokens: int,
    callback=lambda _: _,
    **sampling_kwargs,
):
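    """Autoregressively sample num_new_tokens tokens.

    The GGUF-loaded model keeps no KV cache, so the entire (growing) sequence
    is re-fed to the model on every step.
    """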
    print(f"cur_token: {cur_token}")
    new_tokens, new_probs = [], []
    for _ in range(num_new_tokens):
        # Restrict SDPA to the math backend; kept from gpt-fast, where this is
        # better for Inductor when codegen-ing attention.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):
            next_token, next_prob = decode_one_token(
                model, cur_token.view(1, -1), **sampling_kwargs
            )
            new_tokens.append(next_token.clone())
            callback(next_token)
            new_probs.append(next_prob.clone())
            cur_token = torch.cat((cur_token.squeeze(), next_token), dim=0)

    return new_tokens, new_probs


@torch.no_grad()
def generate(
    model: torch.nn.Module,
    prompt: torch.Tensor,
    max_new_tokens: int,
    *,
    interactive: bool,
    callback=lambda x: x,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    T = prompt.size(0)
    T_new = T + max_new_tokens
    # Unlike gpt-fast, the GGUF-loaded model here has no KV cache, so there is
    # no cache setup or input_pos bookkeeping; the full sequence is re-fed to
    # the model on every decoding step.

    device, dtype = prompt.device, prompt.dtype

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty

    next_token = prefill(model, prompt.view(1, -1), **sampling_kwargs)
    seq[T] = next_token
    callback(next_token)

    cur_tokens = torch.cat((prompt, next_token), dim=0)

    generated_tokens, _ = decode_n_tokens(
        model,
        cur_tokens.view(1, -1),
        max_new_tokens - 1,
        callback=callback,
        **sampling_kwargs,
    )
    seq[T + 1 :] = torch.cat(generated_tokens)

    return seq


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gguf_file",
        type=str,
        required=True,
        help="The GGUF file to load.",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        required=True,
        help="The path to the SentencePiece tokenizer.model file.",
    )
    parser.add_argument(
        "--prompt", type=str, default="Hello, my name is", help="Input prompt."
    )

    args = parser.parse_args()

    tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
    encoded = encode_tokens(tokenizer, args.prompt, bos=True, device="cpu")

    pt_model = load_gguf_q4_0(args.gguf_file)

    max_new_tokens = 100
    buffer = [tokenizer.decode(encoded.tolist())]
    period_id = tokenizer.encode(".")[0]
    done_generating = False

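    # Stream decoded text as it is generated. A known token (a period) is
    # prepended before decoding and its text stripped afterwards so that
    # SentencePiece preserves the leading space of each decoded piece.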
    def callback(x):
        nonlocal done_generating
        if done_generating:
            return
        buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
        if x.item() == tokenizer.eos_id():
            done_generating = True
        if len(buffer) == 4 or done_generating:
            print("".join(buffer), end="", flush=True)
            buffer.clear()

    generate(
        pt_model,
        encoded,
        max_new_tokens,
        interactive=False,
        callback=callback,
        temperature=1.0,
        top_k=10,
    )


if __name__ == "__main__":
    main()
