# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import argparse

import torch
from executorch.examples.llm_pte_finetuning.training_lib import (
    eval_model,
    get_dataloader,
    update_function,
)

from executorch.extension.pybindings.aten_lib import (  # @manual
    _load_for_executorch_from_buffer,
)
from omegaconf import OmegaConf
from torch.nn import functional as F
from torchtune import config
from tqdm import tqdm

# Command-line interface of the fine-tuning runner.
# Both paths are mandatory: main() feeds them straight into OmegaConf.load()
# and open(), so a missing flag would otherwise surface as an opaque error on
# a None path instead of a clear usage message from argparse.
parser = argparse.ArgumentParser(
    prog="Runner",
    description="Fine tunes LoRA model using ExecuTorch.",
    epilog="Model exported to be used for fine-tuning.",
)
parser.add_argument(
    "--cfg",
    type=str,
    required=True,
    help="Path to the config file.",
)
parser.add_argument(
    "--model_file",
    type=str,
    required=True,
    help="Path to the ET model file.",
)


def _fit_to_seq_len(
    tensor: torch.Tensor, seq_len: int, pad_value: int = 0
) -> torch.Tensor:
    """Truncate or right-pad ``tensor`` so that its dimension 1 equals ``seq_len``.

    Truncation slices dimension 1; padding is appended at the end of the last
    dimension. The two coincide for 2-D inputs (assumes batches are shaped
    (batch, seq) -- TODO confirm against the dataloader's collate function).

    Args:
        tensor: Batch tensor to resize.
        seq_len: Target length of the sequence dimension.
        pad_value: Fill value used when the tensor is shorter than ``seq_len``.

    Returns:
        A tensor whose sequence dimension has length ``seq_len``.
    """
    current_len = tensor.shape[1]
    if current_len > seq_len:
        return tensor[:, :seq_len]
    return F.pad(tensor, (0, seq_len - current_len), value=pad_value)


def main() -> None:
    """Fine-tune an exported ExecuTorch model and report loss before and after.

    Reads the config and model paths from the command line, builds the
    tokenizer, loss and an 80/20 train/validation split from the config,
    evaluates the loaded model, runs a fixed number of training steps that
    update the parameters returned by ``forward``, then evaluates again.
    """
    args = parser.parse_args()
    cfg = OmegaConf.load(args.cfg)
    tokenizer = config.instantiate(cfg.tokenizer)
    loss_fn = config.instantiate(cfg.loss)

    ds = config.instantiate(cfg.dataset, tokenizer)
    train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
    train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)
    val_dataloader = get_dataloader(cfg, val_set, tokenizer, loss_fn)

    max_seq_len = cfg.tokenizer.max_seq_len
    # Num of steps to run training. Assume 1 epoch
    num_steps = 100
    num_eval_steps = 10
    learning_rate = 5e-3

    # The file is only needed for the read; close it before the long-running
    # evaluation/training work instead of holding the handle open throughout.
    with open(args.model_file, "rb") as f:
        model_bytes = f.read()
    # `model_bytes` intentionally stays bound for the rest of the function.
    # NOTE(review): the loaded module may reference this buffer rather than
    # copy it -- confirm before inlining the read into the call below.
    et_mod = _load_for_executorch_from_buffer(model_bytes)

    # Evaluate the model before training.
    print("Evaluating the model before training")
    eval_loss = eval_model(
        model=et_mod,
        dataloader=val_dataloader,
        loss_fn=loss_fn,
        max_seq_len=max_seq_len,
        num_eval_steps=num_eval_steps,
    )
    print("Eval loss: ", eval_loss)

    # Based on executorch/extension/training/module/training_module.cpp
    # grads run from [grad_start, param_start]
    # params run from [param_start, outputs_end]
    grad_start = et_mod.run_method("__et_training_gradients_index_forward", [])[0]
    param_start = et_mod.run_method("__et_training_parameters_index_forward", [])[0]

    losses = []
    for i, batch in tqdm(enumerate(train_dataloader), total=num_steps):
        # Run for a limited number of steps.
        if i >= num_steps:
            break

        # Fixed length for now. We need to resize as the input shapes
        # should be the same passed as examples to the export function.
        # NOTE(review): labels are padded with 0, the same fill value as the
        # tokens -- confirm this matches what the exported loss expects for
        # padded positions.
        tokens = _fit_to_seq_len(batch["tokens"], max_seq_len)
        labels = _fit_to_seq_len(batch["labels"], max_seq_len)

        # Do not clone outputs, since we want the original weights to be returned
        # for us to update with the gradients in-place.
        # See https://github.com/pytorch/executorch/blob/main/extension/pybindings/pybindings.cpp#L736
        # for more info.
        out = et_mod.forward((tokens, labels), clone_outputs=False)

        # The first output is the loss; gradients and parameters follow at the
        # indices queried above.
        losses.append(out[0].item())
        with torch.no_grad():
            for grad, param in zip(out[grad_start:param_start], out[param_start:]):
                update_function(param, grad, learning_rate)

    print("Losses: ", losses)

    # Evaluate the model after training.
    print("Evaluating the model after training")
    eval_loss = eval_model(
        model=et_mod,
        dataloader=val_dataloader,
        loss_fn=loss_fn,
        max_seq_len=max_seq_len,
        num_eval_steps=num_eval_steps,
    )
    print("Eval loss: ", eval_loss)


# Run the fine-tuning flow only when this file is executed as a script, so
# importing the module does not start training.
if __name__ == "__main__":
    main()
