# Owner(s): ["oncall: jit"]

import re
import unittest

import torch
import torch._C


# Load the library the tests in this file rely on when they lower modules to
# the "xnnpack" backend through torch._C._jit_to_backend.
# NOTE(review): the path looks like a build-target label rather than a
# filesystem path -- confirm how it is resolved in the test environment.
_XNNPACK_BACKEND_LIBRARY = "//caffe2:xnnpack_backend"
torch.ops.load_library(_XNNPACK_BACKEND_LIBRARY)


class TestXNNPackBackend(unittest.TestCase):
    def test_xnnpack_constant_data(self):
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._constant = torch.ones(4, 4, 4)

            def forward(self, x):
                return x + self._constant

        scripted_module = torch.jit.script(Module())

        lowered_module = torch._C._jit_to_backend(
            "xnnpack",
            scripted_module,
            {
                "forward": {
                    "inputs": [torch.randn(4, 4, 4)],
                    "outputs": [torch.randn(4, 4, 4)],
                }
            },
        )

        for i in range(0, 20):
            sample_input = torch.randn(4, 4, 4)
            actual_output = scripted_module(sample_input)
            expected_output = lowered_module(sample_input)
            self.assertTrue(
                torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
            )

    def test_xnnpack_lowering(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        scripted_module = torch.jit.script(Module())

        faulty_compile_spec = {
            "backward": {
                "inputs": [torch.zeros(1)],
                "outputs": [torch.zeros(1)],
            }
        }
        error_msg = 'method_compile_spec does not contain the "forward" key.'

        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = torch._C._jit_to_backend(
                "xnnpack",
                scripted_module,
                faulty_compile_spec,
            )

        mismatch_compile_spec = {
            "forward": {
                "inputs": [torch.zeros(1), torch.zeros(1)],
                "outputs": [torch.zeros(1)],
            }
        }
        error_msg = (
            "method_compile_spec inputs do not match expected number of forward inputs"
        )

        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = torch._C._jit_to_backend(
                "xnnpack", scripted_module, mismatch_compile_spec
            )

        lowered = torch._C._jit_to_backend(
            "xnnpack",
            scripted_module,
            {
                "forward": {
                    "inputs": [torch.zeros(1)],
                    "outputs": [torch.zeros(1)],
                }
            },
        )
        lowered(torch.zeros(1))

    def test_xnnpack_backend_add(self):
        class AddModule(torch.nn.Module):
            def forward(self, x, y):
                z = x + y
                z = z + x
                z = z + x
                return z

        add_module = AddModule()
        sample_inputs = (torch.rand(1, 512, 512, 3), torch.rand(1, 512, 512, 3))
        sample_output = torch.zeros(1, 512, 512, 3)

        add_module = torch.jit.script(add_module)
        expected_output = add_module(sample_inputs[0], sample_inputs[1])

        lowered_add_module = torch._C._jit_to_backend(
            "xnnpack",
            add_module,
            {
                "forward": {
                    "inputs": [sample_inputs[0].clone(), sample_inputs[1].clone()],
                    "outputs": [sample_output],
                }
            },
        )

        actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
        self.assertTrue(
            torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
        )

    def test_xnnpack_broadcasting(self):
        class AddModule(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        add_module = AddModule()
        sample_inputs = (torch.rand(5, 1, 4, 1), torch.rand(3, 1, 1))
        sample_output = torch.zeros(5, 3, 4, 1)

        add_module = torch.jit.script(add_module)
        expected_output = add_module(sample_inputs[0], sample_inputs[1])

        lowered_add_module = torch._C._jit_to_backend(
            "xnnpack",
            add_module,
            {
                "forward": {
                    "inputs": [sample_inputs[0], sample_inputs[1]],
                    "outputs": [sample_output],
                }
            },
        )

        actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
        self.assertTrue(
            torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
        )

    def test_xnnpack_unsupported(self):
        class AddSpliceModule(torch.nn.Module):
            def forward(self, x, y):
                z = x + y[:, :, 1, :]
                return z

        sample_inputs = (torch.rand(1, 512, 512, 3), torch.rand(1, 512, 512, 3))
        sample_output = torch.zeros(1, 512, 512, 3)

        error_msg = (
            "the module contains the following unsupported ops:\n"
            "aten::select\n"
            "aten::slice\n"
        )

        add_module = torch.jit.script(AddSpliceModule())
        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = torch._C._jit_to_backend(
                "xnnpack",
                add_module,
                {
                    "forward": {
                        "inputs": [sample_inputs[0], sample_inputs[1]],
                        "outputs": [sample_output],
                    }
                },
            )
