# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import enum
import logging
import operator
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from executorch.exir import ExecutorchProgramManager, memory
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from tabulate import tabulate

from torch.ao.quantization.quantize_pt2e import _QUANT_OPS as quant_ops


# Check if the model is quantized, by looking at the graph and finding quant/dequant ops
def model_is_quantized(model: torch.nn.Module) -> bool:
    # Quantized models have to be GraphModules already, from prepare/convert calls.
    # Return false if the model is not a GraphModule.
    if not isinstance(model, torch.fx.GraphModule):
        return False

    # Walk through the graph and look for quant/dequant ops
    for op in quant_ops:
        if model.graph.find_nodes(op="call_function", target=op):
            return True
    return False


# Get the output size of a 1D convolution given the input size and parameters
def get_conv1d_output_size(
    in_size: torch.Size,
    out_channels: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
    channel_last: bool,
) -> torch.Size:
    assert len(in_size) == 3
    if channel_last:
        N, L, C = in_size
    else:
        N, C, L = in_size

    # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
    lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    if channel_last:
        return torch.Size((N, lout, out_channels))
    return torch.Size((N, out_channels, lout))


# Get the output size of a 2D convolution given the input size and parameters
def get_conv2d_output_size(
    in_size: torch.Size,
    out_channels: int,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    kernel_size: List[int],
    channel_last: bool,
) -> torch.Size:
    assert len(in_size) == 4
    if channel_last:
        N, H, W, C = in_size
    else:
        N, C, H, W = in_size

    # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    hout = (H + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[
        0
    ] + 1
    wout = (W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[
        1
    ] + 1
    if channel_last:
        return torch.Size((N, hout, wout, out_channels))
    return torch.Size((in_size[0], out_channels, hout, wout))


# Return the overload packet for the edge op
def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket:
    edge_op_namespace, edge_op_name = (
        edge_op.namespace,
        edge_op._schema.name.split("::")[1],
    )
    edge_op_overload_packet = getattr(
        getattr(exir_ops.edge, edge_op_namespace), edge_op_name
    )
    return edge_op_overload_packet


# Get the frequency list of ops in a graph module
def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
    freq = {}
    # Loop over nodes to count the number of times each op occurs
    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            # Ignore getitem, alloc and view cases, we only want actual operations
            if (
                node.target == operator.getitem
                or node.target.__name__ == "alloc"
                or node.target == memory.view
            ):
                continue
            # If the op is already present, increment the count
            if node.target._name in freq:
                freq[node.target._name] += 1
            # else, add a new entry
            else:
                freq[node.target._name] = 1
    return freq


# Print the ops and how many times they occur multiple graph modules:
# from export, from to_edge, and from final. Print the available
# implementations for each op, and error out if the op is not supported.
def print_ops_info(
    to_edge_gm: torch.fx.GraphModule,
    final_gm: torch.fx.GraphModule,
) -> None:
    to_edge_ops_count = get_ops_count(to_edge_gm)
    final_ops_count = get_ops_count(final_gm)

    removed_ops = []
    # Get the counts of the ops that are removed from the final graph
    for k in to_edge_ops_count:
        if k not in final_ops_count:
            removed_ops.append(k)

    # Create a dict of ops and their counts to pass to tabulate
    ops_count = [
        [
            op,
            final_ops_count[op],
            to_edge_ops_count[op] if op in to_edge_ops_count else 0,
        ]
        for op in final_ops_count
    ]
    sorted_ops_count = sorted(ops_count, key=lambda x: x[1], reverse=True)

    # Create a dict of deleted ops and their counts to pass to tabulate
    removed_ops_count = [
        [
            op,
            0,
            to_edge_ops_count[op] if op in to_edge_ops_count else 0,
        ]
        for op in removed_ops
    ]

    # Print the final ops and their counts in a tabular format
    logging.info(
        tabulate(
            sorted_ops_count,
            headers=[
                "Final Operators                                    ",  # one character longer than the longest op name
                "Final Graph",
                "To_edge Graph",
                "Export Graph",
            ],
            tablefmt="outline",
        )
    )

    # Print the removed ops and their counts in a tabular format (if any)
    if removed_ops != []:
        logging.info(
            tabulate(
                removed_ops_count,
                headers=[
                    "Deleted Operators                                  ",  # one character longer than the longest op name
                    "Final Graph",
                    "To_edge Graph",
                    "Export Graph",
                ],
                tablefmt="outline",
            )
        )


def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool:
    for node in model_gm.graph.nodes:
        if node.op == "call_function":
            if node.target == torch.ops.aten.scaled_dot_product_attention.default:
                return True
    return False


def save_pte_program(
    prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
) -> None:
    if model_name.endswith(".pte"):
        filename = model_name
    else:
        filename = os.path.join(output_dir, f"{model_name}.pte")

    try:
        with open(filename, "wb") as file:
            prog.write_to_file(file)
            logging.info(f"Saved exported program to {filename}")
    except Exception as e:
        logging.error(f"Error while saving to {filename}: {e}")


def save_bpte_program(
    buffer: bytes,
    model_name: str,
    output_dir: str = "",
) -> None:
    if model_name.endswith(".bpte"):
        filename = model_name
    else:
        filename = os.path.join(output_dir, f"{model_name}.bpte")
    try:
        with open(filename, "wb") as f:
            f.write(buffer)
        logging.info(f"Saved exported program to {filename}")
    except Exception as e:
        logging.error(f"Error while saving to {output_dir}: {e}")


@dataclass
class MemoryConfig:
    memory_sizes: List[int]

    # Optional fields for logs
    memory_names: Optional[List[str]] = None
    base_addrs: Optional[List[int]] = None
    memory_xml_path: Optional[str] = None
    MemorySpace: Optional[enum.Enum] = None

    # get num memories indexed from 1..N, compatible with EXIR's spec.mem_id
    def get_num_memories(self) -> int:
        return len(self.memory_sizes) + 1

    # memory_space module provides num_memories indexed 0..num_memories-1.
    def get_size(self, exir_id: int) -> int:
        return self.memory_sizes[exir_id - 1]


# Return default memory config for the backend
def get_default_memory_config() -> MemoryConfig:
    return MemoryConfig(memory_sizes=[0x1000000000])
