# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import collections
import copy
import os
import subprocess
import tempfile
import unittest
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch

from executorch import exir
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    get_soc_to_chipset_map,
)
from executorch.devtools import generate_etrecord, Inspector
from executorch.examples.qualcomm.utils import (
    generate_inputs,
    make_output_dir,
    SimpleADB,
)

from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.program import ExecutorchProgram, ExecutorchProgramManager
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)


def generate_context_binary(
    module: torch.nn.Module,
    inputs: Dict[str, torch.Tensor],
    quantized: bool,
    artifact_dir: str,
):
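    """
    Drive the QNN SDK tool chain (qnn-pytorch-converter,
    qnn-model-lib-generator, qnn-context-binary-generator) to produce a
    context binary `model_ctx.bin` for `module` under `artifact_dir`.

    Illustrative call (module, shapes and paths are placeholders only):

        generate_context_binary(
            module=MyModel().eval(),
            inputs={"x": torch.randn(1, 3, 224, 224)},
            quantized=False,
            artifact_dir="/tmp/ctx_artifacts",
        )
    """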
    # clang must also be available on PATH, otherwise context binary
    # generation may fail
    qnn_sdk = os.environ.get("QNN_SDK_ROOT", None)
    ndk = os.environ.get("ANDROID_NDK_ROOT", None)
    assert qnn_sdk, "QNN_SDK_ROOT was not found in the environment variables"
    assert ndk, "ANDROID_NDK_ROOT was not found in the environment variables"

    inputs_tup = tuple(inputs.values())
    jit_module = torch.jit.trace(module, inputs_tup)
    torch.jit.save(jit_module, f"{artifact_dir}/jit_module.pt")

    # input data
    if quantized:
        input_list = []
        for name, data in inputs.items():
            file_name = f"{artifact_dir}/{name}.raw"
            data.detach().numpy().tofile(file_name)
            input_list.append(file_name)

        with open(f"{artifact_dir}/input_list.txt", "w") as f:
            f.write(" ".join(input_list))

    # run the QNN tool chain: converter -> model-lib-generator -> context-binary-generator
    target = "x86_64-linux-clang"
    inputs_str = [
        f"-d '{k}' {str(tuple(v.shape)).replace(' ', '')[1:-1]}"
        for k, v in inputs.items()
    ]
    cmds = [
        # setup qnn env
        f"source {qnn_sdk}/bin/envsetup.sh;"
        # qnn-pytorch-converter
        f"{qnn_sdk}/bin/{target}/qnn-pytorch-converter",
        f"-i {artifact_dir}/jit_module.pt",
        *inputs_str,
        f"--input_list {artifact_dir}/input_list.txt" if quantized else "",
        "--preserve_io",
        f"-o {artifact_dir}/model.cpp;",
        # qnn-model-lib-generator
        f"{qnn_sdk}/bin/{target}/qnn-model-lib-generator",
        f"-c {artifact_dir}/model.cpp",
        f"-t {target}",
        "-l model",
        f"-o {artifact_dir}/model_libs;",
        # qnn-context-binary-generator
        f"{qnn_sdk}/bin/{target}/qnn-context-binary-generator",
        f"--model {artifact_dir}/model_libs/{target}/libmodel.so",
        f"--backend {qnn_sdk}/lib/{target}/libQnnHtp.so",
        "--binary_file model_ctx",
        f"--output_dir {artifact_dir};",
    ]
    result = subprocess.run(
        " ".join(cmds),
        shell=True,
        executable="/bin/bash",
        capture_output=True,
    )
    assert os.path.isfile(
        f"{artifact_dir}/model_ctx.bin"
    ), f"Failed to generate context binary:\n{result.stderr.decode('utf-8')}"


class TestQNN(unittest.TestCase):
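    """
    Shared base class for Qualcomm backend unit tests. Test entry points are
    expected to populate the class attributes below (device, build_folder,
    compiler_specs, ...) before the helper methods are used.

    Rough usage sketch (module, inputs and attribute values are illustrative,
    not defaults):

        class TestMyModule(TestQNN):
            def test_relu(self):
                module = torch.nn.ReLU()
                sample_inputs = (torch.randn(1, 8),)
                self.lower_module_and_test_output(module, sample_inputs)
    """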
    rtol: float = 0
    atol: float = 0
    host: str = ""
    device: str = ""
    build_folder: str = ""
    model: Optional[QcomChipset] = None
    compiler_specs: Optional[List[CompileSpec]] = None
    chipset_table = get_soc_to_chipset_map()
    error_only = False
    ip = "localhost"
    port = 8080
    executorch_root: str = ""
    artifact_dir: str = ""
    image_dataset: str = ""
    pretrained_weight: str = ""
    enable_profile: bool = False
    online_prepare: bool = False
    use_8a8w: str = "8a8w"
    use_16a16w: str = "16a16w"
    use_16a4w: str = "16a4w"
    shared_buffer: bool = False
    enable_x86_64: bool = False

    def _assert_outputs_equal(self, model_output, ref_output):
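        """Compare outputs element-wise with torch.allclose (class atol/rtol)."""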
        self.assertTrue(len(ref_output) == len(model_output))
        for i in range(len(ref_output)):
            self.assertTrue(
                torch.allclose(
                    model_output[i], ref_output[i], atol=self.atol, rtol=self.rtol
                ),
                msg=f"ref_output:\n{ref_output[i]}\n\nmodel_output:\n{model_output[i]}",
            )

    def _save_model_and_expected_output(
        self,
        module: torch.nn.Module,
        buffer: bytes,
        inputs: Tuple[torch.Tensor],
        dir_name: str,
    ) -> Tuple[str, List[torch.Tensor], str]:
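        """
        Serialize the lowered program buffer to a .pte file under dir_name and
        run the eager module to collect reference outputs.

        Returns (input_list, ref_outputs, pte_fname), where input_list is the
        newline-terminated list of raw input file names consumed by the
        executor runner.
        """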
        # Save the input data list to be executed
        input_list = ""
        for idx, _ in enumerate(inputs):
            input_name = f"input_0_{idx}.raw"
            input_list += input_name + " "
        input_list = input_list.strip() + "\n"

        ref_output = module(*inputs)

        # Save the expected output data to be verified
        ref_outputs = []
        if isinstance(ref_output, collections.OrderedDict):
            ref_outputs.append(ref_output["out"].detach())
        elif isinstance(ref_output, (list, tuple)):
            for output in ref_output:
                ref_outputs.append(output.detach())
        else:
            ref_outputs.append(ref_output.detach())

        pte_fname = f"{dir_name}/qnn_executorch_test.pte"
        with open(pte_fname, "wb") as file:
            file.write(buffer)

        return input_list, ref_outputs, pte_fname

    def verify_output(  # noqa: C901
        self,
        module: torch.nn.Module,
        sample_inputs: Tuple[torch.Tensor],
        executorch_prog: ExecutorchProgram | ExecutorchProgramManager,
        etrecord_path: str = "etrecord.bin",
        expected_profile_events: int = -1,
        expected_intermediate_events: int = -1,
        method_index: int = 0,
    ):
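        """
        Run the lowered program and compare its outputs against the eager
        module. With enable_x86_64 the pre-built qnn_executor_runner is
        invoked directly with the QNN SDK x86_64 libraries on LD_LIBRARY_PATH;
        otherwise the artifacts are pushed to a device through SimpleADB.
        When expected_profile_events or expected_intermediate_events is not -1,
        the generated etdump is also validated.
        """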
        with tempfile.TemporaryDirectory() as tmp_dir:
            (
                input_list,
                ref_outputs,
                pte_fname,
            ) = self._save_model_and_expected_output(
                module,
                executorch_prog.buffer,
                sample_inputs,
                tmp_dir,
            )

            output_dir = f"{tmp_dir}/outputs"
            outputs = []
            etdump_path = f"{tmp_dir}/etdump.etdp"
            debug_output_path = f"{tmp_dir}/debug_output.bin"

            def post_process():
                for i, f in enumerate(sorted(os.listdir(output_dir))):
                    filename = os.path.join(output_dir, f)
                    output = np.fromfile(filename, dtype=ref_outputs[i].numpy().dtype)
                    output = torch.from_numpy(output).reshape(ref_outputs[i].shape)
                    outputs.append(output)

            def validate_profile():
                inspector = Inspector(etdump_path=etdump_path, etrecord=etrecord_path)
                self.assertTrue(
                    len(inspector.to_dataframe().index) == expected_profile_events
                )

            def validate_intermediate_tensor():
                inspector = Inspector(
                    etdump_path=etdump_path, debug_buffer_path=debug_output_path
                )
                for event_block in inspector.event_blocks:
                    if event_block.name == "Execute":
                        self.assertTrue(
                            len(event_block.events) == expected_intermediate_events
                        )

            if self.enable_x86_64:
                generate_inputs(tmp_dir, "input_list.txt", [sample_inputs], input_list)
                make_output_dir(output_dir)

                target = "x86_64-linux-clang"
                qnn_sdk = os.environ.get("QNN_SDK_ROOT", None)
                assert qnn_sdk, "QNN_SDK_ROOT was not found in environment variable"

                # Use the build folder as-is if it is absolute, otherwise
                # resolve it relative to the current working directory.
                build_folder = self.build_folder
                if not os.path.isabs(self.build_folder):
                    build_folder = os.path.join(os.getcwd(), self.build_folder)

                cmd = [
                    # qnn_executor_runner
                    f"{build_folder}/examples/qualcomm/executor_runner/qnn_executor_runner",
                    "--model_path",
                    pte_fname,
                    "--input_list_path",
                    f"{tmp_dir}/input_list.txt",
                    "--output_folder_path",
                    output_dir,
                    "--method_index",
                    str(method_index),
                ]
                if expected_intermediate_events != -1:
                    cmd.append("--dump_intermediate_outputs")

                env = dict(os.environ)
                env["LD_LIBRARY_PATH"] = f"{qnn_sdk}/lib/{target}/:{build_folder}/lib"
                proc = subprocess.run(
                    cmd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    env=env,
                    cwd=tmp_dir,
                )

                self.assertEqual(
                    proc.returncode,
                    0,
                    f"The process running qnn_executorch_runner return {proc.returncode}, "
                    "STDOUT=\n"
                    f"{proc.stdout.decode('utf-8')}",
                )

                # Verify the outputs
                post_process()
                self._assert_outputs_equal(outputs, ref_outputs)

                # Verify the etdump
                if expected_profile_events != -1:
                    validate_profile()

                if expected_intermediate_events != -1:
                    validate_intermediate_tensor()
            else:
                adb = SimpleADB(
                    qnn_sdk=os.getenv("QNN_SDK_ROOT"),
                    build_path=self.build_folder,
                    pte_path=pte_fname,
                    workspace="/data/local/tmp/qnn_executorch_test",
                    device_id=self.device,
                    host_id=self.host,
                    soc_model=self.model,
                    error_only=self.error_only,
                    dump_intermediate_outputs=(expected_intermediate_events != -1),
                )
                adb.push(inputs=[sample_inputs], input_list=input_list)
                adb.execute(method_index=method_index)
                adb.pull(output_path=tmp_dir, callback=post_process)
                self._assert_outputs_equal(outputs, ref_outputs)

                if expected_profile_events != -1:
                    adb.pull_etdump(etdump_path, callback=validate_profile)

                if expected_intermediate_events != -1:
                    adb.pull_debug_output(
                        etdump_path,
                        debug_output_path,
                        callback=validate_intermediate_tensor,
                    )

    def lower_module_and_test_output(
        self,
        module: torch.nn.Module,
        sample_inputs: Tuple[torch.Tensor],
        expected_partitions: int = 1,
        expected_profile_events: int = -1,
        expected_intermediate_events: int = -1,
        assert_output_equal: bool = True,
        skip_node_id_set: Optional[set] = None,
        skip_node_op_set: Optional[set] = None,
    ):
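        """
        Lower the module via capture_program -> QnnPartitioner/to_backend ->
        to_executorch, assert that the expected number of QNN delegates is
        present, and optionally verify numerics and etdump events through
        verify_output.
        """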
        qnn_partitioner = QnnPartitioner(
            self.compiler_specs, skip_node_id_set, skip_node_op_set
        )
        delegated_program = capture_program(module, sample_inputs)

        # this is needed for the ETRecord as lowering modifies the graph in-place
        edge_copy = copy.deepcopy(delegated_program)

        delegated_program.exported_program = to_backend(
            delegated_program.exported_program, qnn_partitioner
        )
        exec_prog = delegated_program.to_executorch(
            exir.ExecutorchBackendConfig(
                # When shared buffers are used, the user passes RPC-allocated
                # memory addresses to the executor runner, so the runtime
                # memory manager should not pre-allocate graph inputs/outputs.
                memory_planning_pass=MemoryPlanningPass(
                    alloc_graph_input=not self.shared_buffer,
                    alloc_graph_output=not self.shared_buffer,
                ),
            )
        )

        # Assert the expected number of partitions and that each delegate
        # targets the QNN backend
        self.assertEqual(
            len(exec_prog.program.execution_plan[0].delegates),
            expected_partitions,
        )
        for i in range(expected_partitions):
            self.assertEqual(
                exec_prog.program.execution_plan[0].delegates[i].id,
                QnnBackend.__name__,
            )

        etrecord_path = "etrecord.bin"
        if self.enable_profile:
            generate_etrecord(etrecord_path, edge_copy, exec_prog)
        # Check numerics
        if (
            assert_output_equal
            or expected_profile_events != -1
            or expected_intermediate_events != -1
        ):
            self.verify_output(
                module,
                sample_inputs,
                exec_prog,
                etrecord_path,
                expected_profile_events,
                expected_intermediate_events,
            )

    def get_qdq_module(
        self,
        module: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
        is_conv_per_channel: Optional[bool] = True,
        is_linear_per_channel: Optional[bool] = False,
        custom_quant_annotations: Tuple[Callable] = (),
        quant_dtype: QuantDtype = QuantDtype.use_8a8w,
    ) -> torch.fx.GraphModule:
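        """
        Post-training quantization helper: export the module, annotate it with
        QnnQuantizer, prepare_pt2e, calibrate with one forward pass on
        `inputs`, then convert_pt2e. Asserts that quantize/dequantize ops were
        actually inserted.

        Minimal sketch (module and input shape are illustrative only):

            qdq_module = self.get_qdq_module(
                torch.nn.Conv2d(3, 8, 3),
                (torch.randn(1, 3, 32, 32),),
                quant_dtype=QuantDtype.use_8a8w,
            )
        """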
        m = torch.export.export(module, inputs).module()

        quantizer = QnnQuantizer()
        quantizer.add_custom_quant_annotations(custom_quant_annotations)
        quantizer.set_per_channel_conv_quant(is_conv_per_channel)
        quantizer.set_per_channel_linear_quant(is_linear_per_channel)
        quantizer.set_quant_config(quant_dtype)

        prepared = prepare_pt2e(m, quantizer)
        prepared(*inputs)
        quantized_module = convert_pt2e(prepared)
        nodes = {node.target for node in quantized_module.graph.nodes}
        q_and_dq = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.quantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
        }
        self.assertTrue(nodes.intersection(q_and_dq))
        return quantized_module

    def get_prepared_qat_module(
        self,
        module: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
        is_conv_per_channel: Optional[bool] = True,
        is_linear_per_channel: Optional[bool] = False,
        custom_quant_annotations: Tuple[Callable] = (),
        quant_dtype: QuantDtype = QuantDtype.use_8a8w,
    ) -> torch.fx.GraphModule:
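        """
        Quantization-aware-training helper: export the module for training,
        set a QAT quant config on QnnQuantizer, run prepare_qat_pt2e, and move
        the exported model to train mode so it can be fine-tuned before
        conversion (see get_converted_sgd_trained_module).
        """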
        m = torch.export.export_for_training(module, inputs).module()

        quantizer = QnnQuantizer()
        quantizer.add_custom_quant_annotations(custom_quant_annotations)
        quantizer.set_per_channel_conv_quant(is_conv_per_channel)
        quantizer.set_per_channel_linear_quant(is_linear_per_channel)

        if quant_dtype == QuantDtype.use_8a8w:
            quantizer.set_quant_config(quant_dtype, is_qat=True)
        else:
            raise RuntimeError("Only QuantDtype.use_8a8w is supported for QAT")

        prepared = prepare_qat_pt2e(m, quantizer)
        return torch.ao.quantization.move_exported_model_to_train(prepared)

    def get_converted_sgd_trained_module(
        self,
        ori_module: torch.nn.Module,
        prepared: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
    ) -> torch.fx.GraphModule:
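        """
        Run a single SGD step on a prepared QAT module, using the original
        module's output as the training target, then convert it to its
        quantize/dequantize representation with convert_pt2e.
        """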
        optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
        criterion = torch.nn.CrossEntropyLoss()
        output = prepared(*inputs)
        loss = criterion(output, ori_module(*inputs))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return torch.ao.quantization.quantize_pt2e.convert_pt2e(prepared)

    def split_graph(self, graph_module: torch.fx.GraphModule, division: int):
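        """
        Insert clone ops so the graph is cut into roughly `division` pieces:
        a clone node is added after every
        ceil(num_call_function_nodes / division) call_function nodes, giving
        the partitioner natural split points.
        """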
        class SplitGraph(ExportPass):
            """
            Split graph based on number of nodes.
            """

            def __init__(self, shares):
                super().__init__()
                self.shares = shares

            def _insert_clone(self, graph_module: torch.fx.GraphModule) -> None:
                num_graph_nodes = 0
                for node in graph_module.graph.nodes:
                    num_graph_nodes += 1 if node.op == "call_function" else 0

                    if num_graph_nodes % self.shares != 0 or node.op != "call_function":
                        continue

                    with graph_module.graph.inserting_after(node):
                        users = list(node.users.keys())
                        inserted_node = graph_module.graph.create_node(
                            "call_function",
                            exir_ops.edge.aten.clone.default,
                            (node,),
                        )
                        inserted_node.meta["val"] = node.meta["val"]
                        if "quant_attrs" in node.meta:
                            inserted_node.meta["quant_attrs"] = node.meta["quant_attrs"]
                        for user in users:
                            user.replace_input_with(node, inserted_node)

            def call(self, graph_module: torch.fx.GraphModule):
                self._insert_clone(graph_module)
                graph_module.recompile()

        num_graph_nodes = 0
        for node in graph_module.graph.nodes:
            num_graph_nodes += 1 if node.op == "call_function" else 0

        # -(a // -b) is ceiling division: shares = ceil(num_graph_nodes / division)
        SplitGraph(-(num_graph_nodes // -division))(graph_module)
