# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Iterator, Union

import torch
from executorch import exir
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.test.backend_with_delegate_mapping_demo import (
    BackendWithDelegateMappingDemo,
)

from executorch.exir.backend.utils import DelegateMappingBuilder


class TestDelegateMapBuilder(unittest.TestCase):
    """Tests for ``DelegateMappingBuilder`` and the delegate-mapping demo backend.

    Most tests start from the node list of a small captured ``sin -> cos``
    program (see ``setUp``) and check how the builder maps delegate debug
    identifiers to the debug handles of those nodes.
    """

    def setUp(self) -> None:
        """Capture a tiny model and cache its graph nodes and their debug handles."""

        class Model(torch.nn.Module):
            def forward(self, x):
                y = torch.sin(x)
                return torch.cos(y)

        model = Model()
        model_inputs = (torch.ones(1, 1),)
        program = (
            exir.capture(model, model_inputs, exir.CaptureConfig(pt2_mode=True))
            .to_edge()
            .to_executorch()
        )

        # Nodes used to exercise the mapping builder. The assertions in this
        # class rely on the following layout (the first and third handles and
        # the overall set (1, 2, 3, 4) are pinned by the tests; the remaining
        # positions are presumed -- TODO confirm):
        #   nodes:         [arg0_1, alloc, aten_sin_default, alloc_1, aten_cos_default, output]
        #   debug handles: [1, None, 2, None, 3, 4]
        self.nodes = list(program.graph_module.graph.nodes)

        self.handles = [node.meta.get("debug_handle") for node in self.nodes]

    def test_basic_generated_identifier(self) -> None:
        """Generated identifiers count up from 0, one per inserted entry."""
        delegate_builder = DelegateMappingBuilder(generated_identifiers=True)

        # Entry 0: every node; nodes without a debug handle contribute nothing.
        expected_mapping = {0: (1, 2, 3, 4)}
        self.assertEqual(
            delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes), 0
        )
        self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

        # Entry 1: a single node passed without a wrapping list.
        expected_mapping = {0: (1, 2, 3, 4), 1: (1,)}
        self.assertEqual(
            delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes[0]), 1
        )
        self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

        # Entry 2: a single raw debug handle.
        expected_mapping = {0: (1, 2, 3, 4), 1: (1,), 2: (2,)}
        self.assertEqual(
            delegate_builder.insert_delegate_mapping_entry(handles=self.handles[2]),
            2,
        )
        self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

        # Entry 3: the full handle list, including the None placeholders.
        expected_mapping = {
            0: (1, 2, 3, 4),
            1: (1,),
            2: (2,),
            3: (1, 2, 3, 4),
        }
        self.assertEqual(
            delegate_builder.insert_delegate_mapping_entry(handles=self.handles), 3
        )
        self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

    def test_basic_manual_int_identifier(self) -> None:
        """Manually supplied integer identifiers are used as mapping keys."""
        self._test_basic_manual_identifier(iter([22, 55]))

    def test_basic_manual_string_identifier(self) -> None:
        """Manually supplied string identifiers are used as mapping keys."""
        self._test_basic_manual_identifier(iter(["22", "55"]))

    def test_adding_manual_identifier_when_generated(self) -> None:
        """A builder that generates identifiers rejects a manual identifier."""
        delegate_builder = DelegateMappingBuilder(generated_identifiers=True)

        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                nodes=self.nodes, identifier="22"
            )
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                handles=self.handles, identifier="22"
            )

    def test_omitting_identifier_when_not_generated(self) -> None:
        """A builder without generated identifiers requires an explicit identifier."""
        delegate_builder = DelegateMappingBuilder()

        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes)
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(handles=self.handles)

    def test_reinsert_delegate_debug_identifier(self) -> None:
        """Inserting an identifier that is already present raises."""
        delegate_builder = DelegateMappingBuilder()
        delegate_builder.insert_delegate_mapping_entry(
            nodes=self.nodes[0], identifier="1"
        )

        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                nodes=self.nodes[0], identifier="1"
            )
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                handles=self.handles[0], identifier="1"
            )

    def test_backend_with_delegate_mapping(self) -> None:
        """Lowering to the demo backend records a debug handle map on the module."""
        model, inputs = BackendWithDelegateMappingDemo.get_test_model_and_inputs()
        edgeir_m = exir.capture(model, inputs, exir.CaptureConfig()).to_edge(
            exir.EdgeCompileConfig(_check_ir_validity=False)
        )
        lowered_module = to_backend(
            "BackendWithDelegateMappingDemo", edgeir_m.exported_program, []
        )
        debug_handle_map = lowered_module.meta.get("debug_handle_map")
        self.assertIsNotNone(debug_handle_map)
        # The demo backend is expected to record five delegate debug identifiers.
        self.assertEqual(len(debug_handle_map), 5)
        # Identifiers 1 through 4 must all be present; assertIn reports the
        # first missing identifier by name.
        # NOTE(review): five entries are expected but only four keys are
        # checked -- confirm whether the fifth identifier should be asserted.
        for identifier in (1, 2, 3, 4):
            self.assertIn(identifier, debug_handle_map)

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_module = lowered_module

            def forward(self, x):
                return self.lowered_module(x)

        composite_model = CompositeModule()
        # TODO: Switch this to lowered_module.program() once lowered_module has support
        # for storing debug delegate identifier maps.
        # The result is discarded: this only checks that the composite model
        # makes it through capture and lowering without raising.
        exir.capture(
            composite_model, inputs, exir.CaptureConfig()
        ).to_edge().to_executorch()

    def test_passing_both_nodes_and_handles(self) -> None:
        """Supplying ``nodes`` and ``handles`` in the same call raises."""
        delegate_builder = DelegateMappingBuilder()

        # NOTE(review): no identifier is supplied to a builder without generated
        # identifiers, which on its own is expected to raise (see
        # test_omitting_identifier_when_not_generated). Confirm the failure here
        # really comes from passing both arguments.
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                nodes=self.nodes, handles=self.handles
            )

    def test_missing_handle_filtering(self) -> None:
        """Entries that contain no usable debug handle raise."""
        delegate_builder = DelegateMappingBuilder()

        # NOTE(review): as above, the missing identifier alone is expected to
        # raise for this builder configuration, so this may not be exercising
        # the handle-filtering path -- confirm.
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(handles=[None])
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(nodes=[None])

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    def _test_basic_manual_identifier(
        self, identifiers: Iterator[Union[int, str]]
    ) -> None:
        """
        Shared body for the manual-identifier tests.

        Two builders are driven in lockstep, one fed graph nodes and the other
        fed the raw debug handles of those same nodes; both must end up with
        identical mappings.

        Using the iteration of identifiers:
        1) Create the Delegate Map Builders
        2) Add an entry with the full list of nodes/handles using the first identifier
        3) Add an entry with a single node/handle using the second identifier

        Args:
            identifiers: Iterator yielding at least two distinct identifiers.
        """

        delegate_builder_nodes = DelegateMappingBuilder()
        delegate_builder_handles = DelegateMappingBuilder()

        # Entry with a list of nodes / handles
        iden_1 = next(identifiers)
        expected_mapping = {iden_1: (1, 2, 3, 4)}
        self.assertEqual(
            delegate_builder_nodes.insert_delegate_mapping_entry(
                nodes=self.nodes, identifier=iden_1
            ),
            iden_1,
        )
        self.assertEqual(
            delegate_builder_handles.insert_delegate_mapping_entry(
                handles=self.handles, identifier=iden_1
            ),
            iden_1,
        )
        self.assertEqual(
            delegate_builder_nodes.get_delegate_mapping(), expected_mapping
        )
        self.assertEqual(
            delegate_builder_handles.get_delegate_mapping(), expected_mapping
        )

        # Entry with a single node / handle
        iden_2 = next(identifiers)
        expected_mapping = {iden_1: (1, 2, 3, 4), iden_2: (1,)}
        self.assertEqual(
            delegate_builder_nodes.insert_delegate_mapping_entry(
                nodes=self.nodes[0], identifier=iden_2
            ),
            iden_2,
        )
        self.assertEqual(
            delegate_builder_handles.insert_delegate_mapping_entry(
                handles=self.handles[0], identifier=iden_2
            ),
            iden_2,
        )
        self.assertEqual(
            delegate_builder_nodes.get_delegate_mapping(), expected_mapping
        )
        self.assertEqual(
            delegate_builder_handles.get_delegate_mapping(), expected_mapping
        )
