#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <ATen/core/function.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/serialization/python_print.h>

namespace torch::jit {

class JitLoggingConfig {
 public:
  static JitLoggingConfig& getInstance() {
    static JitLoggingConfig instance;
    return instance;
  }
  JitLoggingConfig(JitLoggingConfig const&) = delete;
  void operator=(JitLoggingConfig const&) = delete;

 private:
  std::string logging_levels;
  std::unordered_map<std::string, size_t> files_to_levels;
  std::ostream* out;

  JitLoggingConfig() : out(&std::cerr) {
    const char* jit_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL");
    logging_levels.assign(jit_log_level == nullptr ? "" : jit_log_level);

    parse();
  }
  void parse();

 public:
  std::string getLoggingLevels() const {
    return this->logging_levels;
  }
  void setLoggingLevels(std::string levels) {
    this->logging_levels = std::move(levels);
    parse();
  }

  const std::unordered_map<std::string, size_t>& getFilesToLevels() const {
    return this->files_to_levels;
  }

  void setOutputStream(std::ostream& out_stream) {
    this->out = &out_stream;
  }

  std::ostream& getOutputStream() {
    return *(this->out);
  }
};

std::string get_jit_logging_levels() {
  return JitLoggingConfig::getInstance().getLoggingLevels();
}

void set_jit_logging_levels(std::string level) {
  JitLoggingConfig::getInstance().setLoggingLevels(std::move(level));
}

void set_jit_logging_output_stream(std::ostream& stream) {
  JitLoggingConfig::getInstance().setOutputStream(stream);
}

std::ostream& get_jit_logging_output_stream() {
  return JitLoggingConfig::getInstance().getOutputStream();
}

// gets a string representation of a node header
// (e.g. outputs, a node kind and outputs)
std::string getHeader(const Node* node) {
  std::stringstream ss;
  node->print(ss, 0, {}, false, false, false, false);
  return ss.str();
}

void JitLoggingConfig::parse() {
  std::stringstream in_ss;
  in_ss << "function:" << this->logging_levels;

  files_to_levels.clear();
  std::string line;
  while (std::getline(in_ss, line, ':')) {
    if (line.empty()) {
      continue;
    }

    auto index_at = line.find_last_of('>');
    auto begin_index = index_at == std::string::npos ? 0 : index_at + 1;
    size_t logging_level = index_at == std::string::npos ? 0 : index_at + 1;
    auto end_index = line.find_last_of('.') == std::string::npos
        ? line.size()
        : line.find_last_of('.');
    auto filename = line.substr(begin_index, end_index - begin_index);
    files_to_levels.insert({filename, logging_level});
  }
}

bool is_enabled(const char* cfname, JitLoggingLevels level) {
  const auto& files_to_levels =
      JitLoggingConfig::getInstance().getFilesToLevels();
  std::string fname{cfname};
  fname = c10::detail::StripBasename(fname);
  const auto end_index = fname.find_last_of('.') == std::string::npos
      ? fname.size()
      : fname.find_last_of('.');
  const auto fname_no_ext = fname.substr(0, end_index);

  const auto it = files_to_levels.find(fname_no_ext);
  if (it == files_to_levels.end()) {
    return false;
  }

  return level <= static_cast<JitLoggingLevels>(it->second);
}

// Unfortunately, in `GraphExecutor` where `log_function` is invoked
// we won't have access to an original function, so we have to construct
// a dummy function to give to PythonPrint
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
  torch::jit::GraphFunction func("source_dump", graph, nullptr);
  std::vector<at::IValue> constants;
  PrintDepsTable deps;
  PythonPrint pp(constants, deps);
  pp.printFunction(func);
  return pp.str();
}

std::string jit_log_prefix(
    const std::string& prefix,
    const std::string& in_str) {
  std::stringstream in_ss(in_str);
  std::stringstream out_ss;
  std::string line;
  while (std::getline(in_ss, line)) {
    out_ss << prefix << line << '\n';
  }

  return out_ss.str();
}

std::string jit_log_prefix(
    JitLoggingLevels level,
    const char* fn,
    int l,
    const std::string& in_str) {
  std::stringstream prefix_ss;
  prefix_ss << "[";
  prefix_ss << level << " ";
  prefix_ss << c10::detail::StripBasename(std::string(fn)) << ":";
  prefix_ss << std::setfill('0') << std::setw(3) << l;
  prefix_ss << "] ";

  return jit_log_prefix(prefix_ss.str(), in_str);
}

std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
  switch (level) {
    case JitLoggingLevels::GRAPH_DUMP:
      out << "DUMP";
      break;
    case JitLoggingLevels::GRAPH_UPDATE:
      out << "UPDATE";
      break;
    case JitLoggingLevels::GRAPH_DEBUG:
      out << "DEBUG";
      break;
    default:
      TORCH_INTERNAL_ASSERT(false, "Invalid level");
  }

  return out;
}

} // namespace torch::jit
