# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pye-strict

import copy
import unittest
from typing import Any, Dict

import torch
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.backend.test.op_partitioner_demo import (
    AddMulPartitionerDemo,
    NonDecompTestPartitioner,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.error import ExportError
from executorch.exir.lowered_backend_module import get_lowered_submodules
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.program._program import (
    EdgeProgramManager,
    ExecutorchProgramManager,
    to_edge,
    to_edge_transform_and_lower,
    to_edge_with_preserved_ops,
)
from executorch.exir.tracer import _default_decomposition_table
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)
from torch.export import Dim, export, ExportedProgram
from torch.export._trace import _export

from torch.library import impl, Library
from torch.nn import functional as F


class TestLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16, bias=True)

    def forward(self, x):
        return self.linear(x)

    @classmethod
    def _get_random_inputs(cls):
        x = torch.rand(8, 32)
        return (x,)


class TestSDPA(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query, key, value):
        return torch.ops.aten.scaled_dot_product_attention.default(query, key, value)

    @classmethod
    def _get_random_inputs(cls):
        d_k = 64
        batch = 16
        seq_len = 10
        query = torch.rand(batch, seq_len, d_k)
        key = torch.rand(batch, seq_len, d_k)
        value = torch.rand(batch, seq_len, d_k)
        return (query, key, value)


class TestLinearSDPACombined(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16, bias=True)

    def forward(self, x, query, key, value):
        x = self.linear(x)
        return (
            x,
            torch.ops.aten.scaled_dot_product_attention.default(query, key, value),
        )

    @classmethod
    def _get_random_inputs(cls):
        return TestLinear._get_random_inputs() + TestSDPA._get_random_inputs()


class TestUpsample(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return x

    @classmethod
    def _get_random_inputs(cls):
        x = torch.randn(1, 1, 8, 8)
        return (x,)


class TestLSTM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

    def forward(self, x):
        return self.lstm(x)

    @classmethod
    def _get_random_inputs(cls):
        return (torch.rand(1, 10, 8),)


class WrapperModule(torch.nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


lib = Library("exir_program_test_op", "DEF")

# Fake a operator for testing.
# This operator takes two tensors as input and returns the first one.
lib.define("foo(Tensor self, Tensor other) -> Tensor")


@impl(lib, "foo", "CPU")
def foo(a, b):
    # do nothing and return a.
    return a + b


@impl(lib, "foo", "Meta")
def foo_meta(a, b):
    # do nothing and return a.
    return torch.empty_like(a)


def get_exported_programs() -> Dict[str, ExportedProgram]:
    class Forward(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            z = torch.mul(x, y)
            return torch.add(z, x)

    forward = Forward()

    class Foo(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.add(x, torch.ones(1))

    foo = Foo()

    programs = {}
    programs["forward"] = export(
        forward,
        args=(
            torch.ones(1),
            torch.zeros(1),
        ),
    ).run_decompositions()
    programs["foo"] = export(
        foo,
        (torch.ones(1),),
    ).run_decompositions()
    return programs


def get_config_methods() -> Dict[str, Any]:
    def bam():
        return 3

    def bar():
        return "bar"

    return {"bam": bam(), "bar": bar()}


class AddToMulPassEdge(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        if op == exir_ops.edge.aten.add.Tensor:
            return super().call_operator(
                exir_ops.edge.aten.mul.Tensor, args, kwargs, meta
            )
        else:
            return super().call_operator(op, args, kwargs, meta)


class TestProgramManagers(unittest.TestCase):
    def test_edge_manager_basic_api(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )

        # test basic apis
        self.assertEqual(edge_manager.methods, {"forward", "foo"})
        self.assertEqual(edge_manager.config_methods, {"bam", "bar"})

        # test dialect is correct
        try:
            EXIREdgeDialectVerifier()(
                edge_manager.exported_program("forward").graph_module
            )
            EXIREdgeDialectVerifier()(edge_manager.exported_program("foo").graph_module)
        except ExportError as e:
            self.assertTrue(False, msg="Graph not in edge dialect : " + e.msg)

    def test_executorch_manager_basic_api(self):
        executorch_manager: ExecutorchProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        ).to_executorch()

        # test basic apis
        self.assertEqual(executorch_manager.methods, {"forward", "foo"})
        self.assertEqual(executorch_manager.config_methods, {"bam", "bar"})

        # test that the emitted output is correct
        self.assertEqual(
            len(executorch_manager._emitter_output.program.execution_plan), 4
        )

        # test that the buffer is correct
        executorch_module = _load_for_executorch_from_buffer(executorch_manager.buffer)
        self.assertEqual(
            executorch_module.run_method("forward", (torch.ones(1), torch.zeros(1)))[0],
            torch.ones(1),
        )
        self.assertEqual(
            executorch_module.run_method("foo", (torch.ones(1),))[0],
            torch.ones(1) + torch.ones(1),
        )
        self.assertEqual(
            executorch_module.run_method("bar", ())[0],
            "bar",
        )
        self.assertEqual(
            executorch_module.run_method("bam", ())[0],
            3,
        )

    def test_executorch_manager_multi_config(self):
        def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]:
            return {
                "forward": MemoryPlanningPass(
                    alloc_graph_input=True,
                    alloc_graph_output=False,
                ),
                "foo": MemoryPlanningPass(
                    alloc_graph_input=False,
                    alloc_graph_output=True,
                ),
            }

        executorch_manager: ExecutorchProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        ).to_executorch(
            ExecutorchBackendConfig(
                memory_planning_pass=get_executorch_memory_planning_passes()
            )
        )

        method = executorch_manager._emitter_output.program.execution_plan[0]
        if method.name == "forward":
            for input_val in method.inputs:
                evalue = method.values[input_val]
                self.assertEqual(evalue.val.allocation_info, None)
            for output_val in method.outputs:
                evalue = method.values[output_val]
                self.assertNotEqual(evalue.val.allocation_info, None)
        else:
            for input_val in method.inputs:
                evalue = method.values[input_val]
                self.assertEqual(evalue.val.allocation_info, None)
            for output_val in method.outputs:
                evalue = method.values[output_val]
                self.assertNotEqual(evalue.val.allocation_info, None)

    def test_no_getattr(self):
        class Mul(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x * 3.14

        mul = Mul()
        ep = to_edge(torch.export.export(mul, (torch.ones(1),))).exported_program()
        for node in ep.graph.nodes:
            self.assertNotEqual(node.op, "get_attr")
        self.assertEqual(
            len([node for node in ep.graph.nodes if node.op == "placeholder"]), 2
        )

    def test_constraint_present_after_dce(self):
        import executorch.exir as exir

        class M(torch.nn.Module):
            def forward(self, x, y):
                z = y.item()
                torch._check(z > 0)
                torch._check(z < 4)
                return x[z : z + y.shape[0]]

        ep = torch.export.export(M(), (torch.randn(10), torch.tensor([3])))

        edge_manager = to_edge(
            ep, compile_config=exir.EdgeCompileConfig(_check_ir_validity=False)
        )
        edge_manager.to_executorch()

    def test_edge_manager_transform(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )

        original_res = edge_manager.exported_program("forward").module()(
            torch.ones(1), torch.ones(1)
        )

        # perform transformation
        transformed_edge = edge_manager.transform(
            [
                AddToMulPassEdge(),
            ]
        )

        # still have all our methods
        self.assertEqual(len(transformed_edge.methods), 2)
        self.assertEqual(len(transformed_edge.config_methods), 2)

        # transformation was applied
        self.assertEqual(
            transformed_edge.exported_program("forward").module()(
                torch.ones(1), torch.ones(1)
            ),
            torch.ones(1),  # x * y * x
        )

        # original unchanged
        self.assertEqual(
            edge_manager.exported_program("forward").module()(
                torch.ones(1), torch.ones(1)
            ),
            original_res,  # x * y + x
        )

    def test_issue_3659(self):

        class Mul(torch.nn.Module):
            def __init__(self):
                super(Mul, self).__init__()

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.matmul(x, y)

            def get_eager_model(self) -> torch.nn.Module:
                return self

            def get_example_inputs(self):
                return (torch.randn(1, 3, 10), torch.randn(1, 10, 3))

            def get_dynamic_shapes(self):
                dim1_x = Dim("Dot_dim1_x", min=2, max=100)
                dim2_x = Dim("Dot_dim2_x", min=2, max=100)
                return {"x": {1: dim1_x, 2: dim2_x}, "y": {1: dim2_x, 2: dim1_x}}

        model = Mul()
        ep = torch.export.export(
            model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes()
        )

        to_edge(
            ep,
            compile_config=EdgeCompileConfig(
                _check_ir_validity=True,
            ),
        )

    def test_transform_dict_api(self):
        edge_manager = to_edge(get_exported_programs(), get_config_methods())

        transformed_edge = edge_manager.transform(
            {
                "forward": [
                    AddToMulPassEdge(),
                ]
            }
        )

        self.assertEqual(
            transformed_edge.exported_program("forward").module()(
                torch.ones(1), torch.ones(1)
            ),
            torch.ones(1),  # x * y * x
        )

        self.assertEqual(
            transformed_edge.exported_program("foo").module()(
                torch.ones(1),
            ),
            torch.ones(1) + 1,  # x + 1
        )

    def test_edge_to_backend_replaces_subgraph(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )
        delegate_manager: EdgeProgramManager = edge_manager.to_backend(
            AddMulPartitionerDemo()
        )

        forward_program = delegate_manager.exported_program("forward")
        self.assertEqual(
            forward_program.module()(torch.ones(1), torch.ones(1)),
            torch.ones(1) + 1,  # x * y + x
        )

        add_nodes = [
            node
            for node in forward_program.graph_module.graph.nodes
            if node.op == "call_function"
            and node.target == exir_ops.edge.aten.add.Tensor
        ]
        self.assertEqual(len(add_nodes), 0)

        foo_program = delegate_manager.exported_program("foo")
        add_nodes = [
            node
            for node in foo_program.graph_module.graph.nodes
            if node.op == "call_function"
            and node.target == exir_ops.edge.aten.add.Tensor
        ]
        self.assertEqual(len(add_nodes), 0)

        lowered_submods = get_lowered_submodules(foo_program.graph_module)
        self.assertEqual(len(lowered_submods), 1)

        # original unchanged
        lowered_submods = get_lowered_submodules(
            edge_manager.exported_program("forward").graph_module
        )
        self.assertEqual(len(lowered_submods), 0)

        # two delegate blobs for forward and foo
        self.assertEqual(
            len(
                delegate_manager.to_executorch(ExecutorchBackendConfig())
                ._emitter_output.program.execution_plan[0]
                .delegates
            ),
            1,
        )
        self.assertEqual(
            len(
                delegate_manager.to_executorch(ExecutorchBackendConfig())
                ._emitter_output.program.execution_plan[1]
                .delegates
            ),
            1,
        )

    def test_edge_to_backend_selective(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )
        delegate_manager: EdgeProgramManager = edge_manager.to_backend(
            {"forward": AddMulPartitionerDemo()}
        )

        forward_program = delegate_manager.exported_program("forward")
        self.assertEqual(
            forward_program.module()(torch.ones(1), torch.ones(1)),
            torch.ones(1) + 1,  # x * y + x
        )

        add_nodes = [
            node
            for node in forward_program.graph_module.graph.nodes
            if node.op == "call_function"
            and node.target == exir_ops.edge.aten.add.Tensor
        ]
        self.assertEqual(len(add_nodes), 0)

        # foo unchanged
        lowered_submods = get_lowered_submodules(
            delegate_manager.exported_program("foo").graph_module
        )
        self.assertEqual(len(lowered_submods), 0)

        # original unchanged
        lowered_submods = get_lowered_submodules(
            edge_manager.exported_program("forward").graph_module
        )
        self.assertEqual(len(lowered_submods), 0)

        # one delegate blob for forward
        self.assertEqual(
            len(
                delegate_manager.to_executorch(
                    ExecutorchBackendConfig(
                        extract_delegate_segments=False,
                    )
                )
                ._emitter_output.program.execution_plan[0]  # foo
                .delegates
            ),
            0,
        )
        self.assertEqual(
            len(
                delegate_manager.to_executorch(
                    ExecutorchBackendConfig(
                        extract_delegate_segments=False,
                    )
                )
                ._emitter_output.program.execution_plan[1]  # forward
                .delegates
            ),
            1,
        )

    def test_edge_manager_dialect(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )
        self.assertTrue(edge_manager.exported_program().dialect == "EDGE")

    def _test_edge_dialect_verifier(
        self, callable, validate_ir=True, exception_list=None
    ):
        from executorch.exir import EdgeCompileConfig

        edge_compile_config = EdgeCompileConfig(
            _check_ir_validity=validate_ir,
            _core_aten_ops_exception_list=exception_list,
        )
        # pre-autograd export. eventually this will become torch.export
        one = torch.ones(1, dtype=torch.float)
        two = torch.ones(1, dtype=torch.int32)
        inputs = (
            one,
            two,
        )
        if not isinstance(callable, torch.nn.Module):
            callable = WrapperModule(callable)

        exported_foo = export(callable, inputs)
        _ = to_edge(exported_foo, compile_config=edge_compile_config)

    def test_edge_dialect_custom_op(self):
        # We shouldn't error out if there's a custom op in the graph.
        def _use_foo_add(a: torch.Tensor, b: torch.Tensor):
            return torch.ops.exir_program_test_op.foo(a, b)

        from torch._export.verifier import SpecViolationError

        try:
            # This should not raise error
            self._test_edge_dialect_verifier(_use_foo_add)
            self._test_edge_dialect_verifier(_use_foo_add, False)
        except SpecViolationError:
            self.fail("Should not error out on custom op")

    def get_num_nondecomposed_ops(self, ep, partitioner):
        # count the number of aten ops that the partitioner can delegate
        # we do this by running run_decompositions() with the preserved ops given
        # to us by the partitioner. Then we count the number of preserved aten ops
        # which pass the filter_ops fn given by the partitioner
        reference_ep = copy.deepcopy(ep)
        aten_ops_not_decomposed, filter_ops = partitioner.ops_to_not_decompose(ep)
        table = _default_decomposition_table()
        for op in aten_ops_not_decomposed:
            table.pop(op, None)
        reference_decomp_ep = reference_ep.run_decompositions(decomp_table=table)
        num_non_decomposed_aten_ops = 0
        for node in reference_decomp_ep.graph.nodes:
            if (
                node.op == "call_function"
                and node.target in aten_ops_not_decomposed
                and (filter_ops(node) if filter_ops else True)
            ):
                num_non_decomposed_aten_ops += 1
        return num_non_decomposed_aten_ops

    def _test_model_with_non_decomp_partitioner(self, model: torch.nn.Module):
        # This is the pre-dispatch export that we will be switching to primarily
        # in the near future. The input to to_edge_transform_and_lower needs to
        # be a graph generated by this pre dispatch export.
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
        #  torch.nn.modules.module.Module]` is not a function.
        ep = _export(model, model._get_random_inputs(), pre_dispatch=True)
        non_decomp_partitioner = NonDecompTestPartitioner()

        num_non_decomposed_aten_ops = self.get_num_nondecomposed_ops(
            ep, non_decomp_partitioner
        )

        # run to_edge_trasnform_and_lower
        edge = to_edge_transform_and_lower(
            ep,
            compile_config=EdgeCompileConfig(),
            partitioner=[NonDecompTestPartitioner()],
        )
        # Check that non_decomposed_edge_ops are all consumed by the delegate
        non_decomposed_edge_ops = (
            non_decomp_partitioner.supported_non_decomposed_edge_ops
        )
        for node in edge.exported_program().graph.nodes:
            if node.op == "call_function":
                self.assertTrue(node.target not in non_decomposed_edge_ops)

        # check that the number of call_delegate_nodes is equal to the number of
        # non_decomposed_aten_ops we found above
        num_call_delegates = 0
        for node in edge.exported_program().graph_module.graph.nodes:
            # There should only be a single call_function node in the graph
            # and that should be a call_delegate node.
            if (
                node.op == "call_function"
                and node.target == torch.ops.higher_order.executorch_call_delegate
            ):
                num_call_delegates += 1

        self.assertEqual(num_call_delegates, num_non_decomposed_aten_ops)

    def test_to_edge_transform_and_lower(self):
        self._test_model_with_non_decomp_partitioner(TestLinear())

        self._test_model_with_non_decomp_partitioner(TestSDPA())

        self._test_model_with_non_decomp_partitioner(TestLinearSDPACombined())

        self._test_model_with_non_decomp_partitioner(TestUpsample())

        self._test_model_with_non_decomp_partitioner(TestLSTM())

    def test_to_edge_transform_and_lower_with_exception(self):
        class TestLinear(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(32, 16, bias=True)
                self.linear_no_bias = torch.nn.Linear(32, 16, bias=False)

            def forward(self, x):
                return (self.linear(x), self.linear_no_bias(x))

            @classmethod
            def _get_random_inputs(cls):
                x = torch.rand(8, 32)
                return (x,)

        model = TestLinear()
        ep = _export(model, model._get_random_inputs(), pre_dispatch=True)
        edge = to_edge_transform_and_lower(
            ep,
            compile_config=EdgeCompileConfig(),
            partitioner=[NonDecompTestPartitioner()],
        )

        def count_nodes(graph_module, target):
            count = 0
            for node in graph_module.graph.nodes:
                if node.op == "call_function" and node.target == target:
                    count += 1
            return count

        # There should be 1 call_delegate node and 1 node for aten.mm.default for the
        # linear that doesn't have a bias which was decomposed as the partitioner
        # said this node wasn't supported.
        self.assertEqual(
            count_nodes(
                edge.exported_program().graph_module,
                torch.ops.higher_order.executorch_call_delegate,
            ),
            1,
        )
        self.assertEqual(
            count_nodes(
                edge.exported_program().graph_module, exir_ops.edge.aten.mm.default
            ),
            1,
        )

    def test_edge_dialect_non_core_aten_ops(self):
        class LinalgNorm(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.linalg.norm(x)

        from torch._export.verifier import SpecViolationError

        input = torch.arange(9, dtype=torch.float) - 4
        ep = torch.export.export(LinalgNorm(), (input,))

        # aten::linalg_norm is not a core op, so it should error out
        with self.assertRaises(SpecViolationError):
            _ = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=True))

        # with exception list, it should not error out
        try:
            # This should not raise error
            _ = to_edge(
                ep,
                compile_config=EdgeCompileConfig(
                    _check_ir_validity=True,
                    _core_aten_ops_exception_list=[
                        torch.ops.aten.linalg_vector_norm.default
                    ],
                ),
            )
        except SpecViolationError:
            self.fail("Should not error out on linalg_vector_norm op")

    def _test_to_edge_with_preserved_ops(
        self, program, preserved_ops, expected_preserved_ops
    ):
        edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)

        def count_nodes(graph_module, target):
            count = 0
            for node in graph_module.graph.nodes:
                if node.op == "call_function" and node.target in target:
                    count += 1
            return count

        aten_ops_non_decomposed = count_nodes(
            program.graph_module,
            preserved_ops,
        )

        edge_ops_non_decomposed = count_nodes(
            edge.exported_program().graph_module,
            expected_preserved_ops,
        )

        self.assertEqual(aten_ops_non_decomposed, edge_ops_non_decomposed)

    def test_to_edge_with_single_preserved_op(self):
        model = TestLinear()
        program = torch.export.export(model, model._get_random_inputs())

        ops_not_to_decompose = [
            torch.ops.aten.linear.default,
        ]
        expected_non_decomposed_edge_ops = [
            exir_ops.edge.aten.linear.default,
        ]

        self._test_to_edge_with_preserved_ops(
            program, ops_not_to_decompose, expected_non_decomposed_edge_ops
        )

    def test_to_edge_with_partial_ops_preserved(self):
        model = TestLinearSDPACombined()
        program = torch.export.export(model, model._get_random_inputs())

        ops_not_to_decompose = [
            torch.ops.aten.linear.default,
        ]
        expected_non_decomposed_edge_ops = [
            exir_ops.edge.aten.linear.default,
        ]

        self._test_to_edge_with_preserved_ops(
            program, ops_not_to_decompose, expected_non_decomposed_edge_ops
        )

    def test_to_edge_with_multiple_ops_preserved(self):
        model = TestLinearSDPACombined()
        program = torch.export.export(model, model._get_random_inputs())

        ops_not_to_decompose = [
            torch.ops.aten.linear.default,
            torch.ops.aten.scaled_dot_product_attention.default,
        ]
        expected_non_decomposed_edge_ops = [
            exir_ops.edge.aten.linear.default,
            exir_ops.edge.aten.scaled_dot_product_attention.default,
        ]

        self._test_to_edge_with_preserved_ops(
            program, ops_not_to_decompose, expected_non_decomposed_edge_ops
        )

    def test_to_edge_with_preserved_ops_not_in_model(self):
        model = TestSDPA()
        program = torch.export.export(model, model._get_random_inputs())

        ops_not_to_decompose = [
            torch.ops.aten.linear.default,
        ]
        expected_non_decomposed_edge_ops = [
            exir_ops.edge.aten.linear.default,
        ]

        self._test_to_edge_with_preserved_ops(
            program, ops_not_to_decompose, expected_non_decomposed_edge_ops
        )
