# Owner(s): ["module: multiprocessing"]

import os
import pickle
import random
import signal
import sys
import time
import unittest

import torch.multiprocessing as mp

from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    NO_MULTIPROCESSING_SPAWN,
    run_tests,
    TestCase,
)

def _test_success_func(i):
    """Worker target that finishes immediately; the process index ``i`` is ignored."""
    return None


def _test_success_single_arg_func(i, arg):
    """Report the worker index ``i`` by calling ``arg.put(i)``; do nothing when ``arg`` is falsy."""
    if not arg:
        return
    arg.put(i)


def _test_exception_single_func(i, arg):
    """Raise ``ValueError`` in the worker whose index equals ``arg``.

    Every other worker sleeps for one second and then returns normally.
    """
    if i != arg:
        time.sleep(1.0)
        return
    raise ValueError("legitimate exception from process %d" % i)


def _test_exception_all_func(i):
    """Sleep a random interval in [0, 0.1) seconds, then raise ``ValueError`` naming worker ``i``."""
    jitter = random.random() / 10
    time.sleep(jitter)
    raise ValueError("legitimate exception from process %d" % i)


def _test_terminate_signal_func(i):
    """Worker 0 sends SIGABRT to its own process; the others sleep for one second.

    Worker 0 would also sleep if it somehow survived the signal.
    """
    is_victim = i == 0
    if is_victim:
        os.kill(os.getpid(), signal.SIGABRT)
    time.sleep(1.0)


def _test_terminate_exit_func(i, arg):
    """Worker 0 exits via ``sys.exit(arg)``; every other worker sleeps for one second."""
    if i != 0:
        time.sleep(1.0)
        return
    sys.exit(arg)


def _test_success_first_then_exception_func(i, arg):
    """Worker 0 returns at once; every other worker raises ``ValueError`` after 0.1 s.

    ``arg`` is accepted so the signature matches the ``args=`` tuple callers
    pass, but it is not used.
    """
    if i != 0:
        # Let worker 0 finish first, so its success precedes the failure.
        time.sleep(0.1)
        raise ValueError("legitimate exception")


def _test_nested_child_body(i, ready_queue, nested_child_sleep):
    """Grandchild worker: put ``None`` on ``ready_queue``, then sleep ``nested_child_sleep`` seconds."""
    ready_marker = None
    ready_queue.put(ready_marker)
    # Keep running so that only the parent's death, not a normal return,
    # makes this process disappear within the nested test's time window.
    time.sleep(nested_child_sleep)


def _test_infinite_task(i):
    """Worker that never returns on its own; it must be stopped externally (e.g. by a signal)."""
    poll_interval = 1
    while True:
        time.sleep(poll_interval)


def _test_process_exit(idx):
    """Terminate the calling process with exit status 12, whatever ``idx`` is."""
    exit_status = 12
    sys.exit(exit_status)


def _test_nested(i, pids_queue, nested_child_sleep, start_method):
    """Start two grandchildren, report their PIDs, then SIGTERM this process.

    ``i`` is the process index supplied by ``start_processes`` and is unused.
    The grandchildren's PIDs are put on ``pids_queue`` so the parent test
    can check that they disappear once this process is killed.
    """
    num_children = 2
    ready = mp.get_context(start_method).Queue()
    children = mp.start_processes(
        fn=_test_nested_child_body,
        args=(ready, nested_child_sleep),
        nprocs=num_children,
        join=False,
        daemon=False,
        start_method=start_method,
    )
    pids_queue.put(children.pids())

    # Block until every grandchild has reported in, so each one has already
    # called prctl(2) to register a parent death signal.
    for _ in range(num_children):
        ready.get()

    # Killing this process is expected to take the grandchildren down too.
    os.kill(os.getpid(), signal.SIGTERM)

class _TestMultiProcessing:
    """Shared test cases for ``mp.start_processes``, parameterized by ``start_method``.

    This is not a ``TestCase`` on its own. Concrete classes in this file mix
    it in together with ``TestCase`` (e.g. ``SpawnTest``, ``ForkTest``).
    """

    # Start method passed to every ``mp.start_processes`` call below.
    # SpawnTest and ForkTest override it with 'spawn' and 'fork'.
    start_method = None

    def test_success(self):
        """Two no-op workers must finish without raising."""
        mp.start_processes(_test_success_func, nprocs=2, start_method=self.start_method)

    def test_success_non_blocking(self):
        """With ``join=False``, repeated ``join()`` calls must eventually report True."""
        mp_context = mp.start_processes(_test_success_func, nprocs=2, join=False, start_method=self.start_method)

        # After all processes (nproc=2) have joined it must return True
        # NOTE(review): the two unchecked calls assume each join(timeout=None)
        # returns once at least one more worker has exited, so the third call
        # sees both finished. Confirm against ProcessContext.join.
        mp_context.join(timeout=None)
        mp_context.join(timeout=None)
        self.assertTrue(mp_context.join(timeout=None))

    def test_first_argument_index(self):
        """The worker function must receive its process index as its first argument."""
        context = mp.get_context(self.start_method)
        queue = context.SimpleQueue()
        # Each worker puts its own index on the queue.
        mp.start_processes(_test_success_single_arg_func, args=(queue,), nprocs=2, start_method=self.start_method)
        # Arrival order between workers is not fixed, hence the sort.
        self.assertEqual([0, 1], sorted([queue.get(), queue.get()]))

    def test_exception_single(self):
        """When only worker ``i`` raises, the surfaced error must be worker ``i``'s ValueError."""
        nprocs = 2
        for i in range(nprocs):
            # The pattern requires process i's ValueError line to end the
            # surfaced exception message.
            with self.assertRaisesRegex(
                Exception,
                "\nValueError: legitimate exception from process %d$" % i,
            ):
                mp.start_processes(_test_exception_single_func, args=(i,), nprocs=nprocs, start_method=self.start_method)

    def test_exception_all(self):
        """When every worker raises, the surfaced error may come from either worker."""
        with self.assertRaisesRegex(
            Exception,
            "\nValueError: legitimate exception from process (0|1)$",
        ):
            mp.start_processes(_test_exception_all_func, nprocs=2, start_method=self.start_method)

    def test_terminate_signal(self):
        """A worker that sends itself SIGABRT must be reported as terminated by that signal."""
        # SIGABRT is aliased with SIGIOT
        message = "process 0 terminated with signal (SIGABRT|SIGIOT)"

        # Termination through with signal is expressed as a negative exit code
        # in multiprocessing, so we know it was a signal that caused the exit.
        # This doesn't appear to exist on Windows, where the exit code is always
        # positive, and therefore results in a different exception message.
        # Exit code 22 means "ERROR_BAD_COMMAND".
        if IS_WINDOWS:
            message = "process 0 terminated with exit code 22"

        with self.assertRaisesRegex(Exception, message):
            mp.start_processes(_test_terminate_signal_func, nprocs=2, start_method=self.start_method)

    def test_terminate_exit(self):
        """Worker 0 calling ``sys.exit(123)`` must be reported with exit code 123."""
        exitcode = 123
        with self.assertRaisesRegex(
            Exception,
            "process 0 terminated with exit code %d" % exitcode,
        ):
            mp.start_processes(_test_terminate_exit_func, args=(exitcode,), nprocs=2, start_method=self.start_method)

    def test_success_first_then_exception(self):
        """Worker 0 returning successfully must not mask worker 1's later exception."""
        # Only forwarded as the worker's ``arg`` parameter, which
        # _test_success_first_then_exception_func ignores.
        exitcode = 123
        with self.assertRaisesRegex(
            Exception,
            "ValueError: legitimate exception",
        ):
            mp.start_processes(_test_success_first_then_exception_func, args=(exitcode,), nprocs=2, start_method=self.start_method)

    @unittest.skipIf(
        sys.platform != "linux",
        "Only runs on Linux; requires prctl(2)",
    )
    def _test_nested(self):
        """Grandchild processes must exit soon after their parent is killed.

        Currently disabled: the leading underscore keeps unittest from
        collecting this method.
        """
        context = mp.get_context(self.start_method)
        pids_queue = context.Queue()
        # Grandchildren sleep this long unless killed. The assertion below
        # allows only half of it for them to disappear.
        nested_child_sleep = 20.0
        mp_context = mp.start_processes(
            fn=_test_nested,
            args=(pids_queue, nested_child_sleep, self.start_method),
            nprocs=1,
            join=False,
            daemon=False,
            start_method=self.start_method,
        )

        # Wait for nested children to terminate in time
        pids = pids_queue.get()
        start = time.time()
        while len(pids) > 0:
            for pid in pids:
                try:
                    # Signal 0 delivers nothing; it only probes that the PID exists.
                    os.kill(pid, 0)
                except ProcessLookupError:
                    pids.remove(pid)
                    # Stop iterating, because ``pids`` was just mutated.
                    break

            # This assert fails if any nested child process is still
            # alive after (nested_child_sleep / 2) seconds. By
            # extension, this test times out with an assertion error
            # after (nested_child_sleep / 2) seconds.
            self.assertLess(time.time() - start, nested_child_sleep / 2)
            time.sleep(0.1)

@unittest.skipIf(
    NO_MULTIPROCESSING_SPAWN,
    "Disabled for environments that don't support the spawn start method")
class SpawnTest(TestCase, _TestMultiProcessing):
    """Run the shared multiprocessing tests with 'spawn', plus ``mp.spawn``-specific checks."""

    start_method = 'spawn'

    def test_exception_raises(self):
        """An exception raised inside a worker must surface as ``ProcessRaisedException``."""
        # _test_exception_single_func raises ValueError in the worker whose
        # index equals its argument (0 here). The exception therefore comes
        # from the worker body, not from a signature mismatch.
        with self.assertRaises(mp.ProcessRaisedException):
            mp.spawn(_test_exception_single_func, args=(0,), nprocs=1)

    def test_signal_raises(self):
        """A worker killed by SIGTERM must surface as ``ProcessExitedException`` on join."""
        context = mp.spawn(_test_infinite_task, args=(), nprocs=1, join=False)
        for pid in context.pids():
            os.kill(pid, signal.SIGTERM)
        with self.assertRaises(mp.ProcessExitedException):
            context.join()

    def _test_process_exited(self):
        """A worker calling ``sys.exit(12)`` must surface with ``exit_code == 12``.

        Currently disabled: the leading underscore keeps unittest from
        collecting this method.
        """
        with self.assertRaises(mp.ProcessExitedException) as cm:
            mp.spawn(_test_process_exit, args=(), nprocs=1)
        # This check must sit outside the ``with`` block: inside it, the line
        # would never run because spawn() raises first. The raised exception
        # is available as ``cm.exception``, not on the context manager itself.
        self.assertEqual(12, cm.exception.exit_code)


@unittest.skipIf(IS_WINDOWS, "Fork is only available on Unix")
class ForkTest(TestCase, _TestMultiProcessing):
    """Run every shared case from ``_TestMultiProcessing`` with the 'fork' start method."""

    start_method = "fork"


@unittest.skipIf(
    IS_WINDOWS,
    "Fork is only available on Unix",
)
class ParallelForkServerShouldWorkTest(TestCase, _TestMultiProcessing):
    """Run the shared tests with 'forkserver' and parallel process start enabled.

    ``mp.ENV_VAR_PARALLEL_START`` is set to "1" for each test and restored
    to its previous state afterwards.
    """

    # The parallel-start opt-in targets the forkserver start method (the same
    # pairing ParallelForkServerPerfTest times). Without this override, the
    # mixin's default of None would be passed to start_processes, and the
    # forkserver path would never run.
    start_method = 'forkserver'
    # Value of the env var before setUp; None means it was unset.
    orig_paralell_env_val = None

    def setUp(self):
        super().setUp()
        self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
        os.environ[mp.ENV_VAR_PARALLEL_START] = "1"

    def tearDown(self):
        try:
            super().tearDown()
        finally:
            # Always restore the environment so later tests do not inherit the
            # opt-in. Use pop() so that a variable already removed elsewhere
            # does not raise KeyError.
            if self.orig_paralell_env_val is None:
                os.environ.pop(mp.ENV_VAR_PARALLEL_START, None)
            else:
                os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val


@unittest.skipIf(
    IS_WINDOWS,
    "Fork is only available on Unix",
)
class ParallelForkServerPerfTest(TestCase):
    """Timing comparison of sequential vs. parallel process start with 'forkserver'."""

    def test_forkserver_perf(self):
        """Sequential start must take >= nprocs * SLEEP_SECS; parallel start must take less.

        ``Expensive`` sleeps ``SLEEP_SECS`` when its class body executes.
        NOTE(review): this assumes every forkserver child re-imports this module
        to unpickle ``expensive.my_call`` and so pays that cost. Confirm this.
        """
        start_method = 'forkserver'
        expensive = Expensive()
        nprocs = 4
        orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)

        def timed_start(parallel):
            # Toggle the opt-in, start nprocs workers, and return elapsed seconds.
            os.environ[mp.ENV_VAR_PARALLEL_START] = "1" if parallel else "0"
            start = time.perf_counter()
            mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
            return time.perf_counter() - start

        try:
            # Non-parallel case: the elapsed time should be at least
            # {nprocs}x the sleep time.
            self.assertGreaterEqual(timed_start(parallel=False), Expensive.SLEEP_SECS * nprocs)

            # Parallel case: the elapsed time should be less than
            # {nprocs}x the sleep time.
            self.assertLess(timed_start(parallel=True), Expensive.SLEEP_SECS * nprocs)
        finally:
            # Restore the environment even when an assertion fails, so later
            # tests do not run with a leftover parallel-start setting.
            if orig_paralell_env_val is None:
                os.environ.pop(mp.ENV_VAR_PARALLEL_START, None)
            else:
                os.environ[mp.ENV_VAR_PARALLEL_START] = orig_paralell_env_val


class Expensive:
    """Worker target for ParallelForkServerPerfTest with a deliberately slow class body."""

    # Seconds slept while this class body executes, i.e. every time this
    # module is imported. ParallelForkServerPerfTest compares start-up times
    # against multiples of this value.
    SLEEP_SECS = 5
    # Simulate startup overhead such as large imports
    time.sleep(SLEEP_SECS)

    def __init__(self):
        # ~1 MB string payload. NOTE(review): presumably meant to make the
        # bound method ``my_call`` costly to pickle for each child. Confirm.
        self.config: str = "*" * 1000000

    def my_call(self, *args):
        # No-op worker body; the process index and any extra args are ignored.
        pass


class ErrorTest(TestCase):
    """Pickling checks for the exception types that ``torch.multiprocessing`` raises."""

    def test_errors_pickleable(self):
        """Each exception must survive a ``pickle.dumps`` / ``pickle.loads`` round trip."""
        samples = [
            mp.ProcessRaisedException("Oh no!", 1, 1),
            mp.ProcessExitedException("Oh no!", 1, 1, 1),
        ]
        for sample in samples:
            payload = pickle.dumps(sample)
            pickle.loads(payload)


if __name__ == "__main__":
    # Hand control to PyTorch's common test runner when executed as a script.
    run_tests()
