# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Provides builders for LLM models. These builders help users build LLM
# eager models, apply source transformations and quantization, and export them
# to ExecuTorch.
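#
# Example end-to-end flow (illustrative sketch only; `model`, `max_seq_len`,
# `dtype`, `use_kv_cache`, `example_inputs`, `quantizers`, and `partitioners`
# are placeholders supplied by the caller):
#
#   manager = (
#       LLMEdgeManager(model, "llama", max_seq_len, dtype, use_kv_cache, example_inputs)
#       .set_output_dir("/tmp")
#       .export()
#       .pt2e_quantize(quantizers)
#       .export_to_edge()
#       .to_backend(partitioners)
#       .to_executorch()
#   )
#   manager.save_to_pte("llama")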

# pyre-unsafe

import logging
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.exir import EdgeProgramManager
from executorch.exir.backend.partitioner import Partitioner

from executorch.exir.backend.utils import format_delegated_graph
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig

from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.export import export_for_training
from torch.nn.attention import SDPBackend

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


class DType(Enum):
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"

    def to_torch_dtype(self) -> torch.dtype:
        mapping = {
            DType.fp32: torch.float32,
            DType.fp16: torch.float16,
            DType.bf16: torch.bfloat16,
        }
        if self not in mapping:
            raise ValueError(f"Unsupported dtype {self}")
        return mapping[self]


class LLMEdgeManager:
    """
    Hosts a torch.nn.Module for an LLM model and facilitates exporting it to ExecuTorch.
    """

    def __init__(
        self,
        model,
        modelname,
        max_seq_len,
        dtype,
        use_kv_cache,
        example_inputs,
        example_kwarg_inputs: Optional[Dict] = None,
        args: Optional[Any] = None,
        enable_dynamic_shape: bool = False,
        generate_full_logits: bool = False,
        calibration_tasks: Optional[List[str]] = None,
        calibration_limit: Optional[int] = None,
        calibration_seq_length: Optional[int] = None,
        calibration_data: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        verbose: bool = False,
        metadata: Optional[dict] = None,
        dynamic_shapes: Optional[Any] = None,
    ):
        self.model = model
        # graph module returned from export()
        self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
        self.modelname = modelname
        self.max_seq_len = max_seq_len
        self.dtype = dtype
        self.example_inputs = example_inputs
        self.example_kwarg_inputs = example_kwarg_inputs
        self.use_kv_cache = use_kv_cache
        self.generate_full_logits = generate_full_logits
        self.enable_dynamic_shape = enable_dynamic_shape
        self.verbose = verbose
        self.metadata = metadata
        self.applied_source_transforms = []
        self.edge_manager: Optional[EdgeProgramManager] = None
        self.export_program = None
        self.output_dir = "."
        self.dynamic_shapes = dynamic_shapes
        self._saved_pte_filename = None
        self.args = args
        self.calibration_tasks = calibration_tasks
        self.calibration_limit = calibration_limit
        self.calibration_seq_length = calibration_seq_length
        self.calibration_data = calibration_data
        self.tokenizer_path = tokenizer_path

    def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
        """
        Set the directory where the .pte file will be saved.
        Args:
            output_dir (str): The directory to store the .pte file.
        """
        self.output_dir = output_dir
        return self

    def to_dtype(self, dtype_override: Optional[DType]) -> "LLMEdgeManager":
        """
        Convert the model to the specified dtype.
        Args:
            dtype_override (Optional[DType]): Override the dtype of the model.
        """
        assert not dtype_override or isinstance(
            dtype_override, DType
        ), "Override dtype needs to be of type <DType>"
        if dtype_override is not None and dtype_override != self.dtype:
            torch_dtype = dtype_override.to_torch_dtype()
            logging.info(f"model.to {torch_dtype}")
            self.model = self.model.to(dtype=torch_dtype)
            self.dtype = dtype_override
        return self

    def source_transform(
        self, transforms: List[Callable[[torch.nn.Module], torch.nn.Module]]
    ) -> "LLMEdgeManager":
        """
        Apply source transforms to the model. The transforms are callables that
        take an nn.Module as input and return an nn.Module.
        Args:
            transforms (List[Callable[[torch.nn.Module], torch.nn.Module]]): A
                list of source transforms.
        """
        for transform in transforms:
            self.model = transform(self.model)
        self.applied_source_transforms.extend(transforms)

        if self.verbose:
            logging.info(f"Applied source transforms: {self.applied_source_transforms}")
        logging.info(f"Model after source transforms: {self.model}")
        return self

    def _get_dynamic_shape(self) -> Any:
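        """
        Build (and cache) the dynamic_shapes spec handed to torch.export, or
        return None when every input has a static shape.
        """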
        if self.dynamic_shapes:
            return self.dynamic_shapes

        dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)

        if not self.use_kv_cache:
            # Only one input argument: tokens
            self.dynamic_shapes = ({1: dim},)
        elif self.enable_dynamic_shape:
            # Two input arguments: tokens and input_pos, but input_pos has a static shape
            self.dynamic_shapes = ({1: dim}, {0: 1})
        else:
            # Two input arguments: tokens and input_pos, both of static shape
            self.dynamic_shapes = None
        return self.dynamic_shapes

    def _get_edge_config(self) -> EdgeCompileConfig:
        edge_config = EdgeCompileConfig(
            _check_ir_validity=False,
            _skip_type_promotion=bool(self.dtype == DType.fp16),
            _skip_dim_order=True,
        )
        return edge_config

    def export(self) -> "LLMEdgeManager":
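        """
        Export the eager model to a pre-autograd graph module via torch.export
        and store it in self.pre_autograd_graph_module.
        """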
        dynamic_shape = self._get_dynamic_shape()
        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of dropout (it is unclear why training ops show up otherwise)
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            if hasattr(self.args, "qnn") and self.args.qnn:
                # TODO: this is temporary; export_for_training doesn't work with QNN either. We need a
                # functional graph. See https://github.com/pytorch/executorch/pull/4627 for more details.
                exported_module = torch.export.export(
                    self.model,
                    self.example_inputs,
                    self.example_kwarg_inputs,
                    dynamic_shapes=dynamic_shape,
                    strict=True,
                )
            else:
                logging.info("Exporting with:")
                logging.info(f"inputs: {self.example_inputs}")
                logging.info(f"kwargs: {self.example_kwarg_inputs}")
                logging.info(f"dynamic shapes: {dynamic_shape}")
                exported_module = export_for_training(
                    self.model,
                    self.example_inputs,
                    kwargs=self.example_kwarg_inputs,
                    dynamic_shapes=dynamic_shape,
                )
            # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
            #  `Module`.
            self.pre_autograd_graph_module = exported_module.module()
            if hasattr(self.args, "export_only") and self.args.export_only:
                torch.export.save(exported_module, self.args.output_name)

        return self

    def pt2e_calibrate(
        self,
        prepared_module,
        calibration_tasks,
        calibration_limit,
        calibration_seq_length,
        calibration_data,
        tokenizer_path,
    ):
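        """
        Calibrate a prepare_pt2e-prepared module so that its observers collect
        activation statistics: decode the calibration prompt token by token,
        then run the given lm_eval tasks via GraphModuleEvalWrapper and
        simple_evaluate.
        """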
        logging.info("Run calibration...")
        try:
            from executorch.examples.models.llama.eval_llama_lib import (
                GraphModuleEvalWrapper,
            )
            from lm_eval.evaluator import simple_evaluate
        except ImportError:
            raise ImportError(
                "Please install the llm eval dependency via examples/models/llama/install_requirements.sh"
            )

        tokenizer = get_tokenizer(tokenizer_path)

        def calibrate_template(
            module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
        ):
            # TODO: change criteria & support batch inputs if necessary
            pos = torch.tensor(0, dtype=torch.int64)
            token_list = tokenizer.encode(prompts, bos=True, eos=False)

            with torch.no_grad():
                while token_list[-1] != tokenizer.eos_id and pos < max_len:
                    logits = module(
                        torch.full((1, 1), token_list[pos]),
                        torch.tensor((pos,)),
                    )
                    pos += 1
                    if pos >= len(token_list):
                        if self.generate_full_logits:
                            token_list.append(
                                torch.argmax(logits[:, -1], dim=-1).item()
                            )
                        else:
                            token_list.append(torch.argmax(logits[:], dim=-1).item())

        calibrate_template(
            module=prepared_module,
            tokenizer=tokenizer,
            prompts=calibration_data,
            max_len=calibration_seq_length,
        )

        eval_wrapper = GraphModuleEvalWrapper(
            model=prepared_module,
            tokenizer=tokenizer,
            max_seq_length=calibration_seq_length,
            use_kv_cache=self.use_kv_cache,
            generate_full_logits=self.generate_full_logits,
            enable_dynamic_shape=self.enable_dynamic_shape,
        )

        # Evaluate the model
        with torch.no_grad():
            eval_results = simple_evaluate(
                model=eval_wrapper,
                tasks=calibration_tasks,
                limit=calibration_limit,
            )

        for task, res in eval_results["results"].items():
            print(f"{task}: {res}")
        logging.info("Calibration finish...")

    def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
        """
        Quantize the model via the pt2e flow and return this LLMEdgeManager with the quantized model.
        Args:
            quantizers (Optional[List[Quantizer]]): A list of quantizers.
        """
        assert (
            self.edge_manager is None
        ), "export_to_edge is already called, please call pt2e_quantize before export_to_edge"
        logging.info(f"Using pt2e {quantizers} to quantizing the model...")

        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of dropout (it is unclear why training ops show up otherwise)
        if quantizers:
            with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
                if self.verbose:
                    logging.info(f"Applied quantizers: {quantizers}")
                composed_quantizer = ComposableQuantizer(quantizers)
                assert (
                    self.pre_autograd_graph_module is not None
                ), "Please run export() first"
                m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
                # Calibrate
                if (
                    self.calibration_tasks is not None
                    and self.calibration_limit is not None
                    and self.calibration_seq_length is not None
                    and self.calibration_data is not None
                    and self.tokenizer_path is not None
                ):
                    logging.info(
                        f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
                    )
                    self.pt2e_calibrate(
                        prepared_module=m,
                        calibration_tasks=self.calibration_tasks,
                        calibration_limit=self.calibration_limit,
                        calibration_seq_length=self.calibration_seq_length,
                        calibration_data=self.calibration_data,
                        tokenizer_path=self.tokenizer_path,
                    )
                else:
                    logging.info(
                        "No calibration provided, using dummy input to calibrate..."
                    )
                    m(*self.example_inputs)
                m = convert_pt2e(m)
                DuplicateDynamicQuantChainPass()(m)
                self.pre_autograd_graph_module = m
            return self
        else:
            logging.info("No quantizer provided, passing...")
            return self

    def export_to_edge(self) -> "LLMEdgeManager":
        """
        Export the model to the Edge dialect and retrieve an LLMEdgeManager.
        """
        dynamic_shape = self._get_dynamic_shape()
        edge_config = self._get_edge_config()

        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of dropout (it is unclear why training ops show up otherwise)
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            if self.pre_autograd_graph_module is None:
                # Run export() if it didn't run
                self.export()
            self.edge_manager = export_to_edge(
                self.pre_autograd_graph_module,  # pyre-fixme[6]
                self.example_inputs,
                example_kwarg_inputs=self.example_kwarg_inputs,
                dynamic_shapes=dynamic_shape,
                edge_constant_methods=self.metadata,
                edge_compile_config=edge_config,
                verbose=self.verbose,
            )
        return self

    def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
        """
        Partition the model and lower it to different backends. The signature is
        aligned with the `to_backend` method of EdgeProgramManager.
        Args:
            partitioners (Optional[List[Partitioner]]): One or more
                partitioners to be sent to EdgeProgramManager.to_backend().
        """
        if partitioners is None:
            logging.info("No partitioner provided, passing...")
        else:
            for partitioner in partitioners:
                if partitioner is not None:
                    assert (
                        self.edge_manager is not None
                    ), "Need to run export_to_edge() first"
                    self.edge_manager = self.edge_manager.to_backend(partitioner)
                    if self.verbose:
                        logging.info(
                            format_delegated_graph(
                                self.edge_manager.exported_program().graph_module
                            )
                        )
                        logging.info(f"Applied partitioners: {partitioner}")
                else:
                    logging.info("No partitioner provided, passing...")
                    continue

        return self

    def to_executorch(self) -> "LLMEdgeManager":
        """
        Lower the model to ExecuTorch and obtain an ExecutorchProgramManager.
        """
        assert self.edge_manager, "Need to run export_to_edge() first"
        self.export_program = self.edge_manager.to_executorch(
            ExecutorchBackendConfig(
                extract_delegate_segments=True,
                passes=[
                    # If there are Linear operations left in the graph, let's execute
                    # them with the optimized op_linear rather than materializing a
                    # transpose followed by a regular op_mm.
                    ConvertToLinearPass(),
                    QuantFusionPass(),
                ],
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            )
        )
        logging.info(
            "Required memory for activation in bytes: {}".format(
                self.export_program._emitter_output.program.execution_plan[
                    0
                ].non_const_buffer_sizes
            ),
        )
        return self

    def save_to_pte(self, output_name: str) -> None:
        """
        Save the model to a .pte file.
        Args:
            output_name (str): The name of the .pte file.
        """
        assert output_name, "Need a valid output name"
        filename = save_pte_program(self.export_program, output_name, self.output_dir)
        self._saved_pte_filename = filename

    def get_saved_pte_filename(self) -> Optional[str]:
        """
        Return the filename of the most recently saved .pte file, or None if the model has not been saved.
        """
        return self._saved_pte_filename
