/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Subprocess.h"

#include <dlfcn.h>
#include <errno.h>
#include <poll.h>
#include <string.h>
#include <sys/prctl.h>
#include <sys/wait.h>
#include <unistd.h>

#include <algorithm>
#include <condition_variable>
#include <limits>
#include <mutex>
#include <thread>

namespace gfxstream {
namespace {

// Runs a callable when the object goes out of scope, unless Disable() was
// called first. Used for cleanup that must happen on early-return paths.
template <typename F>
class ScopedCloser {
  public:
    // `func` is taken by rvalue reference (with CTAD: a lambda temporary) and
    // moved into the object.
    constexpr explicit ScopedCloser(F&& func) : mFunc(std::forward<F>(func)) {}

    // A copy would run the cleanup a second time, so copying is disabled.
    ScopedCloser(const ScopedCloser&) = delete;
    ScopedCloser& operator=(const ScopedCloser&) = delete;

    ~ScopedCloser() {
        if (mEnabled) {
            mFunc();
        }
    }

    // Cancels the pending cleanup; the destructor then does nothing.
    void Disable() { mEnabled = false; }

  private:
    F mFunc;
    bool mEnabled = true;
};


// Opens a pidfd (a file descriptor referring to process `pid`) through the raw
// pidfd_open(2) syscall, since there is no glibc wrapper for it. Returns the
// new fd, or -1 with errno set when the call fails (for example on kernels
// that predate pidfd_open).
int PidfdOpen(pid_t pid) {
#ifdef SYS_pidfd_open
    const long kPidfdOpenNr = SYS_pidfd_open;
#else
    const long kPidfdOpenNr = 434;
#endif
    constexpr int kNoFlags = 0;
    return syscall(kPidfdOpenNr, pid, kNoFlags);
}

// Blocks until child `pid` terminates and reports how it terminated.
//
// The child is waited for with WNOWAIT, so it is NOT reaped. It stays in a
// waitable (zombie) state and its pid cannot be recycled, which keeps a later
// kill(pid, ...) from hitting an unrelated process. Callers must reap it
// (e.g. with waitpid()) once they no longer need the pid.
//
// Returns Ok{} only if the child exited normally with status 0. Termination by
// a signal or a non-zero exit status is reported as an error.
gfxstream::expected<Ok, std::string> WaitForChild(pid_t pid) {
    siginfo_t info = {};
    if (TEMP_FAILURE_RETRY(waitid(P_PID, pid, &info, WEXITED | WNOWAIT)) != 0) {
        return gfxstream::unexpected("Error from waitid(): " +
                                     std::string(strerror(errno)));
    }
    if (info.si_pid != pid) {
        return gfxstream::unexpected("Error from waitid(): returned different pid.");
    }
    if (info.si_code != CLD_EXITED) {
        // For CLD_KILLED / CLD_DUMPED, si_status holds the signal number.
        return gfxstream::unexpected("Failed to wait for subprocess: terminated by signal " +
                                     std::to_string(info.si_status));
    }
    if (info.si_status != 0) {
        // DoWithSubprocessCheck's child always finishes with _exit(0), so a
        // non-zero status means the checked code terminated the process itself
        // (e.g. via exit()). In the parent that would take down the whole process.
        return gfxstream::unexpected("Failed to wait for subprocess: exited with status " +
                                     std::to_string(info.si_status));
    }
    return Ok{};
}

// Waits up to `timeout` for child `pid` without using a pidfd.
//
// Used when `pidfd_open` is not available: a second thread kills the child
// with SIGKILL if it has not exited within `timeout`. The child is reaped
// (waitpid) before this function returns, so it does not linger as a zombie.
gfxstream::expected<Ok, std::string> WaitForChildWithTimeoutFallback(
        pid_t pid, std::chrono::milliseconds timeout) {
    if (pid <= 0) {
        // kill() with a pid of 0 or -1 signals a whole process group or every
        // process we may signal; never hand such a value to the watchdog.
        return gfxstream::unexpected("Failed to wait for subprocess: invalid pid " +
                                     std::to_string(pid));
    }

    bool childExited = false;
    bool childTimedOut = false;
    std::condition_variable cv;
    std::mutex m;

    std::thread killThread([&]() {
        std::unique_lock<std::mutex> lock(m);
        if (!cv.wait_for(lock, timeout, [&] { return childExited; })) {
            childTimedOut = true;
            // Safe even if the child already exited: WaitForChild() waits with
            // WNOWAIT and does not reap it, so `pid` cannot have been recycled.
            kill(pid, SIGKILL);
        }
    });

    auto result = WaitForChild(pid);
    {
        std::lock_guard<std::mutex> lock(m);
        childExited = true;
    }
    cv.notify_all();
    killThread.join();

    // The watchdog has been joined, so nothing can signal `pid` any more. It is
    // now safe to reap the child that WaitForChild() left in a waitable state.
    (void)TEMP_FAILURE_RETRY(waitpid(pid, nullptr, 0));

    if (childTimedOut) {
        return gfxstream::unexpected("Failed to wait for subprocess: timed out.");
    }
    return result;
}

// Waits up to `timeout` for child `pid` to terminate, polling `pidfd` (a pidfd
// for the same process) for readiness.
//
// Takes ownership of `pidfd` and closes it. If the child does not terminate in
// time, or poll() fails, the child is killed with SIGKILL. On every path the
// child is reaped (waitpid) before returning so it does not remain a zombie.
gfxstream::expected<Ok, std::string> WaitForChildWithTimeout(
        pid_t pid,
        int pidfd,
        std::chrono::milliseconds timeout) {
    // WaitForChild() waits with WNOWAIT, which leaves the child waitable (a
    // zombie); reap it explicitly once it is no longer needed. The pid > 0 check
    // keeps waitpid() from waiting on arbitrary children.
    const auto reapChild = [pid]() {
        if (pid > 0) {
            (void)TEMP_FAILURE_RETRY(waitpid(pid, nullptr, 0));
        }
    };

    // Runs on every early return below: the child must not outlive this call.
    ScopedCloser cleanup([&]() {
        kill(pid, SIGKILL);
        WaitForChild(pid);
        reapChild();
    });

    // poll() takes an int millisecond timeout where any negative value means
    // "wait forever". Saturate instead of silently truncating a large count.
    int pollTimeoutMs = -1;
    if (timeout.count() >= 0) {
        pollTimeoutMs = static_cast<int>(std::min<std::chrono::milliseconds::rep>(
                timeout.count(), std::numeric_limits<int>::max()));
    }

    // A pidfd becomes readable once the process it refers to terminates.
    struct pollfd poll_info = {
        .fd = pidfd,
        .events = POLLIN,
    };
    const int ret = TEMP_FAILURE_RETRY(poll(&poll_info, 1, pollTimeoutMs));
    const int pollErrno = errno;  // close() below may overwrite errno.
    close(pidfd);

    if (ret < 0) {
        return gfxstream::unexpected("Failed to wait for subprocess: poll() returned " +
                                     std::to_string(ret) + ": " +
                                     std::string(strerror(pollErrno)));
    }
    if (ret == 0) {
        return gfxstream::unexpected("Failed to wait for subprocess: subprocess did not "
                                     "finish within " + std::to_string(timeout.count()) +
                                     "ms.");
    }

    cleanup.Disable();
    auto result = WaitForChild(pid);
    reapChild();
    return result;
}

}  // namespace

// Runs `function` first in a forked child process. Only if that child
// terminates successfully within `timeout` is `function` run again in this
// process, and its result returned. Otherwise the error from waiting on the
// child is returned and `function` is not run here.
gfxstream::expected<Ok, std::string> DoWithSubprocessCheck(
        const std::function<gfxstream::expected<Ok, std::string>()>& function,
        std::chrono::milliseconds timeout) {
    const pid_t pid = fork();
    if (pid < 0) {
        // Without this check, -1 would flow into the pid-based wait/kill code.
        return gfxstream::unexpected("Failed to fork subprocess: " +
                                     std::string(strerror(errno)));
    }
    if (pid == 0) {
        // Child: the returned value is discarded; the parent only observes how
        // the child process terminates.
        function();
        _exit(0);
    }

    // Prefer a pidfd, which can be poll()ed with a timeout; otherwise fall back
    // to a watchdog thread.
    const int pidfd = PidfdOpen(pid);
    if (pidfd >= 0) {
        GFXSTREAM_EXPECT(WaitForChildWithTimeout(pid, pidfd, timeout));
    } else {
        GFXSTREAM_EXPECT(WaitForChildWithTimeoutFallback(pid, timeout));
    }

    return function();
}

}  // namespace gfxstream