# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import List, Optional

import torch

from executorch.extension.llm.tokenizer.utils import get_tokenizer


def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is re-normalized based on the selected tokens.
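
    Example:
        With probs = [[0.5, 0.3, 0.15, 0.05]] and p = 0.7, only the first two tokens
        (cumulative mass 0.8) remain candidates; the rest are masked out before the
        multinomial draw.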
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
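    # Drop tokens whose higher-probability predecessors already account for more than
    # p of the mass, keeping the smallest set whose cumulative probability reaches p.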
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
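    """
    Select the next token id from ``logits``: greedy argmax when temperature is 0
    (or below), otherwise temperature-scaled softmax followed by top-p sampling.
    """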
    if temperature > 0:
        probs = torch.softmax(logits / temperature, dim=-1)
        return sample_top_p(probs, top_p).item()
    # Pyre-ignore[7]: Incompatible return type [7]: Expected `int` but got `Union[bool, float, int]`
    return torch.argmax(logits, dim=-1).item()


class LlamaRunner(ABC):
    def __init__(
        self,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        use_kv_cache: bool,
        vocab_size: int,
        device: str = "cpu",
    ):
        """
        Constructor.

        Args:
            tokenizer_path: path to the tokenizer.model file.
            max_seq_len: max length of the output sequence, after which the output will be clipped.
            max_batch_size: max batch size.
            use_kv_cache: whether to use a KV cache.
            vocab_size: number of items in the vocab.
            device: device to run the runner on.
        """
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size
        self.use_kv_cache = use_kv_cache
        self.tokenizer = get_tokenizer(tokenizer_path)
        self.device = device
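        # The exported model's vocab size must match the tokenizer's, otherwise sampled
        # token ids would not decode correctly.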
        assert vocab_size == self.tokenizer.n_words

    @abstractmethod
    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        pass

    def generate(  # noqa: C901
        self,
        prompt_tokens: List[int],
        max_seq_len: int,
        temperature: float = 0.8,
        top_p: float = 0.9,
        echo: bool = False,
        pos_base: int = 0,
    ) -> List[int]:
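        """
        Generate tokens autoregressively: prefill with the prompt, then decode one token
        at a time until ``max_seq_len`` tokens exist or a stop token is produced.

        Args:
            prompt_tokens: token ids of the prompt.
            max_seq_len: maximum total number of tokens (prompt + generated).
            temperature (float, optional): sampling temperature. Defaults to 0.8.
            top_p (float, optional): top-p probability threshold for nucleus sampling. Defaults to 0.9.
            echo (bool, optional): whether to include the prompt tokens in the returned list. Defaults to False.
            pos_base (int, optional): starting position offset for the KV cache. Defaults to 0.

        Returns:
            Generated list of tokens.
        """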
        # Prefill
        logits = self.forward(
            tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
            input_pos=(
                torch.tensor([pos_base], dtype=torch.long, device=self.device)
                if self.use_kv_cache
                else None
            ),
        )

        # Some exported models return logits for every input position; only the last
        # position is needed to sample the first generated token.
        if logits.dim() == 3:
            logits = logits[:, -1, :]
        current_token = next_token(logits, temperature, top_p)
        print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
        tokens = prompt_tokens + [current_token]

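        # Decode loop: with a KV cache, feed back only the newest token and its position;
        # without one, re-run the model on the full sequence each step.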
        while len(tokens) < max_seq_len:
            if self.use_kv_cache:
                logits = self.forward(
                    tokens=torch.tensor(
                        [[current_token]], dtype=torch.long, device=self.device
                    ),
                    input_pos=torch.tensor(
                        [pos_base + len(tokens) - 1],
                        dtype=torch.long,
                        device=self.device,
                    ),
                )
            else:
                logits = self.forward(
                    tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
                )

            # If the logits aren't already clipped to only contain the last logit,
            # clip them here before sampling.
            if logits.dim() == 3:
                logits = logits[:, -1, :]
            current_token = next_token(logits, temperature, top_p)
            tokens.append(current_token)

            if current_token == self.tokenizer.eos_id or (
                hasattr(self.tokenizer, "stop_tokens")
                and current_token in self.tokenizer.stop_tokens
            ):
                break

            print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
        print("\n")

        return tokens if echo else tokens[len(prompt_tokens) :]

    def text_completion(
        self,
        prompt: str,
        temperature: float = 0.6,
        top_p: float = 0.9,
        echo: bool = False,
    ) -> List[int]:
        """
        Perform text completion for a prompt using the language model.

        Args:
            prompt (str): Text prompt for completion.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            Generated list of tokens.

        Note:
            This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
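
        Example (illustrative; assumes ``runner`` is an instance of a concrete
        ``LlamaRunner`` subclass):
            tokens = runner.text_completion("Hello, my name is", temperature=0)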
        """
        return self.generate(
            prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
            max_seq_len=self.max_seq_len,
            temperature=temperature,
            top_p=top_p,
            echo=echo,
        )

    def chat_completion(
        self,
        temperature: float = 0.6,
        top_p: float = 0.9,
    ) -> List[int]:
        """
        Perform multi-turn chat with the language model.

            Args:
                prompt (str): Text prompt for completion.
                temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
                top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
                echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

            Returns:
                Generated list of tokens.

            Note:
                This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
        """
        exit_prompt = "exit"
        tokens = []
        prompt = input("Me: ")
        while prompt and prompt != exit_prompt:
            print("LLM: ", end="", flush=True)
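            # Offset the KV cache positions by the tokens generated in earlier turns so
            # the conversation context is preserved across prompts.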
            new_tokens = self.generate(
                prompt_tokens=self.tokenizer.encode(
                    self._format_prompt(prompt), bos=True, eos=False
                ),
                max_seq_len=self.max_seq_len,
                temperature=temperature,
                top_p=top_p,
                echo=True,
                pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
            )
            tokens.extend(new_tokens)
            prompt = input("Me: ")
        return tokens

    def _format_prompt(self, prompt: str) -> str:
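        # Wrap the user prompt in the Llama 3 Instruct chat template with a minimal
        # system message.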
        return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
