# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
import os
from typing import Dict, Tuple

import torch
from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer

try:
    from .fairseq2 import convert_to_llama_checkpoint
except ImportError:
    # fairseq2 support is optional: keep the name defined so that the module
    # imports cleanly, and only fail if a fairseq2 conversion is requested.
    def convert_to_llama_checkpoint(**kwargs):
        """Stand-in used when `.fairseq2` cannot be imported; always raises."""
        install_hint = "Please install fairseq2 with `pip install fairseq2`."
        raise NotImplementedError(install_hint)


from ..model_base import EagerModelBase


class Llama2Model(EagerModelBase):
    """Eager-mode Llama model assembled from a checkpoint and a JSON params file.

    The `Transformer` is constructed on the ``meta`` device and the checkpoint
    tensors are then assigned into it (``load_state_dict(..., assign=True)``)
    rather than copied. Depending on the checkpoint path and on flags in
    ``args``, the model is additionally converted for int8 weight-only, 8da4w,
    SpinQuant or QAT (optionally LoRA) weights, and its input/output vocabulary
    may be pruned.

    Keyword arguments read from ``kwargs`` (all optional):
        checkpoint: path of a single checkpoint file (defaults to the bundled
            random-weight demo checkpoint).
        checkpoint_dir: directory holding ``consolidated.{0..3}.pth`` shards;
            when given, ``checkpoint`` is ignored.
        params: path of the JSON model params (defaults to the demo config).
        use_kv_cache, use_sdpa_with_kv_cache, generate_full_logits,
            enable_dynamic_shape, max_seq_len (default 128): forwarded to
            `ModelArgs`.
        input_prune_map_path, output_prune_map_path: JSON files whose object
            keys are converted to ints and passed to the vocab-pruning passes.
        fairseq2: convert a fairseq2 checkpoint to the llama layout first.
        verbose: print weight names/sizes and missing/unexpected keys.
        args: namespace-like object read (via attribute access) for the
            SpinQuant / QAT / LoRA / pre-quantization options.
    """

    def __init__(self, **kwargs):
        resource_dir = get_default_model_resource_dir(__file__)

        # Use single checkpoint file.
        checkpoint_path = kwargs.get(
            "checkpoint", resource_dir / "demo_rand_params.pth"
        )
        params_path = kwargs.get("params", resource_dir / "demo_config.json")

        # Check if checkpoint_dir was provided for a sharded checkpoint.
        checkpoint_dir = kwargs.get("checkpoint_dir", None)

        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        self.max_seq_len = kwargs.get("max_seq_len", 128)
        self.args = kwargs.get("args", None)

        # The example is using a dummy small model with random weights for demo purpose only.
        # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
        device = "cpu"
        # flake8: noqa: TOR102
        if checkpoint_dir is not None:
            # Sharded checkpoint: load every shard and ignore the single path.
            # checkpoint_path is reset so the quantization-name checks below
            # only see "None".
            checkpoint_path = None
            checkpoint = self._load_sharded_checkpoint(checkpoint_dir, device)
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)

        # If given checkpoint is fairseq2, convert it to a llama checkpoint.
        fairseq2_checkpoint = kwargs.get("fairseq2", False)
        if fairseq2_checkpoint:
            print("Using fairseq2 checkpoint")
            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
        if "model" in checkpoint:
            # NB: some checkpoints contain a "model" field, which is the actual weights dict.
            checkpoint = checkpoint["model"]

        # Check if user gave a fairseq2 checkpoint unknowingly without specifying --fairseq2.
        if (not fairseq2_checkpoint) and checkpoint.get(
            "final_proj.weight", None
        ) is not None:
            raise ValueError(
                """
************************************************************
This looks like a Fairseq2 checkpoint (based on the presence
of `final_proj.weight`).

You can import Fairseq2 checkpoints using the --fairseq2
option, but --fairseq2 was not specified.  Please verify
the checkpoint format to avoid generating faulty models.
************************************************************
"""
            )

        # Get checkpoint dtype.
        self.dtype = get_checkpoint_dtype(checkpoint)

        with open(params_path, "r") as f:
            params = json.load(f)
        output_prune_map = self._load_prune_map(self.output_prune_map_path)
        input_prune_map = self._load_prune_map(self.input_prune_map_path)

        model_args: ModelArgs = ModelArgs(
            max_seq_len=self.max_seq_len,
            max_batch_size=1,
            use_kv_cache=self.use_kv_cache,
            use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
            generate_full_logits=self.generate_full_logits,
            input_prune_map=input_prune_map,
            output_prune_map=output_prune_map,
            enable_dynamic_shape=self.enable_dynamic_shape,
            **params,
        )
        if kwargs.get("verbose", False):
            print("============= weights ================")
            # Column header; intentionally a plain (non f-) string.
            print("{key} : {weights.numel()} : {weights.size()}")
            for key, weights in checkpoint.items():
                print(f"{key} : {weights.numel()} : {weights.size()}")
            print("============= /weights ================")

        # Within the device="meta" context, tensors that are created do not carry data.
        # They possess all other metadata a tensor carries such as size, stride, requires_grad.
        with torch.device("meta"):
            self.model_ = Transformer(model_args)

        if "int8" in str(checkpoint_path):
            print("Using int8 weight-only quantization!")
            # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.examples.models.source_transformation.quantize`
            from ..source_transformation.quantize import WeightOnlyInt8QuantHandler

            simple_quantizer = WeightOnlyInt8QuantHandler(self.model_)
            self.model_ = simple_quantizer.convert_for_runtime()
        elif "8da4w" in str(checkpoint_path):
            print("Using int4 weight and int8 dynamic activation quantization!")
            from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

            self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                self.model_
            )
        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
            print("Using SPIN quantization.")
            self._transform_for_pre_quantization(checkpoint, model_args)

            from .source_transformation.pre_quantization import (
                sanitize_checkpoint_from_pre_quantization,
            )

            sanitize_checkpoint_from_pre_quantization(checkpoint)
        elif hasattr(self.args, "use_qat") and self.args.use_qat:
            print("Using QAT quantization.")
            self._transform_for_pre_quantization(checkpoint, model_args)
            if hasattr(self.args, "use_lora") and self.args.use_lora:
                assert model_args.lora_args["rank"] == self.args.use_lora
                from .source_transformation.lora import (
                    transform_linear_for_lora_after_quantization,
                )

                self.model_ = transform_linear_for_lora_after_quantization(
                    self.model_,
                    checkpoint,
                    self.args.use_lora,
                )

            from .source_transformation.pre_quantization import (
                sanitize_checkpoint_from_pre_quantization,
            )

            sanitize_checkpoint_from_pre_quantization(checkpoint)

        # assign=True: load params/buffers by assignment instead of performing an in-place copy.
        # Because we are using device="meta", tensors do not have memory associated with them
        # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
        missing, unexpected = self.model_.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )
        if kwargs.get("verbose", False):
            print("============= missing keys ================")
            print(missing)
            print("============= /missing ================")
            print("============= unexpected keys ================")
            print(unexpected)
            print("============= /unexpected ================")

        # Prune the input layer if input_prune_map is provided
        if input_prune_map is not None:
            from .source_transformation.prune_vocab import prune_input_vocab

            self.model_ = prune_input_vocab(self.model_, input_prune_map)

        # Prune the output layer if output_prune_map is provided
        if output_prune_map is not None:
            from .source_transformation.prune_vocab import prune_output_vocab

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

    @staticmethod
    def _load_sharded_checkpoint(
        checkpoint_dir, device, num_shards: int = 4
    ) -> Dict[str, torch.Tensor]:
        """Load ``consolidated.{i}.pth`` shards from *checkpoint_dir* and merge them.

        A key whose tensor differs between shard 0 and shard 1 is treated as
        split across shards and concatenated (on dim=1 for keys containing
        "wo" or "w2", dim=0 otherwise). Any other key is assumed to be
        replicated and is taken from shard 0 only.

        Args:
            checkpoint_dir: directory containing the shard files.
            device: ``map_location`` passed to ``torch.load``.
            num_shards: number of shard files to read (default 4).

        Returns:
            The merged state dict.
        """
        shards = []
        for shard_idx in range(num_shards):
            cp_name = f"consolidated.{shard_idx}.pth"
            print(f"Loading {cp_name}")
            shards.append(
                torch.load(
                    os.path.join(checkpoint_dir, cp_name),
                    map_location=device,
                    mmap=True,
                )
            )

        merged = {}
        for key in shards[0].keys():
            is_split = len(shards) > 1 and not torch.allclose(
                shards[0][key], shards[1][key]
            )
            if is_split:
                values = tuple(shard[key] for shard in shards)
                if "wo" in key or "w2" in key:
                    # Concat on dim=1 for "wo" and "w2".
                    merged[key] = torch.cat(values, dim=1)
                else:
                    # Concat on dim=0 for everything else.
                    merged[key] = torch.cat(values, dim=0)
            else:
                # Do not duplicate layers shared between each checkpoint.
                merged[key] = shards[0][key]
        return merged

    @staticmethod
    def _load_prune_map(path):
        """Read a JSON prune map from *path*, or return None when *path* is None.

        JSON object keys are always strings, so they are converted back to int.
        """
        if path is None:
            return None
        with open(path, "r") as f:
            raw_map = json.load(f)
        return {int(k): v for k, v in raw_map.items()}

    def get_eager_model(self) -> torch.nn.Module:
        """Return the model cast to the checkpoint dtype, or float32 if unknown."""
        if self.dtype:
            # convert to the type of the provided checkpoint
            # input and output are torch.long, so signature unchanged
            return self.model_.to(self.dtype)
        else:
            # int8 quantization code has some bf16,
            # switch all to FP32
            return self.model_.to(torch.float32)

    def get_example_inputs(self):
        """Return example inputs for export; the KV-cache variant adds a start position."""
        if self.use_kv_cache:
            return self.get_example_inputs_kvcache_sdpa()
        else:
            return (
                torch.tensor(
                    [[1, 2, 3]], dtype=torch.long
                ),  # tokens; without a KV cache a multi-token prompt is fed at once.
            )

    # Assumption: the custom op doesn't support dynamic shape yet. It might, but it's
    # untested, so static shape is the default path.
    def get_example_inputs_kvcache_sdpa(self):
        """Return (tokens, start_pos) example inputs for the KV-cache model."""
        if self.enable_dynamic_shape:
            return (
                torch.tensor([[2, 3, 4]], dtype=torch.long),
                torch.tensor([0], dtype=torch.long),
            )
        else:
            return (
                torch.tensor(
                    [[1]], dtype=torch.long
                ),  # tokens, with kv cache our input token length is always just 1 token.
                torch.tensor(
                    [0], dtype=torch.long
                ),  # start_pos, what token of output are we on.
            )

    def _transform_for_pre_quantization(self, checkpoint, model_args):
        """Swap model layers for pre-quantized (SpinQuant / QAT) equivalents.

        Reads ``preq_mode``, ``preq_group_size``, ``dtype_override`` and,
        optionally, ``preq_embedding_quantize`` (a "bit_width,group_size"
        string) from ``self.args``, and replaces ``self.model_`` with the
        transformed module.

        Raises:
            AssertionError: if a required option is missing, ``preq_mode`` is
                unsupported, or ``preq_group_size`` does not match
                ``model_args.quantization_args["group_size"]``.
        """
        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
        assert self.args.preq_mode in [
            "8da4w",
            "8da4w_output_8da8w",
        ], f"Quantization mode {self.args.preq_mode} is not compatible with pre-quantization (SpinQuant/QAT)."
        assert hasattr(
            self.args, "preq_group_size"
        ), "preq_group_size must be specified"
        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
        from .source_transformation.pre_quantization import (
            transform_linear_for_pre_quantization,
        )

        assert (
            self.args.preq_group_size == model_args.quantization_args["group_size"]
        ), (
            f"preq_group_size ({self.args.preq_group_size}) does not match "
            f"quantization_args['group_size'] ({model_args.quantization_args['group_size']})"
        )

        mapping = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }

        # Transform the output layer first if needed.
        if self.args.preq_mode == "8da4w_output_8da8w":
            from .source_transformation.pre_quantization import (
                transform_output_linear_for_pre_quantization,
            )

            self.model_ = transform_output_linear_for_pre_quantization(
                module=self.model_,
                checkpoint=checkpoint,
                dtype=mapping[self.args.dtype_override],
            )

        self.model_ = transform_linear_for_pre_quantization(
            self.model_,
            checkpoint,
            self.args.preq_group_size,
            mapping[self.args.dtype_override],
        )

        # Embedding quantization is optional; a missing or None option skips it
        # (previously a None value crashed on .split()).
        preq_embedding_quantize = getattr(self.args, "preq_embedding_quantize", None)
        if preq_embedding_quantize is not None:
            embedding_bit_width, embedding_group_size = (
                preq_embedding_quantize.split(",")
            )
            from .source_transformation.pre_quantization import (
                transform_embedding_for_pre_quantization,
            )

            # "none" / "None" / "0" are mapped to a group size of None.
            if embedding_group_size in ("none", "None", "0"):
                embedding_group_size = None
            else:
                embedding_group_size = int(embedding_group_size)

            self.model_ = transform_embedding_for_pre_quantization(
                self.model_,
                checkpoint,
                mapping[self.args.dtype_override],
                int(embedding_bit_width),
                embedding_group_size,
            )
