# Owner(s): ["module: onnx"]

import contextlib
import io
import tempfile
import unittest

import numpy as np

import onnx
import parameterized
import pytorch_test_common
from packaging import version

import torch
from torch.onnx import _constants, _experimental, verification
from torch.testing._internal import common_utils


class TestVerification(pytorch_test_common.ExportTestCase):
    def test_check_export_model_diff_returns_diff_when_constant_mismatch(self):
        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # tensor.data() will be exported as a constant,
                # leading to wrong model output under different inputs.
                return x + y.data

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
        ]

        results = verification.check_export_model_diff(
            UnexportableModel(), test_input_groups
        )
        self.assertRegex(
            results,
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Former source location:(.|\n)*"
            r"Latter source location:",
        )

    def test_check_export_model_diff_returns_diff_when_dynamic_controlflow_mismatch(
        self,
    ):
        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                for i in range(x.size(0)):
                    y = x[i] + y
                return y

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(4, 3), torch.randn(2, 3)), {}),
        ]

        export_options = _experimental.ExportOptions(
            input_names=["x", "y"], dynamic_axes={"x": [0]}
        )
        results = verification.check_export_model_diff(
            UnexportableModel(), test_input_groups, export_options
        )
        self.assertRegex(
            results,
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Latter source location:(.|\n)*",
        )

    def test_check_export_model_diff_returns_empty_when_correct_export(self):
        class SupportedModel(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
        ]

        results = verification.check_export_model_diff(
            SupportedModel(), test_input_groups
        )
        self.assertEqual(results, "")

    def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
        self,
    ):
        ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
        pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
        options = verification.VerificationOptions(
            rtol=1e-5,
            atol=1e-6,
            check_shape=True,
            check_dtype=False,
            ignore_none=True,
            acceptable_error_percentage=0.3,
        )
        verification._compare_onnx_pytorch_outputs(
            ort_outs,
            pytorch_outs,
            options,
        )

    def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
        self,
    ):
        ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
        pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
        options = verification.VerificationOptions(
            rtol=1e-5,
            atol=1e-6,
            check_shape=True,
            check_dtype=False,
            ignore_none=True,
            acceptable_error_percentage=None,
        )
        with self.assertRaises(AssertionError):
            verification._compare_onnx_pytorch_outputs(
                ort_outs,
                pytorch_outs,
                options,
            )


@common_utils.instantiate_parametrized_tests
class TestVerificationOnWrongExport(pytorch_test_common.ExportTestCase):
    opset_version: int

    def setUp(self):
        super().setUp()

        def incorrect_add_symbolic_function(g, self, other, alpha):
            return self

        self.opset_version = _constants.ONNX_DEFAULT_OPSET
        torch.onnx.register_custom_op_symbolic(
            "aten::add",
            incorrect_add_symbolic_function,
            opset_version=self.opset_version,
        )

    def tearDown(self):
        super().tearDown()
        torch.onnx.unregister_custom_op_symbolic(
            "aten::add", opset_version=self.opset_version
        )

    @common_utils.parametrize(
        "onnx_backend",
        [
            common_utils.subtest(
                verification.OnnxBackend.REFERENCE,
                decorators=[
                    unittest.skipIf(
                        version.Version(onnx.__version__) < version.Version("1.13"),
                        reason="Reference Python runtime was introduced in 'onnx' 1.13.",
                    )
                ],
            ),
            verification.OnnxBackend.ONNX_RUNTIME_CPU,
        ],
    )
    def test_verify_found_mismatch_when_export_is_wrong(
        self, onnx_backend: verification.OnnxBackend
    ):
        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1

        with self.assertRaisesRegex(AssertionError, ".*Tensor-likes are not close!.*"):
            verification.verify(
                Model(),
                (torch.randn(2, 3),),
                opset_version=self.opset_version,
                options=verification.VerificationOptions(backend=onnx_backend),
            )


@parameterized.parameterized_class(
    [
        # TODO: enable this when ONNX submodule catches up to >= 1.13.
        # {"onnx_backend": verification.OnnxBackend.ONNX},
        {"onnx_backend": verification.OnnxBackend.ONNX_RUNTIME_CPU},
    ],
    class_name_func=lambda cls,
    idx,
    input_dicts: f"{cls.__name__}_{input_dicts['onnx_backend'].name}",
)
class TestFindMismatch(pytorch_test_common.ExportTestCase):
    onnx_backend: verification.OnnxBackend
    opset_version: int
    graph_info: verification.GraphInfo

    def setUp(self):
        super().setUp()
        self.opset_version = _constants.ONNX_DEFAULT_OPSET

        def incorrect_relu_symbolic_function(g, self):
            return g.op("Add", self, g.op("Constant", value_t=torch.tensor(1.0)))

        torch.onnx.register_custom_op_symbolic(
            "aten::relu",
            incorrect_relu_symbolic_function,
            opset_version=self.opset_version,
        )

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = torch.nn.Sequential(
                    torch.nn.Linear(3, 4),
                    torch.nn.ReLU(),
                    torch.nn.Linear(4, 5),
                    torch.nn.ReLU(),
                    torch.nn.Linear(5, 6),
                )

            def forward(self, x):
                return self.layers(x)

        self.graph_info = verification.find_mismatch(
            Model(),
            (torch.randn(2, 3),),
            opset_version=self.opset_version,
            options=verification.VerificationOptions(backend=self.onnx_backend),
        )

    def tearDown(self):
        super().tearDown()
        torch.onnx.unregister_custom_op_symbolic(
            "aten::relu", opset_version=self.opset_version
        )
        delattr(self, "opset_version")
        delattr(self, "graph_info")

    def test_pretty_print_tree_visualizes_mismatch(self):
        f = io.StringIO()
        with contextlib.redirect_stdout(f):
            self.graph_info.pretty_print_tree()
        self.assertExpected(f.getvalue())

    def test_preserve_mismatch_source_location(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()

        self.assertTrue(len(mismatch_leaves) > 0)

        for leaf_info in mismatch_leaves:
            f = io.StringIO()
            with contextlib.redirect_stdout(f):
                leaf_info.pretty_print_mismatch(graph=True)
            self.assertRegex(
                f.getvalue(),
                r"(.|\n)*" r"aten::relu.*/torch/nn/functional.py:[0-9]+(.|\n)*",
            )

    def test_find_all_mismatch_operators(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()

        self.assertEqual(len(mismatch_leaves), 2)

        for leaf_info in mismatch_leaves:
            self.assertEqual(leaf_info.essential_node_count(), 1)
            self.assertEqual(leaf_info.essential_node_kinds(), {"aten::relu"})

    def test_find_mismatch_prints_correct_info_when_no_mismatch(self):
        self.maxDiff = None

        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1

        f = io.StringIO()
        with contextlib.redirect_stdout(f):
            verification.find_mismatch(
                Model(),
                (torch.randn(2, 3),),
                opset_version=self.opset_version,
                options=verification.VerificationOptions(backend=self.onnx_backend),
            )
        self.assertExpected(f.getvalue())

    def test_export_repro_for_mismatch(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
        self.assertTrue(len(mismatch_leaves) > 0)
        leaf_info = mismatch_leaves[0]
        with tempfile.TemporaryDirectory() as temp_dir:
            repro_dir = leaf_info.export_repro(temp_dir)

            with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
                options = verification.VerificationOptions(backend=self.onnx_backend)
                verification.OnnxTestCaseRepro(repro_dir).validate(options)


if __name__ == "__main__":
    common_utils.run_tests()
