# mypy: allow-untyped-defs
# Owner(s): ["module: unknown"]

import threading
import time
import torch
import unittest
from torch.futures import Future
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests
from typing import TypeVar

T = TypeVar("T")


def add_one(fut):
    return fut.wait() + 1


class TestFuture(TestCase):
    def test_set_exception(self) -> None:
        # This test is to ensure errors can propagate across futures.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        f = Future[T]()  # type: ignore[valid-type]
        # Set exception
        f.set_exception(value_error)
        # Exception should throw on wait
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.wait()

        # Exception should also throw on value
        f = Future[T]()  # type: ignore[valid-type]
        f.set_exception(value_error)
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.value()

        def cb(fut):
            fut.value()

        f = Future[T]()  # type: ignore[valid-type]
        f.set_exception(value_error)

        with self.assertRaisesRegex(RuntimeError, "Got the following error"):
            cb_fut = f.then(cb)
            cb_fut.wait()

    def test_set_exception_multithreading(self) -> None:
        # Ensure errors can propagate when one thread waits on future result
        # and the other sets it with an error.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        def wait_future(f):
            with self.assertRaisesRegex(ValueError, "Intentional"):
                f.wait()

        f = Future[T]()  # type: ignore[valid-type]
        t = threading.Thread(target=wait_future, args=(f, ))
        t.start()
        f.set_exception(value_error)
        t.join()

        def cb(fut):
            fut.value()

        def then_future(f):
            fut = f.then(cb)
            with self.assertRaisesRegex(RuntimeError, "Got the following error"):
                fut.wait()

        f = Future[T]()  # type: ignore[valid-type]
        t = threading.Thread(target=then_future, args=(f, ))
        t.start()
        f.set_exception(value_error)
        t.join()

    def test_done(self) -> None:
        f = Future[torch.Tensor]()
        self.assertFalse(f.done())

        f.set_result(torch.ones(2, 2))
        self.assertTrue(f.done())

    def test_done_exception(self) -> None:
        err_msg = "Intentional Value Error"

        def raise_exception(unused_future):
            raise RuntimeError(err_msg)

        f1 = Future[torch.Tensor]()
        self.assertFalse(f1.done())
        f1.set_result(torch.ones(2, 2))
        self.assertTrue(f1.done())

        f2 = f1.then(raise_exception)
        self.assertTrue(f2.done())
        with self.assertRaisesRegex(RuntimeError, err_msg):
            f2.wait()

    def test_wait(self) -> None:
        f = Future[torch.Tensor]()
        f.set_result(torch.ones(2, 2))

        self.assertEqual(f.wait(), torch.ones(2, 2))

    def test_wait_multi_thread(self) -> None:

        def slow_set_future(fut, value):
            time.sleep(0.5)
            fut.set_result(value)

        f = Future[torch.Tensor]()

        t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
        t.start()

        self.assertEqual(f.wait(), torch.ones(2, 2))
        t.join()

    def test_mark_future_twice(self) -> None:
        fut = Future[int]()
        fut.set_result(1)
        with self.assertRaisesRegex(
            RuntimeError,
            "Future can only be marked completed once"
        ):
            fut.set_result(1)

    def test_pickle_future(self):
        fut = Future[int]()
        errMsg = "Can not pickle torch.futures.Future"
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(RuntimeError, errMsg):
                torch.save(fut, fname)

    def test_then(self):
        fut = Future[torch.Tensor]()
        then_fut = fut.then(lambda x: x.wait() + 1)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)

    def test_chained_then(self):
        fut = Future[torch.Tensor]()
        futs = []
        last_fut = fut
        for _ in range(20):
            last_fut = last_fut.then(add_one)
            futs.append(last_fut)

        fut.set_result(torch.ones(2, 2))

        for i in range(len(futs)):
            self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1)

    def _test_then_error(self, cb, errMsg):
        fut = Future[int]()
        then_fut = fut.then(cb)

        fut.set_result(5)
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()

    def test_then_wrong_arg(self):

        def wrong_arg(tensor):
            return tensor + 1

        self._test_then_error(wrong_arg, "unsupported operand type.*Future.*int")

    def test_then_no_arg(self):

        def no_arg():
            return True

        self._test_then_error(no_arg, "takes 0 positional arguments but 1 was given")

    def test_then_raise(self):

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_then_error(raise_value_error, "Expected error")

    def test_add_done_callback_simple(self):
        callback_result = False

        def callback(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = True

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback)

        self.assertFalse(callback_result)
        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertTrue(callback_result)

    def test_add_done_callback_maintains_callback_order(self):
        callback_result = 0

        def callback_set1(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 1

        def callback_set2(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 2

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback_set1)
        fut.add_done_callback(callback_set2)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        # set2 called last, callback_result = 2
        self.assertEqual(callback_result, 2)

    def _test_add_done_callback_error_ignored(self, cb):
        fut = Future[int]()
        fut.add_done_callback(cb)

        fut.set_result(5)
        # error msg logged to stdout
        self.assertEqual(5, fut.wait())

    def test_add_done_callback_error_is_ignored(self):

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_add_done_callback_error_ignored(raise_value_error)

    def test_add_done_callback_no_arg_error_is_ignored(self):

        def no_arg():
            return True

        # Adding another level of function indirection here on purpose.
        # Otherwise mypy will pick up on no_arg having an incompatible type and fail CI
        self._test_add_done_callback_error_ignored(no_arg)

    def test_interleaving_then_and_add_done_callback_maintains_callback_order(self):
        callback_result = 0

        def callback_set1(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 1

        def callback_set2(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 2

        def callback_then(fut):
            nonlocal callback_result
            return fut.wait() + callback_result

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback_set1)
        then_fut = fut.then(callback_then)
        fut.add_done_callback(callback_set2)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        # then_fut's callback is called with callback_result = 1
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
        # set2 called last, callback_result = 2
        self.assertEqual(callback_result, 2)

    def test_interleaving_then_and_add_done_callback_propagates_error(self):
        def raise_value_error(fut):
            raise ValueError("Expected error")

        fut = Future[torch.Tensor]()
        then_fut = fut.then(raise_value_error)
        fut.add_done_callback(raise_value_error)
        fut.set_result(torch.ones(2, 2))

        # error from add_done_callback's callback is swallowed
        # error from then's callback is not
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            then_fut.wait()

    def test_collect_all(self):
        fut1 = Future[int]()
        fut2 = Future[int]()
        fut_all = torch.futures.collect_all([fut1, fut2])

        def slow_in_thread(fut, value):
            time.sleep(0.1)
            fut.set_result(value)

        t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
        fut2.set_result(2)
        t.start()

        res = fut_all.wait()
        self.assertEqual(res[0].wait(), 1)
        self.assertEqual(res[1].wait(), 2)
        t.join()

    @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows")
    def test_wait_all(self):
        fut1 = Future[int]()
        fut2 = Future[int]()

        # No error version
        fut1.set_result(1)
        fut2.set_result(2)
        res = torch.futures.wait_all([fut1, fut2])
        print(res)
        self.assertEqual(res, [1, 2])

        # Version with an exception
        def raise_in_fut(fut):
            raise ValueError("Expected error")
        fut3 = fut1.then(raise_in_fut)
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            torch.futures.wait_all([fut3, fut2])

    def test_wait_none(self):
        fut1 = Future[int]()
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.jit.wait(None)
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.futures.wait_all((None,))  # type: ignore[arg-type]
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.futures.collect_all((fut1, None,))  # type: ignore[arg-type]

if __name__ == '__main__':
    run_tests()
