# Owner(s): ["module: unknown"]

import os
import tempfile

from backend import get_custom_backend_library_path, Model, to_custom_backend

import torch
from torch.testing._internal.common_utils import run_tests, TestCase


class TestCustomBackend(TestCase):
    """Tests for a TorchScript module lowered to the custom backend.

    ``setUp`` loads the custom backend library and builds a freshly
    scripted and lowered ``Model`` for every test.
    """

    def setUp(self):
        """Load the custom backend library and lower a scripted ``Model``."""
        # Chain to the base class so that whatever per-test initialisation
        # TestCase performs is not silently skipped by this override.
        super().setUp()
        # Load the library containing the custom backend.
        self.library_path = get_custom_backend_library_path()
        torch.ops.load_library(self.library_path)
        # Create an instance of the test Module and lower it for
        # the custom backend.
        self.model = to_custom_backend(torch.jit.script(Model()))

    def _check_execution(self, model):
        """Assert that ``model(a, b)`` is close to ``(a + b, a - b)``.

        Args:
            model: Callable taking two tensors and returning an indexable
                pair of tensors.
        """
        a = torch.randn(4)
        b = torch.randn(4)
        # The custom backend is hardcoded to compute f(a, b) = (a + b, a - b).
        expected = (a + b, a - b)
        out = model(a, b)
        for index, want in enumerate(expected):
            got = out[index]
            # Attach both tensors to the failure message; a bare
            # assertTrue would only report "False is not true".
            self.assertTrue(
                torch.allclose(want, got),
                msg=f"output {index} mismatch: expected {want}, got {got}",
            )

    def test_execute(self):
        """
        Test execution using the custom backend.
        """
        self._check_execution(self.model)

    def test_save_load(self):
        """
        Test that a lowered module can be executed correctly
        after saving and loading.
        """
        # Test execution before saving and loading to make sure
        # the lowered module works in the first place.
        self._check_execution(self.model)

        # Save and load. A path inside a temporary directory is used
        # (rather than an open NamedTemporaryFile) so the file is never
        # opened twice at once; the directory and its contents are removed
        # when the ``with`` block exits, even if save or load raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "lowered_module.pt")
            torch.jit.save(self.model, path)
            loaded = torch.jit.load(path)

        # Test execution again, this time on the reloaded module.
        self._check_execution(loaded)


# When this file is executed directly (rather than imported), hand control
# to the run_tests helper imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()
