#
#  Copyright (c) 2023 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#

# Example script for exporting simple models to an ExecuTorch flatbuffer via the MPS backend
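#
# Example invocation (illustrative only; "mv2" stands in for any key in
# MODEL_NAME_TO_MODEL, and the script path depends on your checkout):
#
#   python3 <path/to/this/script> --model_name mv2 --bundled --use_fp16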

import argparse
import copy
import logging

import torch
from examples.apple.mps.scripts.bench_utils import bench_torch, compare_outputs
from executorch import exir
from executorch.backends.apple.mps import MPSBackend
from executorch.backends.apple.mps.partition import MPSPartitioner
from executorch.devtools import BundledProgram, generate_etrecord
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
from executorch.devtools.bundled_program.serialize import (
    serialize_from_bundled_program_to_flatbuffer,
)

from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    ExecutorchProgramManager,
)
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.extension.export_util.utils import export_to_edge, save_pte_program

from ....models import MODEL_NAME_TO_MODEL
from ....models.model_factory import EagerModelFactory

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def get_bundled_program(executorch_program, example_inputs, expected_output):
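    """Wrap an ExecuTorch program, its example inputs, and the expected output
    into a MethodTestSuite, then serialize everything as a BundledProgram
    flatbuffer that the runtime can use to verify the "forward" method."""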
    method_test_suites = [
        MethodTestSuite(
            method_name="forward",
            test_cases=[
                MethodTestCase(
                    inputs=example_inputs, expected_outputs=[expected_output]
                )
            ],
        )
    ]
    logging.info(f"Expected output: {expected_output}")

    bundled_program = BundledProgram(executorch_program, method_test_suites)
    bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer(
        bundled_program
    )
    return bundled_program_buffer


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
    )

    parser.add_argument(
        "--use_fp16",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Whether to automatically convert float32 operations to float16 operations.",
    )

    parser.add_argument(
        "--use_partitioner",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Use MPS partitioner to run the model instead of using whole graph lowering.",
    )

    parser.add_argument(
        "--bench_pytorch",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Bench ExecuTorch MPS foward pass with PyTorch MPS forward pass.",
    )

    parser.add_argument(
        "-b",
        "--bundled",
        action="store_true",
        required=False,
        default=False,
        help="Flag for bundling inputs and outputs in the final flatbuffer program",
    )

    parser.add_argument(
        "-c",
        "--check_correctness",
        action="store_true",
        required=False,
        default=False,
        help="Whether to compare the ExecuTorch MPS results with the PyTorch forward pass",
    )

    parser.add_argument(
        "--generate_etrecord",
        action="store_true",
        required=False,
        default=False,
        help="Generate ETRecord metadata to link with runtime results (used for profiling)",
    )

    parser.add_argument(
        "--checkpoint",
        required=False,
        default=None,
        help="checkpoing for llama model",
    )

    parser.add_argument(
        "--params",
        required=False,
        default=None,
        help="params for llama model",
    )

    args = parser.parse_args()
    return args


def get_model_config(args):
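    """Translate CLI arguments into the kwargs expected by
    EagerModelFactory.create_model; llama2 additionally gets its checkpoint
    and params paths and enables the KV cache."""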
    model_config = {}
    model_config["module_name"] = MODEL_NAME_TO_MODEL[args.model_name][0]
    model_config["model_class_name"] = MODEL_NAME_TO_MODEL[args.model_name][1]

    if args.model_name == "llama2":
        if args.checkpoint:
            model_config["checkpoint"] = args.checkpoint
        if args.params:
            model_config["params"] = args.params
        model_config["use_kv_cache"] = True
    return model_config


if __name__ == "__main__":
    args = parse_args()

    if args.model_name not in MODEL_NAME_TO_MODEL:
        raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.")

    model_config = get_model_config(args)
    model, example_inputs, _, _ = EagerModelFactory.create_model(**model_config)

    model = model.eval()

    # Deep copy the model and clone the inputs so the original eager model can
    # later be run for correctness checks and benchmarking against ExecuTorch
    if args.check_correctness or args.bench_pytorch:
        model_copy = copy.deepcopy(model)
        inputs_copy = []
        for t in example_inputs:
            inputs_copy.append(t.detach().clone())
        inputs_copy = tuple(inputs_copy)

    # Pre-autograd export via torch.export, followed by lowering to the Edge dialect
    with torch.no_grad():
        model = torch.export.export_for_training(model, example_inputs).module()
        edge: EdgeProgramManager = export_to_edge(
            model,
            example_inputs,
            edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )

    edge_program_manager_copy = copy.deepcopy(edge)

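    # The "use_fp16" compile spec is encoded as a single byte: b"\x01" enables
    # automatic fp32 -> fp16 conversion in the MPS delegate, b"\x00" keeps fp32.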
    compile_specs = [CompileSpec("use_fp16", bytes([args.use_fp16]))]

    logging.info(f"Edge IR graph:\n{edge.exported_program().graph}")
    if args.use_partitioner:
        edge = edge.to_backend(MPSPartitioner(compile_specs=compile_specs))
        logging.info(f"Lowered graph:\n{edge.exported_program().graph}")

        executorch_program = edge.to_executorch(
            config=ExecutorchBackendConfig(extract_delegate_segments=False)
        )
    else:
        lowered_module = to_backend(
            MPSBackend.__name__, edge.exported_program(), compile_specs
        )
        executorch_program: ExecutorchProgramManager = export_to_edge(
            lowered_module,
            example_inputs,
            edge_compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        ).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))

    dtype = "float16" if args.use_fp16 else "float32"
    model_name = f"{args.model_name}_mps_{dtype}"

    if args.bundled:
        expected_output = model(*example_inputs)
        bundled_program_buffer = get_bundled_program(
            executorch_program, example_inputs, expected_output
        )
        model_name = f"{model_name}_bundled.pte"

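    # ETRecord ties the edge dialect program captured above to the final
    # ExecuTorch program so profiling results can be linked back to the
    # original graph.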
    if args.generate_etrecord:
        etrecord_path = "etrecord.bin"
        logging.info("generating etrecord.bin")
        generate_etrecord(etrecord_path, edge_program_manager_copy, executorch_program)

    if args.bundled:
        with open(model_name, "wb") as file:
            file.write(bundled_program_buffer)
        logging.info(f"Saved bundled program to {model_name}")
    else:
        save_pte_program(executorch_program, model_name)

    if args.bench_pytorch:
        bench_torch(executorch_program, model_copy, example_inputs, model_name)

    if args.check_correctness:
        compare_outputs(
            executorch_program, model_copy, inputs_copy, model_name, args.use_fp16
        )
