# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator
from collections import defaultdict
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from executorch.exir.backend.backend_details import ExportedProgram
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
    duplicate_constant_node,
)
from executorch.exir.common import setting_python_recursive_limit
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.lowered_backend_module import create_submodule_from_nodes
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.fx.node import Node
from torch.fx.passes.utils.source_matcher_utils import SourcePartition

T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default


# NB: Set this to None to handle validation from MobileBert
@lru_cache(maxsize=None)
def is_same_node(
    node_left: Iterable[torch.fx.Node],
    node_right: Iterable[torch.fx.Node],
) -> bool:
    # two nodes are the same if they have the same target and op
    # same for their args
    if isinstance(node_left, torch.fx.Node) and isinstance(node_right, torch.fx.Node):
        if not (
            (node_left.target == node_right.target)
            and (node_left.op == node_right.op)
            and (len(node_left.all_input_nodes) == len(node_right.all_input_nodes))
            and all(
                is_same_node(arg_left, arg_right)
                for arg_left, arg_right in zip(
                    node_left.all_input_nodes, node_right.all_input_nodes
                )
            )
        ):
            return False
    else:
        if len(list(node_left)) != len(list(node_right)):
            return False
        for n_left, n_right in zip(node_left, node_right):
            if not is_same_node(n_left, n_right):
                return False
    return True


def is_identical_graph(
    graph_left: torch.fx.GraphModule, graph_right: torch.fx.GraphModule
) -> bool:
    # two graph are the same if they have the same nodes and op. The order of nodes also
    # matters in this function is more strict. Two graph are not considered as the same
    # if the topological order of the nodes is the same in this function but the order of nodes
    # is not the same.
    if len(list(graph_left.graph.nodes)) != len(list(graph_right.graph.nodes)):
        return False
    with setting_python_recursive_limit(30000):
        for node_left, node_right in zip(
            graph_left.graph.nodes, graph_right.graph.nodes
        ):
            if not (is_same_node(node_left, node_right)):
                return False
    return True


def remove_first_quant_and_last_dequant(
    graph_module: torch.fx.GraphModule,
) -> None:
    for node in graph_module.graph.nodes:
        if node.target == T_QuantPerTensor:
            if node.args[0].op == "placeholder":
                node_users = list(node.users.keys())
                for dequant_node in node_users:
                    # point the dequant arg to the placeholder
                    dequant_node.args = (node.args[0],) + dequant_node.args[1:]
        elif node.target == T_DQuantPerTensor:
            node_users = list(node.users.keys())
            if node_users[0].op == "output":
                # point the output arg to the quant node
                output_node = node_users[0]
                output_node.args = ([node.args[0]],)
    # Remove the quant/dequant nodes as they don't have users
    graph_module.graph.eliminate_dead_code()
    graph_module.recompile()


# TODO - use edge ops
def replace_quantized_partition_with_op(
    graph_module: torch.fx.GraphModule,
    partition: SourcePartition,
    replacement_op: torch._ops.OpOverloadPacket,
) -> Tuple[torch.fx.Node, List[torch.fx.Node], List[torch.fx.Node]]:
    """
    Replaces partition with the op specified by replacement_op. It's also expected that
    the nodes contained in partition are sourced from a quantized module as this function
    searches for the quantization pattern to consume along with the nodes in the partition,
    to be then replaced by replacement_op.

    Args:
        graph_module: The graph module from which this partition was sourced.
        partition: Partition to be replaced.
        replacement_op: The op to replace paritition with.
    Returns:
        Tuple: First element in the tuple is the new replaced module. The second and third
        node lists in the returned tuple consist of the dq and q nodes that were consumed
        along with this partition to be replaced by the replacement_op.
    """

    dequant_nodes = []
    quant_nodes = []
    input_nodes = []
    output_nodes = []

    partition_nodes = [node for node in partition.nodes if node not in partition.params]

    # We recreate our input nodes and output nodes list instead of using partition.input_nodes
    # and partition.output_nodes as the ordering of the nodes in those lists is not deterministic,
    # whereas for the quant fusion pass we expect deterministic ordering.
    for node in partition.nodes:
        for arg in node.args:
            if isinstance(arg, torch.fx.Node) and (arg not in partition.nodes):
                input_nodes.append(arg)

        for user in node.users.keys():
            if user not in partition.nodes:
                output_nodes.append(node)

    # Try to find all the dq nodes that are feeding into this module.
    for node in input_nodes:
        if node.target == T_DQuantPerTensor:
            dequant_nodes += [node]

    # Try to find all the q nodes that this module is feeding out into.
    for node in output_nodes:
        for user in node.users.keys():
            if user.target == T_QuantPerTensor:
                quant_nodes += [user]

    assert len(dequant_nodes) >= 1, "Dequant nodes missing in node list to be replaced."
    assert len(quant_nodes) >= 1, "Quant nodes missing in node list to be replaced."

    # After this, node list will essentially contain all the nodes in the
    # dq->op->q pattern that we will want to replace with a custom backend op.
    node_list = dequant_nodes + partition_nodes + quant_nodes

    submodule, call_module_node = create_submodule_from_nodes(
        graph_module, node_list, "to_be_replaced", skip_legalize_graph=True
    )

    # Update the replaced op so that we have all the latest args and kwargs.
    with graph_module.graph.inserting_before(call_module_node):
        replaced_op = graph_module.graph.call_function(
            replacement_op,
            call_module_node.args,
            kwargs=call_module_node.kwargs,
        )
        call_module_node.replace_all_uses_with(replaced_op)
        graph_module.graph.erase_node(call_module_node)
        replaced_op.meta = call_module_node.meta
    graph_module.recompile()

    return (replaced_op, dequant_nodes, quant_nodes)


def _assign_new_tag(
    tagged_exported_program: ExportedProgram,
    copied_nodes: Set[str],
):
    """
    Assign new tag to the copied nodes.

    Before the pass
    constant_0 (tag_10) ------------------> op_b (tag_10)
    constant_0_copy (tag_10) -------------> op_a (tag_11)

    After the pass
    constant_0 (tag_10) ------------------> op_b (tag_10)
    constant_0_copy (tag_11) -------------> op_a (tag_11)

    """
    for node in tagged_exported_program.graph.nodes:
        if node.op == "placeholder":
            if node.name in copied_nodes:
                users_tag = set()
                for user in node.users:
                    users_tag.add(user.meta.get("delegation_tag", None))
                # Assign the tag to the copy constant node the same as their users.
                if len(users_tag) == 1:
                    node.meta["delegation_tag"] = users_tag.pop()


def _maybe_duplicate_constant_nodes(
    tagged_exported_program: ExportedProgram,
    tag: str,
) -> None:
    """
    If the constants node is shared by different tagged nodes, like
    constant_0 ----> op_b (tag_10)
    |-------------> op_a (tag_11)

    we make default as constant_0 is duplicated to constant_0_1, constant_0_2, unless the node is tagged with "no_copy"
    constant_0 ------------------> op_b (tag_10)
    constant_0_copy -------------> op_a (tag_11)

    backend can estimate how much they want to duplicate the constant node, either error out or default to duplicate
    """
    candidate_nodes = set()
    for node in tagged_exported_program.graph.nodes:
        if node.meta.get("delegation_tag", "") == tag:
            if node.op == "placeholder":
                for user in node.users:
                    users_tag = user.meta.get("delegation_tag", None)
                    if users_tag != tag:
                        # If the node is tagged with "no_copy", we stop duplicating it and throw an error
                        if node.meta.get("no_copy", False):
                            raise RuntimeError(
                                f"constant data node ({node}) is tagged with ({tag}) but has user ({user}) which has tag ({users_tag})"
                            )
                        else:
                            candidate_nodes.add(node.name)
    copied_nodes = set()
    for candidate_node in candidate_nodes:
        # Both tagged exported program and the owning program need to go through the same duplication pass
        copied_nodes = copied_nodes.union(
            duplicate_constant_node(tagged_exported_program, candidate_node)
        )
    candidate_node_with_copies = candidate_nodes.union(copied_nodes)
    _assign_new_tag(tagged_exported_program, candidate_node_with_copies)


def _get_item_from_executorch_call_delegate(node: torch.fx.Node) -> bool:
    """
    Check if the node is the getitem followed by executorch_call_delegate node. These getitems node
    are just for getting the result from delegate because the input/output to delegates are flattened
    """
    return (
        node.target == operator.getitem
        and len(node.args) == 2
        and node.args[0].target == executorch_call_delegate  # pyre-ignore
        and isinstance(node.args[1], int)
    )


def get_non_lowered_nodes(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    """
    Returns a list of non lowered nodes in the graph module.
    """
    return [
        node
        for node in graph.nodes
        if node.op == "call_function"
        and node.target != executorch_call_delegate
        and (not _get_item_from_executorch_call_delegate(node))
    ]


def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    """
    Returns the list of delegates from the graph.
    """
    return [
        node
        for node in graph.nodes
        if node.op == "get_attr" and node.name.startswith("lowered_module_")
    ]


def print_delegated_graph(graph_module: torch.fx.GraphModule) -> None:
    """
    Print the formatted graph string.
    """
    print(format_delegated_graph(graph_module))


def format_delegated_graph(graph_module: torch.fx.GraphModule) -> str:
    """
    Return the formatted graph string of including lowered_module (both backend id and original graph) together with the graph module. Example output:
    graph():
        %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
        %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
        %arg2_1 : [num_users=2] = placeholder[target=arg2_1]
        %lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0]
            backend_id: BackendWithCompilerDemo
            lowered graph():
                %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
                %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
                %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
                %aten_mm_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mm.default](args = (%arg0_1, %arg1_1), kwargs = {})
                %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mm_default, %arg2_1), kwargs = {})
                return [aten_add_tensor]
        %executorch_call_delegate : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %arg0_1, %arg1_1, %arg2_1), kwargs = {})
        %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {})
        %aten_sub_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.sub.Tensor](args = (%getitem, %arg0_1), kwargs = {})
        %lowered_module_1 : [num_users=1] = get_attr[target=lowered_module_1]
            backend_id: BackendWithCompilerDemo
            lowered graph():
                %aten_sub_tensor : [num_users=1] = placeholder[target=aten_sub_tensor]
                %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
                %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
                %aten_mm_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mm.default](args = (%aten_sub_tensor, %arg1_1), kwargs = {})
                %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mm_default_1, %arg2_1), kwargs = {})
                return [aten_add_tensor_1]
        %executorch_call_delegate_1 : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_1, %aten_sub_tensor, %arg1_1, %arg2_1), kwargs = {})
        %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_1, 0), kwargs = {})
        return [getitem_1]
    """
    lowered_module_dict = {
        node.name: getattr(graph_module, node.name)
        for node in graph_module.graph.nodes
        if node.op == "get_attr" and node.name.startswith("lowered_module_")
    }
    indent = "  "
    graph_format_str = "graph():\n"
    for node in graph_module.graph.nodes:
        graph_format_str += f"{indent}{node.format_node()}\n"
        if node.op == "get_attr" and node.name.startswith("lowered_module_"):
            lowered_module = lowered_module_dict[node.name]
            graph_format_str += f"{indent * 2}backend_id: {lowered_module.backend_id}\n"
            graph_format_str += f"{indent * 2}lowered graph():\n"
            for node_in_lowered_module in lowered_module.original_module.graph.nodes:
                graph_format_str += (
                    f"{indent * 3}{node_in_lowered_module.format_node()}\n"
                )
    return graph_format_str


def tag_constant_data(edge_program: ExportedProgram) -> None:
    """
    Util function for partitioners. This function tags the const/param/buffers nodes
    whose users all belong within the same partition. This should be called after tagging all other nodes.
    Any const/param/buffer which is used as input to a subgraph, will be tagged with the same tag as that
    subgraph. Throw error when const/param/buffers is used across different partitions. That is the
    underlying data will be owned by multiple delegates.
    """
    mutated_buffer = set()
    for node in edge_program.graph.nodes:
        if node.op == "placeholder" and (
            is_param(edge_program, node)
            or is_buffer(edge_program, node)
            or is_lifted_tensor_constant(edge_program, node)
        ):
            for node_user in node.users:
                if node_user.name in edge_program.graph_signature.buffers_to_mutate:
                    logging.info(
                        "The buffer node is a mutated buffer node, which is not constant."
                    )
                    mutated_buffer.add(node)

    for node in edge_program.graph.nodes:
        # go through const/param/buffer nodes, if all users of const/param/buffer nodes are partitioned then partition
        if node.op == "placeholder" and (
            is_param(edge_program, node)
            or is_buffer(edge_program, node)
            or is_lifted_tensor_constant(edge_program, node)
        ):
            if node not in mutated_buffer:
                user_tags = set()
                for user in node.users:
                    user_tag = user.meta.get("delegation_tag", None)
                    if user_tag is not None:
                        user_tags.add(user_tag)
                if len(user_tags) > 1:
                    logging.info(
                        f"The data node is used across multiple partitions, including {user_tags}. "
                        "If the data is too large and it's not preferred to copy, please tag the "
                        "constant node like node.['no_copy'] = True and they won't be copied."
                    )
                # tag the data node with the same tag as the last user
                if len(user_tags) > 0:
                    node.meta["delegation_tag"] = user_tags.pop()


def tag_mutated_buffer(edge_program: ExportedProgram) -> None:
    """
    Util function for partitioners. This function tags the mutated buffer nodes
    whose users all belong within the same partition. This should be called after tagging all other nodes.
    Any buffer which is used as input to a subgraph, will be tagged with the same tag as that
    subgraph. Throw error when buffers is used across different partitions. That is the
    underlying data will be owned by multiple delegates.
    """
    for node in edge_program.graph.nodes:
        # Determine whether this node is a mutated buffer
        is_mutated_buffer_node = False
        if node.op == "placeholder" and is_buffer(edge_program, node):
            for node_user in node.users:
                if node_user.name in edge_program.graph_signature.buffers_to_mutate:
                    is_mutated_buffer_node = True
                    break
        # This node is mutated buffer, tag it
        if is_mutated_buffer_node:
            user_tags = set()
            for user in node.users:
                user_tag = user.meta.get("delegation_tag", None)
                if user_tag is not None:
                    user_tags.add(user_tag)
            if len(user_tags) > 1:
                logging.info(
                    f"The data node is used across multiple partitions, including {user_tags}. "
                    "If the data is too large and it's not preferred to copy, please tag the "
                    "constant node like node.['no_copy'] = True and they won't be copied."
                )
            # tag the data node with the same tag as the last user
            if len(user_tags) > 0:
                node.meta["delegation_tag"] = user_tags.pop()


# TODO - style: use templated types
class DelegateMappingBuilder:
    """
    Profiling helper class for building Delegate Mappings.
    Delegate Mappings are mappings from delegate debug identifiers to node
    debug handles. Specifically this is used to log within backend delegates

    Args:
        generated_identifiers (bool, optional): Whether identifier keys are
            generated automatically. Defaults to False.
    """

    def __init__(self, generated_identifiers: bool = False):
        self._generated_identifiers = generated_identifiers

        # Note that the internal struct has a Set value, while the getter
        # function returns the values as a tuple
        self._debug_handle_map: Union[Dict[int, Set[int]], Dict[str, Set[int]]] = (
            defaultdict(set)
        )
        self._next_index: int = 0

    def get_delegate_mapping(
        self,
    ) -> Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]:
        """
        Returns:
           Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]:
                A map of delegate debug identifier to a list of debug handles
                The keys (identifier) are either integers or strings
                The values are a sorted tuple of integer debug handles
        """
        # pyre-ignore Warning between Union[Dict[K, V], Dict[K2, V]] vs Dict[Union[K, K2], V]
        return {k: tuple(sorted(v)) for k, v in self._debug_handle_map.items()}

    def insert_delegate_mapping_entry(
        self,
        nodes: Optional[Union[Node, List[Node]]] = None,
        handles: Optional[Union[int, List[Optional[int]]]] = None,
        identifier: Optional[Union[int, str]] = None,
    ) -> Union[int, str]:
        """
        Add a new delegate mapping entry

        If self._generated_identifiers = False:
            - A new identifier must be provided, else an exception is thrown

        If self._generated_identifiers = True:
            - New identifiers are generated incrementally, 0 indexed
            - Identifiers cannot be manually provided, else an exception is thrown

        Args:
            nodes (Union[Node, List[Node]]): A (list of) Node(s)
            handles (Union[int, List[Optional[int]]]): A (list of) debug handle(s)
            identifier (Optional[Union[int, str]]):
                Debug identifier corresponding to the Node(s)

        Note: Exactly one of nodes and handles must be provided
        Note: If a debug handle is missing or None, it is skipped

        Returns:
            Union[int, str]:
                Delegate debug identifier inserted
        """

        # Check for manual addition of identifier (with generated identifiers enabled)
        if self._generated_identifiers and identifier is not None:
            raise Exception(
                f"Builders using generated identifiers can't manually add identifiers: {identifier}. Failed to add or update entry"
            )

        if identifier is not None and identifier in self._debug_handle_map:
            raise Exception(
                "This delegate debug identifier was already inserted. Duplicate delegate debug identifiers are not allowed."
            )

        # Check for exactly one of nodes and handles being populated
        if not ((nodes is not None) ^ (handles is not None)):
            raise Exception(
                "Only one of nodes or handles must be provided. Either both were provided or neither were provided. Failed to add or update entry."
            )

        # Resolve Identifier
        if identifier is None:
            if self._generated_identifiers:
                identifier = self._next_index
                self._next_index += 1
            else:
                raise Exception(
                    "No identifier provided. Failed to add or update entry."
                )

        # Collect debug handles
        if nodes is not None:
            new_debug_handles = {
                node.meta.get("debug_handle")
                for node in (nodes if isinstance(nodes, List) else [nodes])
            }
        else:
            new_debug_handles = (
                handles if isinstance(handles, (tuple, List)) else [handles]
            )

        # Filter for empty debug handles
        filtered_debug_handles = {
            handle for handle in new_debug_handles if handle is not None
        }
        if len(filtered_debug_handles) == 0:
            raise Exception("No valid debug handles found. Failed to add entry.")

        # pyre-ignore Warning from Union[int, st] keys
        self._debug_handle_map[identifier] = filtered_debug_handles
        return identifier


class WhyNoPartition:
    """
    Simple helper class for partitioners to log why a node was not lowered.

    Example usage:

        # In your backend partitioner file(s)
        why = WhyNoPartition(logger=your_backend_logger)

        # hypothetical function that checks if a node can be lowered
        if not can_be_lowered(node):
            why(node, "This node was not lowered because ...")
    """

    def __init__(self, logger: logging.Logger):
        self.logger = logger
        self.node: Optional[torch.fx.Node] = None
        self.reason: str = ""

    def __call__(self, node: torch.fx.Node, reason: str) -> None:
        self.node = node
        self.reason = reason
        self.logger.debug(self)

    def __str__(self) -> str:
        return f"WhyNoPartition: Node {self.node} was not partitioned because {self.reason}."
