//==-- llvm/Support/FileCheck.h ---------------------------*- C++ -*-==//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

// API modified from llvm::FileCheck

#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/irange.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <optional>

#include <algorithm>
#include <iostream>
#include <sstream>
#include <string>

namespace torch {
namespace jit {

namespace testing {

enum CheckType {
  CHECK,
  CHECK_NEXT,
  CHECK_SAME,
  CHECK_NOT,
  CHECK_COUNT,
  CHECK_DAG,
  CHECK_SOURCE_HIGHLIGHTED,
  CHECK_REGEX,
};

struct Check {
  Check(
      CheckType type,
      std::string str,
      std::optional<size_t> count = std::nullopt)
      : type_(type), count_(count), search_str_(std::move(str)) {}

  Check(
      CheckType type,
      c10::string_view str,
      std::optional<size_t> count = std::nullopt)
      : Check(type, std::string(str.begin(), str.end()), count) {}

  CheckType type_;
  std::optional<size_t> count_;
  const std::string search_str_;

  friend std::ostream& operator<<(std::ostream& out, const Check& c);
};

std::ostream& operator<<(std::ostream& out, const Check& c) {
  switch (c.type_) {
    case CHECK:
      out << "CHECK";
      break;
    case CHECK_NEXT:
      out << "CHECK-NEXT";
      break;
    case CHECK_SAME:
      out << "CHECK-SAME";
      break;
    case CHECK_NOT:
      out << "CHECK-NOT";
      break;
    case CHECK_DAG:
      out << "CHECK-DAG";
      break;
    case CHECK_COUNT:
      out << "CHECK-COUNT-" << *c.count_;
      break;
    case CHECK_SOURCE_HIGHLIGHTED:
      out << "CHECK-SOURCE-HIGHLIGHTED";
      break;
    case CHECK_REGEX:
      out << "CHECK-REGEX";
      break;
  }
  out << ": " << c.search_str_;
  return out;
};

namespace {

size_t assertFind(
    const SourceRange& search_range,
    const std::string& sub,
    const std::function<void(std::ostream& out)>& extra_msg = nullptr) {
  auto pos = search_range.source()->text_str().find(sub, search_range.start());
  if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
    auto found_range =
        SourceRange(search_range.source(), search_range.start(), sub.size());
    std::stringstream ss;
    ss << "Expected to find ";
    c10::printQuotedString(ss, sub);
    ss << " but did not find it" << std::endl;
    ss << "Searched string:" << std::endl;
    found_range.highlight(ss);
    if (extra_msg) {
      extra_msg(ss);
    }
    throw std::runtime_error(ss.str());
  }
  return pos;
}

size_t assertFind(
    const SourceRange& search_range,
    const std::string& sub,
    const Check& check) {
  return assertFind(search_range, sub, [&](std::ostream& out) {
    out << "From " << check << "\n";
  });
}

size_t assertFind(
    const std::shared_ptr<Source>& source,
    const std::string& sub,
    size_t start,
    const Check& check) {
  return assertFind(SourceRange(source, start, source->size()), sub, check);
}

size_t assertFindRegex(
    const SourceRange& search_range,
    const std::string& sub,
    const std::function<void(std::ostream& out)>& extra_msg = nullptr) {
  auto pos =
      search_range.source()->text_str().find_regex(sub, search_range.start());

  if (pos == std::string::npos) {
    std::stringstream ss;
    ss << "Expected to find regex ";
    c10::printQuotedString(ss, sub);
    ss << " but did not find it" << std::endl;
    ss << "Searched string:" << std::endl;
    if (extra_msg) {
      extra_msg(ss);
    }
    throw std::runtime_error(ss.str());

    return std::string::npos;
  }
  return pos;
}

size_t assertFindRegex(
    const SourceRange& search_range,
    const std::string& sub,
    const Check& check) {
  return assertFindRegex(search_range, sub, [&](std::ostream& out) {
    out << "From " << check << "\n";
  });
}

size_t assertFindRegex(
    const std::shared_ptr<Source>& source,
    const std::string& sub,
    size_t start,
    const Check& check) {
  return assertFindRegex(
      SourceRange(source, start, source->size()), sub, check);
}

void assertNotFind(
    const SourceRange& search_range,
    const std::string& sub,
    const Check& check) {
  auto pos = search_range.source()->text_str().find(sub, search_range.start());
  if (pos != std::string::npos && (pos + sub.size()) <= search_range.end()) {
    auto found_range =
        SourceRange(search_range.source(), pos, sub.size() + pos);
    std::stringstream ss;
    ss << "Expected to not find ";
    c10::printQuotedString(ss, sub);
    ss << " but found it\n";
    found_range.highlight(ss);
    ss << "From " << check << "\n";
    throw std::runtime_error(ss.str());
  }
}

} // namespace

struct FileCheckImpl {
  TORCH_API explicit FileCheckImpl() = default;

  TORCH_API void run(const std::string& test_file) {
    has_run = true;

    if (groups.empty() || groups[0].empty()) {
      throw std::runtime_error(
          "No checks have been added to this instance of"
          "Filecheck! Check for bad input.");
    }

    doChecks(std::make_shared<Source>(test_file));
  }

  TORCH_API void run(
      const std::string& checks_file,
      const std::string& test_file) {
    auto source = std::make_shared<Source>(checks_file);
    parseStrings(source);
    run(test_file);
  }

  TORCH_API void addCheck(const Check& check) {
    // consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group
    if (groups.empty() ||
        (check.type_ != CHECK_NOT && check.type_ != CHECK_DAG)) {
      groups.push_back({check});
    } else {
      auto& last_group = groups.back();
      if (last_group.at(0).type_ == check.type_) {
        last_group.push_back(check);
      } else {
        groups.push_back({check});
      }
    }
    has_run = false;
  }

  TORCH_API void addCheck(
      CheckType type,
      const std::string& s,
      std::optional<size_t> count = std::nullopt) {
    addCheck(Check(type, s, count));
  }

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool has_run = false;

  friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);

 private:
  bool parseSingleCheck(const std::shared_ptr<Source>& source, size_t* start) {
    const static std::vector<std::pair<CheckType, std::string>> check_pairs = {
        {CHECK, ": "},
        {CHECK_NEXT, "-NEXT: "},
        {CHECK_SAME, "-SAME: "},
        {CHECK_NOT, "-NOT: "},
        {CHECK_DAG, "-DAG: "},
        {CHECK_COUNT, "-COUNT-"}, // needs special parsing
        {CHECK_SOURCE_HIGHLIGHTED, "-SOURCE-HIGHLIGHTED: "},
        {CHECK_REGEX, "-REGEX: "},
    };

    for (const auto& check_pair : check_pairs) {
      const std::string& check_suffix = check_pair.second;
      auto suffix_pos = source->text_str().find(check_suffix, *start);
      if (suffix_pos != *start) {
        continue;
      }
      size_t end_check_string = suffix_pos + check_suffix.size();
      CheckType type = check_pair.first;
      std::optional<size_t> count = std::nullopt;
      auto end_line = source->text_str().find("\n", end_check_string);
      bool exactly = false;
      if (type == CHECK_COUNT) {
        const std::string exact = "EXACTLY-";
        if (source->text_str().find(exact, end_check_string) ==
            end_check_string) {
          exactly = true;
          end_check_string += exact.size();
        }
        size_t end =
            assertFind(SourceRange(source, end_check_string, end_line), ":");
        auto count_view = source->text_str()
                              .substr(end_check_string, end - end_check_string)
                              .str();
        count = std::stoll(std::string(count_view.begin(), count_view.end()));
        end_check_string = end + 2; // add ':' and the space
      }
      auto check = Check(
          type,
          source->text_str()
              .substr(end_check_string, end_line - end_check_string)
              .str(),
          count);
      addCheck(check);
      if (exactly) {
        addCheck(CHECK_NOT, check.search_str_);
      }
      *start = end_line;
      return true;
    }
    return false;
  }

  size_t findNextStart(const std::shared_ptr<Source>& source, size_t prev_end) {
    size_t start = source->text_str().find("#", prev_end);
    if (start == std::string::npos) {
      return start;
    }
    start += 1;
    static constexpr size_t max_whitespace = 6;
    size_t i = 0;
    while (start + i < source->size() && i < max_whitespace) {
      auto c = source->char_at(start + i);
      if (c != ' ' && c != '\t') {
        break;
      }
      i++;
    }
    static const std::string check = "CHECK";
    if (source->text_str().substr(start + i, check.size()) == check) {
      return start + i + check.size();
    } else {
      return findNextStart(source, start + i + 1);
    }
  }

  void parseStrings(const std::shared_ptr<Source>& source) {
    size_t start = 0;
    start = findNextStart(source, 0);
    while (start != std::string::npos) {
      bool found_match = parseSingleCheck(source, &start);
      if (!found_match) {
        std::ostringstream ss;
        ss << "Could not parse check at:\n";
        SourceRange(source, start, start + 1).highlight(ss);
        ss << "Check for bad input.";
        has_run = true;
        throw std::runtime_error(ss.str());
      }
      start = findNextStart(source, start);
    }
  }

  void doCheckNot(
      const std::vector<Check>& nots,
      const std::shared_ptr<Source>& source,
      const SourceRange& prev,
      const SourceRange& next) {
    auto start = prev.end(); // inclusive
    auto end = next.start(); // exclusive
    if (end < start) {
      return;
    }
    for (const auto& check : nots) {
      AT_ASSERT(check.type_ == CHECK_NOT);
      assertNotFind(SourceRange(source, start, end), check.search_str_, check);
    }
  }

  // Checks that source token is highlighted, does not advance search range.
  void doCheckSourceHighlighted(
      const Check& check,
      const std::shared_ptr<Source>& source,
      size_t start_offset) {
    auto construct_error_and_throw = [&](size_t error_start_pos) {
      SourceRange error_range(
          source, error_start_pos, check.search_str_.size());
      std::stringstream ss;
      ss << "Expected to find ";
      c10::printQuotedString(ss, check.search_str_);
      ss << "highlighted but it is not." << std::endl;
      error_range.highlight(ss);
      throw std::runtime_error(ss.str());
    };

    size_t search_start_offset = start_offset;
    bool found_token_at_least_once = false;
    size_t pos = search_start_offset;
    while (pos < source->size()) {
      pos = source->text_str().find(check.search_str_, search_start_offset);
      if (pos == std::string::npos) {
        break;
      }

      found_token_at_least_once = true;

      auto lineno = source->lineno_for_offset(pos);
      auto col = pos - source->offset_for_line(lineno);
      auto highlight_lineno = lineno + 1;

      if (highlight_lineno >= source->num_lines()) {
        construct_error_and_throw(pos);
      }

      auto highlight_start_offset =
          source->offset_for_line(highlight_lineno) + col;
      auto highlight_end_offset = std::min(
          highlight_start_offset + check.search_str_.size(), source->size());

      if (highlight_end_offset >= source->size()) {
        construct_error_and_throw(pos);
      }

      bool found_highlight = true;
      for (const auto posi :
           c10::irange(highlight_start_offset, highlight_end_offset)) {
        if (source->char_at(posi) != '~') {
          found_highlight = false;
        }
      }

      if (found_highlight) {
        assertNotFind(
            SourceRange(
                source, highlight_start_offset - 1, highlight_start_offset),
            "~",
            check);
        assertNotFind(
            SourceRange(source, highlight_end_offset, highlight_end_offset + 1),
            "~",
            check);
        return;
      }

      search_start_offset = pos + 1;
    }

    if (!found_token_at_least_once) {
      // Guaranteed to fail to generate error message.
      assertFind(source, check.search_str_, start_offset, check);
    }

    construct_error_and_throw(start_offset);
  }

  SourceRange matchDagGroup(
      const std::vector<Check>& group,
      const std::shared_ptr<Source>& source,
      const SourceRange& prev) {
    size_t group_beg = std::string::npos;
    size_t group_end = 0;

    AT_ASSERT(!groups.empty());
    for (const auto& check : group) {
      AT_ASSERT(check.type_ == group[0].type_);
      auto pos = assertFind(source, check.search_str_, prev.end(), check);
      group_beg = std::min(pos, group_beg);
      group_end = std::max(pos + check.search_str_.size(), group_end);
    }

    return SourceRange(source, group_beg, group_end);
  }

  SourceRange matchGroup(
      const std::vector<Check>& group,
      const std::shared_ptr<Source>& source,
      const SourceRange& prev) {
    AT_ASSERT(!group.empty());
    CheckType type = group[0].type_;

    if (type == CHECK_DAG) {
      return matchDagGroup(group, source, prev);
    }
    AT_ASSERT(type != CHECK_NOT);
    AT_ASSERT(group.size() == 1);

    const auto& check = group[0];
    size_t start_range = prev.end();
    size_t end_range = start_range;

    switch (check.type_) {
      case CHECK: {
        start_range = assertFind(source, check.search_str_, start_range, check);
        end_range = start_range + check.search_str_.size();
      } break;
      case CHECK_SAME: {
        auto pos = assertFind(source, check.search_str_, start_range, check);
        assertNotFind(SourceRange(source, prev.end(), pos), "\n", check);
        start_range = pos;
        end_range = pos + check.search_str_.size();
      } break;
      case CHECK_NEXT: {
        auto line_end = assertFind(source, "\n", start_range, check);
        auto pos = assertFind(source, check.search_str_, line_end + 1, check);
        assertNotFind(SourceRange(source, line_end + 1, pos), "\n", check);
        start_range = pos;
        end_range = pos + check.search_str_.size();
      } break;
      case CHECK_COUNT: {
        auto group_start_range = std::string::npos;
        AT_ASSERT(check.count_ && *check.count_ != 0);
        for (size_t i = 0; i < *check.count_; ++i) {
          start_range =
              assertFind(source, check.search_str_, start_range, check);
          group_start_range = std::min(start_range, group_start_range);
          end_range = start_range + check.search_str_.size();
          start_range = end_range;
        }
        start_range = group_start_range;
      } break;
      case CHECK_SOURCE_HIGHLIGHTED: {
        doCheckSourceHighlighted(check, source, start_range);
        break;
      }
      case CHECK_REGEX: {
        start_range =
            assertFindRegex(source, check.search_str_, start_range, check);
        end_range = start_range + check.search_str_.size();
        break;
      }
      case CHECK_DAG: {
        AT_ERROR();
      } break;
      case CHECK_NOT: {
        AT_ERROR();
      } break;
    }
    return SourceRange(source, start_range, end_range);
  }

  void doChecks(const std::shared_ptr<Source>& source) {
    SourceRange prev(source, 0, 0);
    for (size_t i = 0; i < groups.size(); i++) {
      const auto& curr_group = groups[i];
      CheckType type = curr_group.at(0).type_;
      if (type != CHECK_NOT) {
        prev = matchGroup(curr_group, source, prev);
      } else {
        if (i + 1 < groups.size()) {
          const auto& next_group = groups[i + 1];
          AT_ASSERT(next_group.at(0).type_ != CHECK_NOT);
          SourceRange after_not = matchGroup(next_group, source, prev);
          doCheckNot(curr_group, source, prev, after_not);
          prev = after_not;
          ++i; // already checked the group after
        } else {
          SourceRange end_of_file(
              source, source->size() + 1, source->size() + 1);
          doCheckNot(curr_group, source, prev, end_of_file);
        }
      }
    }
  }

  std::vector<Check> checks;
  std::vector<std::vector<Check>> groups;
};

FileCheck::FileCheck() : fcImpl(new FileCheckImpl()){};

std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc) {
  out << "FileCheck checks:\n";
  for (const Check& c : fc.checks) {
    out << "\t" << c << "\n";
  }
  return out;
};

FileCheck::~FileCheck() {
  if (!fcImpl->has_run) {
    std::cout << "You have not run this instance of FileCheck!\n";
    std::cout << *fcImpl;
  }
  fcImpl.reset();
};

void FileCheck::run(const std::string& test_file) {
  fcImpl->run(test_file);
};

void FileCheck::run(const Graph& graph) {
  std::stringstream graph_str;
  graph_str << graph;
  fcImpl->run(graph_str.str());
};

void FileCheck::run(
    const std::string& input_checks_string,
    const std::string& test_string) {
  fcImpl->run(input_checks_string, test_string);
}

void FileCheck::run(
    const std::string& input_checks_string,
    const Graph& graph) {
  std::stringstream graph_str;
  graph_str << graph;
  fcImpl->run(input_checks_string, graph_str.str());
}

FileCheck* FileCheck::check(const std::string& str) {
  fcImpl->addCheck(CHECK, str);
  return this;
}

FileCheck* FileCheck::check_not(const std::string& str) {
  fcImpl->addCheck(CHECK_NOT, str);
  return this;
}

FileCheck* FileCheck::check_same(const std::string& str) {
  fcImpl->addCheck(CHECK_SAME, str);
  return this;
}

FileCheck* FileCheck::check_next(const std::string& str) {
  fcImpl->addCheck(CHECK_NEXT, str);
  return this;
}

FileCheck* FileCheck::check_count(
    const std::string& str,
    size_t count,
    bool exactly) {
  TORCH_INTERNAL_ASSERT(
      count != 0 || exactly, "Count == 0 && !exactly doesn't do anything");
  if (count) {
    fcImpl->addCheck(CHECK_COUNT, str, count);
  }
  if (exactly) {
    fcImpl->addCheck(CHECK_NOT, str);
  }
  return this;
}

FileCheck* FileCheck::check_dag(const std::string& str) {
  fcImpl->addCheck(CHECK_DAG, str);
  return this;
}

FileCheck* FileCheck::check_source_highlighted(const std::string& str) {
  fcImpl->addCheck(CHECK_SOURCE_HIGHLIGHTED, str);
  return this;
}

FileCheck* FileCheck::check_regex(const std::string& str) {
  fcImpl->addCheck(CHECK_REGEX, str);
  return this;
}

} // namespace testing
} // namespace jit
} // namespace torch
