#pragma once
#include <c10/util/Exception.h>

#include <mutex>
#include <vector>

namespace torch::autograd::utils {

// Warning handler for multi-threaded contexts. Gathers warnings from
// all threads into a single queue, then processes them together at the
// end in the main thread.
//
// replay_warnings() and process() are only declared in this header; their
// definitions are not visible here.
class DelayWarningHandler : public at::WarningHandler {
 public:
  ~DelayWarningHandler() override = default;
  // Processes the queued warnings together; per the class comment this is
  // meant to happen at the end, in the main thread.
  // NOTE(review): whether the queue is emptied afterwards is not visible
  // from this header -- confirm against the definition.
  void replay_warnings();

 private:
  // Override of the base-class handler hook, used for the gathering step
  // described in the class comment.
  // NOTE(review): presumably stores `warning` in warnings_ while holding
  // mutex_ -- confirm against the definition.
  void process(const c10::Warning& warning) override;

  // The single queue of gathered warnings.
  std::vector<c10::Warning> warnings_;
  // NOTE(review): presumably serializes access to warnings_ across
  // threads -- confirm against the definition.
  std::mutex mutex_;
};

} // namespace torch::autograd::utils
