# Owner(s): ["module: onnx"]

import io
import itertools

import onnx

import pytorch_test_common

import torch
import torch.onnx
from torch.nn import Module
from torch.onnx import producer_name, producer_version
from torch.onnx._globals import GLOBALS
from torch.testing._internal import common_utils


def check_onnx_opset_operator(
    model, ops, opset_version=GLOBALS.export_onnx_opset_version
):
    # check_onnx_components
    assert (
        model.producer_name == producer_name
        and model.producer_version == producer_version
        and model.opset_import[0].version == opset_version
    )

    # check the schema with the onnx checker
    onnx.checker.check_model(model)

    # check target type and attributes
    graph = model.graph
    # ops should contain an object for each node
    # in graph.node, in the right order.
    # At least the op_name should be specified,
    # but the op's attributes can optionally be
    # specified as well
    assert len(ops) == len(graph.node)
    for i in range(0, len(ops)):
        assert graph.node[i].op_type == ops[i]["op_name"]
        if "attributes" in ops[i]:
            attributes = ops[i]["attributes"]
            assert len(attributes) == len(graph.node[i].attribute)
            for j in range(0, len(attributes)):
                for attribute_field in attributes[j].keys():
                    assert attributes[j][attribute_field] == getattr(
                        graph.node[i].attribute[j], attribute_field
                    )


def check_onnx_opsets_operator(
    module,
    x,
    ops,
    opset_versions,
    training=torch.onnx.TrainingMode.EVAL,
    input_names=None,
    dynamic_axes=None,
):
    for opset_version in opset_versions:
        f = io.BytesIO()
        torch.onnx.export(
            module,
            x,
            f,
            opset_version=opset_version,
            training=training,
            input_names=input_names,
            dynamic_axes=dynamic_axes,
        )
        model = onnx.load(io.BytesIO(f.getvalue()))
        check_onnx_opset_operator(model, ops[opset_version], opset_version)


class TestONNXOpset(pytorch_test_common.ExportTestCase):
    def test_opset_fallback(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.isnan(x)

        ops = [{"op_name": "IsNaN"}]
        ops = {9: ops, 10: ops}
        x = torch.tensor([1.0, float("nan"), 2.0])
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

    def test_topk(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.topk(x, 3)

        ops_9 = [
            {
                "op_name": "TopK",
                "attributes": [
                    {"name": "axis", "i": -1, "type": 2},
                    {"name": "k", "i": 3, "type": 2},
                ],
            }
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]},
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.arange(1.0, 6.0, requires_grad=True)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

        # test with dynamic k
        class MyModuleDynamic(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input, k):
                return torch.topk(input, k)

        ops_10 = [
            {"op_name": "Constant", "attributes": [{"name": "value", "type": 4}]},
            {"op_name": "Reshape"},
            {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]},
        ]
        ops = {10: ops_10}
        x = torch.arange(1.0, 6.0, requires_grad=True)
        k = torch.tensor(3)
        module = MyModuleDynamic()
        check_onnx_opsets_operator(module, (x, k), ops, opset_versions=[10])

    def test_maxpool(self):
        module = torch.nn.MaxPool1d(2, stride=1)

        ops_9 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops_10 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "ceil_mode", "i": 0, "type": 2},
                    {"name": "dilations", "ints": [1], "type": 7},
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[9, 10])

        # add test with dilations
        module = torch.nn.MaxPool1d(2, stride=1, dilation=2)

        ops_10 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "ceil_mode", "i": 0, "type": 2},
                    {"name": "dilations", "ints": [2], "type": 7},
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops = {10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])

    def test_upsample(self):
        class MyModule(Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                size = [int(i) for i in size]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        module = MyModule()
        ops8 = [
            {
                "op_name": "Upsample",
                "attributes": [
                    {"name": "mode", "s": (b"nearest"), "type": 3},
                    {"name": "scales", "floats": [1.0, 1.0, 2.0, 2.0], "type": 6},
                ],
            }
        ]
        ops9 = [
            {"op_name": "Constant"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
            },
        ]
        ops = {8: ops8, 9: ops9}
        x = torch.randn(2, 2, 2, 2)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])

    def test_cast_constant(self):
        class MyModule(Module):
            def forward(self, x):
                return x - 1

        module = MyModule()
        ops_8 = [
            {"op_name": "Constant"},
            {"op_name": "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]},
            {"op_name": "Sub"},
        ]
        ops_9 = [{"op_name": "Constant"}, {"op_name": "Sub"}]
        ops = {8: ops_8, 9: ops_9}
        x = torch.ones(5, 6, dtype=torch.long)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])

    def test_slice(self):
        class MyModule(Module):
            def forward(self, x):
                return x[0:1]

        ops_9 = [
            {
                "op_name": "Slice",
                "attributes": [
                    {"name": "axes", "ints": [0], "type": 7},
                    {"name": "ends", "ints": [1], "type": 7},
                    {"name": "starts", "ints": [0], "type": 7},
                ],
            }
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(3)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

        class DynamicSliceModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x[1 : x.size(0)]

        module = DynamicSliceModel()
        x = torch.rand(1, 2)
        ops_10 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {
                "op_name": "Unsqueeze",
                "attributes": [{"name": "axes", "i": 0, "type": 7}],
            },
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        check_onnx_opsets_operator(
            module,
            x,
            ops,
            opset_versions=[10],
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])

    def test_flip(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.flip(x, dims=[0])

        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        import numpy

        x = torch.tensor(numpy.arange(6.0).reshape(2, 3))
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[10])

    def test_dropout(self):
        class MyModule(Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = torch.nn.Dropout(0.5)

            def forward(self, x):
                return self.dropout(x)

        x = torch.randn(1, 2, 3)

        # we should only export the onnx Dropout op in training mode; test both modes

        # test training mode
        ops = [
            {
                "op_name": "Dropout",
                "attributes": [{"name": "ratio", "f": 0.5, "type": 1}],
            }
        ]
        ops = {9: ops, 10: ops}
        check_onnx_opsets_operator(
            MyModule(),
            x,
            ops,
            opset_versions=[9, 10],
            training=torch.onnx.TrainingMode.TRAINING,
        )

        # test eval mode
        ops = [{"op_name": "Identity"}]
        ops = {9: ops, 10: ops}
        check_onnx_opsets_operator(
            MyModule(),
            x,
            ops,
            opset_versions=[9, 10],
            training=torch.onnx.TrainingMode.EVAL,
        )

    def test_full(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.full((3, 4), x)

        ops = [
            {"op_name": "Constant"},
            {"op_name": "ConstantOfShape"},
            {"op_name": "Add"},
        ]
        ops = {9: ops, 10: ops}
        x = torch.tensor(12.0)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

    def test_interpolate(self):
        class MyModel(torch.nn.Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        ops_9 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Concat"},
            {"op_name": "Cast"},
            {"op_name": "Shape"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Concat"},
            {"op_name": "Cast"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Resize",
                "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
            },
        ]

        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        check_onnx_opsets_operator(
            MyModel(),
            x,
            ops,
            opset_versions=[9, 10],
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        ops_9 = [
            {"op_name": "Constant"},
            {"op_name": "Shape"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {"op_name": "Resize"},
        ]

        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        check_onnx_opsets_operator(MyModel(), x, ops, opset_versions=[9, 10])

        class MyDynamicModel(torch.nn.Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                # work around for now: turn the dynamic sizes into constant
                size = [int(i) for i in size]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        ops_9 = [
            {"op_name": "Constant"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {
                "op_name": "Resize",
                "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
            },
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10])

    def test_affine_grid(self):
        class MyModule(Module):
            def __init__(self, align_corners):
                super().__init__()
                self.align_corners = align_corners

            def forward(self, theta, size):
                return torch.nn.functional.affine_grid(
                    theta, size, align_corners=self.align_corners
                )

        opset_version = 20
        ops_2d = {
            opset_version: [
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Concat"},
                {"op_name": "AffineGrid"},
            ]
        }

        ops_3d = {
            opset_version: [
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Concat"},
                {"op_name": "AffineGrid"},
            ]
        }
        # 2D affine
        theta_2d = torch.empty(1, 2, 3, dtype=torch.double)
        size_2d = torch.Size([1, 1, 2, 2])
        # 3D affine
        theta_3d = torch.empty(1, 3, 4, dtype=torch.double)
        size_3d = torch.Size([1, 1, 2, 2, 2])

        for inputs, align_corners in itertools.product(
            ((theta_2d, size_2d, ops_2d), (theta_3d, size_3d, ops_3d)),
            (True, False),
        ):
            theta, size, ops = inputs
            args = (
                theta,
                size,
            )
            check_onnx_opsets_operator(
                MyModule(align_corners=align_corners),
                args,
                ops,
                opset_versions=[opset_version],
                training=torch.onnx.TrainingMode.TRAINING,
            )
            check_onnx_opsets_operator(
                MyModule(align_corners=align_corners),
                args,
                ops,
                opset_versions=[opset_version],
                training=torch.onnx.TrainingMode.EVAL,
            )

    def test_grid_sample(self):
        class MyModule(torch.nn.Module):
            def __init__(self, mode, padding_mode, align_corners):
                super().__init__()
                self.mode = mode
                self.padding_mode = padding_mode
                self.align_corners = align_corners

            def forward(self, x, grid):
                return torch.nn.functional.grid_sample(
                    x,
                    grid,
                    mode=self.mode,
                    padding_mode=self.padding_mode,
                    align_corners=self.align_corners,
                )

        for mode, padding_mode, align_corners, opset_version in itertools.product(
            ("bilinear", "nearest", "bicubic"),
            ("zeros", "border", "reflection"),
            (True, False),
            (16, 20),
        ):

            def test_eval_and_training(
                ops, opset_version, mode, padding_mode, align_corners, x_shape, grid
            ):
                args = (
                    torch.randn(*x_shape),  # x
                    torch.randn(grid),  # grid,
                )
                check_onnx_opsets_operator(
                    MyModule(
                        mode=mode,
                        padding_mode=padding_mode,
                        align_corners=align_corners,
                    ),
                    args,
                    ops,
                    opset_versions=[opset_version],
                    training=torch.onnx.TrainingMode.TRAINING,
                )
                check_onnx_opsets_operator(
                    MyModule(
                        mode=mode,
                        padding_mode=padding_mode,
                        align_corners=align_corners,
                    ),
                    args,
                    ops,
                    opset_versions=[opset_version],
                    training=torch.onnx.TrainingMode.EVAL,
                )

            ops = {opset_version: [{"op_name": "GridSample"}]}
            # mode = convert_grid_sample_mode(mode) if opset_version == 20 else mode
            n, c, d_in, h_in, w_in, d_out, h_out, w_out = 1, 1, 2, 3, 2, 3, 2, 4
            test_eval_and_training(
                ops,
                opset_version,
                mode,
                padding_mode,
                align_corners,
                (n, c, h_in, w_in),
                (n, h_out, w_out, 2),
            )
            if opset_version == 20 and mode != "bicubic":
                test_eval_and_training(
                    ops,
                    opset_version,
                    mode,
                    padding_mode,
                    align_corners,
                    (n, c, d_in, h_in, w_in),
                    (n, d_out, h_out, w_out, 3),
                )

    def test_flatten(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.flatten(x)

        module = MyModule()

        ops_0d = [{"op_name": "Constant"}, {"op_name": "Reshape"}]
        ops_1d = [{"op_name": "Identity"}]
        for shape in ([], [3]):
            x = torch.randn(shape)
            for opset_version in [9, 10]:
                ops = {opset_version: (ops_0d if len(shape) == 0 else ops_1d)}
                check_onnx_opsets_operator(
                    module, x, ops, opset_versions=[opset_version]
                )


if __name__ == "__main__":
    common_utils.run_tests()
