/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
#define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "actions/actions_model_generated.h"
#include "actions/conversation_intent_detection/conversation-intent-detection.h"
#include "actions/feature-processor.h"
#include "actions/grammar-actions.h"
#include "actions/ranker.h"
#include "actions/regex-actions.h"
#include "actions/sensitive-classifier-base.h"
#include "actions/types.h"
#include "annotator/annotator.h"
#include "annotator/model-executor.h"
#include "annotator/types.h"
#include "utils/flatbuffers/flatbuffers.h"
#include "utils/flatbuffers/mutable.h"
#include "utils/i18n/locale.h"
#include "utils/memory/mmap.h"
#include "utils/tflite-model-executor.h"
#include "utils/utf8/unilib.h"
#include "utils/variant.h"
#include "utils/zlib/zlib.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/random/random.h"

namespace libtextclassifier3 {

// Class for predicting actions following a conversation.
//
// Instances are created through the static factory functions below, each of
// which returns a std::unique_ptr to the new object.
// NOTE(review): the factories presumably return nullptr when the model cannot
// be loaded or fails validation -- confirm against the implementation file.
//
// Common factory parameters:
//   unilib: optional UniLib instance. The `const UniLib*` overloads do not
//     take ownership; the `std::unique_ptr<UniLib>` overloads do.
//     NOTE(review): when nullptr is passed, an instance is presumably created
//     and owned internally (see SetOrCreateUnilib / owned_unilib_) -- confirm.
//   triggering_preconditions_overlay: NOTE(review): presumably a serialized
//     TriggeringPreconditions flatbuffer that overrides model-provided values
//     (see triggering_preconditions_overlay_) -- confirm.
class ActionsSuggestions {
 public:
  // Creates ActionsSuggestions from given data buffer with model.
  // The buffer is not owned by the created object.
  static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
      const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
      const std::string& triggering_preconditions_overlay = "");

  // Creates ActionsSuggestions from model in the ScopedMmap object and takes
  // ownership of it.
  static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
      std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
      const UniLib* unilib = nullptr,
      const std::string& triggering_preconditions_overlay = "");
  // Same as above, but also takes ownership of the unilib.
  static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
      std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
      std::unique_ptr<UniLib> unilib,
      const std::string& triggering_preconditions_overlay);

  // Creates ActionsSuggestions from model given as a file descriptor, offset
  // and size in it. If offset and size are less than 0, will ignore them and
  // will just use the fd.
  static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
      const int fd, const int offset, const int size,
      const UniLib* unilib = nullptr,
      const std::string& triggering_preconditions_overlay = "");
  // Same as above, but also takes ownership of the unilib.
  static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
      const int fd, const int offset, const int size,
      std::unique_ptr<UniLib> unilib,
      const std::string& triggering_preconditions_overlay = "");

  // Creates ActionsSuggestions from model given as a file descriptor.
  static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
      const int fd, const UniLib* unilib = nullptr,
      const std::string& triggering_preconditions_overlay = "");
  // Same as above, but also takes ownership of the unilib.
  static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
      const int fd, std::unique_ptr<UniLib> unilib,
      const std::string& triggering_preconditions_overlay);

  // Creates ActionsSuggestions from model given as a POSIX path.
  static std::unique_ptr<ActionsSuggestions> FromPath(
      const std::string& path, const UniLib* unilib = nullptr,
      const std::string& triggering_preconditions_overlay = "");
  // Same as above, but also takes ownership of unilib.
  static std::unique_ptr<ActionsSuggestions> FromPath(
      const std::string& path, std::unique_ptr<UniLib> unilib,
      const std::string& triggering_preconditions_overlay);

  // Suggests actions for the given conversation, using default options unless
  // `options` is provided.
  ActionsSuggestionsResponse SuggestActions(
      const Conversation& conversation,
      const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;

  // Same as above, but additionally receives an annotator.
  // NOTE(review): the annotator is presumably used to annotate the messages of
  // the conversation (see AnnotateConversation) and is not owned -- confirm.
  ActionsSuggestionsResponse SuggestActions(
      const Conversation& conversation, const Annotator* annotator,
      const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;

  // Sets up conversation intent detection from a serialized configuration.
  // Returns a success flag.
  // NOTE(review): the expected serialization format of `serialized_config` is
  // not visible from this header -- confirm against the implementation.
  bool InitializeConversationIntentDetection(
      const std::string& serialized_config);

  // Accessors for the model and the entity data schema.
  const ActionsModel* model() const;
  const reflection::Schema* entity_data_schema() const;

  // NOTE(review): presumably the user id that identifies the local (device)
  // user in a conversation -- confirm against the callers.
  static constexpr int kLocalUserId = 0;

 protected:
  // Exposed for testing.
  bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;

  // Embeds the tokens per message separately. Each message is padded to the
  // maximum length with the padding token.
  bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
                             std::vector<float>* embeddings,
                             int* max_num_tokens_per_message) const;

  // Concatenates the embedded message tokens - separated by start and end
  // token between messages.
  // If the total token count is greater than the maximum length, tokens at the
  // start are dropped to fit into the limit.
  // If the total token count is smaller than the minimum length, padding tokens
  // are added to the end.
  // Messages are assumed to be ordered by recency - most recent is last.
  bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>>& tokens,
                             std::vector<float>* embeddings,
                             int* total_token_count) const;

  // Default member initializers keep the raw pointers and counters in a
  // well-defined state even when an instance is default-initialized, e.g. from
  // a testing subclass that provides its own constructor.
  const ActionsModel* model_ = nullptr;

  // Feature extractor and options.
  std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
  std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
  std::vector<float> embedded_padding_token_;
  std::vector<float> embedded_start_token_;
  std::vector<float> embedded_end_token_;
  int token_embedding_size_ = 0;

 private:
  // Checks that model contains all required fields, and initializes internal
  // datastructures.
  bool ValidateAndInitialize();

  // Selects the UniLib instance to use.
  // NOTE(review): presumably stores `unilib` in unilib_ when non-null and
  // otherwise creates one held by owned_unilib_ -- confirm.
  void SetOrCreateUnilib(const UniLib* unilib);

  // Prepare preconditions.
  // Takes values from flag provided data, but falls back to model provided
  // values for parameters that are not explicitly provided.
  bool InitializeTriggeringPreconditions();

  // Tokenizes a conversation and produces the tokens per message.
  std::vector<std::vector<Token>> Tokenize(
      const std::vector<std::string>& context) const;

  // Prepares the interpreter's input for the given conversation dimensions.
  // Returns a success flag.
  bool AllocateInput(const int conversation_length, const int max_tokens,
                     const int total_token_count,
                     tflite::Interpreter* interpreter) const;

  // Feeds the conversation context and per-message metadata into the
  // interpreter. Returns a success flag.
  bool SetupModelInput(const std::vector<std::string>& context,
                       const std::vector<int>& user_ids,
                       const std::vector<float>& time_diffs,
                       const int num_suggestions,
                       const ActionSuggestionOptions& options,
                       tflite::Interpreter* interpreter) const;

  // Fills `suggestion` from the given spec, including entity data.
  void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
                                            ActionSuggestion* suggestion) const;

  // Adds text reply suggestions read from the interpreter to `response`.
  void PopulateTextReplies(
      const tflite::Interpreter* interpreter, int suggestion_index,
      int score_index, const std::string& type, float priority_score,
      const absl::flat_hash_set<std::string>& blocklist,
      const absl::flat_hash_map<std::string, std::vector<std::string>>&
          concept_mappings,
      ActionsSuggestionsResponse* response) const;

  // Adds intent triggering results read from the interpreter to `response`.
  void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
                                int suggestion_index, int score_index,
                                const ActionSuggestionSpec* task_spec,
                                ActionsSuggestionsResponse* response) const;

  // Reads the model output from the interpreter into `response`.
  // Returns a success flag.
  bool ReadModelOutput(tflite::Interpreter* interpreter,
                       const ActionSuggestionOptions& options,
                       ActionsSuggestionsResponse* response) const;

  // Produces suggestions from the TensorFlow Lite model. The interpreter used
  // is handed back to the caller through `interpreter`.
  bool SuggestActionsFromModel(
      const Conversation& conversation, const int num_messages,
      const ActionSuggestionOptions& options,
      ActionsSuggestionsResponse* response,
      std::unique_ptr<tflite::Interpreter>* interpreter) const;

  // Produces suggestions from the conversation intent detection model and
  // appends them to `actions`.
  Status SuggestActionsFromConversationIntentDetection(
      const Conversation& conversation, const ActionSuggestionOptions& options,
      std::vector<ActionSuggestion>* actions) const;

  // Creates options for annotation of a message.
  AnnotationOptions AnnotationOptionsForMessage(
      const ConversationMessage& message) const;

  // Produces suggestions from the annotations of a conversation.
  void SuggestActionsFromAnnotations(
      const Conversation& conversation,
      std::vector<ActionSuggestion>* actions) const;

  // Produces suggestions from a single annotation of the message at
  // `message_index`.
  void SuggestActionsFromAnnotation(
      const int message_index, const ActionSuggestionAnnotation& annotation,
      std::vector<ActionSuggestion>* actions) const;

  // Run annotator on the messages of a conversation.
  Conversation AnnotateConversation(const Conversation& conversation,
                                    const Annotator* annotator) const;

  // Deduplicates equivalent annotations - annotations that have the same type
  // and same span text.
  // Returns the indices of the deduplicated annotations.
  std::vector<int> DeduplicateAnnotations(
      const std::vector<ActionSuggestionAnnotation>& annotations) const;

  // Produces suggestions from Lua code (see lua_bytecode_).
  // Returns a success flag.
  bool SuggestActionsFromLua(
      const Conversation& conversation,
      const TfLiteModelExecutor* model_executor,
      const tflite::Interpreter* interpreter,
      const reflection::Schema* annotation_entity_data_schema,
      std::vector<ActionSuggestion>* actions) const;

  // Collects the suggestions from the individual sources into `response`.
  // Returns a success flag.
  bool GatherActionsSuggestions(const Conversation& conversation,
                                const Annotator* annotator,
                                const ActionSuggestionOptions& options,
                                ActionsSuggestionsResponse* response) const;

  // Memory mapping backing the model, when the object owns one.
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;

  // Tensorflow Lite models.
  std::unique_ptr<const TfLiteModelExecutor> model_executor_;

  // Regex rules model.
  std::unique_ptr<RegexActions> regex_actions_;

  // The grammar rules model.
  std::unique_ptr<GrammarActions> grammar_actions_;

  // UniLib in use; owned_unilib_ holds it when this object owns the instance.
  std::unique_ptr<UniLib> owned_unilib_;
  const UniLib* unilib_ = nullptr;

  // Locales supported by the model.
  std::vector<Locale> locales_;

  // Annotation entities used by the model.
  std::unordered_set<std::string> annotation_entity_types_;

  // Builder for creating extra data.
  const reflection::Schema* entity_data_schema_ = nullptr;
  std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
  std::unique_ptr<ActionsSuggestionsRanker> ranker_;

  std::string lua_bytecode_;

  // Triggering preconditions. These parameters can be backed by the model and
  // (partially) be provided by flags.
  TriggeringPreconditionsT preconditions_;
  std::string triggering_preconditions_overlay_buffer_;
  const TriggeringPreconditions* triggering_preconditions_overlay_ = nullptr;

  // Low confidence input ngram classifier.
  std::unique_ptr<const SensitiveTopicModelBase> sensitive_model_;

  // Conversation intent detection model for additional actions.
  std::unique_ptr<const ConversationIntentDetection>
      conversation_intent_detection_;

  // Used for randomly selecting candidates.
  mutable absl::BitGen bit_gen_;
};

// Interprets the buffer as a Model flatbuffer and returns it for reading.
//   buffer: start of the serialized model data.
//   size: length of the data at `buffer`, in bytes.
// NOTE(review): the returned pointer presumably points into `buffer`, so the
// buffer must outlive any use of the result; whether nullptr is returned for
// invalid data is not visible from this declaration -- confirm both against
// the implementation.
const ActionsModel* ViewActionsModel(const void* buffer, int size);

// Opens model from given path and runs a function, passing the loaded Model
// flatbuffer as an argument. Returns whatever the function returns.
//
// This is mainly useful if we don't want to pay the cost for the model
// initialization because we'll be only reading some flatbuffer values from the
// file.
//
// If the file cannot be memory-mapped, `function` is invoked exactly once with
// nullptr and its result is returned; the failed mapping is never read.
//
// The mapping is a local object that is destroyed when this function returns,
// so `function` should not retain the model pointer past the call.
template <typename ReturnType, typename Func>
ReturnType VisitActionsModel(const std::string& path, Func function) {
  ScopedMmap mmap(path);
  if (!mmap.handle().ok()) {
    // Return here: falling through would view an invalid mapping and call the
    // visitor a second time.
    return function(/*model=*/nullptr);
  }
  const ActionsModel* model =
      ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
  return function(model);
}

// Accessors for the names of the supported action types.
//
// Each accessor returns a reference to a function-local string that is
// allocated on first use and intentionally never freed, so the reference stays
// valid for the rest of the program.
class ActionsSuggestionsTypes {
 public:
  // Should be in sync with those defined in Android.
  // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
  static const std::string& ViewCalendar() {
    static const std::string* const kType = new std::string("view_calendar");
    return *kType;
  }
  static const std::string& ViewMap() {
    static const std::string* const kType = new std::string("view_map");
    return *kType;
  }
  static const std::string& TrackFlight() {
    static const std::string* const kType = new std::string("track_flight");
    return *kType;
  }
  static const std::string& OpenUrl() {
    static const std::string* const kType = new std::string("open_url");
    return *kType;
  }
  static const std::string& SendSms() {
    static const std::string* const kType = new std::string("send_sms");
    return *kType;
  }
  static const std::string& CallPhone() {
    static const std::string* const kType = new std::string("call_phone");
    return *kType;
  }
  static const std::string& SendEmail() {
    static const std::string* const kType = new std::string("send_email");
    return *kType;
  }
  static const std::string& ShareLocation() {
    static const std::string* const kType = new std::string("share_location");
    return *kType;
  }
  static const std::string& CreateReminder() {
    static const std::string* const kType = new std::string("create_reminder");
    return *kType;
  }
  static const std::string& TextReply() {
    static const std::string* const kType = new std::string("text_reply");
    return *kType;
  }
  static const std::string& AddContact() {
    static const std::string* const kType = new std::string("add_contact");
    return *kType;
  }
  static const std::string& Copy() {
    static const std::string* const kType = new std::string("copy");
    return *kType;
  }
};

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
