# Owner(s): ["module: onnx"]
"""Unit tests for the internal registration wrapper module."""

from typing import Optional, Sequence

from torch.onnx import errors
from torch.onnx._internal import registration
from torch.testing._internal import common_utils


@common_utils.instantiate_parametrized_tests
class TestGlobalHelpers(common_utils.TestCase):
    """Tests for the module-level helper ``registration._dispatch_opset_version``."""

    # Each case is ``(available_opsets, target, expected)``: the opset versions
    # passed as the registered set, the requested target opset, and the value
    # ``_dispatch_opset_version`` is expected to return for that pair.
    # ``expected`` is ``None`` for the cases where no version should be chosen.
    # NOTE: keep the order of the cases stable; entries are only appended to.
    @common_utils.parametrize(
        "available_opsets, target, expected",
        [
            ((7, 8, 9, 10, 11), 16, 11),
            ((7, 8, 9, 10, 11), 11, 11),
            ((7, 8, 9, 10, 11), 10, 10),
            ((7, 8, 9, 10, 11), 9, 9),
            ((7, 8, 9, 10, 11), 8, 8),
            ((7, 8, 9, 10, 11), 7, 7),
            ((9, 10, 16), 16, 16),
            ((9, 10, 16), 15, 10),
            ((9, 10, 16), 10, 10),
            ((9, 10, 16), 9, 9),
            ((9, 10, 16), 8, 9),
            ((9, 10, 16), 7, 9),
            ((7, 9, 10, 16), 16, 16),
            ((7, 9, 10, 16), 10, 10),
            ((7, 9, 10, 16), 9, 9),
            ((7, 9, 10, 16), 8, 9),
            ((7, 9, 10, 16), 7, 7),
            ([17], 16, None),  # New op added in 17
            ([9], 9, 9),
            ([9], 8, 9),
            ([], 16, None),
            ([], 9, None),
            ([], 8, None),
            # Ops registered at opset 1 found as a fallback when target >= 9
            ([1], 16, 1),
        ],
    )
    def test_dispatch_opset_version_returns_correct_version(
        self, available_opsets: Sequence[int], target: int, expected: Optional[int]
    ):
        # ``expected`` is Optional because several cases above expect ``None``.
        actual = registration._dispatch_opset_version(target, available_opsets)
        self.assertEqual(actual, expected)


class TestOverrideDict(common_utils.TestCase):
    """Tests for ``registration.OverrideDict``: base values, overrides, and lookup."""

    def setUp(self):
        # Chain to the base fixture first so that whatever per-test setup
        # ``common_utils.TestCase`` performs is not silently skipped.
        super().setUp()
        self.override_dict: registration.OverrideDict[str, int] = (
            registration.OverrideDict()
        )

    def _populate_base_and_overrides(self) -> None:
        # Shared fixture used by several tests:
        #   "a" -> base value 42, overridden with 100
        #   "b" -> base value 0, no override
        #   "c" -> override 1, no base value
        self.override_dict.set_base("a", 42)
        self.override_dict.set_base("b", 0)
        self.override_dict.override("a", 100)
        self.override_dict.override("c", 1)

    def test_get_item_returns_base_value_when_no_override(self):
        self.override_dict.set_base("a", 42)
        self.override_dict.set_base("b", 0)

        self.assertEqual(self.override_dict["a"], 42)
        self.assertEqual(self.override_dict["b"], 0)
        self.assertEqual(len(self.override_dict), 2)

    def test_get_item_returns_overridden_value_when_override(self):
        self._populate_base_and_overrides()

        self.assertEqual(self.override_dict["a"], 100)
        self.assertEqual(self.override_dict["b"], 0)
        self.assertEqual(self.override_dict["c"], 1)
        # "a" is counted once even though it has both a base and an override.
        self.assertEqual(len(self.override_dict), 3)

    def test_get_item_raises_key_error_when_not_found(self):
        self.override_dict.set_base("a", 42)

        with self.assertRaises(KeyError):
            self.override_dict["nonexistent_key"]

    def test_get_returns_overridden_value_when_override(self):
        self._populate_base_and_overrides()

        self.assertEqual(self.override_dict.get("a"), 100)
        self.assertEqual(self.override_dict.get("b"), 0)
        self.assertEqual(self.override_dict.get("c"), 1)
        self.assertEqual(len(self.override_dict), 3)

    def test_get_returns_none_when_not_found(self):
        self.override_dict.set_base("a", 42)

        self.assertIsNone(self.override_dict.get("nonexistent_key"))

    def test_in_base_returns_true_for_base_value(self):
        self._populate_base_and_overrides()

        # Membership covers both base and override-only keys...
        self.assertIn("a", self.override_dict)
        self.assertIn("b", self.override_dict)
        self.assertIn("c", self.override_dict)

        # ...while in_base() is true only for keys that have a base value.
        self.assertTrue(self.override_dict.in_base("a"))
        self.assertTrue(self.override_dict.in_base("b"))
        self.assertFalse(self.override_dict.in_base("c"))
        self.assertFalse(self.override_dict.in_base("nonexistent_key"))

    def test_overridden_returns_true_for_overridden_value(self):
        self._populate_base_and_overrides()

        self.assertTrue(self.override_dict.overridden("a"))
        self.assertFalse(self.override_dict.overridden("b"))
        self.assertTrue(self.override_dict.overridden("c"))
        self.assertFalse(self.override_dict.overridden("nonexistent_key"))

    def test_remove_override_removes_overridden_value(self):
        self._populate_base_and_overrides()

        self.assertEqual(self.override_dict["a"], 100)
        self.assertEqual(self.override_dict["c"], 1)

        self.override_dict.remove_override("a")
        self.override_dict.remove_override("c")
        # "a" falls back to its base value; "c" had no base so it is gone.
        self.assertEqual(self.override_dict["a"], 42)
        self.assertIsNone(self.override_dict.get("c"))
        self.assertFalse(self.override_dict.overridden("a"))
        self.assertFalse(self.override_dict.overridden("c"))

    def test_remove_override_removes_overridden_key(self):
        self.override_dict.override("a", 100)
        self.assertEqual(self.override_dict["a"], 100)
        self.assertEqual(len(self.override_dict), 1)
        self.override_dict.remove_override("a")
        self.assertEqual(len(self.override_dict), 0)
        self.assertNotIn("a", self.override_dict)

    # NOTE: the method name is misspelled ("overriden", "precededs") but is kept
    # as-is because the name is the test's identifier.
    def test_overriden_key_precededs_base_key_regardless_of_insert_order(self):
        # The override must win even when the base value is (re)set afterwards.
        self.override_dict.set_base("a", 42)
        self.override_dict.override("a", 100)
        self.override_dict.set_base("a", 0)

        self.assertEqual(self.override_dict["a"], 100)
        self.assertEqual(len(self.override_dict), 1)

    def test_bool_is_true_when_not_empty(self):
        self.assertFalse(self.override_dict, "OverrideDict should be false when empty")
        # Override-only entry.
        self.override_dict.override("a", 1)
        self.assertTrue(
            self.override_dict, "OverrideDict should be true when not empty"
        )
        # Base and override for the same key.
        self.override_dict.set_base("a", 42)
        self.assertTrue(
            self.override_dict, "OverrideDict should be true when not empty"
        )
        # Only the base value remains after the override is removed.
        self.override_dict.remove_override("a")
        self.assertTrue(
            self.override_dict, "OverrideDict should be true when not empty"
        )


class TestRegistrationDecorators(common_utils.TestCase):
    """Tests for the ``onnx_symbolic`` and ``custom_onnx_symbolic`` decorators.

    Every test registers functions under the name ``"test::test_op"`` in
    ``registration.registry``; ``tearDown`` removes that entry again.
    """

    def tearDown(self) -> None:
        # Drop the entry the test created so the module-level registry is clean
        # for the next test, then let the base class run its own teardown.
        registration.registry._registry.pop("test::test_op", None)
        super().tearDown()

    def _get_test_op_group(self):
        # Fetch the function group registered for "test::test_op". A missing
        # group is reported as a regular test failure instead of surfacing as
        # an AttributeError on ``None`` further down.
        function_group = registration.registry.get_function_group("test::test_op")
        self.assertIsNotNone(function_group)
        assert function_group is not None  # narrows the type for static checkers
        return function_group

    def test_onnx_symbolic_registers_function(self):
        self.assertFalse(registration.registry.is_registered_op("test::test_op", 9))

        @registration.onnx_symbolic("test::test_op", opset=9)
        def test(g, x):
            return g.op("test", x)

        self.assertTrue(registration.registry.is_registered_op("test::test_op", 9))
        function_group = self._get_test_op_group()
        self.assertEqual(function_group.get(9), test)

    def test_onnx_symbolic_registers_function_applied_decorator_when_provided(self):
        wrapper_called = False

        def decorator(func):
            def wrapper(*args, **kwargs):
                # Record that the registered callable is the wrapped one.
                nonlocal wrapper_called
                wrapper_called = True
                return func(*args, **kwargs)

            return wrapper

        @registration.onnx_symbolic("test::test_op", opset=9, decorate=[decorator])
        def test():
            return

        function_group = self._get_test_op_group()
        registered_function = function_group[9]
        # Registration alone must not invoke the wrapper; calling it must.
        self.assertFalse(wrapper_called)
        registered_function()
        self.assertTrue(wrapper_called)

    def test_onnx_symbolic_raises_warning_when_overriding_function(self):
        self.assertFalse(registration.registry.is_registered_op("test::test_op", 9))

        @registration.onnx_symbolic("test::test_op", opset=9)
        def test1():
            return

        # Registering the same name/opset a second time is expected to warn.
        with self.assertWarnsRegex(
            errors.OnnxExporterWarning,
            "Symbolic function 'test::test_op' already registered",
        ):

            @registration.onnx_symbolic("test::test_op", opset=9)
            def test2():
                return

    def test_custom_onnx_symbolic_registers_custom_function(self):
        self.assertFalse(registration.registry.is_registered_op("test::test_op", 9))

        @registration.custom_onnx_symbolic("test::test_op", opset=9)
        def test(g, x):
            return g.op("test", x)

        self.assertTrue(registration.registry.is_registered_op("test::test_op", 9))
        function_group = self._get_test_op_group()
        self.assertEqual(function_group.get(9), test)

    def test_custom_onnx_symbolic_overrides_existing_function(self):
        self.assertFalse(registration.registry.is_registered_op("test::test_op", 9))

        @registration.onnx_symbolic("test::test_op", opset=9)
        def test_original():
            return "original"

        self.assertTrue(registration.registry.is_registered_op("test::test_op", 9))

        @registration.custom_onnx_symbolic("test::test_op", opset=9)
        def test_custom():
            return "custom"

        # The custom registration must take precedence over the original one.
        function_group = self._get_test_op_group()
        self.assertEqual(function_group.get(9), test_custom)


# Allow running this test file directly as a script; test execution is
# delegated to ``common_utils.run_tests()``.
if __name__ == "__main__":
    common_utils.run_tests()
