# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import re
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import torch
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch.fx.node import Argument, Target
from torch.library import Library

# "FRAGMENT" handle on the aten namespace; `_populate_passes` calls `lib.define(...)` on it to declare the
# schemas of old-version operators.
lib = Library("aten", "FRAGMENT")
# "IMPL" handle on the aten namespace; `_populate_passes` calls `impl_lib.impl(...)` on it to register the
# Python upgrader functions as kernels.
impl_lib = Library("aten", "IMPL")

# Module-level logger.
log = logging.getLogger(__name__)


def get_target_version(versioned_upgrader_name: str) -> int:
    """Return the version that the named upgrader upgrades an operator to.

    An upgrader name ends with ``_<first>_<last>``, the range of versions it applies to; the target version is
    ``<last> + 1``. For example, ``div_Scalar_0_3`` applies to div.Scalar of version 0 to 3 and upgrades to
    version 4.

    Raises:
        RuntimeError: if the name does not end with two underscore-separated integers.
    """
    name_pattern = "^.*_[0-9]+_[0-9]+$"
    if re.match(name_pattern, versioned_upgrader_name) is None:
        raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid")

    # The last underscore-separated segment is the highest version the upgrader applies to.
    *_, last_applicable_version = versioned_upgrader_name.split("_")
    return int(last_applicable_version) + 1


def get_upgraders() -> Dict[str, Tuple[str, str]]:
    """Getting upgraders entry map and operator version map and merge them into one dict."""
    upgraders = torch._C._get_upgraders_entry_map()
    op_version_map = torch._C._get_operator_version_map()
    output: Dict[str, Tuple[str, str]] = defaultdict(tuple)  # type: ignore[arg-type]
    for opname, entry_list in op_version_map.items():
        if not entry_list:
            raise RuntimeError(f"Op version map has an empty entry for opname {opname}")
        entry = entry_list[0]
        old_schema = entry.old_schema
        upgrader_name = entry.upgrader_name
        upgrader_str = upgraders.get(upgrader_name, None)
        if not upgrader_str:
            raise RuntimeError(
                f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}"
            )
        output[upgrader_name] = (old_schema, upgrader_str)
    return output


class GraphModuleOpUpgrader:
    """This upgrader is able to upgrade the old version of ops in a given GraphModule, if all upgraders are available.
    To use it, retrieve upgraders from somewhere (TorchScript API or new API) and pass it into this upgrader. In
    __init__() it does the following:
    1. parse the upgrader list and reorder for upgrading purpose.
    2. register old versions of operators as custom ops.
    3. prepare upgrader passes.

    In `upgrade()` API run these upgrader passes.

    An example of op_upgraders input:
    {
        "aten::div__Scalar_0_3": (                              # versioned op name
            "div._Scalar(self: Tensor, other: Scalar)",         # old schema
            '''
            def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor:     # upgrader in literal string
              if (self.is_floating_point() or isinstance(other, float)):
                return self.true_divide_(other)
              return self.divide_(other, rounding_mode='trunc')
            ''',
        ),
    },

    Note that we require the upgrader function to be runnable in Python (which is a stricter requirement than the
    original TorchScript upgrader).
    """

    class UpgraderPass(_ExportPassBaseDeprecatedDoNotUse):
        def __init__(self, old_target: Target, new_target: Target):
            super().__init__()
            self.old_target = old_target
            self.new_target = new_target

        def call_operator(
            self,
            op,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
            meta: NodeMetadata,
        ) -> ProxyValue:
            if op == self.old_target:
                return super().call_operator(self.new_target, args, kwargs, meta)
            return super().call_operator(op, args, kwargs, meta)

    def __init__(
        self,
        compiler_opset_version: Optional[Dict[str, int]] = None,
        model_opset_version: Optional[Dict[str, int]] = None,
        op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None,
    ):
        self.op_upgraders: Dict[str, Tuple[str, str]] = (
            get_upgraders() if not op_upgraders else op_upgraders
        )
        self.compiler_opset_version = (
            compiler_opset_version if compiler_opset_version else {}
        )
        self.model_opset_version = model_opset_version if model_opset_version else {}
        self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = (
            GraphModuleOpUpgrader._populate_passes(
                self._parse_upgraders(self.op_upgraders)
            )
        )

    def _parse_upgraders(
        self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None
    ) -> List[Tuple[str, str]]:
        """Reorder op_upgraders by version number, return an ordered list of tuples, containing old op schema as well
        as the upgrader function string literal."""
        # TODO(larryliu0820): Add support for custom ops
        op_namespace = "aten"
        if (
            not op_upgraders
            or op_namespace not in self.model_opset_version
            or op_namespace not in self.compiler_opset_version
        ):
            return []
        model_ver = self.model_opset_version[op_namespace]
        curr_ver = self.compiler_opset_version[op_namespace]

        # key is the target version. div__Scalar_0_3 should have a key of 4.
        versioned_upgraders: Dict[int, Tuple[str, str]] = {
            get_target_version(name): v for name, v in op_upgraders.items()
        }
        target_upgraders: List[Tuple[str, str]] = []
        # we need all upgraders from model_ver + 1 to curr_ver, inclusively
        for ver in range(model_ver + 1, curr_ver + 1):
            if ver in versioned_upgraders:
                target_upgraders.append(versioned_upgraders[ver])
            else:
                # we may be able to get away with missing upgraders, if that operator is missing from given graph
                # module.
                log.warning(
                    "Missing an upgrader to upgrade to version {ver}.",
                    extra={"ver": ver},
                )

        return target_upgraders

    @staticmethod
    def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]:
        """Given a list of upgraders, loop through it from lower version to higher version and create passes for all
        upgraders. se torch.Library API to register old ops. Op name will be
        <name>_<valid_from_ver>_<valid_till_ver>. Register upgraders as CompositeImplicitAutograd kernels. For example:

        lib = Library("aten", "FRAGMENT")
        lib.define(old_schema)

        impl_lib = Library("aten", "IMPL")
        impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd")

        @:var upgraders: a list of tuples. The first element of the tuple is the old schema and the second is the
        upgrader function literal text.
        @:return upgrader passes, order matters
        """

        upgrader_passes = []

        def register_old_op(name: str, schema: str, impl_str: str):
            """Registers an old version operator using impl_name as old op name."""
            lib.define(schema)
            try:
                exec(impl_str)
            except Exception as e:
                raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e
            impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd")

        for schema, upgrader_str in upgraders:
            upgrader_name = upgrader_str.split("(")[0].split(" ")[-1]
            op_name = schema.split("(")[0].split("::")[-1]
            schema = schema.replace(op_name, upgrader_name)
            try:
                register_old_op(
                    name=upgrader_name, schema=schema, impl_str=upgrader_str
                )
            except RuntimeError as e:
                if "with the same name and overload name multiple times" in str(e):
                    print(f"Registering {upgrader_name} multiple times")
                else:
                    raise RuntimeError from e
            old_op_target = getattr(torch.ops.aten, upgrader_name).default
            # for example, the operator instance of "aten::div" is torch.op.aten.div.default. We need to append the
            # "default" at the end.
            op_name, overload_name = (
                (op_name, "default")
                if "." not in op_name
                else tuple(op_name.split(".")[:2])
            )
            new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name)
            # Note that the graph will have op names in the graph, but actually they are of old versions.
            upgrader_passes.append(
                GraphModuleOpUpgrader.UpgraderPass(
                    old_target=new_op_target, new_target=old_op_target
                )
            )

        return upgrader_passes

    def upgrade(self, exported_program):
        return exported_program
