# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
import unittest
from typing import Dict, List

import executorch.exir as exir
import torch
from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)

# import the backend implementation
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)
from executorch.exir.backend.test.hta_partitioner_demo import (
    HTAPartitionerMultiplePatternsDemo,
    HTAPartitionerOnePatternDemo,
)
from executorch.exir.backend.test.op_partitioner_demo import (
    AddAttributePartitionerDemo,
    AddMulPartitionerDemo,
)
from executorch.exir.backend.test.qnn_backend_demo import QnnBackend

from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.lowered_backend_module import get_lowered_submodules
from executorch.exir.print_program import print_program
from executorch.exir.schema import (
    BackendDelegate,
    BackendDelegateDataReference,
    DataLocation,
    DelegateCall,
    Program,
)

from executorch.extension.pybindings.portable_lib import (  # @manual
    _load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten

from functorch.experimental import control_flow
from torch.ao.quantization import get_default_qconfig_mapping  # @manual
from torch.ao.quantization.backend_config.executorch import (
    get_executorch_backend_config,
)
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.export import ExportedProgram
from torch.testing import FileCheck


def vary_segments(test_method):
    """A decorator that calls the test method with `extract_delegate_segments` set to
    True and False.

    Decorated test methods must expect a boolean parameter named
    `extract_delegate_segments`, and they should pass that value to to_executorch() like:

        m.to_executorch(
            config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments)
        )

    This will cause the delegate data blobs to be extracted from the program and
    serialized as separate, freeable program segments. Backends should detect no
    difference at runtime.
    """

    def wrapper(self):
        for extract_delegate_segments in [False, True]:
            # subTest will create a different top-level test entry for each
            # value, whose full names have a suffix like
            # "(extract_delegate_segments=True)".
            with self.subTest(extract_delegate_segments=extract_delegate_segments):
                test_method(self, extract_delegate_segments=extract_delegate_segments)

    return wrapper


class TestBackends(unittest.TestCase):
    def check_delegate_input(
        self, delegate: LoweredBackendModule, input_len: int
    ) -> None:
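        """Assert that the delegate's original module has exactly `input_len`
        placeholder (i.e. input) nodes."""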
        counter = 0
        for node in delegate.original_module.graph.nodes:
            if node.op == "placeholder":
                counter += 1
        self.assertEqual(counter, input_len)

    def check_backend_delegate(
        self,
        program: Program,
        delegate: BackendDelegate,
        expected_id: str,
        expected_processed: bytes,
    ) -> None:
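        """Assert that `delegate` targets `expected_id` and that its processed
        blob is stored inline in the program and equals `expected_processed`."""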
        self.assertEqual(delegate.id, expected_id)
        processed: BackendDelegateDataReference = delegate.processed
        self.assertEqual(processed.location, DataLocation.INLINE)
        self.assertLess(processed.index, len(program.backend_delegate_data))
        self.assertEqual(
            program.backend_delegate_data[processed.index].data, expected_processed
        )

    @vary_segments
    def test_backend_with_compiler(self, extract_delegate_segments: bool):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name once
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
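        # Lower the whole exported program ahead of time to the demo backend;
        # the compile specs are handed to the backend's preprocess() step.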
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program, compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        model_inputs = (torch.ones(1),)

        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                )
            )
        )
        graph_module = exec_prog.dump_graph_module()

        # Check that there is not an aten.sin node.
        self.assertTrue(
            exir_ops.edge.aten.sin
            not in {node.target for node in graph_module.graph.nodes}
        )

        # Check that there exists a call_delegate, representing the call to the
        # delegated function
        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
            graph_module.code
        )
        lowered_submodules = get_lowered_submodules(graph_module)
        self.assertEqual(len(lowered_submodules), 1)

        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                # Check that first arg is lowered_module_{unique_id}
                self.assertEqual(node.args[0].target, "lowered_module_0")

        program = exec_prog.program

        # Check the program can be printed
        print_program(program)

        # Check the backend delegate
        self.check_backend_delegate(
            program=program,
            delegate=program.execution_plan[0].delegates[0],
            expected_id=BackendWithCompilerDemo.__name__,
            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
        )

        # Check the delegate instruction
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[0].instr_args,
                DelegateCall,
            )
        )
        buff = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(buff)
        model_inputs = torch.ones(1)
        model_outputs = executorch_module.forward([model_inputs])
        self.assertEqual(
            model_inputs,
            torch.ones(1),
        )
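        # The demo backend computes sin approximately: 0.8333 == 1 - 1/6, the
        # first two Taylor terms of sin(1), rather than sin(1) ~= 0.8415, hence
        # the loose tolerances below.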
        expected_output = 0.8333 * torch.ones(1)

        self.assertTrue(
            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
        )

    @vary_segments
    def test_lowered_add_mul(self, extract_delegate_segments: bool):
        class AddMulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = torch.add(y, b)
                return z

        add_mul_module = AddMulModule()
        model_inputs = (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
        edge_graph_module = exir.capture(
            add_mul_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_add_mul = to_backend(
            "BackendWithCompilerDemo", edge_graph_module.exported_program, compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_add_mul = lowered_add_mul

            def forward(self, a, x, b):
                return self.lowered_add_mul(a, x, b)

        composite_model = CompositeModule()

        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                )
            )
        )
        buff = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(buff)

        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        inputs_flattened, _ = tree_flatten(model_inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = add_mul_module(*model_inputs)

        self.assertTrue(
            torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
        )

    def run_model_in_unsupported_backend(self, extract_delegate_segments: bool):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        # The backend only accepts shapes <= 4.
        model_inputs = (torch.ones(6),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program, compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        model_inputs = (torch.zeros(6),)

        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )

        buff = exec_prog.buffer

        # This line should raise an exception like
        # RuntimeError: failed with error 0x12
        _load_for_executorch_from_buffer(buff)

    @vary_segments
    def test_backend_with_compiler_out_of_range(self, extract_delegate_segments: bool):
        with self.assertRaisesRegex(
            RuntimeError,
            "loading method forward failed with error 0x12",
        ):
            self.run_model_in_unsupported_backend(
                extract_delegate_segments=extract_delegate_segments
            )

    @vary_segments
    def test_backend_with_compiler_delegate_and_operator(
        self, extract_delegate_segments: bool
    ):
        # This test includes both delegated calls and a regular (non-delegated) operator
        # import the backend implementation
        from executorch.exir.backend.test.backend_with_compiler_demo import (
            BackendWithCompilerDemo,
        )

        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name once
            # it's resolved on the compiler side.
            def forward(self, x):
                return [torch.sin(x)]

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program, compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                a = self.lowered_linear_sin(x)[0]
                b = self.lowered_linear_sin(x)[0]
                return torch.add(a, b)

        composite_model = CompositeModule()
        model_inputs = (torch.ones(1),)

        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )
        graph_module = exec_prog.dump_graph_module()
        program = exec_prog.program
        buff = exec_prog.buffer

        # Check that there is not an aten.sin node.
        self.assertTrue(
            exir_ops.edge.aten.sin.default
            not in {node.target for node in graph_module.graph.nodes}
        )

        # Check that there exists a call_delegate op, representing the call to the
        # delegated function
        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
            graph_module.code
        )

        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                # Check that first arg is lowered_module_{unique_id}
                self.assertEqual(node.args[0].target, "lowered_module_0")

        # Check the backend delegate
        self.check_backend_delegate(
            program=program,
            delegate=program.execution_plan[0].delegates[0],
            expected_id=BackendWithCompilerDemo.__name__,
            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
        )

        # Check the delegate instruction
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[0].instr_args,
                DelegateCall,
            )
        )

        executorch_module = _load_for_executorch_from_buffer(buff)
        model_inputs = torch.ones(1)

        model_outputs = executorch_module.forward([model_inputs])

        self.assertEqual(
            model_inputs,
            torch.ones(1),
        )
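        # 1.666667 == 2 * 0.8333: each delegate call returns the demo backend's
        # approximation of sin(1), and the two results are added together.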
        expected_output = 1.666667 * torch.ones(1)

        self.assertTrue(
            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
        )

    def test_backend_with_compiler_backend_runtime_exception(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name once
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x) + torch.cos(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        error_msg = r"call_function aten.cos.default is not supported in backend BackendWithCompilerDemo"

        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = to_backend("BackendWithCompilerDemo", edgeir_m.exported_program, [])

    def test_backend_with_compiler_backend_not_found_exception(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name once
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x) + torch.cos(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        error_msg = r"Backend FakeBackendWithCompilerDemo was not found."

        with self.assertRaisesRegex(
            NotImplementedError,
            error_msg,
        ):
            _ = to_backend("FakeBackendWithCompilerDemo", edgeir_m.exported_program, [])

    @vary_segments
    def test_backend_with_compiler_delegate_and_operator_with_two_modules(
        self, extract_delegate_segments: bool
    ):
        # The submodule runs on a specific backend; in this example, the
        # `BackendWithCompilerDemo` backend.
        class LowerableSubModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        # to_be_lowered is a plain nn.Module
        to_be_lowered = LowerableSubModel()
        example_input = (torch.ones(1),)
        to_be_lowered_exir_submodule = exir.capture(
            to_be_lowered, example_input, exir.CaptureConfig()
        ).to_edge()

        max_value = example_input[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_module = to_backend(
            "BackendWithCompilerDemo",
            to_be_lowered_exir_submodule.exported_program,
            compile_specs,
        )

        class NonLowerableSubModel(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.bias = bias

            def forward(self, a, b):
                return torch.add(torch.add(a, b), self.bias)

        # The composite module, containing both the lowered and non-lowered parts
        class CompositeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.non_lowerable = NonLowerableSubModel(torch.ones(1) * 0.3)
                self.lowerable = lowered_module

            def forward(self, x):
                a = self.lowerable(x)
                b = self.lowerable(a)
                ret = self.non_lowerable(a, b)
                return a, b, ret

        composite_model = CompositeModel()

        # Prepare the model input
        model_inputs = (torch.ones(1),)

        # Verify the input works with eager module
        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )
        flatbuffer = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(flatbuffer)
        model_outputs = executorch_module.forward([*model_inputs])

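        # Expected values chain the demo backend's approximate sin:
        # a ~= 0.8333 (approx. sin(1)), b ~= 0.7369 (approx. sin(a)),
        # ret = a + b + 0.3 ~= 1.8702.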
        expected_outputs = [
            0.8333 * torch.ones(1),
            0.7369 * torch.ones(1),
            1.8702 * torch.ones(1),
        ]

        for index, expected_output in enumerate(expected_outputs):
            self.assertTrue(
                torch.allclose(
                    model_outputs[index], expected_output, atol=1e-03, rtol=1e-03
                )
            )

    @vary_segments
    def test_partition_delegate_graph_with_multiple_patterns(
        self, extract_delegate_segments: bool
    ):
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                x = output
                y = cn
                a = torch.sub(x, y)
                b = torch.sub(x, a)
                c = torch.sub(x, b)
                d = torch.add(x, self.weight)
                e = torch.mul(c, d)
                return e, hn, k

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)
        orig_res = composite_m(*inputs)

        traced = exir.capture(composite_m, inputs, exir.CaptureConfig()).to_edge(
            # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
            exir.EdgeCompileConfig(_check_ir_validity=False)
        )

        program_without_delegates = (
            exir.capture(CompositeModel(3), inputs)
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )
        # after this step, part of the graph will be lowered to backend, depending on
        # HTAPartitionerDemo's rule.
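        # Note: this to_backend overload takes the ExportedProgram and a
        # Partitioner instance, unlike the (backend_id, program, compile_specs)
        # overload used elsewhere in this file.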
        program_with_delegates = traced
        program_with_delegates.exported_program = to_backend(
            traced.exported_program, HTAPartitionerMultiplePatternsDemo()
        )
        program_with_delegates = program_with_delegates.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        new_res = program_with_delegates.dump_graph_module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        # Check the backend delegate
        self.check_backend_delegate(
            program=program_with_delegates.program,
            delegate=program_with_delegates.program.execution_plan[0].delegates[0],
            expected_id=QnnBackend.__name__,
            expected_processed=b"imqnncompiled",
        )

        # Check that sub is not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates.program.execution_plan[0].operators
                    if op.name == "aten::sub"
                ]
            ),
        )

        # Check convolution not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates.program.execution_plan[0].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

        # Check convolution in the program without delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_without_delegates.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

    @vary_segments
    def test_partition_delegate_graph_with_one_pattern(
        self, extract_delegate_segments: bool
    ):
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                x = output
                y = cn
                a = torch.sub(x, y)
                b = torch.sub(x, a)
                c = torch.sub(x, b)
                d = torch.add(x, self.weight)
                e = torch.mul(c, d)
                return e, hn, k

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)
        orig_res = composite_m(*inputs)

        traced = exir.capture(
            composite_m,
            inputs,
            exir.CaptureConfig(),
        ).to_edge(
            # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
            exir.EdgeCompileConfig(_check_ir_validity=False)
        )

        program_without_delegates = (
            exir.capture(
                CompositeModel(3),
                (input_x, input_h, input_c),
                exir.CaptureConfig(),
            )
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )
        # after this step, part of the graph will be lowered to backend, depending on
        # HTAPartitionerDemo's rule.
        traced_with_delegate = traced
        traced_with_delegate.exported_program = to_backend(
            traced.exported_program, HTAPartitionerOnePatternDemo()
        )

        new_res = traced_with_delegate(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        program_with_delegates = traced_with_delegate.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        # TODO(T143084047): Currently not retraceable. Retracing is not needed,
        # but the commented-out code below should be re-enabled once the result
        # of to_backend is retraceable.
        # graph_module_with_delegate = exir.capture(
        #     traced_with_delegate,
        #     (input_x, input_h, input_c),
        #     exir.CaptureConfig(),
        # ).to_edge()

        # program_with_delegates = graph_module_with_delegate.to_executorch(
        #     config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments),
        # )

        new_res = program_with_delegates.dump_graph_module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        # Check the backend delegate
        self.check_backend_delegate(
            program=program_with_delegates.program,
            delegate=program_with_delegates.program.execution_plan[0].delegates[0],
            expected_id=QnnBackend.__name__,
            expected_processed=b"imqnncompiled",
        )

        # Check that exactly one sub remains in the program with delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_with_delegates.program.execution_plan[0].operators
                    if op.name == "aten::sub"
                ]
            ),
        )

        # Check convolution not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates.program.execution_plan[0].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

        # Check convolution in the program without delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_without_delegates.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

    @vary_segments
    def test_add_mul_partitioner(self, extract_delegate_segments: bool):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
        orig_res = m(*inputs)

        ep = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
        executorch_prog = ep
        executorch_prog.exported_program = to_backend(
            ep.exported_program, AddMulPartitionerDemo()
        )

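        # Delegate outputs are consumed through operator.getitem nodes; these
        # synthetic nodes are expected to carry no source_fn_stack or
        # nn_module_stack metadata.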
        for node in executorch_prog.exported_program.graph.nodes:
            if node.op == "call_function" and node.target is executorch_call_delegate:
                for user in node.users:
                    self.assertTrue(
                        user.op == "call_function" and user.target == operator.getitem
                    )
                    self.assertTrue(user.meta.get("source_fn_stack", None) is None)
                    self.assertTrue(user.meta.get("nn_module_stack", None) is None)

        executorch_prog = executorch_prog.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        new_res = executorch_prog.dump_graph_module()(*inputs)
        self.assertTrue(torch.allclose(new_res[0], orig_res))

        counter = 0
        for node in executorch_prog.dump_graph_module().graph.nodes:
            if node.op == "get_attr":
                self.assertEqual(node.target, f"lowered_module_{counter}")
                counter += 1
        # There should be 2 delegated modules
        self.assertEqual(counter, 2)

        executorch_module = _load_for_executorch_from_buffer(executorch_prog.buffer)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        inputs_flattened, _ = tree_flatten(inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = m(*inputs)

        self.assertTrue(
            torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03),
        )

    @vary_segments
    def test_partitioner_with_attributes(self, extract_delegate_segments: bool):
        """
        Check that if we tag the getattr nodes, the attributes will be added to
        the lowered submodule rather than being passed into the delegate as
        inputs.
        """

        class AddOne(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.one = torch.ones(1, 3)

            def forward(self, x):
                return x + self.one

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_one = AddOne()

            def forward(self, x, y):
                x = self.add_one(x) * y
                return self.add_one(x), self.add_one(y)

        inputs = (torch.randn(1, 3), torch.randn(1, 3))
        orig_res = Model()(*inputs)
        ep = exir.capture(Model(), inputs, exir.CaptureConfig()).to_edge()
        executorch_prog = ep
        executorch_prog.exported_program = to_backend(
            ep.exported_program, AddAttributePartitionerDemo()
        )

        for node in executorch_prog.exported_program.graph.nodes:
            if node.op == "call_function" and node.target is executorch_call_delegate:
                for user in node.users:
                    self.assertTrue(
                        user.op == "call_function" and user.target == operator.getitem
                    )
                    self.assertTrue(user.meta.get("source_fn_stack", None) is None)
                    self.assertTrue(user.meta.get("nn_module_stack", None) is None)

        executorch_prog = executorch_prog.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        # Check the delegated submodules
        lowered_submodules = get_lowered_submodules(executorch_prog.dump_graph_module())
        self.assertEqual(len(lowered_submodules), 2)
        # Attributes should be stored in the lowered module
        self.check_delegate_input(lowered_submodules[0][1], 1)
        self.check_delegate_input(lowered_submodules[1][1], 2)

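        # Accessing .buffer serializes the program to a flatbuffer; this acts
        # as a smoke check that the delegated program serializes without error.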
        executorch_prog.buffer

        new_res = executorch_prog.dump_graph_module()(*inputs)
        self.assertTrue(torch.allclose(orig_res[0], new_res[0]))
        self.assertTrue(torch.allclose(orig_res[1], new_res[1]))

    def test_bad_partitioner(self):
        """
        Checks that we throw an error if a user-provided partitioner modifies
        the given graph module.
        """
        inputs = (torch.randn(1, 3), torch.randn(1, 3))

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                x = x + y
                x = x * y
                x = x - y
                x = x / y
                x = x * y
                x = x + y
                return x

        class BadPartitioner(Partitioner):
            def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                # Partitioner should not modify the given graph module
                for node in exported_program.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == exir_ops.edge.aten.add.Tensor
                    ):
                        node.target = exir_ops.edge.aten.mul.Tensor
                return PartitionResult(
                    tagged_exported_program=exported_program,
                    partition_tags={
                        "tag1": DelegationSpec("BackendWithCompilerDemo", [])
                    },
                )

        ep = exir.capture(Model(), inputs, exir.CaptureConfig()).to_edge()
        with self.assertRaises(AssertionError):
            _ = to_backend(ep.exported_program, BadPartitioner())

    def test_quantized_with_delegate(self) -> None:
        torch.ops.load_library(
            "//executorch/kernels/quantized:custom_ops_generated_lib"
        )
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        in_size = 2
        input_size = 3
        output_size = 4
        linear = torch.nn.Linear(input_size, output_size).eval()
        example_inputs = (torch.ones(in_size, input_size),)
        prepared_linear = prepare_fx(
            linear,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        converted_linear: torch.nn.Module = _convert_to_reference_decomposed_fx(
            prepared_linear,
        )

        # NB: tracing has been known to fail at this step.
        converted_linear_gm = exir.capture(
            converted_linear,
            example_inputs,
            exir.CaptureConfig(
                enable_aot=True,
            ),
        ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
        FileCheck().check_count("quantize_per_tensor_default", 3).check("addmm").run(
            converted_linear_gm.exported_program.graph_module.code
        )

    def test_partition_with_control_flow(self) -> None:
        def true_fn(x, y):
            x = x - y
            x = x + y
            x = x - y
            return x

        def false_fn(x, y):
            x = x - y
            x = torch.mm(x, y)
            x = x - y
            return x

        def f(x, y):
            x = x + y
            x = torch.ops.higher_order.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
            x = x - y
            return x

        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        orig_res = f(*inputs)
        orig = exir.capture(
            f,
            inputs,
            exir.CaptureConfig(),
        ).to_edge()
        partitioned = orig
        partitioned.exported_program = to_backend(
            orig.exported_program, AddMulPartitionerDemo()
        )

        new_res = partitioned(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

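        # The partitioner should lower the add ops in the toplevel graph and
        # recurse into both cond submodules.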
        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the cond submodules
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(partitioned_submodules), 2)

        true_gm = partitioned_submodules[0][1]
        true_lowered = get_lowered_submodules(true_gm)
        self.assertEqual(len(true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            true_lowered[0][1].original_module.graph_module.code
        )

        false_gm = partitioned_submodules[1][1]
        false_lowered = get_lowered_submodules(false_gm)
        self.assertEqual(len(false_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            false_lowered[0][1].original_module.graph_module.code
        )

    def test_partition_with_map(self) -> None:
        def map_fn(x, y):
            x = x - y
            x = x + y
            return x

        def f(xs, y):
            y = torch.mm(y, y)
            return control_flow.map(map_fn, xs, y)

        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        orig_res = f(*inputs)
        orig = exir.capture(
            f,
            inputs,
            exir.CaptureConfig(),
        ).to_edge()
        partitioned = orig
        partitioned.exported_program = to_backend(
            orig.exported_program, AddMulPartitionerDemo()
        )

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the map submodule
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(partitioned_submodules), 1)

        map_fn_gm = partitioned_submodules[0][1]
        map_fn_lowered = get_lowered_submodules(map_fn_gm)
        self.assertEqual(len(map_fn_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            map_fn_lowered[0][1].original_module.graph_module.code
        )

        new_res = partitioned(*inputs)

        self.assertTrue(torch.allclose(orig_res, new_res[0]))

    def test_partition_with_nested_control_flow(self) -> None:
        """
        Partitions the add and mul ops, including the ones inside the submodules
        """

        def true_nested(y):
            y = y + y
            y = torch.mm(y, y)
            return y

        def false_nested(y):
            return torch.mm(y, y)

        def true_fn(x, pred2):
            z = control_flow.cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _):
            return x.cos()

        def map_fn(x, pred1, pred2, y):
            x = x.cos()
            y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
            x = x + y
            return x.sin()

        def f(xs, pred1, pred2, y):
            y = torch.mm(y, y)
            return control_flow.map(map_fn, xs, pred1, pred2, y)

        inputs = (
            torch.ones(2, 2),
            torch.tensor([False]),
            torch.tensor([False]),
            torch.ones(2, 2),
        )

        orig_res = f(*inputs)
        orig = exir.capture(
            f,
            inputs,
            exir.CaptureConfig(),
        ).to_edge()
        partitioned = orig
        partitioned.exported_program = to_backend(
            orig.exported_program, AddMulPartitionerDemo()
        )

        new_res = partitioned(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the map submodule
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(partitioned_submodules), 1)

        # Map module has the cond submodules
        map_submodules = get_control_flow_submodules(partitioned_submodules[0][1])
        self.assertEqual(len(map_submodules), 2)

        # True module
        true_module = map_submodules[0][1]
        true_lowered = get_lowered_submodules(true_module)
        self.assertEqual(len(true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            true_lowered[0][1].original_module.graph_module.code
        )

        # False module
        false_lowered = get_lowered_submodules(map_submodules[1][1])
        self.assertEqual(len(false_lowered), 0)

        # True module has the nested cond submodules
        true_submodules = get_control_flow_submodules(true_module)
        self.assertEqual(len(true_submodules), 2)

        # Nested True module
        true_true_lowered = get_lowered_submodules(true_submodules[0][1])
        self.assertEqual(len(true_true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").check(
            "executorch_exir_dialects_edge__ops_aten_mm_default"
        ).run(true_true_lowered[0][1].original_module.graph_module.code)

        # Nested False module
        true_false_lowered = get_lowered_submodules(true_submodules[1][1])
        self.assertEqual(len(true_false_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            true_false_lowered[0][1].original_module.graph_module.code
        )

    def test_list_input(self):
        def f(x: List[torch.Tensor]):
            y = x[0] + x[1]
            return y

        inputs = ([torch.randn(2, 2), torch.randn(2, 2)],)
        edge_prog = exir.capture(f, inputs, exir.CaptureConfig()).to_edge()
        lowered_gm = to_backend(
            BackendWithCompilerDemo.__name__, edge_prog.exported_program, []
        )

        class ComposedM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered = lowered_gm

            def forward(self, x: List[torch.Tensor]):
                return self.lowered(x)

        gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
        gm(*inputs)

    def test_dict_input(self):
        class M(torch.nn.Module):
            def forward(self, x: Dict[str, torch.Tensor]):
                y = x["a"] + x["b"]
                return y

        inputs = ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},)
        edge_prog = exir.to_edge(torch.export.export(M(), inputs))
        lowered_gm = to_backend(
            BackendWithCompilerDemo.__name__, edge_prog.exported_program(), []
        )

        class ComposedM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered = lowered_gm

            def forward(self, x: Dict[str, torch.Tensor]):
                return self.lowered(x)

        gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
        gm(*inputs)
