# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import re
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer

from executorch.extension.llm.export.builder import DType

from sentencepiece import SentencePieceProcessor

try:
    from fairseq2.nn.embedding import (
        Embedding as fsEmbedding,
        StandardEmbedding as fsStandardEmbedding,
    )

    from fairseq2.nn.projection import Linear as fsLinear

    print("Using fairseq2 modules.")
except ImportError:
    fsEmbedding = nn.Embedding
    fsStandardEmbedding = nn.Embedding
    fsLinear = nn.Linear


def quantize(  # noqa C901
    model: torch.nn.Module,
    qmode: str,
    activation_dtype: Optional[DType],
    checkpoint_path: Optional[Path] = None,
    # The following argument is only used for int4 or GPTQ quantization.
    group_size: Optional[int] = 128,
    # The following arguments are only used for GPTQ.
    calibration_tasks: Optional[list] = None,
    calibration_limit: Optional[int] = None,
    calibration_seq_length: Optional[int] = None,
    pad_calibration_inputs: bool = False,
    percdamp: float = 0.01,
    blocksize: int = 128,
    tokenizer_path: Optional[Path] = None,
    verbose: bool = False,
) -> torch.nn.Module:
    """
    Quantizes a model by converting all weights to int8.
    Args:
        model: A model to quantize.
        qmode: quantization mode, e.g. int8, 8da4w, 8da4w-gptq
    Returns:
        A quantized model.
    """
    if activation_dtype is not None:
        torch_dtype = activation_dtype.to_torch_dtype()
    else:
        torch_dtype = torch.float16

    assert checkpoint_path, "Need to specify a checkpoint"
    # if checkpoint_path is None:
    #     checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")

    if qmode == "int8":
        # Add quantization mode options here: group size, bit width, etc.
        return WeightOnlyInt8QuantHandler(model).quantized_model()
    elif qmode.startswith("torchao:"):
        pattern = r"torchao:8da(\d+)w"
        matches = re.findall(pattern, qmode)
        assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
        bitwidth = int(matches[0])
        _load_torchao_ops_aten()
        from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer

        with torch.no_grad():
            model = Int8DynActIntxWeightLinearQuantizer(
                device="cpu",
                precision=torch.float32,
                groupsize=group_size,
                bitwidth=bitwidth,
                has_weight_zeros=False,
            ).quantize(model)

        if verbose:
            print("quantized model:", model)
        return model
    elif qmode == "8da4w":
        # Check for required args
        if group_size is None:
            raise Exception("For 8da4w quantization, group size must be specified.")
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

        model = Int8DynActInt4WeightQuantizer(
            precision=torch_dtype, groupsize=group_size
        ).quantize(model)

        if verbose:
            print("quantized model:", model)
        return model
    elif qmode == "8da4w-gptq":
        # Check for required args
        required_args: List[Optional[int]] = [
            group_size,
            calibration_limit,
            calibration_seq_length,
        ]
        if any(arg is None for arg in required_args):
            raise Exception(
                "For 8da4w-gptq quantization, group size, calibration limit and calibration sequence length must be specified."
            )
        if calibration_tasks is None:
            calibration_tasks = ["wikitext"]

        try:
            # torchao 0.3+
            from torchao._eval import InputRecorder  # pyre-fixme[21]
        except ImportError:
            from torchao.quantization.GPTQ import InputRecorder  # pyre-ignore

        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer

        if tokenizer_path is None:
            tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
            model_file=str(tokenizer_path)
        )

        inputs = (
            InputRecorder(  # pyre-fixme[16]
                tokenizer,
                calibration_seq_length,
                None,  # input_prep_func
                pad_calibration_inputs,
                model.vocab_size,
            )
            .record_inputs(
                calibration_tasks,
                calibration_limit,
            )
            .get_inputs()
        )

        gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
            blocksize,
            percdamp,
            group_size,
        )
        model = gptq_quantizer.quantize(model, inputs)
        return model
    elif qmode == "vulkan_4w":
        q_group_size = 256 if group_size is None else group_size
        model = VkInt4WeightOnlyQuantizer(groupsize=q_group_size).quantize(model)

        # Apply additional quantizer for linear layers that aren't lowered to Vulkan
        # at the moment
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

        model = Int8DynActInt4WeightQuantizer(
            precision=torch_dtype, groupsize=q_group_size
        ).quantize(model)

        return model
    else:
        raise Exception(f"Unrecognized quantize mode: {qmode}")


def dynamically_quantize_per_channel(
    x,
    quant_min,
    quant_max,
    target_dtype,
    group_size: Optional[int] = None,
    *,
    scales_dtype=torch.float16,
    enable_non_multiple_groups=True,
):
    """
    Dynamically quantize a tensor per channel. This function is used to quantize
    the weights of linear and embedding layers.

    Arguments:
        x: input tensor,
        quant_min: minimum value after quantization,
        quant_max: maximum value after quantization,
        target_dtype: target data type for weights after quantization,
        group_size: number of elements of the channel to quantize together

    Keyword arguments:
        scales_dtype: data type of scale,
        enable_non_multiple_groups: if True, allow the row size to not be a multiple of
                        group_size; the final group may be smaller than group_size.

    Assumptions:
        This function assumes symmetric quantization, axis == 0, and a dense memory format.
    """

    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    x_shape_1 = x.shape[1]

    if group_size is None or group_size == 0:
        items = x_shape_1
    elif ((x_shape_1 % group_size) == 0) or not enable_non_multiple_groups:
        assert group_size > 0, "group size must be positive"
        assert (
            x_shape_1 % group_size
        ) == 0, f"weights dimension 1 = {x_shape_1} must be a multiple of group size {group_size}"
        items = group_size
    else:
        assert group_size > 0, "group size must be positive"
        print(
            f"row size of weight matrix {x_shape_1} is not divisible by group size {group_size}; padding with zeros to the next multiple"
        )
        assert (
            x_shape_1 % group_size != 0
        ), f"expected x.shape[1] to not be a multiple of group size {group_size}, but got {x_shape_1}"
        padding = group_size - (x_shape_1 % group_size)
        x = F.pad(x, (0, padding))
        items = group_size

    # epsilon used as a lower bound when clamping the scales
    eps = torch.finfo(torch.float32).eps

    x = x.view(x.shape[0], x.shape[1] // items, items)
    # get min and max
    min_val, max_val = torch.aminmax(x, dim=2)
    # print(f"min_val {min_val}")
    # print(f"max_val {max_val}")

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = (
        torch.clamp(x_zp, quant_min, quant_max).to(target_dtype).view(x.shape[0], -1)
    )

    scales = scales.to(dtype=scales_dtype)
    quant = quant[:, :x_shape_1]

    return quant, scales, zero_points
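
# Illustrative sketch: symmetric per-channel int8 quantization of a small weight
# matrix with a group size of 2 (values are arbitrary).
#
#     w = torch.randn(4, 8)
#     q, scales, zp = dynamically_quantize_per_channel(
#         w, -128, 127, torch.int8, group_size=2
#     )
#     # q: int8, shape (4, 8); scales: fp16, shape (4, 4); zp: int64 zeros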


#########################################################################
###                QuantHandler API definition                        ###


class QuantHandler:
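    """
    Minimal interface for eager-mode quantization handlers: subclasses build a
    quantized state dict, swap modules for their quantized counterparts, and then
    reload the state dict into the transformed model; quantized_model() runs all
    three steps.
    """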
    def __init__(self, mod):
        self.mod = mod

    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
        raise NotImplementedError

    def convert_for_runtime(self) -> nn.Module:
        raise NotImplementedError

    def quantized_model(self) -> nn.Module:
        model_updated_state_dict = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.mod.load_state_dict(model_updated_state_dict)
        return self.mod


#########################################################################
###             Weight-only int8 per-channel quantized code           ###


def replace_linear_weight_only_int8_per_channel(module, node_type):
    for name, child in module.named_children():
        # print(f"name: {name}")
        if isinstance(child, nn.Linear):
            if (
                (node_type == "*")
                or (node_type == "output" and name == "output")
                or (node_type == "!output" and name != "output")
            ):
                # print(f"{name, child}")
                # print(f"in_features: {child.in_features}")
                # print(f"out_features: {child.out_features}")
                setattr(
                    module,
                    name,
                    WeightOnlyInt8Linear("cpu", child.in_features, child.out_features),
                )
        else:
            replace_linear_weight_only_int8_per_channel(child, node_type)


class WeightOnlyInt8QuantHandler(QuantHandler):
    def __init__(
        self,
        mod,
        device="cpu",
        *,
        node_type: str = "*",
        bitwidth: Optional[int] = None,
        group_size: Optional[int] = None,
    ):
        self.mod = mod
        self.group_size = group_size
        self.node_type = node_type
        if bitwidth is None:
            self.bitwidth = 8
        else:
            self.bitwidth = bitwidth

    @torch.no_grad()
    def create_quantized_state_dict(self) -> Dict:
        cur_state_dict = self.mod.state_dict()

        if self.bitwidth == 4:
            range_min = -8
            range_max = 7
        elif self.bitwidth == 8:
            range_min = -128
            range_max = 127
        else:
            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

        for fqn, mod in self.mod.named_modules():
            # print(f"maybe? quantize {fqn}...{type(mod)}")
            if isinstance(mod, torch.nn.Linear) or isinstance(mod, fsLinear):
                # print(f"candidate {fqn}, nodetype {self.node_type}")
                if (
                    (self.node_type == "*")
                    or (self.node_type == "output" and fqn in ["output", "final_proj"])
                    or (
                        self.node_type == "!output"
                        and fqn not in ["output", "final_proj"]
                    )
                ):
                    print(
                        f"quantize {self.node_type} {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                    )

                    # print(f"initial weight shape {mod.weight.shape}")
                    input_weight = mod.weight.float()

                    # print(f"expanded weight shape {input_weight.shape}")
                    weight, scales, _ = dynamically_quantize_per_channel(
                        input_weight,
                        range_min,
                        range_max,
                        torch.int8,
                        self.group_size,
                        scales_dtype=mod.weight.dtype,
                    )

                    cur_state_dict[f"{fqn}.weight"] = weight
                    # squeeze makes scales 1-D when group_size equals the row size
                    cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

        return cur_state_dict

    def convert_for_runtime(self) -> nn.Module:
        replace_linear_weight_only_int8_per_channel(self.mod, self.node_type)
        return self.mod

    def quantized_model(self) -> nn.Module:
        model_updated_state_dict = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.mod.load_state_dict(model_updated_state_dict)
        return self.mod
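
# Illustrative usage (sketch): weight-only int8 quantization of every nn.Linear
# in a placeholder model `my_model`.
#
#     my_model = WeightOnlyInt8QuantHandler(my_model, node_type="*").quantized_model()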


class WeightOnlyInt8Linear(torch.nn.Module):
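    """
    Weight-only int8 linear layer: stores an int8 weight with per-output-channel
    scales; forward() casts the weight to the input dtype, runs F.linear, and
    applies the scales to the result.
    """
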
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        device,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype=None,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            "weight", torch.zeros((out_features, in_features), dtype=torch.int8)
        )
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales


def linear_forward_8da8w(
    x,
    weight_int8,
    scales,
    zeros,
    out_features,
    precision,
):
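    """
    Reference forward for 8-bit-dynamic-activation / 8-bit-weight linear:
    dynamically quantize the activations per token, dequantize the per-channel
    int8 weights to `precision`, and apply a float linear.
    """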
    from torchao.quantization.utils import per_token_dynamic_quant

    x = per_token_dynamic_quant(x)
    n_bit = 8
    quant_min = -(2 ** (n_bit - 1))
    quant_max = 2 ** (n_bit - 1) - 1
    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel(
        weight_int8,
        scales,
        zeros,
        0,
        quant_min,
        quant_max,
        torch.int8,
        out_dtype=precision,
    )
    c = torch.nn.functional.linear(x, w_dq)

    return c


class Int8DynActInt8WeightLinear(torch.nn.Module):
    """
    Dynamically quantized linear layer with per-channel int8 weights; activations
    are quantized dynamically per token in the forward pass.

    precision: dtype of the input and output activations, e.g. torch.float32
    means the input activation is float32 and the output is float32.
    """

    __constants__ = ["in_features", "out_features"]

    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=True,
        device=None,
        dtype=None,
        precision: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.precision = precision

        if dtype is not None:
            raise ValueError("Please specify 'precision' instead of 'dtype'")

        # currently storing unpacked int8 weights
        self.register_buffer(
            "weight",
            torch.zeros((out_features, in_features), dtype=torch.int8),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (out_features),
                dtype=torch.float32,
            ),
        )
        self.register_buffer(
            "zeros",
            torch.zeros(
                (out_features),
                dtype=torch.float32,
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(self.precision)
        return linear_forward_8da8w(
            input,
            self.weight,
            self.scales,
            self.zeros,
            self.out_features,
            self.precision,
        )


#########################################################################
###                   embedding table quantization                   ###


def replace_embedding_weight_only_grouped_int8_per_channel(
    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
):
    for name, child in module.named_children():
        # print(f"name: {name}")
        if isinstance(child, nn.Embedding):
            # print(f"{name, child}")
            # print(f"weights size: {child.weight.size()}")
            setattr(
                module,
                name,
                QuantizedGroupEmbedding(
                    device=device,
                    vocab_size=child.weight.shape[0],
                    embedding_dim=child.weight.shape[1],
                    group_size=group_size,
                    dtype=child.weight.dtype,
                    packed=packed,
                    bitwidth=bitwidth,
                ),
            )
        else:
            replace_embedding_weight_only_grouped_int8_per_channel(
                child, device, bitwidth, group_size, packed
            )


class EmbeddingQuantHandler(QuantHandler):
    def __init__(
        self,
        mod,
        device="cpu",
        *,
        bitwidth: int = 8,
        group_size: Optional[int] = None,
        packed=False,
    ):
        if isinstance(packed, str):
            packed = packed == "True"
        self.mod = mod
        self.device = device
        self.group_size = group_size
        self.bitwidth = bitwidth
        self.packed = packed
        if (bitwidth not in [2, 4]) and packed:
            raise RuntimeError("packed quantization only supports bitwidths 2 and 4")

    @torch.no_grad()
    def create_quantized_state_dict(self, packed=False) -> Dict:
        cur_state_dict = self.mod.state_dict()

        if self.bitwidth == 2:
            range_min = -2
            range_max = 1
        elif self.bitwidth == 4:
            range_min = -8
            range_max = 7
        elif self.bitwidth == 8:
            range_min = -128
            range_max = 127
        else:
            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, nn.Embedding):
                # print("****")
                # print(f"Embedding identified: {fqn, mod}")
                # print(f"weights size: {mod.weight.size()}")
                # print(f"quantize {fqn}...")

                print(
                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                )
                weight, scales, _ = dynamically_quantize_per_channel(
                    mod.weight.float(),
                    range_min,
                    range_max,
                    torch.int8,
                    self.group_size,
                    scales_dtype=mod.weight.dtype,
                )

                if packed:
                    if self.bitwidth == 2:
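                        # Pack four 2-bit values per byte: shift values from [-2, 1]
                        # into [0, 3], then place consecutive elements at bit
                        # offsets 0, 2, 4, and 6 within each byte.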
                        if weight.shape[-1] % 4 != 0:
                            raise RuntimeError("automatic padding not implemented yet")
                        weight_range_shifted = weight.add(2).view(torch.uint8)
                        weight_view = weight_range_shifted.view(
                            weight.shape[0], weight.shape[1] // 4, 4
                        )
                        weight_0 = weight_view[:, :, 0]
                        weight_1 = weight_view[:, :, 1] << 2
                        weight_2 = weight_view[:, :, 2] << 4
                        weight_3 = weight_view[:, :, 3] << 6
                        weight_packed = weight_0 + weight_1 + weight_2 + weight_3
                        weight = weight_packed
                    elif self.bitwidth == 4:
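                        # Pack two 4-bit values per byte: shift values from [-8, 7]
                        # into [0, 15]; the even-indexed element occupies the high
                        # nibble and the odd-indexed element the low nibble.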
                        if weight.shape[-1] % 2 != 0:
                            raise RuntimeError("automatic padding not implemented yet")
                        weight_range_shifted = weight.add(8).view(torch.uint8)
                        weight_view = weight_range_shifted.view(
                            weight.shape[0], weight.shape[1] // 2, 2
                        )
                        weight_even = weight_view[:, :, 0] * 16  # left shift 4
                        weight_odd = weight_view[:, :, 1]
                        weight_packed = weight_even + weight_odd
                        weight = weight_packed

                weight = weight.to(device=self.device)
                scales = scales.to(device=self.device)
                # Update state dict
                cur_state_dict[f"{fqn}.weight"] = weight
                # squeeze makes scales 1-D when group_size equals the row size
                cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

        return cur_state_dict

    def convert_for_runtime(self) -> nn.Module:
        replace_embedding_weight_only_grouped_int8_per_channel(
            self.mod, self.device, self.bitwidth, self.group_size, self.packed
        )
        return self.mod

    def quantized_model(self) -> nn.Module:
        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
        self.convert_for_runtime()
        self.mod.load_state_dict(model_updated_state_dict)
        return self.mod
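
# Illustrative usage (sketch): 4-bit grouped, packed embedding quantization of a
# placeholder model `my_model` containing nn.Embedding layers.
#
#     my_model = EmbeddingQuantHandler(
#         my_model, bitwidth=4, group_size=32, packed=True
#     ).quantized_model()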


class QuantizedGroupEmbedding(torch.nn.Module):
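    """
    Groupwise-quantized embedding table: int8 weights (or 2/4-bit values packed
    into uint8) with fp16 scales per group; lookups dispatch to the
    quantized_decomposed embedding ops and return activations in `dtype`.
    """
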
    def __init__(
        self,
        device,
        vocab_size: int,
        embedding_dim: int,
        group_size: Optional[int] = None,
        dtype=torch.half,
        packed=False,
        bitwidth: int = 8,
    ) -> None:
        super().__init__()
        if group_size is None or group_size == 0:
            group_size = embedding_dim
        self.group_size = group_size
        self.dtype = dtype
        self.packed = packed
        self.bitwidth = bitwidth
        if not packed:
            self.register_buffer(
                "weight",
                torch.zeros(
                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
                ),
            )
        else:  # packed
            if bitwidth == 2:
                self.register_buffer(
                    "weight",
                    torch.zeros(
                        (vocab_size, embedding_dim // 4),
                        dtype=torch.uint8,
                        device=device,
                    ),
                )
            elif bitwidth == 4:
                self.register_buffer(
                    "weight",
                    torch.zeros(
                        (vocab_size, embedding_dim // 2),
                        dtype=torch.uint8,
                        device=device,
                    ),
                )

        groups_per_row = (embedding_dim + group_size - 1) // group_size
        if groups_per_row > 1:
            self.register_buffer(
                "scales",
                torch.ones(
                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
                ),
            )
        else:
            self.register_buffer(
                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
            )

    @torch.no_grad()
    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        if not self.packed:  # 8bit
            return torch.ops.quantized_decomposed.embedding_byte.dtype(
                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
            )
        else:  # packed
            if self.bitwidth == 2:
                return torch.ops.quantized_decomposed.embedding_2bit.dtype(
                    self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
                )

            # Remaining case (always return to make pyre happy)
            assert self.bitwidth == 4
            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
            )


############################ Source Transform Start #######################


def get_quant_embedding_transform(args):
    if args.embedding_quantize.startswith("torchao:"):
        bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
        group_size = int(group_size)
        bitwidth = int(bitwidth)
        _load_torchao_ops_aten()
        from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer

        def _torchao_embedding_quantizer(model):
            with torch.no_grad():
                model = IntxWeightEmbeddingQuantizer(
                    device="cpu",
                    precision=torch.float32,
                    bitwidth=bitwidth,
                    groupsize=group_size,
                ).quantize(model)
            return model

        return _torchao_embedding_quantizer

    bitwidth, group_size = args.embedding_quantize.split(",")
    if group_size == "none" or group_size == "None" or group_size == "0":
        group_size = None
    else:
        group_size = int(group_size)
    bitwidth = int(bitwidth)
    return lambda model: EmbeddingQuantHandler(
        model,
        bitwidth=bitwidth,
        group_size=group_size,
        packed=(bitwidth in [2, 4]),
    ).quantized_model()
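
# Illustrative formats for `args.embedding_quantize` (sketch): "8,0" or "4,32"
# select the eager EmbeddingQuantHandler path (bitwidth, group size; 0/none means
# one group per row), while "torchao:4,32" selects the torchao embedding quantizer.
#
#     transform = get_quant_embedding_transform(args)
#     my_model = transform(my_model)  # `my_model` is a placeholder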


def get_quant_weight_transform(args, dtype_override, verbose):
    # If these optional args are None, don't provide them to quantize()
    quant_args_str = [
        "group_size",
        "calibration_tasks",
        "calibration_limit",
        "calibration_seq_length",
    ]
    arg_dict = vars(args)
    quant_args = {
        param: val
        for param in quant_args_str
        if (val := arg_dict.get(param)) is not None
    }

    return partial(
        quantize,
        **quant_args,
        qmode=args.quantization_mode,
        activation_dtype=dtype_override,
        checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
        tokenizer_path=(
            Path(path) if (path := args.tokenizer_path) is not None else None
        ),
    )
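
# Illustrative usage (sketch): the returned partial is applied to an eager model;
# `args` is a parsed argparse namespace providing `quantization_mode`, `checkpoint`,
# and `tokenizer_path`, plus any of the optional quantization arguments above.
#
#     transform = get_quant_weight_transform(args, dtype_override, verbose=False)
#     my_model = transform(my_model)  # `my_model` is a placeholder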


def _load_torchao_ops_aten():
    import glob
    import os

    libs = glob.glob(
        os.path.abspath(
            os.path.join(
                os.environ.get("CMAKE_INSTALL_PREFIX", ""),
                "lib/libtorchao_ops_aten.*",
            )
        )
    )
    assert (
        len(libs) == 1
    ), f"Expected 1 library but got {len(libs)}.  If you installed the torchao ops in a non-standard location, please set CMAKE_INSTALL_PREFIX correctly."
    logging.info(f"Loading custom ops library: {libs[0]}")
    torch.ops.load_library(libs[0])


############################ Source Transform End #######################
