# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Dict, List

import executorch.exir as exir
import torch
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)

# import the backend implementation
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)
from executorch.exir.backend.test.hta_partitioner_demo import (
    HTAPartitionerMultiplePatternsDemo,
    HTAPartitionerOnePatternDemo,
)
from executorch.exir.backend.test.op_partitioner_demo import (
    AddAttributePartitionerDemo,
    AddMulPartitionerDemo,
)
from executorch.exir.backend.test.qnn_backend_demo import QnnBackend

from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.lowered_backend_module import (
    get_lowered_backend_modules,
    get_lowered_submodules,
)
from executorch.exir.print_program import print_program
from executorch.exir.schema import (
    BackendDelegate,
    BackendDelegateDataReference,
    DataLocation,
    DelegateCall,
    Program,
)

from executorch.extension.pybindings.portable_lib import (  # @manual
    _load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten

from functorch.experimental import control_flow
from torch.ao.quantization import get_default_qconfig_mapping  # @manual
from torch.ao.quantization.backend_config.executorch import (
    get_executorch_backend_config,
)
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.export import export, ExportedProgram
from torch.testing import FileCheck


def vary_segments(test_method):
    """A decorator that calls the test method with `extract_delegate_segments` set to
    True and False.

    Decorated test methods must expect a boolean parameter named
    `extract_delegate_segments`, and they should pass that value to to_executorch() like:

        m.to_executorch(
            config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments)
        )

    This will cause the delegate data blobs to be extracted from the program and
    serialized as separate, freeable program segments. Backends should detect no
    difference at runtime.
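
    For example:

        @vary_segments
        def test_foo(self, extract_delegate_segments: bool):
            ...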
    """

    def wrapper(self):
        for extract_delegate_segments in [False, True]:
            # subTest will create a different top-level test entry for each
            # value, whose full names have a suffix like
            # "(extract_delegate_segments=True)".
            with self.subTest(extract_delegate_segments=extract_delegate_segments):
                test_method(self, extract_delegate_segments=extract_delegate_segments)

    return wrapper


class TestBackends(unittest.TestCase):
    def check_delegate_input(
        self, delegate: LoweredBackendModule, input_len: int
    ) -> None:
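        """Checks that the graph of the lowered module has exactly `input_len`
        placeholder (input) nodes."""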
        counter = 0
        for node in delegate.original_module.graph.nodes:
            if node.op == "placeholder":
                counter += 1
        self.assertEqual(counter, input_len)

    def check_backend_delegate(
        self,
        program: Program,
        delegate: BackendDelegate,
        expected_id: str,
        expected_processed: bytes,
    ) -> None:
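        """Checks that `delegate` carries the expected backend id and that its
        processed blob is stored inline in the program's backend delegate data
        at the recorded index."""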
        self.assertEqual(delegate.id, expected_id)
        processed: BackendDelegateDataReference = delegate.processed
        self.assertEqual(processed.location, DataLocation.INLINE)
        self.assertLess(processed.index, len(program.backend_delegate_data))
        self.assertEqual(
            program.backend_delegate_data[processed.index].data, expected_processed
        )

    def test_simple(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        expected_res = sin_module(*model_inputs)
        edgeir_m = to_edge(export(sin_module, model_inputs))

        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program(), []
        )
        new_res = lowered_sin_module(*model_inputs)

        self.assertTrue(torch.allclose(new_res, expected_res))

        # TODO(tkaruturi): emitting single LoweredBackendModule
        # program = to_edge(export(graph_module)).to_executorch()._emitter_output.program

    @vary_segments
    def test_backend_with_compiler(self, extract_delegate_segments: bool):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name once
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
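        # A CompileSpec is a backend-specific (key, bytes) hint handed to the
        # backend at preprocess time. Here `max_value` appears to carry the
        # input size, which the demo backend validates when the program is
        # loaded (see run_model_in_unsupported_backend below).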
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        model_inputs = (torch.ones(1),)

        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            )
        )
        graph_module = exec_prog.exported_program().graph_module

        # Check that there is no aten.sin node left in the graph.
        self.assertTrue(
            exir_ops.edge.aten.sin.default
            not in {node.target for node in graph_module.graph.nodes}
        )

        # Check that there exists a call_delegate node, representing the call
        # to the delegated function.
        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
            graph_module.code
        )
        lowered_submodules = get_lowered_submodules(graph_module)
        self.assertEqual(len(lowered_submodules), 1)

        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                # Check that first arg is lowered_module_{unique_id}
                self.assertEqual(node.args[0].target, "lowered_module_0")

        program = exec_prog._emitter_output.program

        # Check the program can be printed
        print_program(program)

        # Check the backend delegate
        self.check_backend_delegate(
            program=program,
            delegate=program.execution_plan[0].delegates[0],
            expected_id=BackendWithCompilerDemo.__name__,
            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
        )

        # Check the delegate instruction
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[0].instr_args,
                DelegateCall,
            )
        )
        buff = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(buff)
        model_inputs = torch.ones(1)
        model_outputs = executorch_module.forward([model_inputs])
        # Check that the input was not modified in place by the delegate call.
        self.assertEqual(
            model_inputs,
            torch.ones(1),
        )
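        # The demo backend does not compute an exact sin; its generated kernel
        # appears to use the Taylor approximation sin(x) ≈ x - x^3/6, so
        # sin(1) ≈ 1 - 1/6 ≈ 0.8333.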
        expected_output = 0.8333 * torch.ones(1)

        self.assertTrue(
            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
        )

    @vary_segments
    def test_lowered_add_mul(self, extract_delegate_segments: bool):
        class AddMulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = torch.add(y, b)
                return z

        add_mul_module = AddMulModule()
        model_inputs = (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
        edge_graph_module = to_edge(export(add_mul_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_add_mul = to_backend(
            "BackendWithCompilerDemo",
            edge_graph_module.exported_program(),
            compile_specs,
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_add_mul = lowered_add_mul

            def forward(self, a, x, b):
                return self.lowered_add_mul(a, x, b)

        composite_model = CompositeModule()

        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            )
        )
        buff = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(buff)

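        # run_method takes a flat tuple of input tensors, so flatten the
        # pytree inputs first.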
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        inputs_flattened, _ = tree_flatten(model_inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = add_mul_module(*model_inputs)

        self.assertTrue(
            torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
        )

    def run_model_in_unsupported_backend(self, extract_delegate_segments: bool):
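        """Lowers a module whose input size (6) exceeds what the demo backend
        supports (<= 4), then tries to load the resulting program; loading is
        expected to fail."""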
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        # The backend only accepts shapes <= 4, so an input of size 6 should
        # fail to load at runtime.
        model_inputs = (torch.ones(6),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        model_inputs = (torch.zeros(6),)

        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        buff = exec_prog.buffer

        # This line should raise an exception like
        # RuntimeError: loading method forward failed with error 0x12
        _load_for_executorch_from_buffer(buff)

    @vary_segments
    def test_backend_with_compiler_out_of_range(self, extract_delegate_segments: bool):
        with self.assertRaisesRegex(
            RuntimeError,
            "loading method forward failed with error 0x12",
        ):
            self.run_model_in_unsupported_backend(
                extract_delegate_segments=extract_delegate_segments
            )

    @vary_segments
    def test_backend_with_compiler_delegate_and_operator(
        self, extract_delegate_segments: bool
    ):
        # This test includes both a delegate and a non-delegated operator.

        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name once
            # it's resolved on the compiler side.
            def forward(self, x):
                return [torch.sin(x)]

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                a = self.lowered_linear_sin(x)[0]
                b = self.lowered_linear_sin(x)[0]
                return torch.add(a, b)

        composite_model = CompositeModule()
        model_inputs = (torch.ones(1),)

        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )
        graph_module = exec_prog.exported_program().graph_module
        program = exec_prog._emitter_output.program
        buff = exec_prog.buffer

        # Check that there is no aten.sin node left in the graph.
        self.assertTrue(
            exir_ops.edge.aten.sin.default
            not in {node.target for node in graph_module.graph.nodes}
        )

        # Check that there exists a call_delegate op, representing the call
        # to the delegated function.
        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
            graph_module.code
        )

        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                # Check that first arg is lowered_module_{unique_id}
                self.assertEqual(node.args[0].target, "lowered_module_0")

        # Check the backend delegate
        self.check_backend_delegate(
            program=program,
            delegate=program.execution_plan[0].delegates[0],
            expected_id=BackendWithCompilerDemo.__name__,
            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
        )

        # Check the delegate instruction
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[0].instr_args,
                DelegateCall,
            )
        )

        executorch_module = _load_for_executorch_from_buffer(buff)
        model_inputs = torch.ones(1)

        model_outputs = executorch_module.forward([model_inputs])

        # Check that the input was not modified in place by the delegate calls.
        self.assertEqual(
            model_inputs,
            torch.ones(1),
        )
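        # a == b ≈ sin(1) ≈ 0.8333 under the demo backend's apparent
        # approximation (x - x^3/6), so a + b ≈ 1.6667.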
        expected_output = 1.666667 * torch.ones(1)

        self.assertTrue(
            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
        )

    def test_backend_with_compiler_backend_runtime_exception(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name once
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x) + torch.cos(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        error_msg = r"call_function aten.cos.default is not supported in backend BackendWithCompilerDemo"

        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = to_backend("BackendWithCompilerDemo", edgeir_m.exported_program(), [])

    def test_backend_with_compiler_backend_not_found_exception(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name once
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x) + torch.cos(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        error_msg = r"Backend FakeBackendWithCompilerDemo was not found."

        with self.assertRaisesRegex(
            NotImplementedError,
            error_msg,
        ):
            _ = to_backend(
                "FakeBackendWithCompilerDemo", edgeir_m.exported_program(), []
            )

    @vary_segments
    def test_backend_with_compiler_delegate_and_operator_with_two_modules(
        self, extract_delegate_segments: bool
    ):
        # The submodule runs on a specific backend; in this example, the
        # `BackendWithCompilerDemo` backend.
        class LowerableSubModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        # to_be_lowered is a plain nn.Module
        to_be_lowered = LowerableSubModel()
        example_input = (torch.ones(1),)
        to_be_lowered_exir_submodule = to_edge(export(to_be_lowered, example_input))

        max_value = example_input[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_module = to_backend(
            "BackendWithCompilerDemo",
            to_be_lowered_exir_submodule.exported_program(),
            compile_specs,
        )

        class NonLowerableSubModel(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.register_buffer("bias", bias)

            def forward(self, a, b):
                return torch.add(torch.add(a, b), self.bias)

        # The composite module, combining the lowered and non-lowered parts
        class CompositeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.non_lowerable = NonLowerableSubModel(torch.ones(1) * 0.3)
                self.lowerable = lowered_module

            def forward(self, x):
                a = self.lowerable(x)
                b = self.lowerable(a)
                ret = self.non_lowerable(a, b)
                return a, b, ret

        composite_model = CompositeModel()

        # Prepare the model input
        model_inputs = (torch.ones(1),)

        # Verify the input works with eager module
        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )
        flatbuffer = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(flatbuffer)
        model_outputs = executorch_module.forward([*model_inputs])

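        # With the demo backend's apparent sin approximation (x - x^3/6):
        #   a = sin(1) ≈ 0.8333
        #   b = sin(a) ≈ 0.7369
        #   ret = a + b + 0.3 ≈ 1.8702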
        expected_outputs = [
            0.8333 * torch.ones(1),
            0.7369 * torch.ones(1),
            1.8702 * torch.ones(1),
        ]

        for index, expected_output in enumerate(expected_outputs):
            self.assertTrue(
                torch.allclose(
                    model_outputs[index], expected_output, atol=1e-03, rtol=1e-03
                )
            )

    @vary_segments
    def test_partition_delegate_graph_with_multiple_patterns(
        self, extract_delegate_segments: bool
    ):
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                x = output
                y = cn
                a = torch.sub(x, y)
                b = torch.sub(x, a)
                c = torch.sub(x, b)
                d = torch.add(x, self.weight)
                e = torch.mul(c, d)
                return e, hn, k

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)
        orig_res = composite_m(*inputs)

        traced = to_edge(
            export(composite_m, inputs),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False, _use_edge_ops=True
            ),
        )

        program_without_delegates = to_edge(
            export(CompositeModel(3), inputs),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        ).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )
        # After this step, part of the graph will be lowered to the backend,
        # according to HTAPartitionerMultiplePatternsDemo's partitioning rules.
        program_with_delegates = traced
        program_with_delegates = program_with_delegates.to_backend(
            HTAPartitionerMultiplePatternsDemo()
        )
        program_with_delegates = program_with_delegates.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        new_res = program_with_delegates.exported_program().module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        # Check the backend delegate
        self.check_backend_delegate(
            program=program_with_delegates._emitter_output.program,
            delegate=program_with_delegates._emitter_output.program.execution_plan[
                0
            ].delegates[0],
            expected_id=QnnBackend.__name__,
            expected_processed=b"imqnncompiled",
        )

        # Check that sub is not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::sub"
                ]
            ),
        )

        # Check convolution not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

        # Check convolution in the program without delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_without_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

    @vary_segments
    def test_partition_delegate_graph_with_one_patterns(
        self, extract_delegate_segments: bool
    ):
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                x = output
                y = cn
                a = torch.sub(x, y)
                b = torch.sub(x, a)
                c = torch.sub(x, b)
                d = torch.add(x, self.weight)
                e = torch.mul(c, d)
                return e, hn, k

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)
        orig_res = composite_m(*inputs)

        traced = to_edge(
            export(composite_m, inputs),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False, _use_edge_ops=True
            ),
        )

        program_without_delegates = to_edge(
            export(
                CompositeModel(3),
                (input_x, input_h, input_c),
            ),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        ).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )
        # After this step, part of the graph will be lowered to the backend,
        # according to HTAPartitionerOnePatternDemo's partitioning rules.
        traced_with_delegate = traced
        traced_with_delegate = traced_with_delegate.to_backend(
            HTAPartitionerOnePatternDemo()
        )

        new_res = traced_with_delegate.exported_program().module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        program_with_delegates = traced_with_delegate.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        # TODO(T143084047): Currently not retraceable
        # Retracing is not needed for this test, but the snippet below is kept
        # as a reminder to verify that the result of to_backend is retraceable.
        # graph_module_with_delegate = to_edge(export(
        #     traced_with_delegate,
        #     (input_x, input_h, input_c),
        #
        # ))

        # program_with_delegates = graph_module_with_delegate.to_executorch(
        #     config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments),
        # )

        new_res = program_with_delegates.exported_program().module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        # Check the backend delegate
        self.check_backend_delegate(
            program=program_with_delegates._emitter_output.program,
            delegate=program_with_delegates._emitter_output.program.execution_plan[
                0
            ].delegates[0],
            expected_id=QnnBackend.__name__,
            expected_processed=b"imqnncompiled",
        )

        # Check that one sub op remains in the program with delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_with_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::sub"
                ]
            ),
        )

        # Check convolution not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

        # Check convolution in the program without delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_without_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

    @vary_segments
    def test_add_mul_partitioner(self, extract_delegate_segments: bool):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
        orig_res = m(*inputs)

        ep = to_edge(export(m, inputs))
        executorch_prog = ep
        executorch_prog = executorch_prog.to_backend(AddMulPartitionerDemo())
        executorch_prog = executorch_prog.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        new_res = executorch_prog.exported_program().graph_module(*inputs)
        self.assertTrue(torch.allclose(new_res[0], orig_res))

        counter = 0
        for node in executorch_prog.exported_program().graph_module.graph.nodes:
            if node.op == "get_attr":
                self.assertEqual(node.target, f"lowered_module_{counter}")
                counter += 1
        # There should be 2 delegated modules
        self.assertEqual(counter, 2)

        executorch_module = _load_for_executorch_from_buffer(executorch_prog.buffer)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        inputs_flattened, _ = tree_flatten(inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = m(*inputs)

        self.assertTrue(
            torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03),
        )

    @vary_segments
    def test_partitioner_with_attributes(self, extract_delegate_segments: bool):
        """
        check that parameters that are lowered are correctly moved into the sub
        program, rather than being retained and passed as inputs.
        """

        class AddOne(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("one", torch.ones(1, 3))

            def forward(self, x):
                return x + self.one

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_one = AddOne()
                self.add_one_2 = AddOne()

            def forward(self, x, y):
                x = self.add_one(x) * y
                return self.add_one_2(x)

        inputs = (torch.randn(1, 3), torch.randn(1, 3))
        orig_res = Model()(*inputs)
        ep = to_edge(export(Model(), inputs))
        executorch_prog = ep
        executorch_prog = executorch_prog.to_backend(AddAttributePartitionerDemo())
        executorch_prog = executorch_prog.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        # Check the delegated submodules
        lowered_backends = get_lowered_backend_modules(
            executorch_prog.exported_program().graph_module
        )
        self.assertEqual(len(lowered_backends), 2)
        for backend in lowered_backends:
            original_program = backend.original_module
            # Check that the program has the lowered attribute
            self.assertEqual(len(original_program.state_dict), 1)
            # Check that the backend receives one placeholder input and one
            # placeholder parameter
            self.check_delegate_input(backend, 2)

        # Make sure the program can still be serialized to a buffer.
        executorch_prog.buffer

        new_res = executorch_prog.exported_program().graph_module(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

    def test_bad_partitioner(self):
        """
        Checks that we throw an error if user provided partitioner modifies the
        graph module
        """
        inputs = (torch.randn(1, 3), torch.randn(1, 3))

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                x = x + y
                x = x * y
                x = x - y
                x = x / y
                x = x * y
                x = x + y
                return x

        class BadPartitioner(Partitioner):
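            """A partitioner that illegally mutates the exported program handed
            to it; to_backend is expected to reject it."""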
            partition_tags = {"tag1": DelegationSpec("BackendWithCompilerDemo", [])}

            def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                # A partitioner must not modify the given exported program;
                # this one rewrites add targets to mul to trigger the error.
                partition_tags: Dict[str, DelegationSpec] = {}
                for node in exported_program.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == exir_ops.edge.aten.add.Tensor
                    ):
                        node.target = exir_ops.edge.aten.mul.Tensor
                return PartitionResult(
                    tagged_exported_program=exported_program,
                    partition_tags=partition_tags,
                )

        ep = to_edge(export(Model(), inputs))
        with self.assertRaises(AssertionError):
            _ = ep.to_backend(BadPartitioner())

    def test_quantized_with_delegate(self) -> None:
        torch.ops.load_library(
            "//executorch/kernels/quantized:custom_ops_generated_lib"
        )
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        in_size = 2
        input_size = 3
        output_size = 4
        linear = torch.nn.Linear(input_size, output_size).eval()
        example_inputs = (torch.ones(in_size, input_size),)
        prepared_linear = prepare_fx(
            linear,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
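        # _convert_to_reference_decomposed_fx replaces the observers inserted
        # by prepare_fx with explicit quantize/dequantize ops.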
        converted_linear: torch.nn.Module = _convert_to_reference_decomposed_fx(
            prepared_linear,
        )

        # The reference-quantized model is not valid edge IR, so disable IR
        # validity checking for this trace.
        converted_linear_gm = to_edge(
            export(
                converted_linear,
                example_inputs,
            ),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        )
        FileCheck().check_count("quantize_per_tensor_default", 3).check("addmm").run(
            converted_linear_gm.exported_program().graph_module.code
        )

    def test_partition_with_control_flow(self) -> None:
        def true_fn(x, y):
            x = x - y
            x = x + y
            x = x - y
            return x

        def false_fn(x, y):
            x = x - y
            x = torch.mm(x, y)
            x = x - y
            return x

        class Module(torch.nn.Module):
            def forward(self, x, y):
                x = x + y
                x = control_flow.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
                x = x - y
                return x

        f = Module()
        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        orig_res = f(*inputs)
        orig = to_edge(
            export(
                f,
                inputs,
            )
        )
        partitioned = orig
        partitioned = partitioned.to_backend(AddMulPartitionerDemo())

        new_res = partitioned.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the cond submodules
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(partitioned_submodules), 2)

        true_gm = partitioned_submodules[0][1]
        true_lowered = get_lowered_submodules(true_gm)
        self.assertEqual(len(true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            true_lowered[0][1].original_module.graph_module.code
        )

        false_gm = partitioned_submodules[1][1]
        false_lowered = get_lowered_submodules(false_gm)
        self.assertEqual(len(false_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            false_lowered[0][1].original_module.graph_module.code
        )

    def test_partition_with_map(self) -> None:
        def map_fn(x, y):
            x = x - y
            x = x + y
            return x

        class Module(torch.nn.Module):
            def forward(self, xs, y):
                y = torch.mm(y, y)
                return control_flow.map(map_fn, xs, y)

        f = Module()
        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        orig_res = f(*inputs)
        orig = to_edge(
            export(
                f,
                inputs,
            )
        )
        partitioned = orig
        partitioned = partitioned.to_backend(AddMulPartitionerDemo())

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the map submodule
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(partitioned_submodules), 1)

        map_fn_gm = partitioned_submodules[0][1]
        map_fn_lowered = get_lowered_submodules(map_fn_gm)
        self.assertEqual(len(map_fn_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            map_fn_lowered[0][1].original_module.graph_module.code
        )

        new_res = partitioned.exported_program().module()(*inputs)

        self.assertTrue(torch.allclose(orig_res, new_res[0]))

    def test_partition_with_nested_control_flow(self) -> None:
        """
        Partitions the add and mul ops, including the ones inside the submodules
        """

        def true_nested(y):
            y = y + y
            y = torch.mm(y, y)
            return y

        def false_nested(y):
            return torch.mm(y, y)

        def true_fn(x, pred2):
            z = control_flow.cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _):
            return x.cos()

        def map_fn(x, pred1, pred2, y):
            x = x.cos()
            y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
            x = x + y
            return x.sin()

        class Module(torch.nn.Module):
            def forward(self, xs, pred1, pred2, y):
                y = torch.mm(y, y)
                return control_flow.map(map_fn, xs, pred1, pred2, y)

        inputs = (
            torch.ones(2, 2),
            torch.tensor([False]),
            torch.tensor([False]),
            torch.ones(2, 2),
        )

        f = Module()
        orig_res = f(*inputs)
        orig = to_edge(
            export(
                f,
                inputs,
            )
        )
        partitioned = orig
        partitioned = partitioned.to_backend(AddMulPartitionerDemo())

        new_res = partitioned.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the map submodule
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(partitioned_submodules), 1)

        # Map module has the cond submodules
        map_submodules = get_control_flow_submodules(partitioned_submodules[0][1])
        self.assertEqual(len(map_submodules), 2)

        # True module
        true_module = map_submodules[0][1]
        true_lowered = get_lowered_submodules(true_module)
        self.assertEqual(len(true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            true_lowered[0][1].original_module.graph_module.code
        )

        # False module
        false_lowered = get_lowered_submodules(map_submodules[1][1])
        self.assertEqual(len(false_lowered), 0)

        # True module has the nested cond submodules
        true_submodules = get_control_flow_submodules(true_module)
        self.assertEqual(len(true_submodules), 2)

        # Nested True module
        true_true_lowered = get_lowered_submodules(true_submodules[0][1])
        self.assertEqual(len(true_true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").check(
            "executorch_exir_dialects_edge__ops_aten_mm_default"
        ).run(true_true_lowered[0][1].original_module.graph_module.code)

        # Nested False module
        true_false_lowered = get_lowered_submodules(true_submodules[1][1])
        self.assertEqual(len(true_false_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            true_false_lowered[0][1].original_module.graph_module.code
        )

    def test_list_input(self):
        class Module(torch.nn.Module):
            def forward(self, x: List[torch.Tensor]):
                y = x[0] + x[1]
                return y

        f = Module()
        inputs = ([torch.randn(2, 2), torch.randn(2, 2)],)
        edge_prog = to_edge(export(f, inputs))
        lowered_gm = to_backend(
            BackendWithCompilerDemo.__name__, edge_prog.exported_program(), []
        )

        class ComposedM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered = lowered_gm

            def forward(self, x: List[torch.Tensor]):
                return self.lowered(x)

        gm = to_edge(export(ComposedM(), inputs))
        gm.exported_program().module()(*inputs)

    def test_dict_input(self):
        class Module(torch.nn.Module):
            def forward(self, x: Dict[str, torch.Tensor]):
                y = x["a"] + x["b"]
                return y

        f = Module()
        inputs = ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},)
        edge_prog = to_edge(export(f, inputs))
        lowered_gm = to_backend(
            BackendWithCompilerDemo.__name__, edge_prog.exported_program(), []
        )

        class ComposedM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered = lowered_gm

            def forward(self, x: Dict[str, torch.Tensor]):
                return self.lowered(x)

        gm = to_edge(export(ComposedM(), inputs))
        gm.exported_program().module()(*inputs)
