# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from functools import partial
from typing import Any

import torch
from executorch.extension.pybindings.aten_lib import ExecuTorchModule  # @manual

from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from torchtune.data import AlpacaToMessages
from torchtune.data._collate import padded_collate_sft
from torchtune.datasets import PackedDataset, SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer
from tqdm import tqdm


class TrainingModule(torch.nn.Module):
    """
    The model being trained should return the loss from forward(). This
    class wraps the actual model and computes the loss for an LLM
    fine-tuning task. The loss is computed as the cross entropy between
    the tokens and a shifted version of the labels so we learn to predict
    the next token.
    """

    def __init__(
        self, model: torch.nn.Module, loss: torch.nn.modules.loss._Loss
    ) -> None:
        super().__init__()
        self.model = model
        self.loss = loss

    def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Output is of the shape (seq_len, vocab_size).
        logits = self.model(input)
        logits = logits[..., :-1, :].contiguous()
        labels = labels[..., 1:].contiguous()
        logits = logits.transpose(1, 2)
        return self.loss(logits, labels)


def python_code_instructions_alpaca(tokenizer: ModelTokenizer) -> PackedDataset:
    """
    Python code instruction-input-output pairs from iamtarun/python_code_instructions_18k_alpaca templated with Alpaca.
    """
    ds = SFTDataset(
        # pyre-ignore[6]: Incompatible parameter type
        model_transform=tokenizer,
        source="iamtarun/python_code_instructions_18k_alpaca",
        message_transform=AlpacaToMessages(
            train_on_input=False,
        ),
        # pyre-ignore[6]: Incompatible parameter type
        split="train",
    )
    if tokenizer.max_seq_len is None:
        raise ValueError(
            "PackedDataset requires a max_seq_len to be set on the tokenizer."
        )
    return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len, split_across_pack=False)


def update_function(
    param: torch.Tensor,
    grad: torch.Tensor,
    learning_rate: float,
    weight_decay: float = 1.0,
) -> None:
    """SGD update function."""
    grad = grad + weight_decay * param
    param.sub_(learning_rate * grad)


def eval_model(
    model: ExecuTorchModule,
    dataloader: DataLoader,
    loss_fn: torch.nn.modules.loss._Loss,
    max_seq_len: int,
    num_eval_steps: int,
) -> float:
    total_loss = 0
    for i, batch in tqdm(enumerate(dataloader), total=num_eval_steps):
        if i >= num_eval_steps:
            break
        tokens, labels = batch["tokens"], batch["labels"]
        token_size = tokens.shape[1]
        labels_size = labels.shape[1]

        tokens, labels = batch["tokens"], batch["labels"]
        token_size = tokens.shape[1]
        labels_size = labels.shape[1]

        # Fixed length for now. We need to resize as the input shapes
        # should be the same passed as examples to the export function.
        if token_size > max_seq_len:
            tokens = tokens[:, :max_seq_len]
        else:
            tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)

        if labels_size > max_seq_len:
            labels = labels[:, :max_seq_len]
        else:
            labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)

        out = model.forward((tokens, labels))
        loss = out[0]
        total_loss += loss
    return total_loss / num_eval_steps


def get_dataloader(
    cfg: Any,  # pyre-ignore[2]
    ds: Dataset[Any],  # pyre-ignore[2]
    tokenizer: Any,  # pyre-ignore[2]
    loss_fn: torch.nn.modules.loss._Loss,
) -> DataLoader:
    """Given a dataset, tokenizer, and loss function, return a dataloader."""
    packed = cfg.dataset.get("packed", False)

    sampler = DistributedSampler(
        ds,
        num_replicas=1,
        rank=0,
        shuffle=cfg.shuffle,
        seed=0,
    )
    dataloader = DataLoader(
        dataset=ds,
        sampler=sampler,
        batch_size=cfg.batch_size,
        collate_fn=(
            partial(
                padded_collate_sft,
                padding_idx=tokenizer.pad_id,
                ignore_idx=loss_fn.ignore_index,
            )
            if not packed
            else None
        ),
    )
    return dataloader
