/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "annotator/number/number.h"

#include <climits>
#include <cstdlib>
#include <string>

#include "annotator/collections.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/strings/split.h"
#include "utils/utf8/unicodetext.h"

namespace libtextclassifier3 {

bool NumberAnnotator::ClassifyText(
    const UnicodeText& context, CodepointSpan selection_indices,
    AnnotationUsecase annotation_usecase,
    ClassificationResult* classification_result) const {
  TC3_CHECK(classification_result != nullptr);

  const UnicodeText substring_selected = UnicodeText::Substring(
      context, selection_indices.first, selection_indices.second);

  std::vector<AnnotatedSpan> results;
  if (!FindAll(substring_selected, annotation_usecase, ModeFlag_CLASSIFICATION,
               &results)) {
    return false;
  }

  for (const AnnotatedSpan& result : results) {
    if (result.classification.empty()) {
      continue;
    }

    // We make sure that the result span is equal to the stripped selection span
    // to avoid validating cases like "23 asdf 3.14 pct asdf". FindAll will
    // anyway only find valid numbers and percentages and a given selection with
    // more than two tokens won't pass this check.
    if (result.span.first + selection_indices.first ==
            selection_indices.first &&
        result.span.second + selection_indices.first ==
            selection_indices.second) {
      *classification_result = result.classification[0];
      return true;
    }
  }
  return false;
}

// Returns true when every codepoint in the range of `token_length` codepoints
// starting at `token_begin_it` satisfies unilib_->IsCJTletter. An empty range
// yields true.
bool NumberAnnotator::IsCJTterm(UnicodeText::const_iterator token_begin_it,
                                const int token_length) const {
  // Compute the one-past-the-end position of the token.
  auto range_end = token_begin_it;
  std::advance(range_end, token_length);

  auto cursor = token_begin_it;
  while (cursor < range_end) {
    const bool is_cjt_letter = unilib_->IsCJTletter(*cursor);
    if (!is_cjt_letter) {
      return false;
    }
    ++cursor;
  }
  return true;
}

// Returns true when `start_index` lies before the first token (negative) or
// when the token at `start_index` is whitespace.
bool NumberAnnotator::TokensAreValidStart(const std::vector<Token>& tokens,
                                          const int start_index) const {
  // The short-circuit keeps the vector from being indexed with a negative
  // value.
  return start_index < 0 || tokens[start_index].is_whitespace;
}

bool NumberAnnotator::TokensAreValidNumberPrefix(
    const std::vector<Token>& tokens, const int prefix_end_index) const {
  if (TokensAreValidStart(tokens, prefix_end_index)) {
    return true;
  }

  auto prefix_begin_it =
      UTF8ToUnicodeText(tokens[prefix_end_index].value, /*do_copy=*/false)
          .begin();
  const int token_length =
      tokens[prefix_end_index].end - tokens[prefix_end_index].start;
  if (token_length == 1 && unilib_->IsOpeningBracket(*prefix_begin_it) &&
      TokensAreValidStart(tokens, prefix_end_index - 1)) {
    return true;
  }
  if (token_length == 1 && unilib_->IsNumberSign(*prefix_begin_it) &&
      TokensAreValidStart(tokens, prefix_end_index - 1)) {
    return true;
  }
  if (token_length == 1 && unilib_->IsSlash(*prefix_begin_it) &&
      prefix_end_index >= 1 &&
      TokensAreValidStart(tokens, prefix_end_index - 2)) {
    int64 int_val;
    double double_val;
    return TryParseNumber(UTF8ToUnicodeText(tokens[prefix_end_index - 1].value,
                                            /*do_copy=*/false),
                          false, &int_val, &double_val);
  }
  if (IsCJTterm(prefix_begin_it, token_length)) {
    return true;
  }

  return false;
}

bool NumberAnnotator::TokensAreValidEnding(const std::vector<Token>& tokens,
                                           const int ending_index) const {
  if (ending_index >= tokens.size() || tokens[ending_index].is_whitespace) {
    return true;
  }

  auto ending_begin_it =
      UTF8ToUnicodeText(tokens[ending_index].value, /*do_copy=*/false).begin();
  if (ending_index == tokens.size() - 1 &&
      tokens[ending_index].end - tokens[ending_index].start == 1 &&
      unilib_->IsPunctuation(*ending_begin_it)) {
    return true;
  }
  if (ending_index < tokens.size() - 1 &&
      tokens[ending_index].end - tokens[ending_index].start == 1 &&
      unilib_->IsPunctuation(*ending_begin_it) &&
      tokens[ending_index + 1].is_whitespace) {
    return true;
  }

  return false;
}

bool NumberAnnotator::TokensAreValidNumberSuffix(
    const std::vector<Token>& tokens, const int suffix_start_index) const {
  if (TokensAreValidEnding(tokens, suffix_start_index)) {
    return true;
  }

  auto suffix_begin_it =
      UTF8ToUnicodeText(tokens[suffix_start_index].value, /*do_copy=*/false)
          .begin();

  if (percent_suffixes_.find(tokens[suffix_start_index].value) !=
          percent_suffixes_.end() &&
      TokensAreValidEnding(tokens, suffix_start_index + 1)) {
    return true;
  }

  const int token_length =
      tokens[suffix_start_index].end - tokens[suffix_start_index].start;
  if (token_length == 1 && unilib_->IsSlash(*suffix_begin_it) &&
      suffix_start_index <= tokens.size() - 2 &&
      TokensAreValidEnding(tokens, suffix_start_index + 2)) {
    int64 int_val;
    double double_val;
    return TryParseNumber(
        UTF8ToUnicodeText(tokens[suffix_start_index + 1].value,
                          /*do_copy=*/false),
        false, &int_val, &double_val);
  }
  if (IsCJTterm(suffix_begin_it, token_length)) {
    return true;
  }

  return false;
}

int NumberAnnotator::FindPercentSuffixEndCodepoint(
    const std::vector<Token>& tokens,
    const int suffix_token_start_index) const {
  if (suffix_token_start_index >= tokens.size()) {
    return -1;
  }

  if (percent_suffixes_.find(tokens[suffix_token_start_index].value) !=
          percent_suffixes_.end() &&
      TokensAreValidEnding(tokens, suffix_token_start_index + 1)) {
    return tokens[suffix_token_start_index].end;
  }
  if (tokens[suffix_token_start_index].is_whitespace) {
    return FindPercentSuffixEndCodepoint(tokens, suffix_token_start_index + 1);
  }

  return -1;
}

// Attempts to parse `token_text` as a number.
//
// Returns false when the UTF-8 encoding of the token is at least
// max_number_of_digits_ bytes long, or when unilib_->ParseDouble rejects it.
// On success `parsed_double_value` holds the parsed value and
// `parsed_int_value` holds that value truncated toward zero; both are negated
// when `is_negative` is set. The outputs should not be relied on when false
// is returned.
bool NumberAnnotator::TryParseNumber(const UnicodeText& token_text,
                                     const bool is_negative,
                                     int64* parsed_int_value,
                                     double* parsed_double_value) const {
  // Overly long tokens are rejected before any parsing is attempted.
  const bool too_long =
      token_text.ToUTF8String().size() >= max_number_of_digits_;
  if (too_long) {
    return false;
  }

  if (!unilib_->ParseDouble(token_text, parsed_double_value)) {
    return false;
  }

  *parsed_int_value = std::trunc(*parsed_double_value);
  if (is_negative) {
    // The sign comes from a separate token, so it is applied after parsing.
    *parsed_int_value = -*parsed_int_value;
    *parsed_double_value = -*parsed_double_value;
  }

  return true;
}

// Finds all number and percentage annotations in `context` and appends them
// to `result`.
//
// A token is annotated when it starts with a digit, parses as a number (see
// TryParseNumber), and is surrounded by a valid prefix and suffix (see
// TokensAreValidNumberPrefix / TokensAreValidNumberSuffix). A minus token
// directly before the number makes the value negative and extends the span
// by one codepoint to the left. When a percent suffix follows (see
// FindPercentSuffixEndCodepoint), an additional percentage annotation
// reaching to the end of that suffix is emitted.
//
// Number and percentage annotations are each only emitted when
// `annotation_usecase` is enabled in the corresponding options bitmask.
//
// Always returns true; when the annotator or the requested `mode` is disabled
// nothing is appended.
//
// NOTE(review): CreateAnnotatedSpan takes the integer value as `int`, so the
// int64 `parsed_int_value` is narrowed at the call sites below — confirm that
// this is intended.
bool NumberAnnotator::FindAll(const UnicodeText& context,
                              AnnotationUsecase annotation_usecase,
                              ModeFlag mode,
                              std::vector<AnnotatedSpan>* result) const {
  if (!options_->enabled() || !(options_->enabled_modes() & mode)) {
    return true;
  }

  // The usecase does not change per token, so the enable bits are resolved
  // once up front.
  const bool number_usecase_enabled =
      ((1 << annotation_usecase) & options_->enabled_annotation_usecases()) !=
      0;
  const bool percentage_usecase_enabled =
      ((1 << annotation_usecase) &
       options_->percentage_annotation_usecases()) != 0;

  const std::vector<Token> tokens = tokenizer_.Tokenize(context);
  for (int i = 0; i < tokens.size(); ++i) {
    // Bind by reference: copying the token would deep-copy its string value
    // on every iteration.
    const Token& token = tokens[i];
    if (token.value.empty()) {
      continue;
    }

    const UnicodeText token_text =
        UTF8ToUnicodeText(token.value, /*do_copy=*/false);
    if (!unilib_->IsDigit(*token_text.begin())) {
      continue;
    }

    // The previous token is only inspected when it is non-empty, mirroring
    // the guard applied to the current token; dereferencing begin() of an
    // empty text is not valid.
    const bool is_negative =
        (i > 0) && !tokens[i - 1].value.empty() &&
        unilib_->IsMinus(
            *UTF8ToUnicodeText(tokens[i - 1].value, /*do_copy=*/false).begin());

    int64 parsed_int_value;
    double parsed_double_value;
    if (!TryParseNumber(token_text, is_negative, &parsed_int_value,
                        &parsed_double_value)) {
      continue;
    }
    // With a leading minus the prefix ends one token further to the left.
    if (!TokensAreValidNumberPrefix(tokens, is_negative ? i - 2 : i - 1) ||
        !TokensAreValidNumberSuffix(tokens, i + 1)) {
      continue;
    }

    const bool has_decimal = parsed_int_value != parsed_double_value;
    // Include the sign in the annotated span.
    const int new_start_codepoint = is_negative ? token.start - 1 : token.start;

    if (number_usecase_enabled) {
      result->push_back(CreateAnnotatedSpan(
          new_start_codepoint, token.end, parsed_int_value, parsed_double_value,
          Collections::Number(), options_->score(),
          /*priority_score=*/
          has_decimal ? options_->float_number_priority_score()
                      : options_->priority_score()));
    }

    const int percent_end_codepoint =
        FindPercentSuffixEndCodepoint(tokens, i + 1);
    if (percent_end_codepoint != -1 && percentage_usecase_enabled) {
      result->push_back(CreateAnnotatedSpan(
          new_start_codepoint, percent_end_codepoint, parsed_int_value,
          parsed_double_value, Collections::Percentage(), options_->score(),
          options_->percentage_priority_score()));
    }
  }

  return true;
}

// Builds an AnnotatedSpan covering [start, end) with a single classification
// for `collection`, carrying the given score, priority score and the parsed
// numeric values.
//
// NOTE(review): `int_value` is an `int`; callers that hold a wider integer
// have it narrowed when passing it here — confirm that this is intended.
AnnotatedSpan NumberAnnotator::CreateAnnotatedSpan(
    const int start, const int end, const int int_value,
    const double double_value, const std::string collection, const float score,
    const float priority_score) const {
  AnnotatedSpan annotated_span;
  annotated_span.span = {start, end};

  // Append the classification first, then fill in the numeric details in
  // place.
  annotated_span.classification.push_back(
      ClassificationResult{collection, score});
  ClassificationResult& entry = annotated_span.classification.back();
  entry.numeric_value = int_value;
  entry.numeric_double_value = double_value;
  entry.priority_score = priority_score;

  return annotated_span;
}

// Converts a flatbuffer string holding '\0'-separated entries into a set of
// strings. Every whitespace character (per std::isspace) is removed from each
// entry before it is inserted. A null input yields an empty set.
std::unordered_set<std::string>
NumberAnnotator::FromFlatbufferStringToUnordredSet(
    const flatbuffers::String* flatbuffer_percent_strings) {
  std::unordered_set<std::string> suffixes;
  if (flatbuffer_percent_strings != nullptr) {
    const std::string joined_suffixes = flatbuffer_percent_strings->str();
    for (StringPiece piece : strings::Split(joined_suffixes, '\0')) {
      const std::string raw_suffix = piece.ToString();

      // Keep only the non-whitespace characters of the entry.
      std::string cleaned_suffix;
      for (const char character : raw_suffix) {
        if (!std::isspace(static_cast<unsigned char>(character))) {
          cleaned_suffix.push_back(character);
        }
      }
      suffixes.insert(cleaned_suffix);
    }
  }

  return suffixes;
}

}  // namespace libtextclassifier3
