# Owner(s): ["module: onnx"]

import copy
import functools
import io
import re
import warnings
from typing import Callable

import onnx

import parameterized
import pytorch_test_common
import torchvision
from autograd_helper import CustomFunction as CustomFunction2
from pytorch_test_common import (
    skipIfNoCuda,
    skipIfUnsupportedMaxOpsetVersion,
    skipIfUnsupportedMinOpsetVersion,
)

import torch
import torch.onnx
import torch.utils.cpp_extension
from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
from torch.onnx._globals import GLOBALS
from torch.onnx.symbolic_helper import _unpack_list, parse_args
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNoLapack


def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str:
    """Remove test environment prefix added to module.

    Remove prefix to normalize scope names, since different test environments add
    prefixes with slight differences.

    Example:

        >>> _remove_test_environment_prefix_from_scope_name(
        >>>     "test_utility_funs.M"
        >>> )
        "M"
        >>> _remove_test_environment_prefix_from_scope_name(
        >>>     "test_utility_funs.test_abc.<locals>.M"
        >>> )
        "M"
        >>> _remove_test_environment_prefix_from_scope_name(
        >>>     "__main__.M"
        >>> )
        "M"
    """
    prefixes_to_remove = ["test_utility_funs", "__main__"]
    for prefix in prefixes_to_remove:
        scope_name = re.sub(f"{prefix}\\.(.*?<locals>\\.)?", "", scope_name)
    return scope_name


class _BaseTestCase(pytorch_test_common.ExportTestCase):
    def _model_to_graph(
        self,
        model,
        input,
        do_constant_folding=True,
        training=TrainingMode.EVAL,
        operator_export_type=OperatorExportTypes.ONNX,
        input_names=None,
        dynamic_axes=None,
    ):
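        """Lower `model` to a TorchScript ONNX graph without running the full export.

        A thin wrapper over `utils._model_to_graph` shared by the tests below: it sets
        the requested train/eval mode and validates `dynamic_axes` the same way
        `torch.onnx.export` does, then returns the graph, params dict, and torch output.
        """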
        torch.onnx.utils._setup_trace_module_map(model, False)
        if training == torch.onnx.TrainingMode.TRAINING:
            model.train()
        elif training == torch.onnx.TrainingMode.EVAL:
            model.eval()
        utils._validate_dynamic_axes(dynamic_axes, model, None, None)
        graph, params_dict, torch_out = utils._model_to_graph(
            model,
            input,
            do_constant_folding=do_constant_folding,
            _disable_torch_constant_prop=True,
            operator_export_type=operator_export_type,
            training=training,
            input_names=input_names,
            dynamic_axes=dynamic_axes,
        )
        return graph, params_dict, torch_out


@common_utils.instantiate_parametrized_tests
class TestUnconvertibleOps(pytorch_test_common.ExportTestCase):
    """Unit tests for the `unconvertible_ops` function."""

    def setUp(self):
        class EinsumModule(torch.nn.Module):
            def forward(self, x):
                return torch.einsum("ii", x)

        self.einsum_module = EinsumModule()

    def test_it_returns_graph_and_unconvertible_ops_at_lower_opset_version(self):
        x = torch.randn(4, 4)

        # Einsum is supported since opset 12. It should be unconvertible at opset 9.
        graph, unconvertible_ops = utils.unconvertible_ops(
            self.einsum_module, (x,), opset_version=9
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "prim::Constant")
        self.assertEqual(next(nodes).kind(), "prim::ListConstruct")
        self.assertEqual(next(nodes).kind(), "prim::Constant")
        self.assertEqual(next(nodes).kind(), "aten::einsum")
        self.assertEqual(unconvertible_ops, ["aten::einsum"])

    @common_utils.parametrize(
        "jit_function",
        [
            common_utils.subtest(
                functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
                name="traced",
            ),
            common_utils.subtest(torch.jit.script, name="scripted"),
        ],
    )
    def test_it_returns_unconvertible_ops_at_lower_opset_version_for_jit_module(
        self, jit_function: Callable
    ):
        module = jit_function(self.einsum_module)
        x = torch.randn(4, 4)

        # Einsum is supported since opset 12. It should be unconvertible at opset 9.
        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=9)
        self.assertEqual(unconvertible_ops, ["aten::einsum"])

    @common_utils.parametrize(
        "jit_function",
        [
            common_utils.subtest(lambda x: x, name="nn_module"),
            common_utils.subtest(
                functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
                name="traced",
            ),
            common_utils.subtest(torch.jit.script, name="scripted"),
        ],
    )
    def test_it_returns_empty_list_when_all_ops_convertible(
        self, jit_function: Callable
    ):
        module = jit_function(self.einsum_module)
        x = torch.randn(4, 4)

        # Einsum is supported since opset 12
        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12)
        self.assertEqual(unconvertible_ops, [])

    def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self):
        class SkipConnectionModule(torch.nn.Module):
            def forward(self, x):
                out = x
                out += x
                out = torch.nn.functional.relu(out, inplace=True)
                return out

        module = SkipConnectionModule()
        x = torch.randn(4, 4)
        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13)
        self.assertEqual(unconvertible_ops, [])


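# Run every test in TestUtilityFuns once per opset version, from the base ONNX opset
# up to the newest opset supported by the TorchScript exporter.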
@parameterized.parameterized_class(
    [
        {"opset_version": opset}
        for opset in range(
            _constants.ONNX_BASE_OPSET,
            _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET + 1,
        )
    ],
    class_name_func=lambda cls, num, params_dict: (
        f"{cls.__name__}_opset_{params_dict['opset_version']}"
    ),
)
class TestUtilityFuns(_BaseTestCase):
    opset_version = None

    def test_is_in_onnx_export(self):
        test_self = self

        class MyModule(torch.nn.Module):
            def forward(self, x):
                test_self.assertTrue(torch.onnx.is_in_onnx_export())
                raise ValueError
                return x + 1

        x = torch.randn(3, 4)
        f = io.BytesIO()
        try:
            torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
        except ValueError:
            self.assertFalse(torch.onnx.is_in_onnx_export())

    def test_validate_dynamic_axes_invalid_input_output_name(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            utils._validate_dynamic_axes(
                {"input1": {}, "output": {}, "invalid_name1": {}, "invalid_name2": {}},
                None,
                ["input1", "input2"],
                ["output"],
            )
            messages = [str(warning.message) for warning in w]
        self.assertIn(
            "Provided key invalid_name1 for dynamic axes is not a valid input/output name",
            messages,
        )
        self.assertIn(
            "Provided key invalid_name2 for dynamic axes is not a valid input/output name",
            messages,
        )
        self.assertEqual(len(messages), 2)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_split_to_slice(self):
        class SplitModule(torch.nn.Module):
            def forward(self, x, y, t):
                splits = (x.size(1), y.size(1))
                out, out2 = torch.split(t, splits, dim=1)
                return out, out2

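        # The tests in this class call the internal _model_to_graph helper directly
        # (instead of going through torch.onnx.export), so the target opset version
        # and operator export type must be set on GLOBALS by hand.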
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.randn(2, 3)
        y = torch.randn(2, 4)
        t = torch.randn(2, 7)
        graph, _, _ = self._model_to_graph(
            SplitModule(),
            (x, y, t),
            input_names=["x", "y", "t"],
            dynamic_axes={"x": [0, 1], "y": [0, 1], "t": [0, 1]},
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::SplitToSequence")

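    # The constant-folding tests below share a pattern: build a module whose forward
    # computes something from a constant tensor, lower it to an ONNX graph, and assert
    # that the op which should have been folded away no longer appears in the graph.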
    def test_constant_fold_transpose(self):
        class TransposeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.transpose(a, 1, 0)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(3, 2)
        graph, _, __ = self._model_to_graph(
            TransposeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Transpose")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    @skipIfUnsupportedMaxOpsetVersion(17)
    def test_constant_fold_reduceL2(self):
        class ReduceModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.norm(a, p=2, dim=-2, keepdim=False)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            ReduceModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::ReduceL2")

    @skipIfUnsupportedMaxOpsetVersion(17)
    def test_constant_fold_reduceL1(self):
        class NormModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.norm(a, p=1, dim=-2)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            NormModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::ReduceL1")

    def test_constant_fold_slice(self):
        class NarrowModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.narrow(a, 0, 0, 1)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 3)
        graph, _, __ = self._model_to_graph(
            NarrowModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Slice")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_slice_index_exceeds_dim(self):
        class SliceIndexExceedsDimModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = a[1:10]  # index exceeds dimension
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 3)
        graph, _, __ = self._model_to_graph(
            SliceIndexExceedsDimModule(),
            (x,),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Slice")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_slice_negative_index(self):
        class SliceNegativeIndexModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = a[0:-1]  # index relative to the end
                c = torch.select(a, dim=-1, index=-2)
                d = torch.select(a, dim=1, index=0)
                return b + x, c + d

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 3)
        graph, _, __ = self._model_to_graph(
            SliceNegativeIndexModule(),
            (x,),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Slice")
            self.assertNotEqual(node.kind(), "onnx::Cast")

    def test_constant_fold_gather(self):
        class GatherModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.select(a, dim=1, index=-2)
                c = torch.index_select(a, dim=-2, index=torch.tensor([0, 1]))
                return b + 1, c + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 3)
        model = GatherModule()
        model(x)
        graph, _, __ = self._model_to_graph(
            GatherModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Gather")

    def test_constant_fold_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.unsqueeze(a, -2)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 2, 3)
        graph, _, __ = self._model_to_graph(
            UnsqueezeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Unsqueeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_unsqueeze_multi_axes(self):
        class PReluModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                a = torch.randn(2, 3, 4, 5, 8, 7)
                return self.prelu(x) + a

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.randn(2, 3, 4, 5, 8, 7)
        graph, _, __ = self._model_to_graph(
            PReluModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3, 4, 5]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Unsqueeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 5)

    def test_constant_fold_squeeze_without_axes(self):
        class SqueezeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
                return torch.squeeze(a) + x + torch.squeeze(a)

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            SqueezeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Squeeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 4)

    def test_constant_fold_squeeze_with_axes(self):
        class SqueezeAxesModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
                return torch.squeeze(a, dim=-3) + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            SqueezeAxesModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Squeeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_concat(self):
        class ConcatModule(torch.nn.Module):
            def forward(self, x):
                # Why is there a Cast here? ONNX constant folding appears to
                # intentionally skip extracting constant tensors into the
                # initializer graph unless they are attached to an op known to
                # be foldable. Without these casts we would fail to pull out
                # one of the constants, and constant folding would fail. The
                # test is arguably wrong: a better approach would be to set up
                # a predicate describing the invariants a graph must satisfy
                # after constant folding, and then verify that predicate here.
                # The asserts below are an attempt at such a predicate, but
                # they are not quite right.
                #
                # More commentary at
                # https://github.com/pytorch/pytorch/pull/18698/files#r340107552
                a = torch.tensor([[1.0, 2.0, 3.0]]).to(torch.float)
                b = torch.tensor([[4.0, 5.0, 6.0]]).to(torch.float)
                c = torch.cat((a, b), 0)
                d = b + c
                return x + d

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            ConcatModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Concat")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_lstm(self):
        class GruNet(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)

            def forward(self, input, initial_state):
                return self.mygru(input, initial_state)

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        input = torch.randn(5, 3, 7)
        h0 = torch.randn(1, 3, 3)
        graph, _, __ = self._model_to_graph(
            GruNet(),
            (input, h0),
            input_names=["input", "h0"],
            dynamic_axes={"input": [0, 1, 2], "h0": [0, 1, 2]},
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Slice")
            self.assertNotEqual(node.kind(), "onnx::Concat")
            self.assertNotEqual(node.kind(), "onnx::Unsqueeze")

        if self.opset_version <= 12:
            self.assertEqual(len(list(graph.nodes())), 3)
        else:
            # Unsqueeze op parameter "axes" as an input instead of as an attribute when opset version >= 13
            self.assertEqual(len(list(graph.nodes())), 4)

    def test_constant_fold_transpose_matmul(self):
        class MatMulNet(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.B = torch.nn.Parameter(torch.ones(5, 3))

            def forward(self, A):
                return torch.matmul(A, torch.transpose(self.B, -1, -2))

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        A = torch.randn(2, 3)
        graph, _, __ = self._model_to_graph(
            MatMulNet(), (A,), input_names=["A"], dynamic_axes={"A": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Transpose")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                b = self.weight.reshape(1, -1, 1, 1)
                return x * b

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.randn(4, 5)
        graph, _, __ = self._model_to_graph(
            ReshapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Reshape")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_div(self):
        class Module(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                div = self.weight.div(torch.tensor([1, 2, 3, 4, 5]))
                return div * x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Div")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_mul(self):
        class Module(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5]))
                return mul / x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Mul")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_add(self):
        class Module(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                add = self.weight + torch.tensor([1, 2, 3, 4, 5])
                return add - x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, params_dict, __ = self._model_to_graph(
            Module(),
            (x,),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )
        for node in graph.nodes():
            self.assertTrue(node.kind() != "onnx::Add")
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        self.assertEqual(weight, torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0]))

    def test_constant_fold_sub(self):
        class Module(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
                return sub + x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, params_dict, __ = self._model_to_graph(
            Module(),
            (x,),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Sub")
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        self.assertEqual(weight, torch.tensor([0.0, -1.0, -2.0, -3.0, -4.0]))

    def test_constant_fold_sqrt(self):
        class Module(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                sqrt = torch.sqrt(self.weight)
                return sqrt / x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Sqrt")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_shape(self):
        class ShapeModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Buffer(torch.ones(5))

            def forward(self, x):
                shape = self.weight.shape[0]
                return x + shape

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            ShapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Shape")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_upsample_scale_fold_as_constant(self):
        # The upsample scale is a constant, not a model parameter, so it should
        # not be added as an initializer after constant folding.
        model = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        x = torch.randn(1, 32, 224, 224)
        f = io.BytesIO()
        torch.onnx.export(model, x, f)
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        self.assertEqual(len(onnx_model.graph.initializer), 0)

    def test_verbose(self):
        class MyModule(torch.nn.Module):
            def forward(self, input):
                return torch.exp(input)

        x = torch.randn(3, 4)

        def is_model_stripped(f, verbose=None):
            if verbose is None:
                torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
            else:
                torch.onnx.export(
                    MyModule(), x, f, verbose=verbose, opset_version=self.opset_version
                )
            model = onnx.load(io.BytesIO(f.getvalue()))
            model_strip = copy.copy(model)
            onnx.helper.strip_doc_string(model_strip)
            return model == model_strip

        # test verbose=False (default)
        self.assertTrue(is_model_stripped(io.BytesIO()))
        # test verbose=True
        self.assertFalse(is_model_stripped(io.BytesIO(), True))

    # NB: remove this test once DataParallel can be correctly handled
    def test_error_on_data_parallel(self):
        model = torch.nn.DataParallel(torch.nn.ReflectionPad2d((1, 2, 3, 4)))
        x = torch.randn(1, 2, 3, 4)
        f = io.BytesIO()
        with self.assertRaisesRegex(
            ValueError,
            "torch.nn.DataParallel is not supported by ONNX "
            "exporter, please use 'attribute' module to "
            "unwrap model from torch.nn.DataParallel. Try ",
        ):
            torch.onnx.export(model, x, f, opset_version=self.opset_version)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_sequence_dim(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):
                return [x, y]

        model = Module()
        # Export with scripting to keep output as Sequence type.
        # Tracing unpacks the list.
        script_model = torch.jit.script(model)
        x = torch.randn(2, 3)

        # Case 1: dynamic axis
        f = io.BytesIO()
        y = torch.randn(2, 3)
        torch.onnx.export(
            script_model,
            (x, y),
            f,
            opset_version=self.opset_version,
            input_names=["x", "y"],
            dynamic_axes={"y": [1]},
        )
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        loop_output_value_info_proto = onnx_model.graph.output[0]
        ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(
            loop_output_value_info_proto.name, 1, [2, None]
        )
        self.assertEqual(loop_output_value_info_proto, ref_value_info_proto)

        # Case 2: no dynamic axes.
        f = io.BytesIO()
        y = torch.randn(2, 3)
        torch.onnx.export(script_model, (x, y), f, opset_version=self.opset_version)
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        loop_output_value_info_proto = onnx_model.graph.output[0]
        ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(
            loop_output_value_info_proto.name, 1, [2, 3]
        )
        self.assertEqual(loop_output_value_info_proto, ref_value_info_proto)

    def test_export_mode(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = x + 1
                return y

        model = MyModule()
        x = torch.randn(10, 3, 128, 128)
        f = io.BytesIO()

        # set model to inference mode and export in training mode
        model.eval()
        old_state = model.training
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            training=torch.onnx.TrainingMode.TRAINING,
        )
        # verify that the model state is preserved
        self.assertEqual(model.training, old_state)

        # set model to training mode and export in inference mode
        model.train()
        old_state = model.training
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            training=torch.onnx.TrainingMode.EVAL,
        )
        # verify that the model state is preserved
        self.assertEqual(model.training, old_state)

    def test_export_does_not_fail_on_frozen_scripted_module(self):
        class Inner(torch.nn.Module):
            def forward(self, x):
                if x > 0:
                    return x
                else:
                    return x * x

        class Outer(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.inner = torch.jit.script(Inner())

            def forward(self, x):
                return self.inner(x)

        x = torch.zeros(1)
        # Freezing is only implemented for modules in eval mode, so we need to call eval()
        outer_module = Outer().eval()
        module = torch.jit.trace_module(outer_module, {"forward": (x,)})
        # jit.freeze removes the training attribute in the module
        module = torch.jit.freeze(module)

        torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version)

    @skipIfUnsupportedMinOpsetVersion(15)
    def test_local_function(self):
        class N(torch.nn.Module):
            def __init__(self, prob):
                super().__init__()
                self.dropout = torch.nn.Dropout(prob)

            def forward(self, x):
                return self.dropout(x)

        class M(torch.nn.Module):
            def __init__(self, num_layers):
                super().__init__()
                self.num_layers = num_layers
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)]
                )
                self.celu1 = torch.nn.CELU(1.0)
                self.celu2 = torch.nn.CELU(2.0)
                self.dropout = N(0.5)

            def forward(self, x, y, z):
                res1 = self.celu1(x)
                res2 = self.celu2(y)
                for ln in self.lns:
                    z = ln(z)
                return res1 + res2, self.dropout(z)

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)

        # Export specified modules. Also test specifying modules that won't
        # exist in the exported model.
        # Exporting in inference mode removes the dropout node,
        # so the dropout module no longer exists in the graph.
        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions={
                torch.nn.CELU,
                torch.nn.Dropout,
                torch.nn.LayerNorm,
            },
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))

        # Check function definition
        funcs = onnx_model.functions
        celu_funcs = [f for f in funcs if f.name == "CELU"]
        self.assertEqual(len(celu_funcs), 1)
        self.assertEqual(celu_funcs[0].domain, "torch.nn.modules.activation")
        self.assertEqual(len(celu_funcs[0].attribute), 3)
        ln_funcs = [f for f in funcs if f.name == "LayerNorm"]
        self.assertEqual(len(ln_funcs), 1)
        self.assertEqual(ln_funcs[0].domain, "torch.nn.modules.normalization")
        self.assertEqual(len(ln_funcs[0].attribute), 3)

        # Check local function nodes
        nodes = onnx_model.graph.node
        celu_ns = [n for n in nodes if n.op_type == "CELU"]
        ln_ns = [n for n in nodes if n.op_type == "LayerNorm"]
        self.assertEqual(len(celu_ns), 2)
        self.assertEqual(celu_ns[0].domain, "torch.nn.modules.activation")
        self.assertEqual(len(celu_ns[0].attribute), 3)
        self.assertEqual(len(ln_ns), 3)
        self.assertEqual(ln_ns[0].domain, "torch.nn.modules.normalization")
        self.assertEqual(len(ln_ns[0].attribute), 3)

        # Export specified modules.
        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions={torch.nn.CELU},
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertEqual(len(funcs), 1)
        self.assertEqual(funcs[0].name, "CELU")

        # Export with empty specified modules. Normal export.
        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions=set(),
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertEqual(len(funcs), 0)

        # Export all modules. Should contain {M, CELU, LayerNorm}.
        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions=True,
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertEqual(len(funcs), 3)

    @skipIfUnsupportedMinOpsetVersion(15)
    def test_local_function_overloads(self):
        class NWithOverloads(torch.nn.Module):
            def forward(self, x, y=None, z=None):
                if y is None:
                    return x + 1
                elif z is None:
                    return x + y
                else:
                    return x + y, x + z

        class M(torch.nn.Module):
            def __init__(self, num_layers):
                super().__init__()
                self.n = NWithOverloads()

            def forward(self, x, y, z):
                return self.n(x), self.n(x, y), self.n(x, y, z)

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)

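        # Each distinct call signature of NWithOverloads should be exported as its own
        # local function, with numeric suffixes disambiguating the overloads.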
        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions={NWithOverloads},
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertEqual(len(funcs), 3)
        func_names = [f.name for f in funcs]
        self.assertIn("NWithOverloads", func_names)
        self.assertIn("NWithOverloads.1", func_names)
        self.assertIn("NWithOverloads.2", func_names)

    # Failing after ONNX 1.13.0
    @skipIfUnsupportedMaxOpsetVersion(1)
    def test_local_function_infer_scopes(self):
        class M(torch.nn.Module):
            def forward(self, x):
                # Concatenation of scalars inserts unscoped tensors into the IR graph.
                new_tensor_shape = x.size()[:-1] + (1, 1, -1)
                tensor = x.view(*new_tensor_shape)
                return tensor

        x = torch.randn(4, 5)
        f = io.BytesIO()
        torch.onnx.export(
            M(),
            (x,),
            f,
            export_modules_as_functions=True,
            opset_version=self.opset_version,
            do_constant_folding=False,
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertIn("M", [f.name for f in funcs])

    @skipIfUnsupportedMinOpsetVersion(15)
    def test_local_function_predefined_attributes(self):
        class M(torch.nn.Module):
            num_layers: int

            def __init__(self, num_layers):
                super().__init__()
                self.num_layers = num_layers
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=1e-4) for _ in range(num_layers)]
                )

            def forward(self, x):
                for ln in self.lns:
                    x = ln(x)
                return x

        x = torch.randn(2, 3)
        f = io.BytesIO()
        model = M(3)
        torch.onnx.export(
            model,
            (x,),
            f,
            export_modules_as_functions=True,
            opset_version=self.opset_version,
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        m_funcs = [fn for fn in funcs if fn.name == "M"]
        self.assertEqual(m_funcs[0].attribute, ["num_layers"])
        ln_funcs = [fn for fn in funcs if fn.name == "LayerNorm"]
        self.assertEqual(ln_funcs[0].attribute, ["eps", "elementwise_affine"])

        from onnx import helper

        m_node = [n for n in onnx_model.graph.node if n.op_type == "M"]
        self.assertEqual(
            m_node[0].attribute[0],
            helper.make_attribute("num_layers", model.num_layers),
        )

        ln_nodes = [n for n in m_funcs[0].node if n.op_type == "LayerNorm"]
        expected_ln_attrs = [
            helper.make_attribute(
                "elementwise_affine", model.lns[0].elementwise_affine
            ),
            helper.make_attribute("eps", model.lns[0].eps),
        ]
        for ln_node in ln_nodes:
            self.assertIn(ln_node.attribute[0], expected_ln_attrs)
            self.assertIn(ln_node.attribute[1], expected_ln_attrs)

    # This test case checks the issue where an object does not have an attribute.
    # When `export_modules_as_functions=True` is enabled, the exporter could raise an
    # AttributeError. With this test case, we check that the export passes successfully
    # without any AttributeError exceptions.
    # See https://github.com/pytorch/pytorch/pull/109759 for an example. The exception that
    # this test tries to avoid is `AttributeError: 'Embedding' object has no attribute 'freeze'`.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_local_function_subset_of_predefined_attributes(self):
        class M(torch.nn.Module):
            num_layers: int

            def __init__(self, num_layers):
                super().__init__()
                self.embed_layer = torch.nn.Embedding.from_pretrained(
                    torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
                )
                self.num_layers = num_layers
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=1e-4) for _ in range(num_layers)]
                )

            def forward(self, x):
                e = self.embed_layer(torch.LongTensor([1]))
                for ln in self.lns:
                    x = ln(x)
                return x, e

        x = torch.randn(2, 3)
        f = io.BytesIO()
        model = M(3)
        torch.onnx.export(
            model,
            (x,),
            f,
            export_modules_as_functions=True,
            opset_version=self.opset_version,
            verbose=True,  # Allows the test case to print `Skipping module attribute 'freeze'`
        )

    def test_node_scope(self):
        class N(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        class M(torch.nn.Module):
            def __init__(self, num_layers):
                super().__init__()
                self.num_layers = num_layers
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=float(i)) for i in range(num_layers)]
                )
                self.gelu1 = torch.nn.GELU()
                self.gelu2 = torch.nn.GELU()
                self.relu = N()

            def forward(self, x, y, z):
                res1 = self.gelu1(x)
                res2 = self.gelu2(y)
                for ln in self.lns:
                    z = ln(z)
                return res1 + res2, self.relu(z)

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)

        model = M(3)
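        # Scope names encode the root class followed by each submodule's class and
        # attribute path, e.g. "M::/torch.nn.modules.activation.GELU::gelu1".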
        expected_scope_names = {
            "M::/torch.nn.modules.activation.GELU::gelu1",
            "M::/torch.nn.modules.activation.GELU::gelu2",
            "M::/torch.nn.modules.normalization.LayerNorm::lns.0",
            "M::/torch.nn.modules.normalization.LayerNorm::lns.1",
            "M::/torch.nn.modules.normalization.LayerNorm::lns.2",
            "M::/N::relu/torch.nn.modules.activation.ReLU::relu",
            "M::",
        }

        graph, _, _ = self._model_to_graph(
            model, (x, y, z), input_names=[], dynamic_axes={}
        )
        for node in graph.nodes():
            self.assertIn(
                _remove_test_environment_prefix_from_scope_name(node.scopeName()),
                expected_scope_names,
            )

        graph, _, _ = self._model_to_graph(
            torch.jit.script(model), (x, y, z), input_names=[], dynamic_axes={}
        )
        for node in graph.nodes():
            self.assertIn(
                _remove_test_environment_prefix_from_scope_name(node.scopeName()),
                expected_scope_names,
            )

    def test_scope_of_constants_when_combined_by_cse_pass(self):
        layer_num = 3

        class M(torch.nn.Module):
            def __init__(self, constant):
                super().__init__()
                self.constant = constant

            def forward(self, x):
                # 'self.constant' is designed to be the same for all layers,
                # hence it is a common subexpression.
                return x + self.constant

        class N(torch.nn.Module):
            def __init__(self, layers: int = layer_num):
                super().__init__()
                self.layers = torch.nn.ModuleList(
                    [M(constant=torch.tensor(1.0)) for i in range(layers)]
                )

            def forward(self, x):
                for layer in self.layers:
                    x = layer(x)
                return x

        graph, _, _ = self._model_to_graph(
            N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
        )

        # NOTE: Duplicated constants are populated due to implicit casting in scalar_type_analysis,
        #       so we expect 3 constants with different scopes, one per layer.
        #       If CSE in the exporter is improved later, this test needs to be updated:
        #       it should then expect 1 constant, with the same scope as the root.
        expected_root_scope_name = "N::"
        expected_layer_scope_name = "M::layers"
        expected_constant_scope_name = [
            f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
            for i in range(layer_num)
        ]

        constant_scope_names = []
        for node in graph.nodes():
            if node.kind() == "onnx::Constant":
                constant_scope_names.append(
                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
                )
        self.assertEqual(constant_scope_names, expected_constant_scope_name)

    def test_scope_of_nodes_when_combined_by_cse_pass(self):
        layer_num = 3

        class M(torch.nn.Module):
            def __init__(self, constant, bias):
                super().__init__()
                self.constant = constant
                self.bias = bias

            def forward(self, x):
                # 'constant' and 'x' are designed to be the same for all layers,
                # hence `x + self.constant` is a common subexpression.
                # 'bias' is designed to be different for each layer,
                # hence `* self.bias` is not a common subexpression.
                return (x + self.constant) * self.bias

        class N(torch.nn.Module):
            def __init__(self, layers: int = layer_num):
                super().__init__()

                self.layers = torch.nn.ModuleList(
                    [
                        M(constant=torch.tensor([1.0]), bias=torch.randn(1))
                        for i in range(layers)
                    ]
                )

            def forward(self, x):
                y = []
                for layer in self.layers:
                    y.append(layer(x))
                return y[0], y[1], y[2]

        graph, _, _ = self._model_to_graph(
            N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
        )
        expected_root_scope_name = "N::"
        expected_layer_scope_name = "M::layers"
        expected_add_scope_names = [
            f"{expected_root_scope_name}/{expected_layer_scope_name}.0"
        ]
        expected_mul_scope_names = [
            f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
            for i in range(layer_num)
        ]

        add_scope_names = []
        mul_scope_names = []
        for node in graph.nodes():
            if node.kind() == "onnx::Add":
                add_scope_names.append(
                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
                )
            elif node.kind() == "onnx::Mul":
                mul_scope_names.append(
                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
                )
        self.assertEqual(add_scope_names, expected_add_scope_names)
        self.assertEqual(mul_scope_names, expected_mul_scope_names)

    def test_aten_fallthrough(self):
        # Test aten export of op with no symbolic
        class Module(torch.nn.Module):
            def forward(self, x):
                return torch.erfc(x)

        x = torch.randn(2, 3, 4)
        GLOBALS.export_onnx_opset_version = self.opset_version
        graph, _, __ = self._model_to_graph(
            Module(),
            (x,),
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2]},
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "aten::erfc")

    def test_custom_op_fallthrough(self):
        # Test custom op
        op_source = """
        #include <torch/script.h>

        torch::Tensor custom_add(torch::Tensor self, torch::Tensor other) {
          return self + other;
        }

        static auto registry =
          torch::RegisterOperators("custom_namespace::custom_op", &custom_add);
        """

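        # Compile and register the custom op at test time; this requires a working
        # C++ toolchain on the test machine.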
        torch.utils.cpp_extension.load_inline(
            name="custom_add",
            cpp_sources=op_source,
            is_python_module=False,
            verbose=True,
        )

        class FooModel(torch.nn.Module):
            def forward(self, input, other):
                # Calling custom op
                return torch.ops.custom_namespace.custom_op(input, other)

        x = torch.randn(2, 3, 4, requires_grad=False)
        y = torch.randn(2, 3, 4, requires_grad=False)
        model = FooModel()
        graph, _, __ = self._model_to_graph(
            model,
            (x, y),
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "custom_namespace::custom_op")

    # gelu is exported as onnx::Gelu for opset >= 20
    @skipIfUnsupportedMaxOpsetVersion(19)
    def test_custom_opsets_gelu(self):
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::gelu", 9)

        def gelu(g, self, approximate):
            return g.op("com.microsoft::Gelu", self).setType(self.type())

        torch.onnx.register_custom_op_symbolic("::gelu", gelu, 9)
        model = torch.nn.GELU(approximate="none")
        x = torch.randn(3, 3)
        f = io.BytesIO()
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            custom_opsets={"com.microsoft": 1},
        )

        graph = onnx.load(io.BytesIO(f.getvalue()))
        self.assertEqual(graph.graph.node[0].op_type, "Gelu")
        self.assertEqual(graph.opset_import[0].version, self.opset_version)
        self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
        self.assertEqual(graph.opset_import[1].version, 1)

    # gelu is exported as onnx::Gelu for opset >= 20
    @skipIfUnsupportedMaxOpsetVersion(19)
    def test_register_aten_custom_op_symbolic(self):
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "aten::gelu", 9)

        def gelu(g, self, approximate):
            return g.op("com.microsoft::Gelu", self).setType(self.type())

        torch.onnx.register_custom_op_symbolic("aten::gelu", gelu, 9)
        model = torch.nn.GELU(approximate="none")
        x = torch.randn(3, 3)
        f = io.BytesIO()
        torch.onnx.export(model, (x,), f, opset_version=self.opset_version)
        graph = onnx.load(io.BytesIO(f.getvalue()))

        self.assertEqual(graph.graph.node[0].op_type, "Gelu")
        self.assertEqual(graph.opset_import[1].domain, "com.microsoft")

    @skipIfNoLapack
    def test_custom_opsets_inverse(self):
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)

        class CustomInverse(torch.nn.Module):
            def forward(self, x):
                return torch.inverse(x) + x

        def linalg_inv(g, self):
            return g.op("com.microsoft::Inverse", self).setType(self.type())

        torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv, 9)
        model = CustomInverse()
        x = torch.randn(2, 3, 3)
        f = io.BytesIO()
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            custom_opsets={"com.microsoft": 1},
        )

        graph = onnx.load(io.BytesIO(f.getvalue()))
        self.assertEqual(graph.graph.node[0].op_type, "Inverse")
        self.assertEqual(graph.opset_import[0].version, self.opset_version)
        self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
        self.assertEqual(graph.opset_import[1].version, 1)

    def test_onnx_fallthrough(self):
        # Test aten export of op with symbolic for aten
        class Module(torch.nn.Module):
            def forward(self, x):
                return torch.digamma(x)

        x = torch.randn(100, 128)
        graph, _, __ = self._model_to_graph(
            Module(),
            (x,),
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "aten::digamma")

    # prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11
    @skipIfUnsupportedMaxOpsetVersion(10)
    def test_prim_fallthrough(self):
        # Test prim op
        class PrimModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                if isinstance(x, list):
                    y = x
                else:
                    y = [x]
                return y

        x = torch.tensor([2])
        model = PrimModule()
        model.eval()
        graph, _, __ = self._model_to_graph(
            model,
            (x,),
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
            input_names=["x"],
            dynamic_axes={"x": [0]},
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "prim::ListConstruct")

    def test_custom_layer_tuple(self):
        class CustomFunction(torch.autograd.Function):
            @staticmethod
            def symbolic(g, input):
                return g.op("CustomNamespace::Custom", input, outputs=2)

            @staticmethod
            def forward(ctx, input):
                return input, input

        class Custom(torch.nn.Module):
            def forward(self, input):
                return CustomFunction.apply(input)

        model = Custom()
        batch = torch.FloatTensor(1, 3)

        graph, _, _ = self._model_to_graph(
            model, batch, input_names=["batch"], dynamic_axes={"batch": [0, 1]}
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "CustomNamespace::Custom")

    def test_autograd_onnx_fallthrough(self):
        class CustomFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                return input.clamp(min=0)

            @staticmethod
            def backward(ctx, grad_output):
                (input,) = ctx.saved_tensors
                grad_input = grad_output.clone()
                grad_input[input < 0] = 0
                return grad_input

        class Custom(torch.nn.Module):
            def forward(self, input):
                return CustomFunction.apply(input)

        model = Custom()
        batch = torch.FloatTensor(1, 3)

        graph, _, _ = self._model_to_graph(
            model,
            batch,
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
            input_names=["batch"],
            dynamic_axes={"batch": [0, 1]},
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "prim::PythonOp")

    def test_autograd_module_name(self):
        class CustomFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                return input.clamp(min=0)

            @staticmethod
            def backward(ctx, grad_output):
                (input,) = ctx.saved_tensors
                grad_input = grad_output.clone()
                grad_input[input < 0] = 0
                return grad_input

        class Custom(torch.nn.Module):
            def forward(self, input):
                return CustomFunction.apply(input) + CustomFunction2.apply(input)

        model = Custom()
        batch = torch.FloatTensor(1, 3)

        graph, _, _ = self._model_to_graph(
            model,
            batch,
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
            input_names=["batch"],
            dynamic_axes={"batch": [0, 1]},
        )
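        # Both custom autograd Functions should export as prim::PythonOp nodes
        # whose "module" attributes record the distinct modules defining them.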
        it = graph.nodes()
        autograd1 = next(it)
        autograd2 = next(it)
        self.assertEqual(autograd1.kind(), "prim::PythonOp")
        self.assertEqual(autograd2.kind(), "prim::PythonOp")
        self.assertNotEqual(autograd1.s("module"), autograd2.s("module"))

    def test_unused_initializers(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2 = torch.nn.ConvTranspose2d(
                    16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(1, 1)
                )
                self.k_proj = torch.nn.Linear(5, 5, bias=True)

            def forward(self, x):
                x = self.conv2(x)
                return x

        x = torch.randn(20, 16, 50, 100)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        _, params_dict, __ = self._model_to_graph(
            Model(),
            (x,),
            do_constant_folding=False,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )
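        # Only the ConvTranspose2d parameters used in forward() should remain in
        # params_dict; the unused Linear parameters are expected to be dropped.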

        self.assertEqual(len(params_dict), 2)

    def test_scripting_param(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 16, kernel_size=1, stride=2, padding=3, bias=True
                )
                self.bn = torch.nn.BatchNorm2d(16, affine=True)

            def forward(self, x):
                x = self.conv(x)
                bn = self.bn(x)
                return bn

        model = torch.jit.script(MyModule())
        x = torch.randn(10, 3, 128, 128)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            model,
            (x,),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            training=torch.onnx.TrainingMode.TRAINING,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        graph_input_params = [param.debugName() for param in graph.inputs()]
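        # Every named parameter of the scripted module should appear among the
        # graph inputs under its original name.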
        for name, _ in model.named_parameters():
            self.assertIn(
                name,
                graph_input_params,
                "Graph parameter names do not match model parameters.",
            )

    def test_fuse_conv_bn(self):
        class Fuse(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 2, kernel_size=1, stride=2, padding=3, bias=True
                )
                self.bn = torch.nn.BatchNorm2d(2)

            def forward(self, x):
                out = self.conv(x)
                return self.bn(out)

        x = torch.randn(2, 3, 2, 2, requires_grad=True)
        graph, _, __ = self._model_to_graph(
            Fuse(),
            (x,),
            training=TrainingMode.EVAL,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )
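        # In EVAL mode the Conv + BatchNorm pair should be fused into a single
        # onnx::Conv node.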
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::BatchNormalization")
            self.assertEqual(node.kind(), "onnx::Conv")

        self.assertEqual(len(list(graph.nodes())), 1)

    def test_fuse_resnet18(self):
        model = torchvision.models.resnet18(weights=None)
        x = torch.randn(2, 3, 224, 224, requires_grad=True)
        graph, _, __ = self._model_to_graph(
            model,
            (x,),
            training=TrainingMode.EVAL,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::BatchNormalization")

    def test_onnx_function_substitution_pass(self):
        @torch.jit.script
        def f(x: torch.Tensor, y: torch.Tensor):
            z = x - y
            return x + z

        class MyModule(torch.nn.Module):
            def forward(self, x, y):
                return f(x, y)

        input_1 = torch.tensor([11])
        input_2 = torch.tensor([12])
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            MyModule(),
            (input_1, input_2),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["input_1", "input_2"],
            dynamic_axes={"input_1": [0], "input_2": [0]},
        )
        # Check that the prim::Constant node representing the scripted function
        # `f` is removed and the following prim::CallFunction is replaced by an
        # inlined graph containing onnx::Sub and onnx::Add nodes.
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "prim::Constant")
        self.assertEqual(
            len(list(graph.nodes())), 2
        )  # onnx::Sub and onnx::Add nodes only.

    def test_onnx_value_name(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.in_weight = torch.nn.Parameter(torch.Tensor(3, 3))
                self.in_bias = torch.nn.Parameter(torch.Tensor(3))

            def forward(self, x):
                start = 0
                end = None
                weight = self.in_weight
                bias = self.in_bias
                weight = weight[start:end, :]
                if bias is not None:
                    bias = bias[start:end]
                return torch.nn.functional.linear(x, weight, bias)

        model = MyModule()
        x = torch.randn(3, 3)
        f = io.BytesIO()

        model.eval()
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            keep_initializers_as_inputs=True,
        )
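        # With keep_initializers_as_inputs=True, the parameters should appear as
        # graph inputs under their attribute names.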
        graph = onnx.load(io.BytesIO(f.getvalue()))
        self.assertEqual(graph.graph.input[1].name, "in_weight")
        self.assertEqual(graph.graph.input[2].name, "in_bias")

    def test_onnx_node_naming(self):
        class MainModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._module_1 = torch.nn.Linear(10, 10)
                self._module_2 = torch.nn.Linear(10, 10)
                self._module_3 = torch.nn.Linear(10, 10)
                self._module_4 = torch.nn.Linear(10, 10)

            def forward(self, x):
                y = self._module_1(x)
                z = self._module_2(y)
                z = self._module_3(y * z)
                z = self._module_4(y * z)
                return z

        module = MainModule()
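        # Exported node names are expected to encode the owning module path,
        # e.g. "/_module_1/Gemm" for the first Linear layer.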
        ref_node_names = [
            "/_module_1/Gemm",
            "/_module_2/Gemm",
            "/_module_3/Gemm",
            "/_module_4/Gemm",
            "/Mul",
            "/Mul_1",
        ]
        f = io.BytesIO()

        torch.onnx.export(module, torch.ones(1, 10), f, output_names=["y"])
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        for n in onnx_model.graph.node:
            self.assertIn(n.name, ref_node_names)

        torch.onnx.export(
            torch.jit.script(module), torch.ones(1, 10), f, output_names=["y"]
        )
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        for n in onnx_model.graph.node:
            self.assertIn(n.name, ref_node_names)

    def _test_deduplicate_initializers(self, torchscript=False):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer1 = torch.nn.Linear(3, 3)
                self.layer2 = torch.nn.Linear(3, 3)

                # Reusing layers.
                self.layer3 = self.layer1

                # Reusing parameters.
                self.layer2.weight = self.layer1.weight
                self.layer1.bias = self.layer2.bias

                # Parameter with different tensors equal in value.
                self.param1 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
                self.param2 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))

            def forward(self, x):
                return (
                    self.layer3(self.layer2(self.layer1(x))) + self.param1 + self.param2
                )

        model = torch.jit.script(MyModule()) if torchscript else MyModule()

        x = torch.randn(3, 3)
        param_name_set = {k for k, _ in model.named_parameters()}

        # Test training mode.
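        # In TRAINING and PRESERVE mode every named parameter is expected to be
        # exported as its own initializer.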
        model.train()
        f = io.BytesIO()
        torch.onnx.export(
            model,
            (x,),
            f,
            training=TrainingMode.TRAINING,
            opset_version=self.opset_version,
        )
        graph = onnx.load(io.BytesIO(f.getvalue()))
        self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)

        model.train()
        f = io.BytesIO()
        torch.onnx.export(
            model,
            (x,),
            f,
            training=TrainingMode.PRESERVE,
            opset_version=self.opset_version,
        )
        graph = onnx.load(io.BytesIO(f.getvalue()))
        self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)

        # Test eval mode.
        model.eval()
        f = io.BytesIO()
        torch.onnx.export(model, (x,), f, opset_version=self.opset_version)
        graph = onnx.load(io.BytesIO(f.getvalue()))
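        # In eval mode, parameters that are equal in value are deduplicated, so
        # param2 should no longer appear among the initializers.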
        param_name_set.remove("param2")
        self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)

    def test_deduplicate_initializers(self):
        self._test_deduplicate_initializers(torchscript=False)

    def test_deduplicate_initializers_torchscript(self):
        self._test_deduplicate_initializers(torchscript=True)

    @skipIfNoCuda
    def test_deduplicate_initializers_diff_devices(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w_cpu = torch.nn.Parameter(
                    torch.ones(3, device=torch.device("cpu"))
                )
                self.w_cuda = torch.nn.Parameter(
                    torch.ones(3, device=torch.device("cuda"))
                )

            def forward(self, x, y):
                return x + self.w_cpu, y + self.w_cuda

        x = torch.randn(3, 3, device=torch.device("cpu"))
        y = torch.randn(3, 3, device=torch.device("cuda"))
        f = io.BytesIO()
        torch.onnx.export(Model(), (x, y), f, opset_version=self.opset_version)
        graph = onnx.load(io.BytesIO(f.getvalue()))
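        # The two parameters hold identical values, so they are expected to be
        # deduplicated into a single initializer (w_cpu).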
        self.assertSetEqual({i.name for i in graph.graph.initializer}, {"w_cpu"})

    def test_duplicated_output_node(self):
        class DuplicatedOutputNet(torch.nn.Module):
            def __init__(self, input_size, num_classes):
                super().__init__()
                self.fc1 = torch.nn.Linear(input_size, num_classes)

            def forward(self, input0, input1):
                out1 = self.fc1(input0)
                out2 = self.fc1(input1)
                return out1, out1, out2, out1, out2

        N, D_in, H, D_out = 64, 784, 500, 10
        pt_model = DuplicatedOutputNet(D_in, D_out)

        f = io.BytesIO()
        x = torch.randn(N, D_in)
        dynamic_axes = {
            "input0": {0: "input0_dim0", 1: "input0_dim1"},
            "input1": {0: "input1_dim0", 1: "input1_dim1"},
            "output-0": {0: "output-0_dim0", 1: "output-0_dim1"},
            "output-1": {0: "output-1_dim0", 1: "output-1_dim1"},
            "output-2": {0: "output-2_dim0", 1: "output-2_dim1"},
            "output-3": {0: "output-3_dim0", 1: "output-3_dim1"},
            "output-4": {0: "output-4_dim0", 1: "output-4_dim1"},
        }

        torch.onnx.export(
            pt_model,
            (x, x),
            f,
            input_names=["input0", "input1"],
            output_names=["output-0", "output-1", "output-2", "output-3", "output-4"],
            do_constant_folding=False,
            training=torch.onnx.TrainingMode.TRAINING,
            dynamic_axes=dynamic_axes,
            verbose=True,
            keep_initializers_as_inputs=True,
        )

        graph = onnx.load(io.BytesIO(f.getvalue()))
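        # Duplicated outputs should be materialized through Identity nodes so
        # that each of the five output names stays distinct.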
        self.assertEqual(graph.graph.input[0].name, "input0")
        self.assertEqual(graph.graph.input[1].name, "input1")
        for i in range(5):
            self.assertEqual(graph.graph.output[i].name, f"output-{i}")
        self.assertEqual(graph.graph.node[0].op_type, "Gemm")
        self.assertEqual(graph.graph.node[1].op_type, "Identity")
        self.assertEqual(graph.graph.node[2].op_type, "Identity")
        self.assertEqual(graph.graph.node[3].op_type, "Gemm")
        self.assertEqual(graph.graph.node[4].op_type, "Identity")

    def test_deduplicate_ignore_upsample_scale(self):
        # The upsample scale is a constant, not a model parameter, and should
        # therefore be ignored by shared-weight deduplication.
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.upsample_1 = torch.nn.Upsample(scale_factor=2)
                self.upsample_2 = torch.nn.Upsample(scale_factor=2)

            def forward(self, x):
                return self.upsample_1(x), self.upsample_2(x)

        f = io.BytesIO()
        x = torch.randn(1, 32, 224, 224)
        torch.onnx.export(Model(), x, f)
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        # aten::upsample is exported as onnx::Resize.
        resize_nodes = [n for n in onnx_model.graph.node if n.op_type == "Resize"]
        self.assertEqual(len(resize_nodes), 2)
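        # Each Resize node should keep its own Constant scale input instead of
        # sharing a deduplicated initializer.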
        for resize_node in resize_nodes:
            scale_node = [
                n for n in onnx_model.graph.node if n.output[0] == resize_node.input[2]
            ]
            self.assertEqual(len(scale_node), 1)
            self.assertEqual(scale_node[0].op_type, "Constant")

    def test_bad_symbolic_registration(self):
        _onnx_opset_version = 9

        @parse_args("v")
        def cat(g, tensor_list, dim):
            tensors = _unpack_list(tensor_list)
            return g.op("Concat", *tensors, axis_i=dim)

        torch.onnx.register_custom_op_symbolic("::cat", cat, _onnx_opset_version)
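        # The symbolic above is intentionally mis-registered: @parse_args("v")
        # declares a single descriptor while the symbolic takes two arguments
        # (tensor_list and dim), so export is expected to fail with a
        # descriptor-mismatch error.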

        class CatModel(torch.nn.Module):
            def forward(self, x):
                return torch.cat((x, x, x), 0)

        model = CatModel()
        x = torch.randn(2, 3)
        f = io.BytesIO()
        self.assertExpectedRaisesInline(
            AssertionError,
            lambda: torch.onnx.export(
                model, (x,), f, opset_version=_onnx_opset_version
            ),
            (
                "A mismatch between the number of arguments (2) and their descriptors (1) was found at symbolic function "
                "'cat'. If you believe this is not due to custom symbolic implementation within your code or an external "
                "library, please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to "
                "report this bug."
            ),
        )
        torch.onnx.unregister_custom_op_symbolic("::cat", _onnx_opset_version)


if __name__ == "__main__":
    common_utils.run_tests()
