# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from collections import defaultdict
from typing import final, List

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager

import torch  # noqa: F401
from executorch.backends.qualcomm._passes.fuse_consecutive_transpose import (
    FuseConsecutiveTranspose,
)
from executorch.backends.qualcomm._passes.insert_io_qdq import InsertIOQDQ
from executorch.backends.qualcomm._passes.insert_requantize import InsertRequantize
from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform
from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
from executorch.backends.qualcomm.partition.utils import generate_qnn_executorch_option
from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    PreprocessResult,
)
from executorch.exir.passes import PassManager
from torch.export.exported_program import ExportedProgram

# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably kept for importers elsewhere -- confirm before removing.
DEFAULT_DEBUG_HANDLE = 65535

# Module-level logger. Its own threshold is set to DEBUG explicitly, so it does
# not inherit an effective level from its parent loggers.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


@final
class QnnBackend(BackendDetails):
    """Backend implementation that lowers an exported edge program for QNN."""

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """Lower ``edge_program`` into the payload handed back to ExecuTorch.

        The graph module is first rewritten by the QNN delegate passes. Every
        ``call_function`` node is then given to the node visitor registered
        under the target's ``__name__``, and the collected op wrappers are
        compiled into a context binary. ``get_attr``, ``placeholder`` and
        ``output`` nodes are skipped.

        If a ``call_function`` node has no visitor but its target equals the
        op of the same name under ``torch.ops.<OpContextLoader.namespace>``,
        the value stored in ``node.meta[OpContextLoader.meta_ctx_bin]`` is
        returned directly and nothing is compiled.

        Args:
            edge_program: Exported program to lower.
            compile_specs: Compile specs used to generate the option passed
                to the QNN manager.

        Returns:
            A ``PreprocessResult`` carrying the context binary and an empty
            ``debug_handle_map``.

        Raises:
            RuntimeError: If a node kind other than the ones listed above is
                encountered, or a ``call_function`` target has neither a
                visitor nor a matching context loader op.
            AssertionError: If the pass pipeline returns ``None`` or the
                compiled context binary is empty.
        """
        option = generate_qnn_executorch_option(compile_specs)
        qnn_manager = PyQnnManager.QnnManager(option)
        qnn_manager.Init()
        # Everything after Init() runs under try/finally so the manager is
        # destroyed on every exit path: normal return, early return, or error.
        try:
            # QNN Delegate Specific Passes
            qnn_compiler_passes = PassManager(
                passes=[
                    InsertRequantize(edge_program),
                    InsertIOQDQ(edge_program),
                    LayoutTransform(edge_program, insert_permute=True),
                    FuseConsecutiveTranspose(),
                ]
            )

            pass_result = qnn_compiler_passes(edge_program.graph_module)
            assert pass_result is not None

            enable_tensor_dump = qnn_manager.IsTensorDump()
            nodes_to_wrappers = defaultdict(dict)
            node_visitors = get_node_visitors(
                edge_program, enable_tensor_dump=enable_tensor_dump
            )
            py_op_wrapper_list = []
            for node in pass_result.graph_module.graph.nodes:
                if node.op in ("get_attr", "placeholder", "output"):
                    continue
                if node.op != "call_function":
                    raise RuntimeError(f"{node.op} is not supported in Qnn")

                op_name = node.target.__name__
                logger.info("Visiting: %s, %s", node, op_name)

                if op_name in node_visitors:
                    py_op_wrapper = node_visitors[op_name].define_node(
                        node, nodes_to_wrappers
                    )
                    # A visitor may return nothing, one wrapper, or a list.
                    if py_op_wrapper is None:
                        continue
                    if isinstance(py_op_wrapper, list):
                        py_op_wrapper_list.extend(py_op_wrapper)
                    else:
                        py_op_wrapper_list.append(py_op_wrapper)
                    continue

                err_msg = (
                    f"For {node}, {node.op}:{op_name} "
                    "is not supported in Qnn Delegate"
                )
                try:
                    # Resolve torch.ops.<namespace>.<op_name> by attribute
                    # walk. op_name may itself be dotted, hence the split.
                    # This avoids eval() and leaves module globals untouched.
                    context_loader_target = torch.ops
                    qualified_name = f"{OpContextLoader.namespace}.{op_name}"
                    for attr in qualified_name.split("."):
                        context_loader_target = getattr(context_loader_target, attr)
                    # Explicit check instead of assert so it survives -O.
                    if node.target != context_loader_target:
                        raise LookupError(err_msg)
                    context_binary = node.meta[OpContextLoader.meta_ctx_bin]
                except Exception as exc:
                    raise RuntimeError(err_msg) from exc
                # if graph has context binary loader node, return directly
                return PreprocessResult(
                    processed_bytes=context_binary,
                    debug_handle_map={},
                )

            graph_name = qnn_manager.GetGraphNames()[0]
            op_wrappers = [wrapper.GetOpWrapper() for wrapper in py_op_wrapper_list]
            qnn_context_binary = qnn_manager.Compile(graph_name, op_wrappers)
            assert (
                len(qnn_context_binary) != 0
            ), "Failed to generate Qnn context binary."
            # For now, debug_handle_map is not used by QNN ExecuTorch
            return PreprocessResult(
                processed_bytes=bytes(qnn_context_binary),
                debug_handle_map={},
            )
        finally:
            qnn_manager.Destroy()
