# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
from pathlib import Path

import torch
import torch._C
from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo


# hacky way to skip these tests in fbcode:
# during test execution in fbcode, test_nnapi is available during test discovery,
# but not during test execution. So we can't try-catch here, otherwise it'll think
# it sees tests but then fails when it tries to actuall run them.
if not IS_FBCODE:
    from test_nnapi import TestNNAPI

    HAS_TEST_NNAPI = True
else:
    from torch.testing._internal.common_utils import TestCase as TestNNAPI

    HAS_TEST_NNAPI = False


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

"""
Unit Tests for Nnapi backend with delegate
Inherits most tests from TestNNAPI, which loads Android NNAPI models
without the delegate API.
"""
# First skip is needed for IS_WINDOWS or IS_MACOS to skip the tests.
torch_root = Path(__file__).resolve().parent.parent.parent
lib_path = torch_root / "build" / "lib" / "libnnapi_backend.so"


@skipIfTorchDynamo("weird py38 failures")
@unittest.skipIf(
    not os.path.exists(lib_path),
    "Skipping the test as libnnapi_backend.so was not found",
)
@unittest.skipIf(IS_FBCODE, "test_nnapi.py not found")
class TestNnapiBackend(TestNNAPI):
    def setUp(self):
        super().setUp()

        # Save default dtype
        module = torch.nn.PReLU()
        self.default_dtype = module.weight.dtype
        # Change dtype to float32 (since a different unit test changed dtype to float64,
        # which is not supported by the Android NNAPI delegate)
        # Float32 should typically be the default in other files.
        torch.set_default_dtype(torch.float32)

        # Load nnapi delegate library
        torch.ops.load_library(str(lib_path))

    # Override
    def call_lowering_to_nnapi(self, traced_module, args):
        compile_spec = {"forward": {"inputs": args}}
        return torch._C._jit_to_backend("nnapi", traced_module, compile_spec)

    def test_tensor_input(self):
        # Lower a simple module
        args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        module = torch.nn.PReLU()
        traced = torch.jit.trace(module, args)

        # Argument input is a single Tensor
        self.call_lowering_to_nnapi(traced, args)
        # Argument input is a Tensor in a list
        self.call_lowering_to_nnapi(traced, [args])

    # Test exceptions for incorrect compile specs
    def test_compile_spec_santiy(self):
        args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        module = torch.nn.PReLU()
        traced = torch.jit.trace(module, args)

        errorMsgTail = r"""
method_compile_spec should contain a Tensor or Tensor List which bundles input parameters: shape, dtype, quantization, and dimorder.
For input shapes, use 0 for run/load time flexible input.
method_compile_spec must use the following format:
{"forward": {"inputs": at::Tensor}} OR {"forward": {"inputs": c10::List<at::Tensor>}}"""

        # No forward key
        compile_spec = {"backward": {"inputs": args}}
        with self.assertRaisesRegex(
            RuntimeError,
            'method_compile_spec does not contain the "forward" key.' + errorMsgTail,
        ):
            torch._C._jit_to_backend("nnapi", traced, compile_spec)

        # No dictionary under the forward key
        compile_spec = {"forward": 1}
        with self.assertRaisesRegex(
            RuntimeError,
            'method_compile_spec does not contain a dictionary with an "inputs" key, '
            'under it\'s "forward" key.' + errorMsgTail,
        ):
            torch._C._jit_to_backend("nnapi", traced, compile_spec)

        # No inputs key (in the dictionary under the forward key)
        compile_spec = {"forward": {"not inputs": args}}
        with self.assertRaisesRegex(
            RuntimeError,
            'method_compile_spec does not contain a dictionary with an "inputs" key, '
            'under it\'s "forward" key.' + errorMsgTail,
        ):
            torch._C._jit_to_backend("nnapi", traced, compile_spec)

        # No Tensor or TensorList under the inputs key
        compile_spec = {"forward": {"inputs": 1}}
        with self.assertRaisesRegex(
            RuntimeError,
            'method_compile_spec does not contain either a Tensor or TensorList, under it\'s "inputs" key.'
            + errorMsgTail,
        ):
            torch._C._jit_to_backend("nnapi", traced, compile_spec)
        compile_spec = {"forward": {"inputs": [1]}}
        with self.assertRaisesRegex(
            RuntimeError,
            'method_compile_spec does not contain either a Tensor or TensorList, under it\'s "inputs" key.'
            + errorMsgTail,
        ):
            torch._C._jit_to_backend("nnapi", traced, compile_spec)

    def tearDown(self):
        # Change dtype back to default (Otherwise, other unit tests will complain)
        torch.set_default_dtype(self.default_dtype)
