#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]

import ctypes
import os
import unittest
from typing import Tuple

import torch
from torch.backends._nnapi.prepare import convert_model_to_nnapi
from torch.testing._internal.common_quantized import supported_qengines
from torch.testing._internal.common_utils import run_tests, TestCase


def qpt(t, scale, zero_point, dtype=torch.quint8):
    t = torch.tensor(t)
    return torch.quantize_per_tensor(t, scale, zero_point, dtype)


def nhwc(t):
    t = t.clone().contiguous(memory_format=torch.channels_last)
    t.nnapi_nhwc = True
    return t


@unittest.skipUnless(
    "qnnpack" in supported_qengines,
    "This Pytorch Build has not been built with or does not support QNNPACK",
)
class TestNNAPI(TestCase):
    def setUp(self):
        # Avoid saturation in fbgemm
        torch.backends.quantized.engine = "qnnpack"

        libneuralnetworks_path = os.environ.get("LIBNEURALNETWORKS_PATH")
        if libneuralnetworks_path:
            ctypes.cdll.LoadLibrary(libneuralnetworks_path)
            print("Will attempt to run NNAPI models.")
            self.can_run_nnapi = True
        else:
            self.can_run_nnapi = False

    # Created for easy override by subclasses (eg TestNnapiBackend)
    def call_lowering_to_nnapi(self, traced_module, args):
        return convert_model_to_nnapi(traced_module, args)

    # Created for subclasses to set can_run_nnapi (eg TestNnapiBackend)
    def set_can_run_nnapi(self, can_run):
        self.can_run_nnapi = can_run

    def check(
        self,
        module,
        arg_or_args,
        *,
        trace_args=None,
        convert_args=None,
        atol_rtol=None,
        limit=None,
        expected_memory_format=None,
    ):
        with torch.no_grad():
            if isinstance(arg_or_args, torch.Tensor):
                args = [arg_or_args]
            else:
                args = arg_or_args
            module.eval()
            traced = torch.jit.trace(module, trace_args or args)
            nnapi_module = self.call_lowering_to_nnapi(traced, convert_args or args)
            if not self.can_run_nnapi:
                # Only test that the model was converted successfully.
                return
            eager_output = module(*args)
            nnapi_output = nnapi_module(*args)
            kwargs = {}
            if atol_rtol is not None:
                kwargs["atol"] = atol_rtol[0]
                kwargs["rtol"] = atol_rtol[1]
            self.assertEqual(eager_output, nnapi_output, **kwargs)
            if limit is not None:
                mismatches = eager_output.int_repr().to(
                    torch.int32
                ) - nnapi_output.int_repr().to(torch.int32)
                if mismatches.count_nonzero() > limit:
                    # Too many mismatches.  Re-run the check with no tolerance
                    # to get a nice message.
                    self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)
            if expected_memory_format:
                self.assertTrue(
                    nnapi_output.is_contiguous(memory_format=expected_memory_format)
                )

    def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
        torch.manual_seed(29)
        inp_quant = qpt(inp_float, 0.03, 128)
        return [
            ("float", inp_float),
            ("float-nhwc", nhwc(inp_float)),
            ("quant", inp_quant),
            ("quant-nhwc", nhwc(inp_quant)),
        ]

    def test_prelu(self):
        arg = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        single_a = torch.nn.PReLU()
        self.check(single_a, arg)
        multi_a = torch.nn.PReLU(4)
        with torch.no_grad():
            multi_a.weight.copy_(torch.tensor([0.1, 0.2, 0.3, 0.4]))
        self.check(multi_a, nhwc(arg))

        # Test flexible size
        self.check(
            multi_a,
            arg,
            trace_args=[torch.zeros(1, 4, 3, 3)],
            convert_args=[nhwc(torch.zeros(1, 4, 0, 0))],
        )

    def test_quantize(self):
        self.check(
            torch.ao.nn.quantized.Quantize(0.25, 2, torch.quint8),
            nhwc(torch.tensor([[[[1.0]], [[2.0]]]])),
        )

    def test_dequantize(self):
        self.check(
            torch.ao.nn.quantized.DeQuantize(), nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2))
        )

    def test_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, arg):
                return arg.unsqueeze(self.dim)

        self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(0), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(2), torch.randn(4, 2, 2))

    def test_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self, shape):
                super().__init__()
                self.shape = shape

            def forward(self, arg):
                return arg.reshape(self.shape)

        self.check(ReshapeModule((2, 4)), torch.randn(4, 2, 1, 1))

        self.check(ReshapeModule((8, -1)), nhwc(torch.randn(4, 2, 1, 1)))

        with self.assertRaisesRegex(Exception, "target size"):
            self.check(ReshapeModule((2, 4)), nhwc(torch.randn(4, 2, 1, 1)))

    def test_flatten(self):
        for mod in [
            torch.nn.Flatten(),
            torch.nn.Flatten(start_dim=2, end_dim=3),
            torch.nn.Flatten(start_dim=2, end_dim=4),
            torch.nn.Flatten(start_dim=0, end_dim=-2),
            torch.nn.Flatten(start_dim=0, end_dim=4),
        ]:
            self.check(mod, torch.randn(4, 2, 1, 3, 7))

        # flex inputs
        self.check(
            torch.nn.Flatten(),
            torch.randn(4, 2, 1, 3, 7),
            convert_args=[torch.zeros(0, 2, 1, 3, 7)],
        )

        # channels last
        self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 1, 4, 7)))
        self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 3, 1, 1)))

        # Exceptions
        with self.assertRaisesRegex(Exception, "not supported on NHWC"):
            self.check(torch.nn.Flatten(), nhwc(torch.randn(1, 3, 4, 4)))
        with self.assertRaisesRegex(
            Exception, "Flattening flexible dims is not supported yet"
        ):
            self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
        with self.assertRaisesRegex(Exception, "Only 1 dim"):
            self.check(
                torch.nn.Flatten(start_dim=1, end_dim=-2), torch.randn(0, 2, 1, 3, 0)
            )

    def test_slice(self):
        class SliceModule(torch.nn.Module):
            def __init__(self, start, stop, step):
                super().__init__()
                self.start = start
                self.stop = stop
                self.step = step

            def forward(self, t):
                return t[1:, self.start : self.stop : self.step, :]

        class SliceModule2(torch.nn.Module):
            def forward(self, t):
                return t[3:]

        self.check(SliceModule(1, 5, 2), torch.randn(4, 6, 2))
        self.check(SliceModule2(), torch.randn(5))

        # flex inputs
        self.check(
            SliceModule(1, 5, 2),
            torch.randn(4, 6, 2),
            convert_args=[torch.zeros(4, 6, 0)],
        )
        with self.assertRaisesRegex(Exception, "slice with flexible shape"):
            self.check(
                SliceModule(1, 5, 2),
                torch.randn(4, 6, 2),
                convert_args=[torch.zeros(0, 0, 0)],
            )

    def test_cat(self):
        class CatModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, t1, t2):
                return torch.cat([t1, t2], self.dim)

        self.check(
            CatModule(0),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(2, 2, 3, 3),
            ],
        )

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
        )

        self.check(
            CatModule(1),
            [
                nhwc(torch.randn(1, 2, 3, 3)),
                nhwc(torch.randn(1, 4, 3, 3)),
            ],
        )

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
            convert_args=[torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0)],
        )

    def test_pointwise_unary(self):
        for op in ["relu", "sigmoid"]:
            with self.subTest(op):

                class UnaryModule(torch.nn.Module):
                    def forward(self, arg):
                        if op == "relu":
                            return torch.nn.functional.relu(arg)
                        if op == "sigmoid":
                            return torch.sigmoid(arg)
                        raise Exception("Bad op")  # noqa: TRY002

                self.check(UnaryModule(), torch.tensor([-1.0, 1.0]))
                self.check(
                    UnaryModule(),
                    qpt(torch.tensor([-1.0, 1.0]), 1.0 / 256, 0),
                )

    def test_pointwise_binary(self):
        for op in ["add", "sub", "mul", "div"]:
            with self.subTest(op):

                class BinaryModule(torch.nn.Module):
                    def forward(self, lhs, rhs):
                        if op == "add":
                            return lhs + rhs
                        if op == "sub":
                            return lhs - rhs
                        if op == "mul":
                            return lhs * rhs
                        if op == "div":
                            return lhs / rhs
                        raise Exception("Bad op")  # noqa: TRY002

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([1.0, 2.0]),
                        torch.tensor([3.0, 4.0]),
                    ],
                )

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([[1.0, 2.0]]),
                        torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                    ],
                )

                with self.assertRaisesRegex(Exception, "Non-equal-rank broadcast"):
                    self.check(
                        BinaryModule(),
                        [
                            torch.tensor([1.0, 2.0]),
                            torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                        ],
                    )

    def test_pointwise_binary_const(self):
        const = torch.randn(1, 4, 6, 6)

        class ArgPlusConst(torch.nn.Module):
            def forward(self, arg):
                return arg + const

        class ConstPlusArg(torch.nn.Module):
            def forward(self, arg):
                return const + arg

        arg_contig = torch.randn(2, 4, 6, 6)
        arg_nhwc = nhwc(torch.randn(2, 4, 6, 6))

        for mod_class in [ArgPlusConst, ConstPlusArg]:
            for use_nhwc in [False, True]:
                with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc):
                    arg = arg_nhwc if use_nhwc else arg_contig
                    memory_format = (
                        torch.channels_last if use_nhwc else torch.contiguous_format
                    )
                    self.check(mod_class(), arg, expected_memory_format=memory_format)

    def test_hardtanh(self):
        inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
        self.check(torch.nn.Hardtanh(), inp)
        self.check(torch.nn.Hardtanh(0.0, 6.0), inp)
        with self.assertRaisesRegex(Exception, "hardtanh with args"):
            self.check(torch.nn.Hardtanh(0.0, 5.0), inp)

    def test_softmax(self):
        inp = torch.tensor([[-2.0, -0.5], [0.5, 2.0]])
        self.check(torch.nn.Softmax(), inp)
        self.check(torch.nn.Softmax(dim=0), inp)
        # Test flexible size
        self.check(
            torch.nn.Softmax(),
            inp,
            convert_args=[torch.zeros(0, 0)],
        )

    def test_to(self):
        class ToCPU(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                y = x.to("cpu")
                # add prelu since input operand can't be output
                return self.prelu(y)

        arg = torch.randn(1, 2, 3, 3)
        self.check(ToCPU(), arg)
        # Test flexible size
        self.check(
            ToCPU(),
            arg,
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_detach(self):
        class DetachModule(torch.nn.Module):
            def forward(self, x):
                y = x.detach()
                return torch.nn.functional.relu(y)

        self.check(DetachModule(), torch.randn(1, 2, 3, 3))
        self.check(
            DetachModule(),
            torch.randn(1, 2, 3, 3),
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_log_softmax(self):
        inp = torch.randn(3, 10)
        self.check(torch.nn.LogSoftmax(), inp)
        self.check(torch.nn.LogSoftmax(0), inp)

    def test_mean(self):
        class MeanModule(torch.nn.Module):
            def __init__(self, dim, keep=False):
                super().__init__()
                self.dim = dim
                self.keep = keep

            def forward(self, t):
                return torch.mean(t, dim=self.dim, keepdim=self.keep)

        self.check(MeanModule(0), torch.randn(2, 3))
        self.check(MeanModule(1), torch.randn(2, 3))
        self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6))
        self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6)))

    def test_max_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.MaxPool2d(2), inp)
                self.check(torch.nn.MaxPool2d((3, 4)), inp)
                self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp)

    def test_avg_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                atol_rtol = None
                limit = None
                convert_dims = (2, 3, 0, 0)
                convert_arg = torch.zeros(*convert_dims)

                for model in (
                    torch.nn.AvgPool2d(2),
                    torch.nn.AvgPool2d((3, 4)),
                    torch.nn.AvgPool2d((3, 4), (1, 2)),
                ):
                    if "quant" in name:
                        atol_rtol = (1, 0)
                        limit = model(inp).numel()
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)
                    if "nhwc" in name:
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit,
                    )

    def test_adaptive_avg_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.AdaptiveAvgPool2d((1, 1)), inp)
                with self.assertRaisesRegex(Exception, "with output size"):
                    self.check(torch.nn.AdaptiveAvgPool2d((2, 2)), inp)

    def test_upsample_nearest2d(self):
        convert_args = dict(
            self.float_and_quant_and_nhwc(torch.randn(2, 3, 0, 0), 0.3, 128)
        )
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.UpsamplingNearest2d(size=(16, 20)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(24, 32)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(36, 48)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(1.5, 1.5)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(3.0, 3.0)), inp)

                self.check(
                    torch.nn.UpsamplingNearest2d(size=(24, 32)),
                    inp,
                    convert_args=[convert_args[name]],
                )
                self.check(
                    torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)),
                    inp,
                    convert_args=[convert_args[name]],
                )

    def test_linear(self):
        torch.manual_seed(29)
        self.check(torch.nn.Linear(16, 32), torch.randn(2, 16))
        self.check(
            torch.nn.Linear(16, 32),
            torch.randn(2, 16),
            convert_args=[torch.zeros(0, 16)],
        )

    def test_conv2d(self):
        cases = [
            # in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim,      name
            (4, 8, (3, 3), 1, 0, 1, 1, (2, 4, 16, 16), "3x3"),  # noqa: E201,E241
            (4, 8, (3, 3), 1, 0, 1, 0, (2, 4, 16, 16), "3x3nobias"),  # noqa: E201,E241
            (4, 16, (3, 3), 1, 1, 1, 1, (2, 4, 16, 16), "3x3p1"),  # noqa: E201,E241
            (8, 8, (3, 3), 2, 0, 1, 1, (2, 8, 16, 16), "3x3s2"),  # noqa: E201,E241
            (4, 8, (5, 5), 1, 0, 1, 1, (2, 4, 16, 16), "5x5"),  # noqa: E201,E241
            (4, 4, (3, 3), 1, 0, 4, 1, (2, 4, 16, 16), "3x3dw"),  # noqa: E201,E241
            (8, 4, (1, 1), 1, 0, 1, 1, (2, 8, 16, 16), "1x1"),  # noqa: E201,E241
        ]

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            for case in cases:
                (
                    in_ch,
                    out_ch,
                    kernel,
                    stride,
                    padding,
                    groups,
                    bias,
                    input_dim,
                    name,
                ) = case
                with self.subTest(f"{kind}-{name}"):
                    inp = torch.randn(input_dim)
                    model = torch.nn.Conv2d(
                        in_ch,
                        out_ch,
                        kernel,
                        stride,
                        padding,
                        groups=groups,
                        bias=bool(bias),
                    )
                    output_size = model(inp).numel()
                    atol_rtol = None
                    limit = None
                    convert_dims = (0, in_ch, 0, 0)
                    convert_arg = torch.zeros(*convert_dims)

                    if "quant" in kind:
                        model = torch.nn.Sequential(model)
                        model.eval()
                        model.qconfig = torch.ao.quantization.get_default_qconfig(
                            "qnnpack"
                        )
                        model = torch.ao.quantization.prepare(model)
                        model(inp)
                        model = torch.ao.quantization.convert(model)
                        inp = qpt(inp, 1.0 / 16, 128)
                        # I've seen numerical differences between QNNPACK and NNAPI,
                        # but never more than 1 quantum, and never more than ~1% of
                        # the output in this test.
                        atol_rtol = (1, 0)
                        limit = output_size * 0.03
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)

                    if "nhwc" in kind:
                        inp = nhwc(inp)
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit,
                    )

    def test_conv2d_transpose(self):
        torch.manual_seed(29)
        in_ch, out_ch, kernel = (5, 7, (2, 2))
        input_dim = (4, 5, 3, 3)
        convert_dims = input_dim[:2] + (0, 0)

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            with self.subTest(kind):
                inp = torch.randn(input_dim)
                model = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel)
                output_size = model(inp).numel()
                atol_rtol = (0.0002, 0)
                limit = None
                convert_arg = torch.zeros(*convert_dims)

                if "quant" in kind:
                    model = torch.ao.nn.quantized.ConvTranspose2d(in_ch, out_ch, kernel)
                    model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
                    inp = qpt(inp, 1.0 / 16, 128)
                    # I've seen numerical differences between QNNPACK and NNAPI,
                    # but never more than 1 quantum, and never more than ~10% of
                    # the output in this test.
                    atol_rtol = (1, 0)
                    limit = output_size * 0.1
                    convert_arg = qpt(convert_arg, 1.0 / 16, 128)

                if "nhwc" in kind:
                    inp = nhwc(inp)
                    convert_arg = nhwc(convert_arg)

                self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                self.check(
                    model,
                    inp,
                    convert_args=[convert_arg],
                    atol_rtol=atol_rtol,
                    limit=limit,
                )

    def test_qadd(self):
        func = torch.ao.nn.quantized.QFunctional()
        func.scale = 0.5
        func.zero_point = 120

        class AddMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add(lhs, rhs)

        class AddReluMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add_relu(lhs, rhs)

        class MulMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.mul(lhs, rhs)

        for name, mod in [("add", AddMod), ("add_relu", AddReluMod), ("mul", MulMod)]:
            with self.subTest(name):
                self.check(
                    mod(),
                    [
                        qpt([1.0, 2.0], 0.25, 128),
                        qpt([3.0, 4.0], 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ],
                )
                # NOTE: NNAPI qadd supports broadcast, but PT does not.

    def test_qlinear(self):
        torch.manual_seed(29)
        weight = qpt(torch.randn(16, 32), 0.125, 0, torch.qint8)
        bias = torch.randn(16)
        mod = torch.ao.nn.quantized.Linear(32, 16)
        mod.set_weight_bias(weight, bias)
        inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
        self.check(mod, inp)

    def test_seblock_mul(self):
        class MulModel(torch.nn.Module):
            def forward(self, lhs, rhs):
                return lhs * rhs

        self.check(
            MulModel(),
            [
                nhwc(torch.randn(2, 3, 4, 4)),
                torch.randn(1, 3, 1, 1),
            ],
        )

    def test_multi_output(self):
        class MultiModel(torch.nn.Module):
            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
                the_sum = lhs + rhs
                the_diff = lhs - rhs
                return the_sum, the_diff

        self.check(MultiModel(), [torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0])])


if __name__ == "__main__":
    run_tests()
