# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import codecs
import getpass
import json
import os
import time
from multiprocessing.connection import Client

import torch
from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo

from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    convert_linear_to_conv2d,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    get_soc_to_chipset_map,
)
from executorch.examples.qualcomm.oss_scripts.llama2.model.static_llama import (
    LlamaModel,
    ModelArgs,
)
from executorch.examples.qualcomm.utils import (
    make_output_dir,
    make_quantizer,
    setup_common_args_and_variables,
    SimpleADB,
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.extension.llm.export.builder import DType

from sentencepiece import SentencePieceProcessor
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


pte_filename = "llama2_qnn"


def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
    """
    Annotate aten.matmul nodes with a 16a8w scheme: 16-bit activations for the
    first input and the output, and an 8-bit spec for the second input.
    """

    from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
    from executorch.backends.qualcomm.quantizer.quantizer import (
        get_16a8w_qnn_ptq_config,
        get_8a8w_qnn_ptq_config,
        QuantizationConfig,
    )
    from torch.ao.quantization.quantizer import (
        QuantizationAnnotation,
        SharedQuantizationSpec,
    )
    from torch.fx import Node

    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec

        input_act1 = node.args[1]
        input_spec1 = quantization_config.weight
        input_qspec_map[input_act1] = input_spec1

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
        input_nodes = node.args[0]

        first_input_node = input_nodes[0]
        input_qspec_map = {}
        input_qspec_map[first_input_node] = quantization_config.input_activation
        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
            (first_input_node, node)
        )

        for input_node in input_nodes[1:]:
            if input_node not in input_qspec_map:
                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=share_qparams_with_input_act0_qspec,
            _annotated=True,
        )

    def annotate_single_in_single_out(
        node: Node, quantization_config: QuantizationConfig
    ) -> None:

        input_qspec_map = {}
        input_act = node.args[0]
        input_qspec_map[input_act] = quantization_config.input_activation

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

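    # The second matmul operand typically comes from the KV-cache path (via
    # cat / transpose / permute). Walk that chain upwards and annotate it with
    # an 8-bit activation config so the cached tensors stay 8-bit while the
    # matmul itself runs 16a8w.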
    def annotate_matmul_input1(node: Node):
        quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True)
        while isinstance(node, Node) and node.op == "call_function":
            if node.target in [
                torch.ops.aten.permute.default,
                torch.ops.aten.transpose.int,
            ]:
                annotate_single_in_single_out(node, quantization_config_8a8w)
                node = node.args[0]
            elif node.target == torch.ops.aten.cat.default:
                annotate_cat(node, quantization_config_8a8w)
                node = node.args[0][0]
            else:
                node = node.args[0]

    quantization_config_16a8w = get_16a8w_qnn_ptq_config()

    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
            annotate_matmul(node, quantization_config_16a8w)
            annotate_matmul_input1(node.args[1])


def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None:
    from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
    from executorch.backends.qualcomm.quantizer.quantizer import (
        get_ptq_per_channel_quant_config,
        QuantizationConfig,
    )
    from torch.ao.quantization.quantizer import QuantizationAnnotation
    from torch.fx import Node

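    # convert_linear_to_conv2d() (applied in compile()) turns nn.Linear into
    # 1x1 conv2d, so the lm-head ("llama.output") appears here as a conv2d
    # node; it gets 16-bit activations with per-channel 8-bit weights.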
    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec

        weight = node.args[1]
        input_qspec_map[weight] = quantization_config.weight

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
        torch.uint16, weight_dtype=torch.int8
    )
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
            if "nn_module_stack" in node.meta:
                module_values_list = list(node.meta["nn_module_stack"].values())
                full_qualified_name = module_values_list[0][0]
                if full_qualified_name == "L['self'].llama.output":
                    annotate_conv2d(
                        node, quantization_config=quantization_config_16a8w_per_channel
                    )


def calibrate(
    example_inputs,
    user_prompts,
    module: torch.fx.GraphModule,
    tokenizer_model_path="tokenizer.model",
):
    sp_model = SentencePieceProcessor(model_file=tokenizer_model_path)
    _, _, atten_mask, k_caches, v_caches = example_inputs

    # TODO: change criteria & support batch inputs if necessary
    pos = torch.tensor(0, dtype=torch.int32)
    token_list = [sp_model.bos_id()]
    for prompt in user_prompts.split():
        token_list += sp_model.encode(prompt)

    def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
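        # Nucleus (top-p) sampling: keep the smallest set of highest-probability
        # tokens whose cumulative mass stays within top_p, renormalize, and
        # draw the next token from that set.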
        probs_sort, probs_indices = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0
        probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
        next_token = torch.multinomial(probs_sort, num_samples=1)
        return probs_indices.gather(dim=-1, index=next_token)

    with torch.no_grad():
        while token_list[-1] != sp_model.eos_id() and pos < 128:
            logits, new_k_caches, new_v_caches = module(
                torch.full((1, 1), token_list[pos]),
                torch.full((1, 1), pos),
                atten_mask,
                *k_caches,
                *v_caches,
            )
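            # Slide each static cache by one position and append the newly
            # produced K/V entry (k caches keep the sequence in the last
            # dimension, v caches in the second-to-last).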
            k_caches = [
                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
                for i, k_cache in enumerate(k_caches)
            ]
            v_caches = [
                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
                for i, v_cache in enumerate(v_caches)
            ]

            pos += 1
            atten_mask[0][-pos - 1] = 0
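            # Teacher-force the prompt tokens; once the prompt is exhausted,
            # sample the next token (temperature 0.8, top-p 0.9) so calibration
            # also covers generated text.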
            if pos >= len(token_list):
                probs = torch.softmax(logits[:, -1] / 0.8, dim=-1)
                token_list.append(sample_top_p(probs, 0.9).item())

    print(f"calibration data:\n{sp_model.decode(token_list)}")


class SingleLlama:
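    """Wraps the static llama model and drives PT2E quantization and lowering to a QNN .pte."""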
    def __init__(self, llama_model) -> None:
        super().__init__()
        self.llama_model = llama_model
        self.quant_dtype = None
        self.llama_meta = self.llama_model.get_metadata()
        self.has_quant_io = False
        tokens, pos_ids, atten_mask, k_caches, v_caches = self.get_example_inputs()
        self.inputs = (tokens, pos_ids, atten_mask, *k_caches, *v_caches)

    def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type):
        if not self.has_quant_io:
            return

        # shape of k caches and v caches
        input_cache_shape = {
            (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]),
            (self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]),
        }
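        # Tag KV-cache placeholders (inputs whose trailing dims match a cache
        # shape) and the single-position cache slices in the outputs so they
        # keep the quantized kv_type at the QNN graph boundary.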
        for n in gm.graph.nodes:
            if (
                n.op == "placeholder"
                and len(users := list(n.users)) == 1
                and users[0].meta["val"].size()[-2:] in input_cache_shape
            ):
                n.meta[QCOM_QUANTIZED_IO] = kv_type
            elif n.op == "output":
                for a in n.args[0]:
                    if (
                        a.meta["val"].flatten().size()[0]
                        == self.llama_meta["get_head_dim"]
                    ):
                        a.meta[QCOM_QUANTIZED_IO] = kv_type

    def quantize(self, quant_dtype, custom_annotations=()):
        self.quant_dtype = quant_dtype
        quantizer = make_quantizer(
            quant_dtype=quant_dtype,
            per_channel_conv=True,
            per_channel_linear=True,
            act_observer=MinMaxObserver,
        )
        quantizer.add_custom_quant_annotations(custom_annotations)

        self.has_quant_io = True
        fx_graph_module = None

        with torch.no_grad():
            fx_graph_module = torch.export.export(
                self.llama_model, self.inputs
            ).module()
            fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
        print("Quantizing the model...")
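        # Calibration reads the prompt and tokenizer path from the module-level
        # `args` parsed in __main__.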
        calibrate(
            self.get_example_inputs(),
            args.prompt,
            fx_graph_module,
            tokenizer_model_path=args.tokenizer_model,
        )

        self.llama_model = convert_pt2e(fx_graph_module)

    def lowering_modules(
        self, work_space, kv_type=torch.uint8, soc_model=QcomChipset.SM8650
    ):
        executorch_config = ExecutorchBackendConfig(
            passes=[
                BuildQuantIo(),
            ],
            # With shared buffers, the user passes memory addresses allocated
            # via RPC memory to the executor runner, so graph inputs/outputs
            # must not be pre-allocated by the runtime memory manager.
            memory_planning_pass=MemoryPlanningPass(
                alloc_graph_input=False,
                alloc_graph_output=False,
            ),
            extract_delegate_segments=True,
        )
        with torch.no_grad():
            # HTP backend options; fp16 is disabled since the model is quantized.
            backend_options = generate_htp_compiler_spec(use_fp16=False)
            compiler_specs = generate_qnn_executorch_compiler_spec(
                soc_model=soc_model,
                backend_options=backend_options,
                shared_buffer=True,
            )
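            # Capture the quantized model as an edge program, tag the KV-cache
            # I/O so BuildQuantIo keeps their quantized dtypes, then delegate
            # to QNN and serialize the result to a .pte file.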
            partitioner = QnnPartitioner(compiler_specs)
            edge_prog = capture_program(self.llama_model, self.inputs)
            self._tag_kv_ios(edge_prog.exported_program.graph_module, kv_type=kv_type)
            edge_prog_mgr = EdgeProgramManager(
                edge_programs={"forward": edge_prog.exported_program},
                constant_methods=self.llama_meta,
                compile_config=EdgeCompileConfig(_check_ir_validity=False),
            )
            edge_prog_mgr = edge_prog_mgr.to_backend(partitioner)
            exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config)
            with open(f"{work_space}/{pte_filename}.pte", "wb") as file:
                exec_prog_mgr.write_to_file(file)

    def get_example_inputs(self):
        return self.llama_model.get_example_inputs()


def compile(args):
    os.makedirs(args.artifact, exist_ok=True)
    start_ts = time.time()
    with open(args.params) as f:
        config = ModelArgs(**json.load(f))
        # TODO: support batch inputs if necessary
        config.max_batch_size = 1
        config.max_seq_len = 1024
    state_dict = torch.load(
        args.checkpoint, weights_only=True, map_location="cpu", mmap=True
    )
    end_load_ts = time.time()
    print("torch.load checkpoint", end_load_ts - start_ts)

    llama_instance = None
    with torch.device("meta"):
        llama_instance = LlamaModel(config, output_new_cache_only=True)
    if "model" in state_dict:
        state_dict = state_dict["model"]
    llama_instance.load_state_dict(
        state_dict,
        strict=False,
        assign=True,
    )
    end_load_state_dict_ts = time.time()
    print("instance.load_state_dict", end_load_state_dict_ts - end_load_ts)

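    # When the attention module provides prepare_sha(), split multi-head
    # attention into per-head (single-head) computations, which tends to map
    # better onto HTP.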
    for layer in llama_instance.layers:
        if getattr(layer.attention, "prepare_sha", None):
            layer.attention.prepare_sha()

    kv_type = torch.uint8
    assert args.ptq in [
        "8a8w",
        "16a4w",
    ], f"No support for quant type {args.ptq}. Support 8a8w and 16a4w."
    quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
    assert args.tokenizer_model is not None, "Need tokenizer model for calibration"

    if args.dtype_override is not None:
        dtype_override = DType[args.dtype_override]
        llama_instance = llama_instance.to(dtype_override.to_torch_dtype())

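    # convert_linear_to_conv2d() replaces nn.Linear layers with equivalent 1x1
    # conv2d, which the HTP backend handles better; this is also why the
    # lm-head is annotated via annotate_linear_16a8w_in_affine_layer above.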
    llama_instance = convert_linear_to_conv2d(llama_instance)
    single_llama = SingleLlama(llama_instance.eval())

    start_quantize_ts = time.time()
    single_llama.quantize(
        quant_dtype,
        custom_annotations=(
            annotate_matmul_16a8w,
            annotate_linear_16a8w_in_affine_layer,
        ),
    )
    end_quantize_ts = time.time()
    print("single_llama.quantize(quant_dtype)", end_quantize_ts - start_quantize_ts)
    single_llama.lowering_modules(
        args.artifact, kv_type=kv_type, soc_model=get_soc_to_chipset_map()[args.model]
    )
    end_lowering_ts = time.time()
    print("Complete Compile", end_lowering_ts - end_quantize_ts)


def inference(args, pre_gen_pte=""):
    workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"

    runner_args = " ".join(
        [
            f"--model_path {pte_filename}.pte",
            "--output_folder_path outputs",
            f"--tokenizer_path {os.path.basename(args.tokenizer_bin)}",
            f'--prompt "{args.prompt}"',
            f"--seq_len {args.seq_len}",
            f"--temperature {args.temperature}",
        ]
    )
    runner_cmd = " ".join(
        [
            f"cd {workspace} &&",
            f"./qnn_llama_runner {runner_args}",
        ]
    )

    pte_path = (
        f"{pre_gen_pte}/{pte_filename}.pte"
        if pre_gen_pte
        else f"{args.artifact}/{pte_filename}.pte"
    )
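    # Push the compiled PTE and the tokenizer binary to the device, run
    # qnn_llama_runner there, then pull the generated outputs back to the host.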
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=pte_path,
        workspace=workspace,
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
        runner="examples/qualcomm/oss_scripts/llama2/qnn_llama_runner",
    )
    # No pre-generated inputs, so input_list is not required.
    adb.push(inputs=[], input_list="", files=[args.tokenizer_bin])
    adb.execute(custom_runner_cmd=runner_cmd)

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    outputs = []

    def post_process():
        for f in sorted(
            os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1])
        ):
            with codecs.open(
                os.path.join(output_data_folder, f),
                "r",
                encoding="utf-8",
                errors="replace",
            ) as fdata:
                outputs.append(fdata.read())

    adb.pull(output_path=args.artifact, callback=post_process)

    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(
                json.dumps(
                    {
                        "result": outputs,
                    }
                )
            )
    else:
        for idx, output in enumerate(outputs):
            print(f"Results[{idx}]:\n{output}")


# flake8: noqa: C901
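# Example invocation (a sketch; flags such as -m/--model and -b/--build_folder
# come from setup_common_args_and_variables and may differ, so check --help):
#
#   python llama.py -b build-android -m SM8650 \
#       --checkpoint <checkpoint>.pth --params <params>.json \
#       --tokenizer_model <tokenizer>.model --tokenizer_bin <tokenizer>.bin \
#       --prompt "Once upon a time" --seq_len 128 --ptq 16a4w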
if __name__ == "__main__":
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "-a",
        "--artifact",
        help="Path for storing artifacts and outputs generated by this example. Default: ./llama2_qnn",
        default="./llama2_qnn",
        type=str,
    )

    parser.add_argument(
        "-P",
        "--ptq",
        help="PTQ quantization scheme to use. Default is 16-bit activations with 4-bit weights (16a4w). Supported: 8a8w and 16a4w.",
        default="16a4w",
    )

    parser.add_argument(
        "--checkpoint",
        help="Path to the llama2 checkpoint.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--params",
        help="Path to the llama2 params JSON file.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--tokenizer_bin",
        help="Path to the llama2 tokenizer binary (pushed to the device for the runner).",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--tokenizer_model",
        help="Path to the llama2 SentencePiece tokenizer model (used for calibration).",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--prompt",
        help="User prompt for llama2.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--seq_len",
        help="Output sequence length for llama2.",
        default=128,
        type=int,
    )

    parser.add_argument(
        "--temperature",
        help="Sampling temperature for llama2.",
        default=0.8,
        type=float,
    )

    parser.add_argument(
        "-d",
        "--dtype-override",
        default="fp32",
        type=str,
        choices=["fp32", "fp16"],
        help="Override the model dtype. Options: fp32, fp16. Default: fp32.",
    )

    parser.add_argument(
        "--pre_gen_pte",
        help="Run the pre-generated llama2 PTE found in the given directory.",
        type=str,
    )

    args = parser.parse_args()
    if args.compile_only and args.pre_gen_pte:
        exit("Cannot set both compile_only and pre_gen_pte")

    if args.pre_gen_pte:
        inference(args, args.pre_gen_pte)
        exit(f"Finished running the pre-generated PTE from {args.pre_gen_pte}")

    if args.compile_only:
        compile(args)
        exit(f"Finished compile_only; the PTE was saved to {args.artifact}")

    try:
        compile(args)
        inference(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Re-raise to preserve the original exception type and traceback.
            raise
