/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/threadpool/threadpool.h>

#include <algorithm>
#include <atomic>
#include <memory>

#include <executorch/extension/threadpool/threadpool_guard.h>
#include <executorch/runtime/platform/assert.h>

#include <cpuinfo.h>

namespace executorch::extension::threadpool {

#if !(defined(WIN32))
namespace {
// After fork(), the child process inherits the data structures of the
// parent's thread pool, but not its worker threads, so the inherited pool is
// corrupt. get_threadpool() leaks it (instead of destroying it) in order to
// prevent segfaults.
// Ref: https://github.com/pytorch/pytorch/issues/54752#issuecomment-810315302
//
// Set by child_atfork() and tested/cleared by get_threadpool(). Atomic because
// get_threadpool() may be called from several threads in the child; reading
// and writing a plain bool concurrently would be a data race.
std::atomic<bool> leak_corrupted_threadpool{false};

// pthread_atfork() child handler: flags the inherited pool as corrupt.
void child_atfork() {
  leak_corrupted_threadpool = true;
}

} // namespace
#endif

ThreadPool::ThreadPool(size_t thread_count)
    : threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {}

// Returns the number of threads in the underlying pthreadpool.
// Aborts if the pool handle is null.
size_t ThreadPool::get_thread_count() const {
  // Serialize with the other members that take mutex_ (run() and
  // _unsafe_reset_threadpool(), the latter of which can swap the pool out).
  const std::lock_guard<std::mutex> guard(mutex_);

  auto* const pool = threadpool_.get();
  ET_CHECK_MSG(pool, "Invalid threadpool!");
  return pthreadpool_get_threads_count(pool);
}

bool ThreadPool::_unsafe_reset_threadpool(uint32_t new_thread_count) {
  // No need to do anything if the count is same or 0
  if (new_thread_count == get_thread_count() || new_thread_count == 0) {
    return true;
  }

  std::lock_guard<std::mutex> lock{mutex_};

  threadpool_.reset(pthreadpool_create(new_thread_count));
  return true;
}

// Calls fn(i) for every i in [0, range).
//
// If NoThreadPoolGuard is enabled for the caller, all items run inline on the
// calling thread. Otherwise the work is dispatched through
// pthreadpool_parallelize_1d() while holding mutex_. Aborts if the pool
// handle is null.
void ThreadPool::run(
    const std::function<void(size_t)>& fn,
    const size_t range) {
  // Run on same thread if NoThreadPoolGuard guard is enabled
  if (NoThreadPoolGuard::is_enabled()) {
    for (size_t i = 0; i < range; ++i) {
      fn(i);
    }
    return;
  }

  std::lock_guard<std::mutex> lock{mutex_};

  // Defensive re-check; the early return above normally handles this case.
  ET_CHECK_MSG(!NoThreadPoolGuard::is_enabled(), "Inside a threadpool guard!");
  ET_CHECK_MSG(threadpool_.get(), "Invalid threadpool!");

  struct Context final {
    const std::function<void(size_t)>& fn;
  } context{
      fn,
  };

  // pthreadpool_parallelize_1d() is a blocking function: it does not return
  // until every item has been processed, so passing the address of the
  // stack-allocated `context` is safe. The callback is a captureless lambda
  // (convertible to a plain function pointer), so all state travels through
  // the opaque `void*` argument.
  pthreadpool_parallelize_1d(
      threadpool_.get(),
      [](void* const opaque, const size_t item) {
        // Each item runs under NoThreadPoolGuard so that a nested run() from
        // inside fn executes inline instead of trying to re-acquire mutex_,
        // which the outer run() is still holding.
        NoThreadPoolGuard guard;
        static_cast<Context*>(opaque)->fn(item);
      },
      &context,
      range,
      0u);
}

// Returns the process-wide ThreadPool, creating it on first use.
//
// The pool is sized to the processor count reported by cpuinfo, capped at 63
// (see the tsan note below). On non-Windows platforms, if a fork() has
// happened since the last call, the pool inherited from the parent is
// intentionally leaked and replaced with a fresh one of the same size.
//
// get_threadpool is not thread safe: leak_corrupted_threadpool is tested and
// cleared in two separate steps, and `threadpool` is replaced without a lock,
// so concurrent calls after a fork can race.
// Make this part threadsafe: TODO(kimishpatel)
ThreadPool* get_threadpool() {
  // The processor query only feeds the one-time construction of the static
  // pool, so it lives in the initializer rather than running on every call.
  static auto threadpool = []() {
    ET_CHECK_MSG(cpuinfo_initialize(), "cpuinfo initialization failed");
    /*
     * For llvm-tsan, the limit on the number of locks held by a single thread
     * is 63 (because of a comparison `< 64` instead of `<= 64`). pthreadpool's
     * worst case is the number of threads in the pool, so the pool size must
     * stay at or below 63 when running under tsan. Detecting tsan reliably is
     * tricky, so for now the default thread count is capped at the tsan limit
     * unconditionally.
     */
    constexpr int tsan_thread_limit = 63;
    const int num_threads = std::min(
        static_cast<int>(cpuinfo_get_processors_count()), tsan_thread_limit);
    return std::make_unique<ThreadPool>(num_threads);
  }();

// Inheriting from old threadpool to get around segfault issue
// commented above at child_atfork
#if !(defined(WIN32))
  // Register the fork handler exactly once.
  // @lint-ignore CLANGTIDY facebook-hte-std::once_flag
  static std::once_flag flag;
  // @lint-ignore CLANGTIDY facebook-hte-std::call_once
  std::call_once(
      flag, []() { pthread_atfork(nullptr, nullptr, child_atfork); });
  if ET_UNLIKELY (leak_corrupted_threadpool) {
    leak_corrupted_threadpool = false;
    // Deliberately leak the inherited pool; its worker threads do not exist
    // in the child. Only its size is carried over to the replacement.
    if (auto leaked = threadpool.release()) {
      // NOTE(review): get_thread_count() locks the leaked pool's mutex. If a
      // parent thread held that mutex (e.g. inside run()) at the moment of
      // fork(), the child's copy stays locked and this call may block
      // forever. Confirm whether that scenario needs handling.
      auto t = leaked->get_thread_count();
      threadpool = std::make_unique<ThreadPool>(t);
    }
  }
#endif
  return threadpool.get();
}

// Returns the raw pthreadpool handle of the process-wide ThreadPool, or
// nullptr when NoThreadPoolGuard is enabled for the calling thread. Aborts if
// no ThreadPool instance can be obtained.
//
// The handle is returned without taking the ThreadPool's mutex. A later
// ThreadPool::_unsafe_reset_threadpool() destroys the pool it refers to.
pthreadpool_t get_pthreadpool() {
  if (!NoThreadPoolGuard::is_enabled()) {
    ThreadPool* const pool = get_threadpool();
    ET_CHECK_MSG(pool, "Failed to acquire an instance of ThreadPool!");
    return pool->threadpool_.get();
  }
  // Threading disabled for this caller: signal "no pool" to the consumer.
  return nullptr;
}

} // namespace executorch::extension::threadpool
