# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Load a llama model from a GGUF file quantized in Q4_0 format.
# Float weights are loaded directly from the GGUF file.
# Q4_0 weights are loaded into a Tensor subclass (GGMLInt4LinearWeight)
# by replacing each linear module's weight with the subclass.
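#
# For reference, a GGML Q4_0 block packs 32 weights into 18 bytes: a 2-byte
# float16 scale followed by 16 bytes holding the 32 4-bit quantized values
# (dequantized roughly as w = scale * (q - 8)). This is why the packed weight
# tensors below are viewed as [numel / 32, 18].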

import logging
import os
from typing import Callable, Dict, Mapping

import torch
from executorch.examples.models.llama.experimental.subclass import (
    _unpack_two_uint8,
    GGMLInt4LinearWeight,
    to_float,
)
from executorch.extension.gguf_util.converters.llama_converter import (
    _convert_gguf_tensor_name_to_llama_nn,
    _create_pt_model,
)
from executorch.extension.gguf_util.load_gguf import GGUFWeights, load_file
from gguf import ReaderTensor
from gguf.constants import GGMLQuantizationType
from torchao.quantization.subclass import QuantizedLinearWeightBase

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def _replace_with_custom_fn_if_matches_filter(
    # pyre-fixme[2]: Parameter must be annotated.
    model,
    replacement_fn,
    filter_fn,
    cur_fqn="",
) -> torch.nn.Module:
    """
    For each `child` in `model`, replaces it with `replacement_fn(child)`
    if `filter_fn(child)` is `True`
    """
    if filter_fn(model, cur_fqn[:-1]):
        model = replacement_fn(model, cur_fqn[:-1])
        return model
    else:
        for name, child in model.named_children():
            new_child = _replace_with_custom_fn_if_matches_filter(
                child, replacement_fn, filter_fn, f"{cur_fqn}{name}."
            )
            if new_child is not child:
                setattr(model, name, new_child)
        return model


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _get_subclass_inserter(
    weight_map: Dict[str, ReaderTensor]
) -> Callable[[torch.nn.Module, str], torch.nn.Module]:
    def insert_subclass(lin, fqn):
        # TODO: replace weights with gguf format tensor
        # packed tensor should have size [numel / 32, 18]
        fqn = fqn + ".weight"
        assert (
            fqn in weight_map
        ), f"Expect {fqn} to be in weight map but not found. All keys are {weight_map.keys()}"
        tensor = weight_map[fqn]
        logging.debug(f"{fqn}: {tensor.shape}, {tensor.data.shape}, {lin.weight.shape}")
        packed = torch.from_numpy(tensor.data).reshape(-1, 18)
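        # The first two bytes of each 18-byte block hold the float16 scale;
        # the remaining 16 bytes stay packed and are handed to the subclass.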
        scale = torch.tensor(_unpack_two_uint8(packed[:, :2]), dtype=torch.float16)
        lin.weight = torch.nn.Parameter(
            GGMLInt4LinearWeight(packed, scale, lin.weight.shape)
        )
        return lin

    return insert_subclass


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _get_filter_fn(
    weight_map: Dict[str, ReaderTensor]
) -> Callable[[torch.nn.Module, str], bool]:
    def _is_linear(mod, fqn):
        return (
            isinstance(mod, torch.nn.Linear)
            and hasattr(mod, "weight")
            and weight_map[fqn + ".weight"].tensor_type == GGMLQuantizationType.Q4_0
            and not isinstance(mod.weight, QuantizedLinearWeightBase)
        )

    return _is_linear


def change_linear_weights_to_q4_0_tensors(
    model: torch.nn.Module, gguf_weights: GGUFWeights
) -> None:
    """
    Converts all linear weight tensors to the
    `GGMLInt4LinearWeight` tensor subclass,
    effectively applying the same form of quantization
    as apply_dynamic_quant while not modifying the linear modules.
    """
    assert gguf_weights is not None, "Must provide gguf_weights"
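    # Map each GGUF tensor to its llama_transformer fqn so the filter and the
    # inserter below can look weights up by module path.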
    weight_map = {
        _convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor
        for tensor in gguf_weights.tensors
    }

    _replace_with_custom_fn_if_matches_filter(
        model,
        _get_subclass_inserter(weight_map),
        _get_filter_fn(weight_map),
    )


def get_float_weights(
    pt_model: torch.nn.Module, gguf_weights: GGUFWeights
) -> Mapping[str, torch.Tensor]:
    """
    Returns a mapping from the fqn to the float weight tensor. Even though
    the model is quantized in Q4_0, these weights are still stored as float.
    Args:
        pt_model (torch.nn.Module): The model to load the weights.
        gguf_weights (GGUFWeights): The weights to extract the weights from.
    """
    state_dict = {}
    for tensor in gguf_weights.tensors:
        model_key = _convert_gguf_tensor_name_to_llama_nn(tensor.name)
        if (
            tensor.tensor_type == GGMLQuantizationType.F32
            or tensor.tensor_type == GGMLQuantizationType.F16
        ):
            logging.debug(f"Loading float tensor: {tensor.name}")
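            # GGUF lists tensor dimensions in the reverse order of PyTorch,
            # so reverse the shape before wrapping the raw data.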
            reversed_shape = tensor.shape[::-1]
            new_tensor = tensor.data.reshape(reversed_shape)
            state_dict[model_key] = torch.from_numpy(new_tensor)
        # token_embd.weight is stored in Q4_0; dequantize it into float here.
        elif tensor.tensor_type == GGMLQuantizationType.Q4_0:
            if tensor.name == "token_embd.weight":
                logging.debug(f"Dequantizing tensor: {tensor.name}")
                unpacked = to_float(torch.from_numpy(tensor.data.reshape(-1, 18)))
                state_dict[model_key] = unpacked.reshape(
                    pt_model.params.vocab_size, pt_model.params.dim
                )

    # Fake-initialize the causal attention masks so the state dict matches
    # what llama_transformer.py expects.
    for layer_id in range(pt_model.params.n_layers):
        mask_name = f"layers.{layer_id}.attention.mask"
        mask = torch.full(
            (1, 1, pt_model.params.max_seq_len, pt_model.params.max_seq_len),
            float("-inf"),
        )
        mask = torch.triu(mask, diagonal=1)
        state_dict[mask_name] = mask
    return state_dict


def load_gguf_q4_0(gguf_file: str) -> torch.nn.Module:
    assert os.path.isfile(gguf_file), f"Expect a valid gguf_file path, got {gguf_file}"

    logging.info(f"Loading GGUF file: {gguf_file}")
    gguf_model_args, gguf_weights = load_file(gguf_file)

    logging.info("Creating the PyTorch model")
    pt_model = _create_pt_model(
        gguf_model_args,
    )

    logging.info("Load float weights")
    state_dict = get_float_weights(pt_model, gguf_weights)
    pt_model.load_state_dict(state_dict, strict=False)

    logging.info("Change linear weights to Q4_0 tensors")
    change_linear_weights_to_q4_0_tensors(pt_model, gguf_weights)

    pt_model = pt_model.to(dtype=torch.float16)

    return pt_model
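

# A minimal, illustrative entry point for manually exercising this module:
# load the Q4_0 GGUF file whose path is given on the command line and print
# the resulting PyTorch module. The path is supplied by the user; nothing
# about it is assumed by the loader above.
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print(f"Usage: python {sys.argv[0]} <path-to-q4_0-gguf-file>")
        sys.exit(1)
    loaded_model = load_gguf_q4_0(sys.argv[1])
    print(loaded_model)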
