# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from argparse import ArgumentParser, BooleanOptionalAction

import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    get_quantizer_and_quant_params,
)
from executorch.examples.models.llama.source_transformation.quantize import (
    EmbeddingQuantHandler,
    get_quant_weight_transform,
)
from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)
from executorch.examples.models.llava.image_util import serialize_image
from executorch.examples.models.llava.model import LlavaModel
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)

from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import (
    ConstraintBasedSymShapeEvalPass,
    HintBasedSymShapeEvalPass,
)

from executorch.extension.llm.export.builder import DType, LLMEdgeManager
from executorch.extension.llm.tokenizer.tokenizer import Tokenizer
from executorch.util.activation_memory_profiler import generate_memory_trace
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import Dim
from torch.nn.attention import SDPBackend

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


class LlavaEdgeManager(LLMEdgeManager):
    def export(self) -> "LlavaEdgeManager":
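        """Trace self.model with torch.export in non-strict mode and cache the
        resulting graph module on self.pre_autograd_graph_module.

        Overrides LLMEdgeManager.export; in this file it is used to export the
        image encoder. Returns self so calls can be chained.
        """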
        dynamic_shape = self._get_dynamic_shape()
        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) bypasses a dynamo
        #    error that otherwise occurs while tracing SDPA.
        # 2. torch.no_grad() strips dropout, which otherwise shows up as a
        #    training op in the traced graph.
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            self.export_program = torch.export.export(
                self.model,
                self.example_inputs,
                dynamic_shapes=dynamic_shape,
                strict=False,
            )
            self.pre_autograd_graph_module = self.export_program.module()
        return self


def export_text_model(llava, embeddings, dynamic_shapes):
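    """Export the LLaVA text decoder to an ExportedProgram.

    Applies the llama source transforms (optionally the custom SDPA op, plus
    8da4w weight quantization) and PT2E quantization before the final
    torch.export call.
    """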
    class LlavaTextModel(torch.nn.Module):
        """Takes images and prompts and encode them into embeddings. Result will be sent to the text model LlavaTextModel."""

        def __init__(self, llava):
            super().__init__()
            self.text_model = llava.text_model

        def forward(self, input_pos, embeddings):
            return self.text_model(None, input_pos, embeddings)

    llava_text_model = LlavaTextModel(llava)

    text_model_em = LLMEdgeManager(
        model=llava_text_model,
        modelname="llava_text_model",
        max_seq_len=llava.text_model_args.max_seq_len,
        dtype=DType.fp32,
        use_kv_cache=True,
        example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
        dynamic_shapes=dynamic_shapes,
        args=llava.text_model_args,
    )

    dtype_override = DType.fp32
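    # Reuse the llama export CLI to construct quantization settings:
    # "-X" targets the XNNPACK backend; "-qmode 8da4w" selects 8-bit dynamic
    # activations with 4-bit grouped weights (group size 128); and
    # "--embedding-quantize 4,32" quantizes embeddings to 4 bits, group size 32.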
    parser = build_args_parser()
    args = parser.parse_args(
        ["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
    )
    quant_transform = get_quant_weight_transform(args, dtype_override, False)
    _, quantizers, _ = get_quantizer_and_quant_params(args)
    source_transforms = []
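    # Optionally swap SDPA for the custom sdpa_with_kv_cache op, then apply the
    # weight quantization transform.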
    if llava.use_sdpa_with_kv_cache_op:
        source_transforms.append(replace_sdpa_with_custom_op)
    source_transforms.append(quant_transform)
    manager = (
        text_model_em.set_output_dir("./")
        .to_dtype(dtype_override)
        .source_transform(source_transforms)
        .export()
        .pt2e_quantize(quantizers)
    )

    with torch.no_grad():
        text_model_ep = torch.export.export(
            manager.pre_autograd_graph_module,
            manager.example_inputs,
            dynamic_shapes=manager._get_dynamic_shape(),
        )
    return text_model_ep


def export_image_encoder(llava, resized, dynamic_shapes):
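    """Export the LLaVA image encoder to an ExportedProgram.

    The encoder is statically quantized (8-bit symmetric) via the PT2E flow
    before the final torch.export call; `resized` is the example image input.
    """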
    class LlavaImageEncoder(torch.nn.Module):
        """Takes images and prompts and encode them into embeddings. Result will be sent to the text model LlavaTextModel."""

        def __init__(self, llava):
            super().__init__()
            self.llava = llava

        def forward(self, images):
            return self.llava.image_embedding(images)

    llava_image_encode = LlavaImageEncoder(llava)

    # Static 8-bit symmetric quantization for the image encoder via the PT2E
    # XNNPACKQuantizer.
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())

    manager = (
        LlavaEdgeManager(
            model=llava_image_encode,
            modelname="llava_image_encoder",
            max_seq_len=llava.text_model_args.max_seq_len,  # borrowed from the text model; may not apply to the encoder
            dtype=DType.fp32,
            use_kv_cache=True,
            example_inputs=(resized,),
            dynamic_shapes=dynamic_shapes,
            args=None,
        )
        .export()
        .pt2e_quantize([quantizer])
    )

    # lower to executorch
    with torch.no_grad():
        image_encoder_ep = torch.export.export(
            manager.pre_autograd_graph_module,
            manager.example_inputs,
            dynamic_shapes=manager.dynamic_shapes,
        )
    return image_encoder_ep


def export_token_embedding(llava, prompt):
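    """Export the token embedding layer, quantized to 8 bits with group size
    32, as its own ExportedProgram for the llava runner."""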
    def quant_embedding(model):
        return EmbeddingQuantHandler(
            model,
            bitwidth=8,
            group_size=32,
            packed=False,
        ).quantized_model()

    quantized_token_embed = quant_embedding(llava.model_.language_model.model)
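    # Dimension 1 of the prompt (the sequence length) is dynamic, from 2 up to
    # max_seq_len.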
    token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
    dynamic_shapes = [{1: token_dim_1}]
    with torch.no_grad():
        token_embedding_ep = torch.export.export(
            quantized_token_embed.embed_tokens, (prompt,), dynamic_shapes=dynamic_shapes
        )
    return token_embedding_ep


def export_all(llava_model: LlavaModel):
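    """Export the image encoder, token embedding, and text model, lower them,
    and emit a single multi-method ExecuTorch program."""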
    llava = llava_model.get_eager_model()

    (
        prompt_before_image,
        resized,
        prompt_after_image,
    ) = llava_model.get_inputs_for_prefill()

    image_encoder_ep = export_image_encoder(
        llava, resized, llava_model._get_image_dynamic_shapes()
    )

    embeddings = llava.prefill_embedding(
        prompt_before_image, resized, prompt_after_image
    )

    text_model_ep = export_text_model(
        llava, embeddings, llava_model._get_prompt_dynamic_shapes()
    )

    token_embedding_ep = export_token_embedding(llava, prompt_before_image)

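    # Lower all three exported methods into a single multi-method ExecuTorch
    # program, using a per-method partitioner configuration.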
    lowered_and_edge = to_edge_transform_and_lower(
        {
            "image_encoder": image_encoder_ep,
            "token_embedding": token_embedding_ep,
            "text_model": text_model_ep,
        },
        partitioner={
            "image_encoder": [XnnpackPartitioner()],
            "text_model": [
                # First partition the DQLinear nodes, then partition the rest of the nodes,
                # to avoid multiple DQLinear nodes in the same partition,
                # to avoid holding multiple unpacked and packed weight buffers in memory,
                # to reduce peak memory footprint.
                XnnpackPartitioner(
                    config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
                    per_op_mode=True,
                ),
                XnnpackPartitioner(),
            ],
        },
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )

    executorch_program = lowered_and_edge.to_executorch(
        ExecutorchBackendConfig(
            extract_delegate_segments=True,
            passes=[
                QuantFusionPass(),
            ],
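            # Graph inputs are provided by the caller, so skip planning memory
            # for them.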
            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
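            # Constraint-based evaluation derives shape upper bounds from the
            # export constraints; token_embedding falls back to example-input
            # hints.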
            sym_shape_eval_pass={
                "image_encoder": ConstraintBasedSymShapeEvalPass(),
                "text_model": ConstraintBasedSymShapeEvalPass(),
                "token_embedding": HintBasedSymShapeEvalPass(),
            },
        )
    )
    for execution_plan in executorch_program._emitter_output.program.execution_plan:
        logging.info(
            f"Required memory for activation in bytes: {execution_plan.non_const_buffer_sizes}"
        )
    return executorch_program


def get_image_tensor_for_llava_runner(llava_model):
    # The llava runner has no image reader, so serialize an example image
    # tensor for it to load.
    (resized,) = llava_model.get_example_inputs()

    serialize_image(resized, "image.pt")


def get_tokenizer_for_llava_runner(llava_model):
    # Serialize the tokenizer into tokenizer.bin: save_vocabulary writes
    # tokenizer.model, which is then exported in the runner's binary format.
    llava_model.tokenizer.save_vocabulary("./")
    t = Tokenizer("tokenizer.model")
    t.export("tokenizer.bin")


def main():
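    """Parse CLI args, export LLaVA, and write the combined .pte file."""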
    parser = ArgumentParser()
    parser.add_argument(
        "--use-sdpa-with-kv-cache",
        default=True,
        action=BooleanOptionalAction,
        help="Use sdpa_with_kv_cache custom op in LLava text model.",
    )
    parser.add_argument(
        "--max-seq-len",
        default=768,
        type=int,
        help="Maximum sequence length for the text model.",
    )
    parser.add_argument(
        "--pte-name",
        default="llava_combined_xnnpack.pte",
        help="Name of the exported ExecuTorch program.",
    )
    parser.add_argument(
        "--with-artifacts",
        default=False,
        action=BooleanOptionalAction,
        help="Generate artifacts for llava runner.",
    )
    parser.add_argument(
        "--profile_memory",
        required=False,
        action="store_true",
        help="Generate chrome trace of activation memory for intermediate tensors.",
    )
    args = parser.parse_args()
    logging.info(
        f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {args.use_sdpa_with_kv_cache}, max_seq_len: {args.max_seq_len}"
    )
    llava_model = LlavaModel(
        use_sdpa_with_kv_cache_op=args.use_sdpa_with_kv_cache,
        max_seq_len=args.max_seq_len,
    )

    executorch_program = export_all(llava_model)

    # memory profiling
    if args.profile_memory:
        for method_name in executorch_program.methods:
            generate_memory_trace(
                executorch_program,
                f"{args.pte_name}_{method_name}.json",
                method_name=method_name,
            )

    with open(args.pte_name, "wb") as f:
        executorch_program.write_to_file(f)
    logging.info(f"Exported ExecuTorch program to {args.pte_name}")

    # artifacts
    if args.with_artifacts:
        get_image_tensor_for_llava_runner(llava_model)
        get_tokenizer_for_llava_runner(llava_model)


if __name__ == "__main__":
    main()
