/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/regex-actions.h"

#include "actions/utils.h"
#include "utils/base/logging.h"
#include "utils/regex-match.h"
#include "utils/utf8/unicodetext.h"
#include "utils/zlib/zlib_regex.h"

namespace libtextclassifier3 {
namespace {

// Creates an annotation from a regex capturing group.
bool FillAnnotationFromMatchGroup(
    const UniLib::RegexMatcher* matcher,
    const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
    const std::string& group_match_text, const int message_index,
    ActionSuggestionAnnotation* annotation) {
  if (group->annotation_name() != nullptr ||
      group->annotation_type() != nullptr) {
    int status = UniLib::RegexMatcher::kNoError;
    const CodepointSpan span = {matcher->Start(group->group_id(), &status),
                                matcher->End(group->group_id(), &status)};
    if (status != UniLib::RegexMatcher::kNoError) {
      TC3_LOG(ERROR) << "Could not extract span from rule capturing group.";
      return false;
    }
    return FillAnnotationFromCapturingMatch(span, group, message_index,
                                            group_match_text, annotation);
  }
  return true;
}

}  // namespace

bool RegexActions::InitializeRules(
    const RulesModel* rules, const RulesModel* low_confidence_rules,
    const TriggeringPreconditions* triggering_preconditions_overlay,
    ZlibDecompressor* decompressor) {
  if (rules != nullptr) {
    if (!InitializeRulesModel(rules, decompressor, &rules_)) {
      TC3_LOG(ERROR) << "Could not initialize action rules.";
      return false;
    }
  }

  if (low_confidence_rules != nullptr) {
    if (!InitializeRulesModel(low_confidence_rules, decompressor,
                              &low_confidence_rules_)) {
      TC3_LOG(ERROR) << "Could not initialize low confidence rules.";
      return false;
    }
  }

  // Extend by rules provided by the overwrite.
  // NOTE: The rules from the original models are *not* cleared.
  if (triggering_preconditions_overlay != nullptr &&
      triggering_preconditions_overlay->low_confidence_rules() != nullptr) {
    // These rules are optionally compressed, but separately.
    std::unique_ptr<ZlibDecompressor> overwrite_decompressor =
        ZlibDecompressor::Instance();
    if (overwrite_decompressor == nullptr) {
      TC3_LOG(ERROR) << "Could not initialze decompressor for overwrite rules.";
      return false;
    }
    if (!InitializeRulesModel(
            triggering_preconditions_overlay->low_confidence_rules(),
            overwrite_decompressor.get(), &low_confidence_rules_)) {
      TC3_LOG(ERROR)
          << "Could not initialize low confidence rules from overwrite.";
      return false;
    }
  }

  return true;
}

bool RegexActions::InitializeRulesModel(
    const RulesModel* rules, ZlibDecompressor* decompressor,
    std::vector<CompiledRule>* compiled_rules) const {
  if (rules->regex_rule() == nullptr) {
    return true;
  }
  for (const RulesModel_::RegexRule* rule : *rules->regex_rule()) {
    std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
        UncompressMakeRegexPattern(
            unilib_, rule->pattern(), rule->compressed_pattern(),
            rules->lazy_regex_compilation(), decompressor);
    if (compiled_pattern == nullptr) {
      TC3_LOG(ERROR) << "Failed to load rule pattern.";
      return false;
    }

    // Check whether there is a check on the output.
    std::unique_ptr<UniLib::RegexPattern> compiled_output_pattern;
    if (rule->output_pattern() != nullptr ||
        rule->compressed_output_pattern() != nullptr) {
      compiled_output_pattern = UncompressMakeRegexPattern(
          unilib_, rule->output_pattern(), rule->compressed_output_pattern(),
          rules->lazy_regex_compilation(), decompressor);
      if (compiled_output_pattern == nullptr) {
        TC3_LOG(ERROR) << "Failed to load rule output pattern.";
        return false;
      }
    }

    compiled_rules->emplace_back(rule, std::move(compiled_pattern),
                                 std::move(compiled_output_pattern));
  }

  return true;
}

bool RegexActions::IsLowConfidenceInput(
    const Conversation& conversation, const int num_messages,
    std::vector<const UniLib::RegexPattern*>* post_check_rules) const {
  for (int i = 1; i <= num_messages; i++) {
    const std::string& message =
        conversation.messages[conversation.messages.size() - i].text;
    const UnicodeText message_unicode(
        UTF8ToUnicodeText(message, /*do_copy=*/false));
    for (int low_confidence_rule = 0;
         low_confidence_rule < low_confidence_rules_.size();
         low_confidence_rule++) {
      const CompiledRule& rule = low_confidence_rules_[low_confidence_rule];
      const std::unique_ptr<UniLib::RegexMatcher> matcher =
          rule.pattern->Matcher(message_unicode);
      int status = UniLib::RegexMatcher::kNoError;
      if (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
        // Rule only applies to input-output pairs, so defer the check.
        if (rule.output_pattern != nullptr) {
          post_check_rules->push_back(rule.output_pattern.get());
          continue;
        }
        return true;
      }
    }
  }
  return false;
}

bool RegexActions::FilterConfidenceOutput(
    const std::vector<const UniLib::RegexPattern*>& post_check_rules,
    std::vector<ActionSuggestion>* actions) const {
  if (post_check_rules.empty() || actions->empty()) {
    return true;
  }
  std::vector<ActionSuggestion> filtered_text_replies;
  for (const ActionSuggestion& action : *actions) {
    if (action.response_text.empty()) {
      filtered_text_replies.push_back(action);
      continue;
    }
    bool passes_post_check = true;
    const UnicodeText text_reply_unicode(
        UTF8ToUnicodeText(action.response_text, /*do_copy=*/false));
    for (const UniLib::RegexPattern* post_check_rule : post_check_rules) {
      const std::unique_ptr<UniLib::RegexMatcher> matcher =
          post_check_rule->Matcher(text_reply_unicode);
      if (matcher == nullptr) {
        TC3_LOG(ERROR) << "Could not create matcher for post check rule.";
        return false;
      }
      int status = UniLib::RegexMatcher::kNoError;
      if (matcher->Find(&status) || status != UniLib::RegexMatcher::kNoError) {
        passes_post_check = false;
        break;
      }
    }
    if (passes_post_check) {
      filtered_text_replies.push_back(action);
    }
  }
  *actions = std::move(filtered_text_replies);
  return true;
}

bool RegexActions::SuggestActions(
    const Conversation& conversation,
    const MutableFlatbufferBuilder* entity_data_builder,
    std::vector<ActionSuggestion>* actions) const {
  // Create actions based on rules checking the last message.
  const int message_index = conversation.messages.size() - 1;
  const std::string& message = conversation.messages.back().text;
  const UnicodeText message_unicode(
      UTF8ToUnicodeText(message, /*do_copy=*/false));
  for (const CompiledRule& rule : rules_) {
    const std::unique_ptr<UniLib::RegexMatcher> matcher =
        rule.pattern->Matcher(message_unicode);
    int status = UniLib::RegexMatcher::kNoError;
    while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
      for (const RulesModel_::RuleActionSpec* rule_action :
           *rule.rule->actions()) {
        const ActionSuggestionSpec* action = rule_action->action();
        std::vector<ActionSuggestionAnnotation> annotations;

        std::unique_ptr<MutableFlatbuffer> entity_data =
            entity_data_builder != nullptr ? entity_data_builder->NewRoot()
                                           : nullptr;

        // Add entity data from rule capturing groups.
        if (rule_action->capturing_group() != nullptr) {
          for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
               *rule_action->capturing_group()) {
            Optional<std::string> group_match_text =
                GetCapturingGroupText(matcher.get(), group->group_id());
            if (!group_match_text.has_value()) {
              // The group was not part of the match, ignore and continue.
              continue;
            }

            UnicodeText normalized_group_match_text =
                NormalizeMatchText(unilib_, group, group_match_text.value());

            if (!MergeEntityDataFromCapturingMatch(
                    group, normalized_group_match_text.ToUTF8String(),
                    entity_data.get())) {
              TC3_LOG(ERROR)
                  << "Could not merge entity data from a capturing match.";
              return false;
            }

            // Create a text annotation for the group span.
            ActionSuggestionAnnotation annotation;
            if (FillAnnotationFromMatchGroup(matcher.get(), group,
                                             group_match_text.value(),
                                             message_index, &annotation)) {
              annotations.push_back(annotation);
            }

            // Create text reply.
            SuggestTextRepliesFromCapturingMatch(
                entity_data_builder, group, normalized_group_match_text,
                smart_reply_action_type_, actions);
          }
        }

        if (action != nullptr) {
          ActionSuggestion suggestion;
          suggestion.annotations = annotations;
          FillSuggestionFromSpec(action, entity_data.get(), &suggestion);
          actions->push_back(suggestion);
        }
      }
    }
  }
  return true;
}

}  // namespace libtextclassifier3
