# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

try:  # noqa: C901
    from torch._higher_order_ops.executorch_call_delegate import (
        executorch_call_delegate as executorch_call_delegate,
        get_lowered_module_name as get_lowered_module_name,
        is_lowered_module as is_lowered_module,
    )

except ImportError:

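    # TODO: Delete this fallback once the PyTorch pin is new enough to provide
    # torch._higher_order_ops.executorch_call_delegate (the import attempted above).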

    from typing import Any, cast

    import torch
    import torch.utils._pytree as pytree
    from torch._ops import HigherOrderOperator
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.proxy_tensor import (
        disable_proxy_modes_tracing,
        get_proxy_slot,
        ProxyTorchDispatchMode,
        track_tensor_tree,
    )
    from torch.utils._pytree import tree_flatten

    executorch_call_delegate = HigherOrderOperator("executorch_call_delegate")
    # Register these dispatch keys as fallthroughs: no kernel is provided for
    # them in this file, so dispatch skips past them to the next key.
    for _passthrough_key in (
        torch._C.DispatchKey.PythonDispatcher,
        torch._C.DispatchKey.PythonTLSSnapshot,
        torch._C.DispatchKey.ADInplaceOrView,
        torch._C.DispatchKey.AutocastCPU,
    ):
        executorch_call_delegate.fallthrough(_passthrough_key)
    # Keep the module namespace identical to the unrolled registration.
    del _passthrough_key

    # Class name that is_lowered_module() compares against; matching by name
    # avoids importing LoweredBackendModule here.
    LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"

    # pyre-ignore
    # pyre-ignore
    def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
        """
        Record a call of ``func_overload`` on ``lowered_module`` into the graph
        being built by ``proxy_mode``'s tracer.

        Steps:
          1. Run the delegate with proxy-mode tracing disabled, producing output
             values.
          2. Attach ``lowered_module`` as a submodule of ``tracer.root``.
          3. Emit a ``call_function`` node named "executorch_call_delegate" whose
             arguments are the proxies of ``(lowered_module, *args)``.
          4. Associate the outputs from step 1 with that node via
             ``track_tensor_tree``.

        Raises:
            ValueError: if ``lowered_module`` is not a LoweredBackendModule
                (checked by class name).
        """
        if not is_lowered_module(lowered_module):
            raise ValueError(
                "executorch_call_delegate()'s first argument must be a LoweredBackendModule"
            )

        tracer = proxy_mode.tracer

        # pyre-ignore
        def _as_proxy(value):
            # Tensors and symbolic scalars are replaced by the proxy the tracer
            # holds for them. get_proxy_slot is given ``value`` itself as the
            # default. Every other value passes through untouched.
            if isinstance(value, (torch.Tensor, torch.SymInt, torch.SymFloat)):
                return get_proxy_slot(
                    cast(torch.Tensor, value), tracer, value, lambda slot: slot.proxy
                )
            return value

        with disable_proxy_modes_tracing():
            concrete_out = call_delegate_cpu(lowered_module, *args)

        # The returned name is unused. The call is made for its side effect of
        # adding lowered_module to tracer.root.
        get_lowered_module_name(tracer.root, lowered_module)

        graph_args = pytree.tree_map(_as_proxy, (lowered_module, *args))
        out_proxy = tracer.create_proxy(
            "call_function",
            func_overload,
            graph_args,
            {},
            name="executorch_call_delegate",
        )
        return track_tensor_tree(concrete_out, out_proxy, constant=None, tracer=tracer)

    @executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
    # pyre-ignore
    def call_delegate_cpu(lowered_module, *args):
        """
        CompositeExplicitAutograd kernel: run the delegate's original module
        eagerly on ``args``.

        FX's ``immutable_dict`` and ``immutable_list`` containers in ``args`` are
        first converted to plain ``dict`` and ``list``. Such containers are
        treated as pytree leaves (``is_leaf``), so only the outermost one is
        converted and its contents are not traversed.

        ``lowered_module.original_module`` is presumably an ExportedProgram
        (it must expose ``.module()``). TODO confirm against LoweredBackendModule.
        """
        # FX creates this immutable_dict/list concept; get rid of it before the call.
        fx_to_builtin = {
            torch.fx.immutable_collections.immutable_dict: dict,
            torch.fx.immutable_collections.immutable_list: list,
        }
        immutable_types = tuple(fx_to_builtin)

        def _to_builtin(container):
            return fx_to_builtin[type(container)](container)

        def _is_immutable(node):
            return isinstance(node, immutable_types)

        plain_args = pytree.tree_map_only(
            immutable_types, _to_builtin, args, _is_immutable
        )
        eager_module = lowered_module.original_module.module()
        return eager_module(*plain_args)

    @executorch_call_delegate.py_impl(torch._C.DispatchKey.Autograd)
    # pyre-ignore
    def call_delegate_autograd(lowered_module, *args):
        """
        Autograd kernel for executorch_call_delegate.

        Real autograd is not supported. The call is redispatched with the
        AutogradCPU key excluded. If any tensor among the flattened
        ``(lowered_module, *args)`` requires grad, every output tensor is
        replaced by its ``detach()``. Floating-point and complex outputs then
        get ``requires_grad=True``. These are new leaf tensors, so gradients do
        not flow back to the inputs.
        """
        # TODO: support autograd
        operands, _spec = tree_flatten([lowered_module, *args])
        any_input_requires_grad = any(
            isinstance(leaf, torch.Tensor) and leaf.requires_grad for leaf in operands
        )

        autograd_cpu = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
        with torch._C._ExcludeDispatchKeyGuard(autograd_cpu):
            outputs = executorch_call_delegate(lowered_module, *args)
            if not any_input_requires_grad:
                return outputs

            # pyre-ignore
            def _mark_requires_grad(tensor):
                # Detached alias; only float/complex dtypes can require grad.
                if tensor is None:
                    return tensor
                alias = tensor.detach()
                if torch.is_floating_point(alias) or torch.is_complex(alias):
                    alias.requires_grad = True
                return alias

            return pytree.tree_map_only(torch.Tensor, _mark_requires_grad, outputs)

    @executorch_call_delegate.py_impl(ProxyTorchDispatchMode)
    # pyre-ignore
    def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args):
        """ProxyTorchDispatchMode kernel: record the delegate call in ``mode``'s graph."""
        return trace_call_delegate(
            mode, executorch_call_delegate, lowered_module, *args
        )

    @executorch_call_delegate.py_impl(FakeTensorMode)
    # pyre-ignore
    def call_delegate_fake_tensor_mode(mode, lowered_module, *args):
        """FakeTensorMode kernel: run the eager CPU kernel with ``mode`` active."""
        with mode:
            fake_out = call_delegate_cpu(lowered_module, *args)
        return fake_out

    @executorch_call_delegate.py_functionalize_impl
    # pyre-ignore
    def call_delegate_functionalize(ctx, lowered_module, *args):
        """
        Functionalization kernel: unwrap each positional argument, redispatch the
        operator to the next key, and wrap the result back. ``lowered_module``
        itself is passed through without unwrapping.
        """
        plain_args = tuple(map(ctx.unwrap_tensors, args))
        with ctx.redispatch_to_next():
            out = executorch_call_delegate(lowered_module, *plain_args)
            return ctx.wrap_tensors(out)

    # pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre
    def is_lowered_module(obj: Any) -> bool:
        """
        Return True when ``obj``'s exact class is named "LoweredBackendModule".

        The check compares class names instead of using ``isinstance`` so that
        LoweredBackendModule does not have to be imported, which could create a
        circular import. Subclasses with a different name are therefore not
        recognized.
        """
        class_name = type(obj).__name__
        return class_name == LOWERED_BACKEND_MODULE_TYPE

    def get_lowered_module_name(
        root: torch.nn.Module,
        # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
        lowered_module: LOWERED_BACKEND_MODULE_TYPE,  # noqa
    ) -> str:
        """
        Register ``lowered_module`` on ``root`` under a fresh attribute name.

        Candidate names are ``lowered_module_0``, ``lowered_module_1``, ... in
        order. The first one that is not already an attribute of ``root`` is used.

        Args:
            root: Module that receives the new submodule.
            lowered_module: Lowered backend module to attach.

        Returns:
            The attribute name under which ``lowered_module`` was registered.
        """
        # hasattr() rather than a submodule-only lookup, so that parameters,
        # buffers and plain attributes with the candidate name are skipped too.
        suffix = 0
        while hasattr(root, f"lowered_module_{suffix}"):
            suffix += 1
        qualname = f"lowered_module_{suffix}"

        root.add_module(qualname, lowered_module)
        return qualname
