#include <torch/csrc/DataLoader.h>

// Together with `torch/utils/data/_utils/signal_handling.py`, the following
// code is a best-effort attempt to give users a meaningful error message when
// a worker dies due to an error or a critical signal.
//
// See NOTE [ Signal handling in multiprocessing data loading ] for more
// details.

// TODO: The following does not work on Windows, specifically the sigaction and
// waitid calls and the SIGCHLD handler. Currently, dummy implementations are
// provided for Windows.

#ifndef _WIN32

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils/python_numbers.h>

#include <c10/util/irange.h>
#include <fmt/format.h>

#include <sys/wait.h>
#include <csignal>
#include <map>
#include <set>
#include <sstream>
#include <utility>

using namespace torch;

// Critical signal handlers should be registered in worker processes before
// they start doing work.
// Each handler restores the default action for its signal and re-raises it, so
// that the main process can still retrieve how the worker was killed.
// The Python entry point is _set_worker_signal_handlers().
// Defines a `static` SA_SIGINFO-style signal handler.
//   SIGNAL:       the signal the generated handler restores and re-raises.
//   HANDLER_NAME: name of the generated handler function.
//   ERROR_MSG:    a string literal (not a pointer: its length is taken with
//                 sizeof) written verbatim to stderr.
#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG)                    \
  static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) {          \
    /* Only async-signal-safe calls may be used here, hence write(2).  */  \
    /* sizeof counts the literal's trailing NUL; drop it so no stray   */  \
    /* NUL byte is emitted to stderr.                                  */  \
    auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) - 1);      \
    (void)_w;                                                              \
    /* Restore the default action and re-raise, so the process dies by */  \
    /* SIGNAL and the parent can observe that (e.g. via waitid()).     */  \
    struct sigaction sa {};                                                \
    sa.sa_handler = SIG_DFL;                                               \
    sa.sa_flags = 0;                                                       \
    if (sigemptyset(&sa.sa_mask) != 0 ||                                   \
        sigaction(SIGNAL, &sa, nullptr) != 0) {                            \
      _exit(EXIT_FAILURE);                                                 \
    } else {                                                               \
      raise(SIGNAL);                                                       \
    }                                                                      \
  }

// Installs `handler` as the SA_SIGINFO-style action for `signal`.
//
// sigaction(2) is used rather than signal(2), whose semantics are not
// portable: http://man7.org/linux/man-pages/man2/signal.2.html
//
// Flags used:
//   SA_SIGINFO   - the handler receives a siginfo_t describing the sender;
//   SA_RESTART   - interrupted system calls are restarted;
//   SA_NOCLDSTOP - only meaningful for SIGCHLD;
//   SA_NODEFER   - the signal is not blocked while its handler runs, so the
//                  handler may re-raise it.
// When `old_sa_ptr` is non-null, the previous action is stored there.
//
// Throws std::runtime_error if the action cannot be installed.
static inline void setSignalHandler(
    int signal,
    void (*handler)(int, siginfo_t*, void*),
    struct sigaction* old_sa_ptr) {
  struct sigaction action {};
  action.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;
  action.sa_sigaction = handler;
  const bool installed = sigemptyset(&action.sa_mask) == 0 &&
      sigaction(signal, &action, old_sa_ptr) == 0;
  if (installed) {
    return;
  }
  std::ostringstream message;
  message << "An error occurred while setting handler for "
          << strsignal(signal) << ".";
  throw std::runtime_error(message.str());
}

// Worker-side handlers for fatal signals, generated by the SIGNAL_HANDLER
// macro above: each writes its message to stderr, restores the default action
// for its signal and re-raises it. They are installed by
// THPModule_setWorkerSignalHandlers.

// SIGBUS: the message points at shared memory as a likely cause.
SIGNAL_HANDLER(
    SIGBUS,
    handler_SIGBUS,
    "ERROR: Unexpected bus error encountered in worker. "
    "This might be caused by insufficient shared memory (shm).\n");
// SIGSEGV: segmentation fault inside a worker.
SIGNAL_HANDLER(
    SIGSEGV,
    handler_SIGSEGV,
    "ERROR: Unexpected segmentation fault encountered in worker.\n");
// SIGFPE: floating-point exception inside a worker.
SIGNAL_HANDLER(
    SIGFPE,
    handler_SIGFPE,
    "ERROR: Unexpected floating-point exception encountered in worker.\n");

// When an error happens in DataLoader methods and Python starts to exit, the
// error trace keeps the loader alive, and Python may kill the child processes
// before deleting the loader object. In that case the cleanup methods in
// DataLoader.__del__ have not been called yet, and the SIGCHLD handler would
// print an error saying a worker was killed by SIGTERM. To avoid this, a
// SIGTERM sent by the main loader process is suppressed here by calling
// _exit(EXIT_SUCCESS). Note that if we exited with a nonzero code, the loader's
// SIGCHLD handler might report a RuntimeError again, defeating the purpose.
static void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) {
  // A SIGTERM sent by our parent (the main loader process) is an orderly
  // shutdown: leave quietly with a success code.
  const bool sent_by_parent = info->si_pid == getppid();
  if (sent_by_parent) {
    _exit(EXIT_SUCCESS);
  }

  // Any other sender: fall back to the default action and re-raise, so the
  // worker terminates by SIGTERM as it would have without this handler.
  struct sigaction default_action {};
  default_action.sa_flags = 0;
  default_action.sa_handler = SIG_DFL;
  const bool restored = sigemptyset(&default_action.sa_mask) == 0 &&
      sigaction(SIGTERM, &default_action, nullptr) == 0;
  if (!restored) {
    _exit(EXIT_FAILURE);
  }
  raise(SIGTERM);
}

// Default no-op hook, called by THPModule_setWorkerSignalHandlers after the
// built-in worker handlers are installed. It is a weak symbol, so a strong
// definition elsewhere in the link replaces it (presumably to install extra
// handlers; confirm against any overriding definition).
__attribute__((weak)) void setDataLoaderSignalHandlers() {
  // Intentionally empty.
}

// Python: _set_worker_signal_handlers(). Installs the critical-signal handlers
// in a worker process, then runs the (weak, overridable) extra hook.
// A failure to install any handler throws std::runtime_error, which
// propagates out through HANDLE_TH_ERRORS (presumably surfacing as a Python
// exception).
static PyObject* THPModule_setWorkerSignalHandlers(
    PyObject* module,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  using Handler = void (*)(int, siginfo_t*, void*);
  // Installed in this order; the first failure aborts the remaining ones.
  static const std::pair<int, Handler> kWorkerHandlers[] = {
      {SIGBUS, &handler_SIGBUS},
      {SIGSEGV, &handler_SIGSEGV},
      {SIGTERM, &handler_SIGTERM},
      {SIGFPE, &handler_SIGFPE},
  };
  for (const auto& [signum, handler] : kWorkerHandlers) {
    setSignalHandler(signum, handler, nullptr);
  }
  setDataLoaderSignalHandlers();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// Worker pids registered via _set_worker_pids, keyed by the id identifying
// each _BaseDataLoaderIter. THPModule_errorIfAnyWorkerFails inspects only these
// pids; THPModule_removeWorkerPIDs drops an entry.
static std::map<int64_t, std::set<pid_t>> worker_pids{};

// Describes how worker `worker_pid` failed, or returns an empty string if it
// has not changed state, exited with EXIT_SUCCESS, or cannot be queried.
// waitid() is used with WNOWAIT (rather than waitpid) so the child stays
// waitable and Python and other handlers can still collect its status.
static std::string describeWorkerFailure(pid_t worker_pid) {
  siginfo_t infop{};
  infop.si_pid = 0;
  const int rc =
      waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT);
  // Errors and "no waitable child yet" are deliberately ignored.
  if (rc < 0 || infop.si_pid == 0) {
    return {};
  }

  std::ostringstream oss;
  switch (infop.si_code) {
    case CLD_EXITED:
      if (infop.si_status == EXIT_SUCCESS) {
        return {};
      }
      oss << "DataLoader worker (pid " << worker_pid << ") exited "
          << "unexpectedly with exit code " << infop.si_status << ". "
          << "Details are lost due to multiprocessing. Rerunning with "
          << "num_workers=0 may give better error trace.";
      break;
    case CLD_KILLED:
    case CLD_DUMPED:
      oss << "DataLoader worker (pid " << worker_pid << ") is killed "
          << "by signal: " << strsignal(infop.si_status) << ". ";
      if (infop.si_status == SIGBUS) {
        oss << "It is possible that dataloader's workers are out of shared memory. "
            << "Please try to raise your shared memory limit.";
      }
      break;
    default:
      return {};
  }
  return oss.str();
}

// Python: _error_if_any_worker_fails(). Checks every registered worker pid
// and throws std::runtime_error describing the first failed worker found;
// returns None otherwise.
static PyObject* THPModule_errorIfAnyWorkerFails(
    PyObject* module,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  // Only the pids registered through _set_worker_pids are checked.
  for (auto& entry : worker_pids) {
    std::set<pid_t>& pids = entry.second;
    for (const pid_t pid : pids) {
      std::string failure = describeWorkerFailure(pid);
      if (failure.empty()) {
        continue;
      }
      // Forget this loader's workers before raising; otherwise the error
      // tears down the remaining workers and each would trigger this report
      // again. `pid` is a copy, and we throw right away, so the invalidated
      // range-for iterator is never used again.
      pids.clear();
      throw std::runtime_error(failure);
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// Python: _set_worker_pids(key, child_pids). Registers the worker pids of one
// _BaseDataLoaderIter, identified by the integer `key`; `child_pids` is a tuple
// of pids. Only registered pids are examined by _error_if_any_worker_fails, so
// a SIGCHLD from an unrelated child is not treated as a worker failure.
//
// Fails via TORCH_CHECK_TYPE when the argument count is wrong or `child_pids`
// is not a tuple. Fails via TORCH_CHECK_VALUE when `key` is already registered
// or a pid does not fit in pid_t. These presumably surface as
// TypeError / ValueError.
static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* args) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_TYPE(
      PyTuple_GET_SIZE(args) == 2,
      "_set_worker_pids expects exactly 2 arguments.");
  int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
  TORCH_CHECK_VALUE(
      worker_pids.find(key) == worker_pids.end(),
      "_set_worker_pids should be called only once for each _BaseDataLoaderIter.");
  PyObject* child_pids = PyTuple_GET_ITEM(args, 1);
  TORCH_CHECK_TYPE(
      PyTuple_Check(child_pids),
      "_set_worker_pids expects a tuple for child_pids, but got ",
      Py_TYPE(child_pids)->tp_name,
      ".");
  std::set<pid_t> pids_set;
  const auto size = PyTuple_GET_SIZE(child_pids);
  for (const auto idx : c10::irange(size)) {
    const int64_t raw_pid =
        THPUtils_unpackLong(PyTuple_GET_ITEM(child_pids, idx));
    const auto pid = static_cast<pid_t>(raw_pid);
    // pid_t is typically narrower than int64_t; reject values that would
    // silently wrap around to a different (possibly unrelated) process id.
    TORCH_CHECK_VALUE(
        static_cast<int64_t>(pid) == raw_pid,
        "_set_worker_pids got an out-of-range pid: ",
        raw_pid);
    pids_set.insert(pid);
  }

  // Hand the freshly built set over instead of copying it.
  worker_pids[key] = std::move(pids_set);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// Python: _remove_worker_pids(loader_id). Drops the worker pids registered for
// the _BaseDataLoaderIter identified by the integer `loader_id`. Fails via
// TORCH_CHECK_VALUE if nothing was registered under that id.
static PyObject* THPModule_removeWorkerPIDs(
    PyObject* module,
    PyObject* loader_id) {
  HANDLE_TH_ERRORS
  const int64_t key = THPUtils_unpackLong(loader_id);
  // erase(key) returns how many entries were removed (0 or 1 for std::map).
  const auto removed = worker_pids.erase(key);
  TORCH_CHECK_VALUE(
      removed == 1,
      "Cannot find worker information for _BaseDataLoaderIter with id ",
      key);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

#undef SIGNAL_HANDLER

#else
// dummy implementations for windows

// Windows stub: worker signal handlers are not installed; returns None.
static PyObject* THPModule_setWorkerSignalHandlers(
    PyObject* /* module */,
    PyObject* /* unused */) {
  Py_RETURN_NONE;
}

// Windows stub: worker pids are not tracked; returns None.
static PyObject* THPModule_setWorkerPIDs(
    PyObject* /* module */,
    PyObject* /* args */) {
  Py_RETURN_NONE;
}

// Windows stub: nothing is tracked, so there is nothing to remove; returns
// None.
static PyObject* THPModule_removeWorkerPIDs(
    PyObject* /* module */,
    PyObject* /* loader_id */) {
  Py_RETURN_NONE;
}

// Windows stub: worker failures are never detected here; returns None.
static PyObject* THPModule_errorIfAnyWorkerFails(
    PyObject* /* module */,
    PyObject* /* unused */) {
  Py_RETURN_NONE;
}

#endif

// Python-visible functions of this module. Each PyMethodDef entry is
// {name, C function, calling convention, docstring}; the table ends with an
// all-null sentinel entry.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
PyMethodDef DataLoaderMethods[] = {
    // No arguments.
    {"_set_worker_signal_handlers",
     THPModule_setWorkerSignalHandlers,
     METH_NOARGS,
     nullptr},
    // Positional arguments: (key, child_pids).
    {"_set_worker_pids",
     THPModule_setWorkerPIDs,
     METH_VARARGS,
     nullptr},
    // Single argument: loader id.
    {"_remove_worker_pids",
     THPModule_removeWorkerPIDs,
     METH_O,
     nullptr},
    // No arguments.
    {"_error_if_any_worker_fails",
     THPModule_errorIfAnyWorkerFails,
     METH_NOARGS,
     nullptr},
    // Sentinel.
    {nullptr, nullptr, 0, nullptr}};
