/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/wordpiece_tokenizer.h"

#include "utils/utf8/unicodetext.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"

namespace libtextclassifier3 {

namespace {

// Queries 'vocab_map' for the piece token[byte_start, byte_end).
//
// A piece that starts past the beginning of the token (byte_start > 0) is
// looked up with 'suffix_indicator' prepended; otherwise the piece is looked
// up verbatim. 'in_vocab' is forwarded to WordpieceVocab::Contains, and the
// status returned by that call is passed through unchanged.
LookupStatus Lookup(int byte_start, int byte_end, const absl::string_view token,
                    const std::string& suffix_indicator,
                    const WordpieceVocab* vocab_map, bool* in_vocab) {
  const absl::string_view piece(token.data() + byte_start,
                                byte_end - byte_start);
  if (byte_start > 0) {
    // Mid-word piece: the vocab stores it with the suffix indicator in front.
    const std::string key = absl::StrCat(suffix_indicator, piece);
    return vocab_map->Contains(key, in_vocab);
  }
  // Word-initial piece: plain copy of the bytes (absl::CopyToString
  // equivalent).
  const std::string key(piece.begin(), piece.end());
  return vocab_map->Contains(key, in_vocab);
}

// Finds the longest piece of 'token' starting at 'byte_start' that ends on a
// code point boundary and is present in the vocab.
//
// The candidates are the prefixes of token[byte_start, ...) made of 1, 2, ...
// code points. When 'max_chars_per_subtoken' is positive, at most that many
// code points are considered. Candidates are tried from longest to shortest.
//
// Outputs:
//   * 'found_match': true if a piece was selected, false otherwise.
//   * 'byte_end': exclusive end offset (relative to the start of 'token') of
//     the selected piece; left untouched when no piece is selected.
//   * 'match_is_unknown_character': true only when no candidate was in the
//     vocab and, because 'split_unknown_characters' is set, the single code
//     point at 'byte_start' was selected instead.
//
// Returns the failing status as soon as a vocab lookup fails, OK otherwise.
LookupStatus LongestMatchStartingAt(
    int byte_start, const absl::string_view token,
    const std::string& suffix_indicator, const int max_chars_per_subtoken,
    bool split_unknown_characters, const WordpieceVocab* vocab_map,
    int* byte_end, bool* found_match, bool* match_is_unknown_character) {
  *found_match = false;
  *match_is_unknown_character = false;

  // Collect the end offset of every candidate piece, shortest first.
  const UnicodeText remainder =
      UTF8ToUnicodeText(token.substr(byte_start), /*do_copy=*/false);
  std::vector<int32_t> candidate_ends;
  int32_t running_end = byte_start;
  for (auto cp = remainder.begin(); cp != remainder.end(); ++cp) {
    running_end += cp.utf8_length();
    candidate_ends.push_back(running_end);
    // When a per-subtoken character limit is given, do not look at longer
    // candidates than that.
    if (max_chars_per_subtoken > 0 &&
        candidate_ends.size() == max_chars_per_subtoken) {
      break;
    }
  }

  // Probe the vocab, longest candidate first.
  for (auto end_it = candidate_ends.rbegin(); end_it != candidate_ends.rend();
       ++end_it) {
    bool in_vocab = false;
    const LookupStatus status = Lookup(byte_start, *end_it, token,
                                       suffix_indicator, vocab_map, &in_vocab);
    if (!status.success) return status;
    if (in_vocab) {
      *byte_end = *end_it;
      *found_match = true;
      return LookupStatus::OK();
    }
  }

  // No candidate is in the vocab. If requested, fall back to emitting the
  // first code point on its own as an unknown character.
  if (split_unknown_characters && !candidate_ends.empty()) {
    *byte_end = candidate_ends.front();
    *found_match = true;
    *match_is_unknown_character = true;
  }
  return LookupStatus::OK();
}

// Emits the whole token as a single piece; used when the token cannot be
// completely covered by pieces.
//
// Appends exactly one entry to each of 'subwords', 'begin_offset' and
// 'end_offset', with offsets [0, token.length()), and increments
// 'num_word_pieces' by one. The piece text is 'unknown_token' when
// 'use_unknown_token' is set and a copy of 'token' otherwise.
// Always returns OK.
LookupStatus NoTokenFound(const absl::string_view token, bool use_unknown_token,
                          const std::string& unknown_token,
                          std::vector<std::string>* subwords,
                          std::vector<int>* begin_offset,
                          std::vector<int>* end_offset, int* num_word_pieces) {
  if (use_unknown_token) {
    subwords->push_back(unknown_token);
  } else {
    subwords->emplace_back(token.data(), token.length());
  }
  // The offsets span the full token regardless of which text was emitted.
  begin_offset->push_back(0);
  end_offset->push_back(token.length());
  *num_word_pieces += 1;

  return LookupStatus::OK();
}

// Records the piece token[byte_start, byte_end) that was found in the vocab.
//
// Appends the piece text to 'subwords' (prefixed with 'suffix_indicator' when
// the piece does not start at the beginning of the token), 'byte_start' to
// 'begin_offset' and 'byte_end' to 'end_offset'.
void AddWord(const absl::string_view token, int byte_start, int byte_end,
             const std::string& suffix_indicator,
             std::vector<std::string>* subwords, std::vector<int>* begin_offset,
             std::vector<int>* end_offset) {
  const int piece_len = byte_end - byte_start;
  const bool within_word = byte_start > 0;

  if (within_word) {
    const absl::string_view piece(token.data() + byte_start, piece_len);
    subwords->push_back(::absl::StrCat(suffix_indicator, piece));
  } else {
    // Word-initial piece: copied as is, without the suffix indicator.
    subwords->emplace_back(token.data(), piece_len);
  }

  begin_offset->push_back(byte_start);
  end_offset->push_back(byte_end);
}

// Records a single out-of-vocab character, token[byte_start, byte_end), as its
// own piece. Called when split_unknown_characters is true.
//
// The emitted text is 'unknown_token' when 'use_unknown_token' is set and the
// character's own bytes otherwise. In both cases 'suffix_indicator' is
// prepended when the character does not start the token (byte_start > 0).
// 'byte_start' and 'byte_end' are appended to 'begin_offset' and 'end_offset'.
void AddUnknownCharacter(const absl::string_view token, int byte_start,
                         int byte_end, const std::string& suffix_indicator,
                         bool use_unknown_token,
                         const std::string& unknown_token,
                         std::vector<std::string>* subwords,
                         std::vector<int>* begin_offset,
                         std::vector<int>* end_offset) {
  const int char_len = byte_end - byte_start;

  if (byte_start > 0) {
    // Inside a word: whichever text is emitted gets the suffix indicator.
    const absl::string_view text =
        use_unknown_token
            ? absl::string_view(unknown_token)
            : absl::string_view(token.data() + byte_start, char_len);
    subwords->push_back(::absl::StrCat(suffix_indicator, text));
  } else if (use_unknown_token) {
    subwords->push_back(unknown_token);
  } else {
    subwords->emplace_back(token.data(), char_len);
  }

  begin_offset->push_back(byte_start);
  end_offset->push_back(byte_end);
}

// Splits 'token' into pieces left to right, always taking the longest match
// reported by LongestMatchStartingAt at the current position.
//
// If every byte of the token ends up covered, the pieces and their byte
// offsets are appended to 'subwords', 'begin_offset' and 'end_offset', and
// 'num_word_pieces' is increased by the number of pieces. If at some position
// no match is found, the pieces gathered so far are dropped and the whole
// token is emitted as one piece via NoTokenFound. A failed vocab lookup is
// returned immediately, with the output vectors left unmodified.
//
// 'max_bytes_per_token' is accepted but not used by this function.
LookupStatus TokenizeL2RGreedy(
    const absl::string_view token, const int max_bytes_per_token,
    const int max_chars_per_subtoken, const std::string& suffix_indicator,
    bool use_unknown_token, const std::string& unknown_token,
    bool split_unknown_characters, const WordpieceVocab* vocab_map,
    std::vector<std::string>* subwords, std::vector<int>* begin_offset,
    std::vector<int>* end_offset, int* num_word_pieces) {
  // Pieces are staged locally so that nothing is emitted for this token
  // unless the whole token can be covered.
  std::vector<std::string> staged_subwords;
  std::vector<int> staged_begins;
  std::vector<int> staged_ends;

  const int token_len = token.length();
  int cursor = 0;
  while (cursor < token_len) {
    int piece_end = cursor;
    bool matched = false;
    bool is_unknown_char = false;
    const LookupStatus status = LongestMatchStartingAt(
        cursor, token, suffix_indicator, max_chars_per_subtoken,
        split_unknown_characters, vocab_map, &piece_end, &matched,
        &is_unknown_char);
    if (!status.success) return status;

    if (!matched) {
      // Dead end: give up on splitting and emit the token as a whole.
      return NoTokenFound(token, use_unknown_token, unknown_token, subwords,
                          begin_offset, end_offset, num_word_pieces);
    }

    if (is_unknown_char) {
      AddUnknownCharacter(token, cursor, piece_end, suffix_indicator,
                          use_unknown_token, unknown_token, &staged_subwords,
                          &staged_begins, &staged_ends);
    } else {
      AddWord(token, cursor, piece_end, suffix_indicator, &staged_subwords,
              &staged_begins, &staged_ends);
    }
    cursor = piece_end;
  }

  // The token was fully covered: publish the staged pieces.
  subwords->insert(subwords->end(), staged_subwords.begin(),
                   staged_subwords.end());
  begin_offset->insert(begin_offset->end(), staged_begins.begin(),
                       staged_begins.end());
  end_offset->insert(end_offset->end(), staged_ends.begin(),
                     staged_ends.end());
  *num_word_pieces += staged_subwords.size();
  return LookupStatus::OK();
}

}  // namespace

// Tokenizes a single 'token' into wordpieces.
//
// Tokens of at most 'max_bytes_per_token' bytes are handed to the greedy
// left-to-right tokenizer. Longer tokens are not split: one piece with offsets
// [0, token.size()) is appended to the output vectors, its text being
// 'unknown_token' when 'use_unknown_token' is set and the token itself
// otherwise.
//
// NOTE(review): on the over-long path 'num_word_pieces' is assigned 1, whereas
// the greedy path adds to the value already stored there. The two agree only
// if the caller passes in a counter that starts at 0 - confirm against callers.
LookupStatus WordpieceTokenize(
    const absl::string_view token, const int max_bytes_per_token,
    const int max_chars_per_subtoken, const std::string& suffix_indicator,
    bool use_unknown_token, const std::string& unknown_token,
    bool split_unknown_characters, const WordpieceVocab* vocab_map,
    std::vector<std::string>* subwords, std::vector<int>* begin_offset,
    std::vector<int>* end_offset, int* num_word_pieces) {
  const int token_len = token.size();
  if (token_len <= max_bytes_per_token) {
    return TokenizeL2RGreedy(token, max_bytes_per_token,
                             max_chars_per_subtoken, suffix_indicator,
                             use_unknown_token, unknown_token,
                             split_unknown_characters, vocab_map, subwords,
                             begin_offset, end_offset, num_word_pieces);
  }

  // Over-long token: emit it as a single piece without consulting the vocab.
  if (use_unknown_token) {
    subwords->emplace_back(unknown_token);
  } else {
    subwords->emplace_back(token);
  }
  begin_offset->push_back(0);
  end_offset->push_back(token.size());
  *num_word_pieces = 1;
  return LookupStatus::OK();
}

// Convenience overload of WordpieceTokenize that applies no limit on the
// number of characters per subtoken and does not split unknown characters.
// All other arguments are forwarded unchanged to the full overload.
LookupStatus WordpieceTokenize(
    const absl::string_view token, const int max_bytes_per_token,
    const std::string& suffix_indicator, bool use_unknown_token,
    const std::string& unknown_token, const WordpieceVocab* vocab_map,
    std::vector<std::string>* subwords, std::vector<int>* begin_offset,
    std::vector<int>* end_offset, int* num_word_pieces) {
  // A non-positive limit disables the per-subtoken character cap.
  constexpr int kNoMaxCharsPerSubtoken = 0;
  constexpr bool kSplitUnknownCharacters = false;
  return WordpieceTokenize(token, max_bytes_per_token, kNoMaxCharsPerSubtoken,
                           suffix_indicator, use_unknown_token, unknown_token,
                           kSplitUnknownCharacters, vocab_map, subwords,
                           begin_offset, end_offset, num_word_pieces);
}
}  // namespace libtextclassifier3
