# Owner(s): ["module: onnx"]
from __future__ import annotations

import contextlib
import copy
import dataclasses
import os
import sys
import unittest
from typing import Tuple

import onnxruntime
from parameterized import parameterized

import torch
import torch._dynamo.backends.registry
from torch import nn
from torch.onnx import (
    _OrtBackend as OrtBackend,
    _OrtBackendOptions as OrtBackendOptions,
    ExportOptions,
)
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNNModuleInlined


sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import onnx_test_common


def make_aot_ort(dynamic: bool = False):
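    # Return the same OrtBackend instance twice: tests pass the first value to
    # torch.compile(backend=...) and keep the second one to inspect the
    # backend's caches and execution counters afterwards.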
    ort_backend = OrtBackend(
        options=OrtBackendOptions(
            export_options=ExportOptions(
                dynamic_shapes=dynamic,
            )
        )
    )
    return ort_backend, ort_backend


class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
    def setUp(self):
        super().setUp()
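        # Start each test from a clean slate: drop Dynamo's compiled code and
        # any cached OrtBackend instances (the caching tests below assert that
        # the backend cache is empty at the beginning of a test).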
        torch._dynamo.reset()
        OrtBackend.clear_cached_instances()

    def tearDown(self):
        super().tearDown()
        torch._dynamo.reset()
        OrtBackend.clear_cached_instances()

    def test_get_ort_device_type(self):
        from onnxruntime.capi import _pybind_state as ORTC

        self.assertEqual(
            torch.onnx._internal.onnxruntime._get_ort_device_type("cuda"),
            ORTC.OrtDevice.cuda(),
        )
        self.assertEqual(
            torch.onnx._internal.onnxruntime._get_ort_device_type("cpu"),
            ORTC.OrtDevice.cpu(),
        )
        self.assertEqual(
            torch.onnx._internal.onnxruntime._get_ort_device_type("maia"),
            ORTC.OrtDevice.npu(),
        )

    def test_torch_compile_backend_registration(self):
        self.assertIn("onnxrt", torch._dynamo.backends.registry.list_backends())
        backend = torch._dynamo.backends.registry.lookup_backend("onnxrt")
        self.assertEqual(backend.__module__, "torch.onnx._internal.onnxruntime")

    def _test_torch_compile_backend_caching_assert_reused(
        self, options: OrtBackendOptions
    ):
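        # Request a backend twice with the same options and assert that the
        # second request returns the cached instance; when the options can be
        # converted with dataclasses.asdict, also check that the dict form maps
        # to the same cache entry.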
        # setUp/tearDown should have left the cache empty.
        self.assertFalse(OrtBackend.get_cached_instances())
        new_backend = OrtBackend.get_cached_instance_for_options(options)
        reused_backend = OrtBackend.get_cached_instance_for_options(options)
        self.assertEqual(len(OrtBackend.get_cached_instances()), 1)
        self.assertIs(reused_backend, new_backend)
        if options is None or options.ort_session_options is None:
            # OrtBackendOptions.ort_session_options is a pybind11 object that
            # cannot be deep-copied by dataclasses.asdict, so the dict-based
            # round trip is only tested when it is unset.
            self.assertEqual(
                new_backend,
                OrtBackend.get_cached_instance_for_options(
                    dataclasses.asdict(options) if options else None
                ),
            )

    @parameterized.expand(
        [
            (None,),
            (OrtBackendOptions(),),
            (OrtBackendOptions(use_aot_autograd=True),),
            (OrtBackendOptions(use_aot_autograd=False),),
            (OrtBackendOptions(preallocate_output=True),),
            (OrtBackendOptions(preallocate_output=False),),
            (OrtBackendOptions(infer_execution_providers=True),),
            (OrtBackendOptions(infer_execution_providers=False),),
            (OrtBackendOptions(preferred_execution_providers=["A", "B", "C"]),),
            (
                OrtBackendOptions(
                    preferred_execution_providers=["A", "B", ("C", {"option": "value"})]
                ),
            ),
            (OrtBackendOptions(default_execution_providers=["Something"]),),
            (
                OrtBackendOptions(
                    export_options=ExportOptions(
                        dynamic_shapes=True,
                    )
                ),
            ),
        ]
    )
    def test_torch_compile_backend_caching_assert_reused(
        self, options: OrtBackendOptions
    ):
        self._test_torch_compile_backend_caching_assert_reused(options)

    @parameterized.expand(
        [
            (OrtBackendOptions(ort_session_options=onnxruntime.SessionOptions()),),
        ]
    )
    def test_torch_compile_backend_caching_assert_not_reused(
        self, options: OrtBackendOptions
    ):
        with self.assertRaises(AssertionError):
            self._test_torch_compile_backend_caching_assert_reused(options)

    def _test_model_numerically(
        self,
        model,
        dynamo_backend,
        example_args_collection,
        fullgraph: bool = False,
        test_backward: bool = False,
        atol: float = 1e-5,
        rtol: float = 1e-6,
    ):
        """Run original and compiled model and compare the results.

        Args:
            model: The model to test.
            dynamo_backend: The Dynamo backend to use. In these tests it is either
              the string "onnxrt" or the first value returned by
              make_aot_ort(dynamic=True).
            example_args_collection: A tuple of example argument groups to test. E.g.,
                (
                  (torch.randn(2), torch.randn(2)),
                  (torch.randn(4), torch.randn(4)),
                )
              would test
                model(torch.randn(2), torch.randn(2)) and
                model(torch.randn(4), torch.randn(4)).
        """
        compiled_model = torch.compile(
            model if not isinstance(model, torch.nn.Module) else copy.deepcopy(model),
            backend=dynamo_backend,
            dynamic=True,
            fullgraph=fullgraph,
        )

        for example_args in example_args_collection:
            baseline_result = model(*example_args)
            result = compiled_model(*example_args)
            if isinstance(baseline_result, torch.Tensor):
                torch.testing.assert_close(
                    baseline_result, result, atol=atol, rtol=rtol
                )
                if test_backward:
                    baseline_result.sum().backward()
                    result.sum().backward()
                    for baseline_param, param in zip(
                        model.parameters(), compiled_model.parameters()
                    ):
                        torch.testing.assert_close(
                            baseline_param.grad, param.grad, atol=atol, rtol=rtol
                        )
            else:
                assert (
                    test_backward is False
                ), "Calculating backward with multiple outputs is not supported yet."
                for baseline_elem, result_elem in zip(baseline_result, result):
                    torch.testing.assert_close(
                        baseline_elem, result_elem, atol=atol, rtol=rtol
                    )

    def _assert_counting_information(
        self,
        ort_backend: OrtBackend,
        # Number of session runs.
        # If there is no graph break, this should be the same as
        # total number of forward calls.
        expected_execution_count: int,
        # Number of GraphModule's cached.
        # With one graph break, a model will be mapped
        # to two GraphModule's.
        number_of_cached_graph_modules: int,
        # Number of ONNX models cached for each GraphModule,
        # number_of_exported_onnx_models[i] contains # of ONNX models exported from
        # the i-th element (type: torch.fx.GraphModule) in
        # OrtBackend._all_ort_execution_info.execution_info_per_graph_module.values().
        number_of_exported_onnx_models_for_all_graph_modules: Tuple[int, ...],
    ):
        self.assertEqual(expected_execution_count, ort_backend.execution_count)
        self.assertEqual(
            len(ort_backend._all_ort_execution_info.execution_info_per_graph_module),
            number_of_cached_graph_modules,
        )
        self.assertEqual(
            len(ort_backend._all_ort_execution_info.execution_info_per_graph_module),
            len(number_of_exported_onnx_models_for_all_graph_modules),
        )
        for (
            onnx_info,
            expected_number_of_onnx_models,
        ) in zip(
            ort_backend._all_ort_execution_info.execution_info_per_graph_module.values(),
            number_of_exported_onnx_models_for_all_graph_modules,
        ):
            self.assertEqual(len(onnx_info), expected_number_of_onnx_models)

    def _assert_dynamic_input_and_output_shapes_in_all_onnx_models(self, backend):
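        # Every exported ONNX model should declare at least one symbolic
        # (dim_param) dimension in its inputs and in its outputs, i.e., dynamic
        # shapes should survive the export.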
        for (
            onnx_session_infos
        ) in backend._all_ort_execution_info.execution_info_per_graph_module.values():
            for onnx_session_info in onnx_session_infos:
                inputs_have_dynamic_shapes = False
                for input in onnx_session_info.input_value_infos:
                    if hasattr(input.type, "tensor_type") and hasattr(
                        input.type.tensor_type, "shape"
                    ):
                        for dim in input.type.tensor_type.shape.dim:
                            inputs_have_dynamic_shapes = (
                                inputs_have_dynamic_shapes or hasattr(dim, "dim_param")
                            )
                outputs_have_dynamic_shapes = False
                for output in onnx_session_info.output_value_infos:
                    if hasattr(output.type, "tensor_type") and hasattr(
                        output.type.tensor_type, "shape"
                    ):
                        for dim in output.type.tensor_type.shape.dim:
                            outputs_have_dynamic_shapes = (
                                outputs_have_dynamic_shapes or hasattr(dim, "dim_param")
                            )
                self.assertTrue(inputs_have_dynamic_shapes)
                self.assertTrue(outputs_have_dynamic_shapes)

    @parameterized.expand(
        [
            (True,),
            (False,),
        ]
    )
    def test_elementwise_function_single_output(self, test_local_backend: bool):
        example_args_collection = tuple(
            (torch.randn(batch, dtype=torch.float32),) for batch in (2, 4, 6, 8, 10)
        )

        def elementwise_model(x: torch.Tensor):
            y = x.relu()
            z = y.sigmoid()
            return z

        if test_local_backend:
            local_aot_ort, local_ort = make_aot_ort(dynamic=True)
        else:
            # This will use the global ONNXRuntime backend registered
            # in Dynamo to compile the tested model.
            local_aot_ort, local_ort = "onnxrt", None

        self._test_model_numerically(
            elementwise_model,
            local_aot_ort,
            example_args_collection,
        )

        # We can only check the local backend's counting information,
        # since the global backend's counters accumulate contributions
        # from every model compiled in the process.
        if test_local_backend:
            assert local_ort is not None
            self._assert_counting_information(
                local_ort,
                # OrtBackend._ort_acclerated_call should have been called 5 times because
                # we have 5 different batch sizes to test.
                expected_execution_count=len(example_args_collection),
                # Since this local_ort only compiled one function,
                # there should be only one GraphModule in its cache.
                number_of_cached_graph_modules=1,
                # Since dynamic shape is enabled, we should only have one ONNX model
                # to support different batch sizes.
                number_of_exported_onnx_models_for_all_graph_modules=(1,),
            )

    @parameterized.expand(
        [
            (True,),
            (False,),
        ]
    )
    def test_elementwise_function_multiple_output(self, test_local_backend: bool):
        example_args_collection = tuple(
            (torch.randn(batch, dtype=torch.float32),) for batch in (2, 4, 8)
        )

        def elementwise_model_with_multiple_outputs(w: torch.Tensor):
            x = w + w
            y = x.relu()
            z = y * y
            return x, y, z

        if test_local_backend:
            local_aot_ort, local_ort = make_aot_ort(dynamic=True)
        else:
            local_aot_ort, local_ort = "onnxrt", None

        self._test_model_numerically(
            elementwise_model_with_multiple_outputs,
            local_aot_ort,
            example_args_collection,
        )

        if test_local_backend:
            assert local_ort is not None
            self._assert_counting_information(
                local_ort,
                expected_execution_count=len(example_args_collection),
                number_of_cached_graph_modules=1,
                number_of_exported_onnx_models_for_all_graph_modules=(1,),
            )

    @parameterized.expand(
        [
            (True,),
            (False,),
        ]
    )
    def test_mlp_with_local_backend(self, test_local_backend: bool):
        example_args_collection = tuple(
            (torch.randn(batch, 2, dtype=torch.float32),) for batch in (1, 2, 4, 6, 8)
        )

        class MLP(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 4, bias=True)
                self.fc2 = nn.Linear(4, 2, bias=True)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.fc1(tensor_x)
                tensor_x = torch.sigmoid(tensor_x)
                tensor_x = self.fc2(tensor_x)
                tensor_x = torch.sigmoid(tensor_x)
                return tensor_x

        if test_local_backend:
            local_aot_ort, local_ort = make_aot_ort(dynamic=True)
        else:
            local_aot_ort, local_ort = "onnxrt", None

        self._test_model_numerically(
            MLP(),
            local_aot_ort,
            example_args_collection,
        )

        if test_local_backend:
            assert local_ort is not None
            self._assert_counting_information(
                local_ort,
                # OrtBackend._ort_acclerated_call should have been called 5 times because
                # we have 5 different batch sizes to test.
                expected_execution_count=len(example_args_collection),
                # Although this local_ort only compiled one function, Dynamo
                # specializes on batch size 1, so there should be two GraphModules
                # in its cache: one for batch sizes 2, 4, 6, 8 and one for batch size 1.
                number_of_cached_graph_modules=2,
                # Since dynamic shape is enabled, we should only have one ONNX model
                # to support different batch sizes.
                number_of_exported_onnx_models_for_all_graph_modules=(1, 1),
            )

    @parameterized.expand(
        [
            (True, True),
            (True, False),
        ]
    )
    @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/129456")
    def test_llama_attention_with_local_backend(
        self, test_local_backend: bool, test_backward: bool
    ):
        from transformers import LlamaConfig  # noqa: F811
        from transformers.models.llama.modeling_llama import (  # noqa: F811
            LlamaAttention,
        )

        hidden_size = 16

        config = LlamaConfig(
            num_hidden_layers=1,
            vocab_size=1024,
            hidden_size=hidden_size,
            intermediate_size=16,
            max_position_embeddings=256,
            num_attention_heads=2,
            hidden_dropout_prob=0.0,
            attention_dropout_prob=0.0,
        )

        class LlamaAttentionWrapper(torch.nn.Module):
            def __init__(self, config):
                super().__init__()
                try:
                    # New version of LlamaAttention has layer_idx argument.
                    self.attention = LlamaAttention(config, layer_idx=0)
                except TypeError:
                    # Fall back to old version of LlamaAttention.
                    self.attention = LlamaAttention(config)

            def forward(self, hidden_states, attention_mask, position_ids):
                attn_output, _, _ = self.attention(
                    hidden_states, attention_mask, position_ids
                )
                return attn_output

        def generate_example_inputs(batch: int, seq: int, hidden_size: int):
            # shape: batch x seq x hidden_size
            hidden_state = torch.randn(batch, seq, hidden_size)
            # [0.0000e+00, ..., 0.0000e+00, -3.4028e+38, ...]
            # shape: batch x 1 x seq x seq
            attention_mask = torch.zeros(batch, 1, seq, seq, dtype=torch.float)
            position_ids = torch.arange(0, seq, dtype=torch.int64)
            position_ids = position_ids.unsqueeze(0).view(-1, seq)

            return hidden_state, attention_mask, position_ids

        # Reason for using multiple example argument groups:
        #  export the model to ONNX with one example argument group
        #  and test it with the other example argument groups.
        example_args_collection = (
            generate_example_inputs(2, 8, hidden_size),
            generate_example_inputs(4, 7, hidden_size),
            generate_example_inputs(9, 15, hidden_size),
        )

        if test_local_backend:
            local_aot_ort, local_ort = make_aot_ort(dynamic=True)
        else:
            local_aot_ort, local_ort = "onnxrt", None

        model = LlamaAttentionWrapper(config).eval()

        self._test_model_numerically(
            model,
            local_aot_ort,
            example_args_collection,
            fullgraph=True,
            test_backward=test_backward,
        )

        if test_local_backend:
            assert local_ort is not None
            number_of_captured_graphs = 2 if test_backward else 1

            execution_count = len(example_args_collection) * number_of_captured_graphs
            self._assert_counting_information(
                local_ort,
                # Number of InferenceSession runs.
                expected_execution_count=execution_count,
                # Number of GraphModule's seen by ORT.
                number_of_cached_graph_modules=number_of_captured_graphs,
                # Number of InferenceSession's created per GraphModule.
                number_of_exported_onnx_models_for_all_graph_modules=(1,)
                * number_of_captured_graphs,
            )
            self._assert_dynamic_input_and_output_shapes_in_all_onnx_models(local_ort)

    @parameterized.expand(
        [
            (True, False),
            (True, True),
        ]
    )
    @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/129456")
    def test_llama_decoder_with_local_backend(
        self, test_local_backend: bool, test_backward: bool
    ):
        from transformers import LlamaConfig  # noqa: F811
        from transformers.models.llama.modeling_llama import (  # noqa: F811
            LlamaDecoderLayer,
        )

        hidden_size = 16

        config = LlamaConfig(
            num_hidden_layers=1,
            vocab_size=1024,
            hidden_size=hidden_size,
            intermediate_size=16,
            max_position_embeddings=256,
            num_attention_heads=2,
            hidden_dropout_prob=0.0,
            attention_dropout_prob=0.0,
        )

        class LlamaDecoderWrapper(torch.nn.Module):
            def __init__(self, config):
                super().__init__()
                try:
                    # New version of LlamaDecoderLayer has layer_idx argument.
                    self.decoder = LlamaDecoderLayer(config, layer_idx=0)
                except TypeError:
                    # Fall back to old version of LlamaDecoderLayer.
                    self.decoder = LlamaDecoderLayer(config)

            def forward(self, hidden_states, attention_mask, position_ids):
                (decoder_output,) = self.decoder(
                    hidden_states, attention_mask, position_ids
                )
                return decoder_output

        def generate_example_inputs(batch: int, seq: int, hidden_size: int):
            # shape: batch x seq x hidden_size
            hidden_state = torch.randn(batch, seq, hidden_size)
            # [0.0000e+00, ..., 0.0000e+00, -3.4028e+38, ...]
            # shape: batch x 1 x seq x seq
            attention_mask = torch.zeros(batch, 1, seq, seq, dtype=torch.float)
            position_ids = torch.arange(0, seq, dtype=torch.int64)
            position_ids = position_ids.unsqueeze(0).view(-1, seq)
            return hidden_state, attention_mask, position_ids

        # Reason for using multiple example argument groups:
        #  export the model to ONNX with one example argument group
        #  and test it with the other example argument groups.
        example_args_collection = (
            generate_example_inputs(2, 8, hidden_size),
            generate_example_inputs(4, 7, hidden_size),
            generate_example_inputs(9, 15, hidden_size),
        )

        if test_local_backend:
            local_aot_ort, local_ort = make_aot_ort(dynamic=True)
        else:
            local_aot_ort, local_ort = "onnxrt", None

        model = LlamaDecoderWrapper(config).eval()

        self._test_model_numerically(
            model,
            local_aot_ort,
            example_args_collection,
            fullgraph=True,
            test_backward=test_backward,
        )

        if test_local_backend:
            assert local_ort is not None
            number_of_captured_graphs = 2 if test_backward else 1

            execution_count = len(example_args_collection) * number_of_captured_graphs

            self._assert_counting_information(
                local_ort,
                expected_execution_count=execution_count,
                number_of_cached_graph_modules=number_of_captured_graphs,
                number_of_exported_onnx_models_for_all_graph_modules=(1,)
                * number_of_captured_graphs,
            )
            self._assert_dynamic_input_and_output_shapes_in_all_onnx_models(local_ort)

    @parameterized.expand(
        [
            (True, False),
            (True, True),
        ]
    )
    @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/129456")
    def test_llama_with_local_backend(
        self, test_local_backend: bool, test_backward: bool
    ):
        from transformers import LlamaConfig  # noqa: F811
        from transformers.models.llama.modeling_llama import LlamaModel  # noqa: F811

        config = LlamaConfig(
            num_hidden_layers=1,
            vocab_size=1024,
            hidden_size=16,
            intermediate_size=16,
            max_position_embeddings=256,
            num_attention_heads=2,
            hidden_dropout_prob=0.0,
            attention_dropout_prob=0.0,
        )

        config._attn_implementation = "eager"

        class LlamaModelWrapper(torch.nn.Module):
            def __init__(self, config):
                super().__init__()
                self.llama = LlamaModel(config)

            def forward(self, input_ids, attention_mask, position_ids):
                decoder_output = self.llama(
                    input_ids, attention_mask, position_ids, return_dict=False
                )
                return decoder_output[0]

        def generate_example_inputs(batch: int, seq: int):
            # shape: batch x seq
            input_ids = torch.randint(0, 7, size=(batch, seq), dtype=torch.int64)
            # Usually, attention_mask is a tensor of shape batch x seq x seq.
            # However, to bypass some control flow in the model, we use None.
            attention_mask = None
            position_ids = torch.arange(0, seq, dtype=torch.int64)
            position_ids = position_ids.unsqueeze(0).view(-1, seq)
            return input_ids, attention_mask, position_ids

        # Reason for using multiple example argument groups:
        #  export the model to ONNX with one example argument group
        #  and test it with the other example argument groups.
        example_args_collection = (
            generate_example_inputs(2, 8),
            generate_example_inputs(4, 7),
            generate_example_inputs(9, 15),
        )

        if test_local_backend:
            local_aot_ort, local_ort = make_aot_ort(dynamic=True)
        else:
            local_aot_ort, local_ort = "onnxrt", None

        model = LlamaModelWrapper(config).eval()

        self._test_model_numerically(
            model,
            local_aot_ort,
            example_args_collection,
            fullgraph=True,
            test_backward=test_backward,
            atol=1e-4,
            rtol=1e-4,
        )

        if test_local_backend:
            assert local_ort is not None
            number_of_captured_graphs = 2 if test_backward else 1
            execution_count = len(example_args_collection) * number_of_captured_graphs
            self._assert_counting_information(
                local_ort,
                expected_execution_count=execution_count,
                number_of_cached_graph_modules=number_of_captured_graphs,
                number_of_exported_onnx_models_for_all_graph_modules=(1,)
                * number_of_captured_graphs,
            )
            self._assert_dynamic_input_and_output_shapes_in_all_onnx_models(local_ort)

    @parameterized.expand(
        [
            (True,),
            (False,),
        ]
    )
    def test_dump_model(self, test_local_backend: bool):
        @contextlib.contextmanager
        def onnxrt_dump_path(path):
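            # ONNXRT_DUMP_PATH points the ONNXRuntime backend at a filename
            # prefix; each exported ONNX model (and a text rendering of its
            # graph) should be dumped to "<prefix>N.onnx" / "<prefix>N.txt".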
            key = "ONNXRT_DUMP_PATH"
            before = os.environ.get(key, None)
            os.environ[key] = path
            try:
                yield
            finally:
                if before is None:
                    del os.environ[key]
                else:
                    os.environ[key] = before

        example_args_collection = tuple(
            (torch.randn(batch, 2, dtype=torch.float32),) for batch in (1, 2, 4, 6, 8)
        )

        class MLP(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 4, bias=True)
                self.fc2 = nn.Linear(4, 2, bias=True)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.fc1(tensor_x)
                tensor_x = torch.sigmoid(tensor_x)
                tensor_x = self.fc2(tensor_x)
                tensor_x = torch.sigmoid(tensor_x)
                return tensor_x

        if test_local_backend:
            local_aot_ort, local_ort = make_aot_ort(dynamic=True)
        else:
            local_aot_ort, local_ort = "onnxrt", None

        prefix = f"test_dump_model_{'local' if test_local_backend else 'onnxrt'}_"
        expected = f"{prefix}0.onnx"
        expected_graph = f"{prefix}0.txt"
        if os.path.exists(expected):
            os.remove(expected)
        if os.path.exists(expected_graph):
            os.remove(expected_graph)
        not_expected = f"{prefix}1.onnx"
        self.assertFalse(os.path.exists(not_expected))

        model = MLP()
        compiled_model = torch.compile(
            model if not isinstance(model, torch.nn.Module) else copy.deepcopy(model),
            backend=local_aot_ort,
            dynamic=True,
        )

        self.assertFalse(os.path.exists(expected))
        self.assertFalse(os.path.exists(not_expected))

        with onnxrt_dump_path(prefix):
            example_args = example_args_collection[0]
            result = compiled_model(*example_args)
            self.assertTrue(os.path.exists(expected))
            self.assertTrue(os.path.exists(expected_graph))
            self.assertFalse(os.path.exists(not_expected))

            result = compiled_model(*example_args)
            self.assertTrue(os.path.exists(expected))
            self.assertFalse(os.path.exists(not_expected))

    @unittest.skipIf(not torch.cuda.is_available(), "No CUDA to run mixed-device inputs")
    def test_mix_device_inputs(self):
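        # Compile a function whose inputs live on different devices (a CUDA
        # tensor reshaped to the shape of a CPU tensor) and check that the
        # onnxrt backend still produces the expected result.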
        data = torch.randn(4, 8, device="cuda")
        ref_data = torch.randn(8, 4, device="cpu")

        def reshape_wrapper(data, ref_cpu_data):
            # Dummy line to make sure ref_cpu_data
            # is included in the captured graph.
            ref_cpu_data += 1
            shape = ref_cpu_data.shape
            # A call with GPU and CPU inputs.
            return torch.reshape(data, shape)

        compiled_model = torch.compile(
            reshape_wrapper,
            backend="onnxrt",
            dynamic=True,
        )

        result = compiled_model(data, ref_data)

        self.assertTrue(torch.allclose(result, data.view(ref_data.shape)))

    def test_no_input(self):
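        # A graph that takes no inputs should still be exported and executed by
        # the onnxrt backend; the pre_ort_model_transforms hook records the
        # exported model so it can be inspected below.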
        def reshape_wrapper():
            # A model that takes no inputs.
            ones = torch.ones(4, 8)
            zeros = torch.zeros(4, 8)
            return ones + zeros

        recorded_models = []

        def record_onnx_model_transform(onnx_model):
            # Record the ONNX model seen by the transform.
            recorded_models.append(onnx_model)

        compiled_model = torch.compile(
            reshape_wrapper,
            backend="onnxrt",
            dynamic=True,
            options=torch.onnx._OrtBackendOptions(
                pre_ort_model_transforms=[
                    record_onnx_model_transform,
                ]
            ),
        )

        result = compiled_model()

        self.assertEqual(len(recorded_models), 1)
        # NOTE: the addition of the two constant tensors is constant-folded by
        # the optimizer, so the exported graph contains a Constant node.
        self.assertTrue(
            "Constant" in [node.op_type for node in recorded_models[0].graph.node]
        )

        self.assertEqual(result, torch.ones(4, 8))

    def test_custom_onnx_transform(self):
        # This test consists of 2 parts:
        # 1. Check that a registered ONNX transform is called and records the model.
        # 2. Check that a registered ONNX transform is called and changes the model.

        # Part 1: Record the ONNX model seen by the transform.
        # This list contains the models recorded by record_onnx_model_transform.
        recorded_models = []

        def record_onnx_model_transform(onnx_model):
            # Record the ONNX model seen by the transform.
            recorded_models.append(onnx_model)

        def example_model(x: torch.Tensor):
            y = torch.sigmoid(x)
            z = x + y
            return z

        compiled_model = torch.compile(
            example_model,
            backend="onnxrt",
            dynamic=True,
            options=torch.onnx._OrtBackendOptions(
                pre_ort_model_transforms=[record_onnx_model_transform]
            ),
        )

        x = torch.randn(2)
        assert len(recorded_models) == 0
        y = compiled_model(x)
        assert len(recorded_models) == 1

        # Part 2: Change the ONNX model seen by the transform so that
        # ORT receives a different model.
        # NOTE: the function is optimized away by the optimizer.
        def replace_relu_with_sigmoid(onnx_model):
            for node in onnx_model.graph.node:
                if node.op_type == "Relu":
                    node.op_type = "Sigmoid"

        def another_example_model(x: torch.Tensor):
            y = torch.relu(x)
            z = x + y
            return z

        another_compiled = torch.compile(
            another_example_model,
            backend="onnxrt",
            dynamic=True,
            options=torch.onnx._OrtBackendOptions(
                pre_ort_model_transforms=[
                    replace_relu_with_sigmoid,
                    record_onnx_model_transform,
                ]
            ),
        )

        another_y = another_compiled(x)
        # record_onnx_model_transform has now recorded 2 models, one per
        # torch.compile call above.
        assert len(recorded_models) == 2
        # Since replace_relu_with_sigmoid changed "Relu" to "Sigmoid",
        # the result should be the same as the previous y.
        torch.testing.assert_close(y, another_y)
        # another_example_model still uses "Relu" in eager mode, so its result
        # should differ from y.
        self.assertFalse(torch.allclose(y, another_example_model(x)))


if __name__ == "__main__":
    common_utils.run_tests()
