/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
#define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_

#include <vector>

#include "actions/types.h"
#include "annotator/types.h"
#include "utils/flatbuffers/mutable.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#include "flatbuffers/reflection_generated.h"

#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lua.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif

namespace libtextclassifier3 {

static constexpr const char kLengthKey[] = "__len";
static constexpr const char kPairsKey[] = "__pairs";
static constexpr const char kIndexKey[] = "__index";
static constexpr const char kGcKey[] = "__gc";
static constexpr const char kNextKey[] = "__next";

static constexpr const int kIndexStackTop = -1;

// Casts to the lua user data type.
template <typename T>
void* AsUserData(const T* value) {
  return static_cast<void*>(const_cast<T*>(value));
}
template <typename T>
void* AsUserData(const T value) {
  return reinterpret_cast<void*>(value);
}

// Retrieves up-values.
template <typename T>
T FromUpValue(const int index, lua_State* state) {
  return static_cast<T>(lua_touserdata(state, lua_upvalueindex(index)));
}

class LuaEnvironment {
 public:
  virtual ~LuaEnvironment();
  explicit LuaEnvironment();

  // Compile a lua snippet into binary bytecode.
  // NOTE: The compiled bytecode might not be compatible across Lua versions
  // and platforms.
  bool Compile(StringPiece snippet, std::string* bytecode) const;

  // Loads default libraries.
  void LoadDefaultLibraries();

  // Provides a callback to Lua.
  template <typename T>
  void PushFunction(int (T::*handler)()) {
    PushFunction(std::bind(handler, static_cast<T*>(this)));
  }

  template <typename F>
  void PushFunction(const F& func) const {
    // Copy closure to the lua stack.
    new (lua_newuserdata(state_, sizeof(func))) F(func);

    // Register garbage collection callback.
    lua_newtable(state_);
    lua_pushcfunction(state_, &ReleaseFunction<F>);
    lua_setfield(state_, -2, kGcKey);
    lua_setmetatable(state_, -2);

    // Push dispatch.
    lua_pushcclosure(state_, &CallFunction<F>, 1);
  }

  // Sets up a named table that calls back whenever a member is accessed.
  // This allows to lazily provide required information to the script.
  template <typename T>
  void PushLazyObject(int (T::*handler)()) {
    PushLazyObject(std::bind(handler, static_cast<T*>(this)));
  }

  template <typename F>
  void PushLazyObject(const F& func) const {
    lua_newtable(state_);
    lua_newtable(state_);
    PushFunction(func);
    lua_setfield(state_, -2, kIndexKey);
    lua_setmetatable(state_, -2);
  }

  void Push(const int64 value) const { lua_pushinteger(state_, value); }
  void Push(const uint64 value) const { lua_pushinteger(state_, value); }
  void Push(const int32 value) const { lua_pushinteger(state_, value); }
  void Push(const uint32 value) const { lua_pushinteger(state_, value); }
  void Push(const int16 value) const { lua_pushinteger(state_, value); }
  void Push(const uint16 value) const { lua_pushinteger(state_, value); }
  void Push(const int8 value) const { lua_pushinteger(state_, value); }
  void Push(const uint8 value) const { lua_pushinteger(state_, value); }
  void Push(const float value) const { lua_pushnumber(state_, value); }
  void Push(const double value) const { lua_pushnumber(state_, value); }
  void Push(const bool value) const { lua_pushboolean(state_, value); }
  void Push(const StringPiece value) const { PushString(value); }
  void Push(const flatbuffers::String* value) const {
    if (value == nullptr) {
      PushString("");
    } else {
      PushString(StringPiece(value->c_str(), value->size()));
    }
  }

  template <typename T>
  T Read(const int index = -1) const;

  template <>
  int64 Read<int64>(const int index) const {
    return static_cast<int64>(lua_tointeger(state_, /*idx=*/index));
  }

  template <>
  uint64 Read<uint64>(const int index) const {
    return static_cast<uint64>(lua_tointeger(state_, /*idx=*/index));
  }

  template <>
  int32 Read<int32>(const int index) const {
    return static_cast<int32>(lua_tointeger(state_, /*idx=*/index));
  }

  template <>
  uint32 Read<uint32>(const int index) const {
    return static_cast<uint32>(lua_tointeger(state_, /*idx=*/index));
  }

  template <>
  int16 Read<int16>(const int index) const {
    return static_cast<int16>(lua_tointeger(state_, /*idx=*/index));
  }

  template <>
  uint16 Read<uint16>(const int index) const {
    return static_cast<uint16>(lua_tointeger(state_, /*idx=*/index));
  }

  template <>
  int8 Read<int8>(const int index) const {
    return static_cast<int8>(lua_tointeger(state_, /*idx=*/index));
  }

  template <>
  uint8 Read<uint8>(const int index) const {
    return static_cast<uint8>(lua_tointeger(state_, /*idx=*/index));
  }

  template <>
  float Read<float>(const int index) const {
    return static_cast<float>(lua_tonumber(state_, /*idx=*/index));
  }

  template <>
  double Read<double>(const int index) const {
    return static_cast<double>(lua_tonumber(state_, /*idx=*/index));
  }

  template <>
  bool Read<bool>(const int index) const {
    return lua_toboolean(state_, /*idx=*/index);
  }

  template <>
  StringPiece Read<StringPiece>(const int index) const {
    return ReadString(index);
  }

  template <>
  std::string Read<std::string>(const int index) const {
    return ReadString(index).ToString();
  }

  // Reads a string from the stack.
  StringPiece ReadString(int index) const;

  // Pushes a string to the stack.
  void PushString(const StringPiece str) const;

  // Pushes a flatbuffer to the stack.
  void PushFlatbuffer(const reflection::Schema* schema,
                      const flatbuffers::Table* table) const {
    PushFlatbuffer(schema, schema->root_table(), table);
  }

  // Reads a flatbuffer from the stack.
  int ReadFlatbuffer(int index, MutableFlatbuffer* buffer) const;

  // Pushes an iterator.
  template <typename ItemCallback, typename KeyCallback>
  void PushIterator(const int length, const ItemCallback& item_callback,
                    const KeyCallback& key_callback) const {
    lua_newtable(state_);
    CreateIteratorMetatable(length, item_callback);
    PushFunction([this, length, item_callback, key_callback]() {
      return Iterator::Dispatch(this, length, item_callback, key_callback);
    });
    lua_setfield(state_, -2, kIndexKey);
    lua_setmetatable(state_, -2);
  }

  template <typename ItemCallback>
  void PushIterator(const int length, const ItemCallback& item_callback) const {
    lua_newtable(state_);
    CreateIteratorMetatable(length, item_callback);
    PushFunction([this, length, item_callback]() {
      return Iterator::Dispatch(this, length, item_callback);
    });
    lua_setfield(state_, -2, kIndexKey);
    lua_setmetatable(state_, -2);
  }

  template <typename ItemCallback>
  void CreateIteratorMetatable(const int length,
                               const ItemCallback& item_callback) const {
    lua_newtable(state_);
    PushFunction([this, length]() { return Iterator::Length(this, length); });
    lua_setfield(state_, -2, kLengthKey);
    PushFunction([this, length, item_callback]() {
      return Iterator::IterItems(this, length, item_callback);
    });
    lua_setfield(state_, -2, kPairsKey);
    PushFunction([this, length, item_callback]() {
      return Iterator::Next(this, length, item_callback);
    });
    lua_setfield(state_, -2, kNextKey);
  }

  template <typename T>
  void PushVectorIterator(const std::vector<T>* items) const {
    PushIterator(items ? items->size() : 0, [this, items](const int64 pos) {
      this->Push(items->at(pos));
      return 1;
    });
  }

  template <typename T>
  void PushVector(const std::vector<T>& items) const {
    lua_newtable(state_);
    for (int i = 0; i < items.size(); i++) {
      // Key: index, 1-based.
      Push(i + 1);

      // Value.
      Push(items[i]);
      lua_settable(state_, /*idx=*/-3);
    }
  }

  void PushEmptyVector() const { lua_newtable(state_); }

  template <typename T>
  std::vector<T> ReadVector(const int index = -1) const {
    std::vector<T> result;
    if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) {
      TC3_LOG(ERROR) << "Expected a table, got: "
                     << lua_type(state_, /*idx=*/kIndexStackTop);
      lua_pop(state_, 1);
      return {};
    }
    lua_pushnil(state_);
    while (Next(index - 1)) {
      result.push_back(Read<T>(/*index=*/kIndexStackTop));
      lua_pop(state_, 1);
    }
    return result;
  }

  // Runs a closure in protected mode.
  // `func`: closure to run in protected mode.
  // `num_lua_args`: number of arguments from the lua stack to process.
  // `num_results`: number of result values pushed on the stack.
  template <typename F>
  int RunProtected(const F& func, const int num_args = 0,
                   const int num_results = 0) const {
    PushFunction(func);
    // Put the closure before the arguments on the stack.
    if (num_args > 0) {
      lua_insert(state_, -(1 + num_args));
    }
    return lua_pcall(state_, num_args, num_results, /*errorfunc=*/0);
  }

  // Auxiliary methods to handle model results.
  // Provides an annotation to lua.
  void PushAnnotation(const ClassificationResult& classification,
                      const reflection::Schema* entity_data_schema) const;
  void PushAnnotation(const ClassificationResult& classification,
                      StringPiece text,
                      const reflection::Schema* entity_data_schema) const;
  void PushAnnotation(const ActionSuggestionAnnotation& annotation,
                      const reflection::Schema* entity_data_schema) const;

  template <typename Annotation>
  void PushAnnotations(const std::vector<Annotation>* annotations,
                       const reflection::Schema* entity_data_schema) const {
    PushIterator(
        annotations ? annotations->size() : 0,
        [this, annotations, entity_data_schema](const int64 index) {
          PushAnnotation(annotations->at(index), entity_data_schema);
          return 1;
        },
        [this, annotations, entity_data_schema](StringPiece name) {
          if (const Annotation* annotation =
                  GetAnnotationByName(*annotations, name)) {
            PushAnnotation(*annotation, entity_data_schema);
            return 1;
          } else {
            return 0;
          }
        });
  }

  // Pushes a span to the lua stack.
  void PushAnnotatedSpan(const AnnotatedSpan& annotated_span,
                         const reflection::Schema* entity_data_schema) const;
  void PushAnnotatedSpans(const std::vector<AnnotatedSpan>* annotated_spans,
                          const reflection::Schema* entity_data_schema) const;

  // Reads a message text span from lua.
  MessageTextSpan ReadSpan() const;

  ActionSuggestionAnnotation ReadAnnotation(
      const reflection::Schema* entity_data_schema) const;
  int ReadAnnotations(
      const reflection::Schema* entity_data_schema,
      std::vector<ActionSuggestionAnnotation>* annotations) const;
  ClassificationResult ReadClassificationResult(
      const reflection::Schema* entity_data_schema) const;

  // Provides an action to lua.
  void PushAction(
      const ActionSuggestion& action,
      const reflection::Schema* actions_entity_data_schema,
      const reflection::Schema* annotations_entity_data_schema) const;

  void PushActions(
      const std::vector<ActionSuggestion>* actions,
      const reflection::Schema* actions_entity_data_schema,
      const reflection::Schema* annotations_entity_data_schema) const;

  ActionSuggestion ReadAction(
      const reflection::Schema* actions_entity_data_schema,
      const reflection::Schema* annotations_entity_data_schema) const;

  int ReadActions(const reflection::Schema* actions_entity_data_schema,
                  const reflection::Schema* annotations_entity_data_schema,
                  std::vector<ActionSuggestion>* actions) const;

  // Conversation message iterator.
  void PushConversation(
      const std::vector<ConversationMessage>* conversation,
      const reflection::Schema* annotations_entity_data_schema) const;

  lua_State* state() const { return state_; }

 protected:
  // Wrapper for handling iteration over containers.
  class Iterator {
   public:
    // Starts a new key-value pair iterator.
    template <typename ItemCallback>
    static int IterItems(const LuaEnvironment* env, const int length,
                         const ItemCallback& callback) {
      env->PushFunction([env, callback, length, pos = 0]() mutable {
        if (pos >= length) {
          lua_pushnil(env->state());
          return 1;
        }

        // Push key.
        lua_pushinteger(env->state(), pos + 1);

        // Push item.
        return 1 + callback(pos++);
      });
      return 1;  // Num. results.
    }

    // Gets the next element.
    template <typename ItemCallback>
    static int Next(const LuaEnvironment* env, const int length,
                    const ItemCallback& item_callback) {
      int64 pos = lua_isnil(env->state(), /*idx=*/kIndexStackTop)
                      ? 0
                      : env->Read<int64>(/*index=*/kIndexStackTop);
      if (pos < length) {
        // Push next key.
        lua_pushinteger(env->state(), pos + 1);

        // Push item.
        return 1 + item_callback(pos);
      } else {
        lua_pushnil(env->state());
        return 1;
      }
    }

    // Returns the length of the container the iterator processes.
    static int Length(const LuaEnvironment* env, const int length) {
      lua_pushinteger(env->state(), length);
      return 1;  // Num. results.
    }

    // Handles item queries to the iterator.
    // Elements of the container can either be queried by name or index.
    // Dispatch will check how an element is accessed and
    // calls `key_callback` for access by name and `item_callback` for access by
    // index.
    template <typename ItemCallback, typename KeyCallback>
    static int Dispatch(const LuaEnvironment* env, const int length,
                        const ItemCallback& item_callback,
                        const KeyCallback& key_callback) {
      switch (lua_type(env->state(), kIndexStackTop)) {
        case LUA_TNUMBER: {
          // Lua is one based, so adjust the index here.
          const int64 index = env->Read<int64>(/*index=*/kIndexStackTop) - 1;
          if (index < 0 || index >= length) {
            TC3_LOG(ERROR) << "Invalid index: " << index;
            lua_error(env->state());
            return 0;
          }
          return item_callback(index);
        }
        case LUA_TSTRING: {
          return key_callback(env->ReadString(kIndexStackTop));
        }
        default:
          TC3_LOG(ERROR) << "Unexpected access type: "
                         << lua_type(env->state(), kIndexStackTop);
          lua_error(env->state());
          return 0;
      }
    }

    template <typename ItemCallback>
    static int Dispatch(const LuaEnvironment* env, const int length,
                        const ItemCallback& item_callback) {
      switch (lua_type(env->state(), kIndexStackTop)) {
        case LUA_TNUMBER: {
          // Lua is one based, so adjust the index here.
          const int64 index = env->Read<int64>(/*index=*/kIndexStackTop) - 1;
          if (index < 0 || index >= length) {
            TC3_LOG(ERROR) << "Invalid index: " << index;
            lua_error(env->state());
            return 0;
          }
          return item_callback(index);
        }
        default:
          TC3_LOG(ERROR) << "Unexpected access type: "
                         << lua_type(env->state(), kIndexStackTop);
          lua_error(env->state());
          return 0;
      }
    }
  };

  // Calls the deconstructor from a previously pushed function.
  template <typename T>
  static int ReleaseFunction(lua_State* state) {
    static_cast<T*>(lua_touserdata(state, 1))->~T();
    return 0;
  }

  template <typename T>
  static int CallFunction(lua_State* state) {
    return (*static_cast<T*>(lua_touserdata(state, lua_upvalueindex(1))))();
  }

  // Auxiliary methods to expose (reflective) flatbuffer based data to Lua.
  void PushFlatbuffer(const reflection::Schema* schema,
                      const reflection::Object* type,
                      const flatbuffers::Table* table) const;
  int GetField(const reflection::Schema* schema, const reflection::Object* type,
               const flatbuffers::Table* table) const;

  // Reads a repeated field from lua.
  template <typename T>
  void ReadRepeatedField(const int index, RepeatedField* result) const {
    for (const T& element : ReadVector<T>(index)) {
      result->Add(element);
    }
  }

  template <>
  void ReadRepeatedField<MutableFlatbuffer>(const int index,
                                            RepeatedField* result) const {
    lua_pushnil(state_);
    while (Next(index - 1)) {
      ReadFlatbuffer(index, result->Add());
      lua_pop(state_, 1);
    }
  }

  // Pushes a repeated field to the lua stack.
  template <typename T>
  void PushRepeatedField(const flatbuffers::Vector<T>* items) const {
    PushIterator(items ? items->size() : 0, [this, items](const int64 pos) {
      Push(items->Get(pos));
      return 1;  // Num. results.
    });
  }

  void PushRepeatedFlatbufferField(
      const reflection::Schema* schema, const reflection::Object* type,
      const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>>* items)
      const {
    PushIterator(items ? items->size() : 0,
                 [this, schema, type, items](const int64 pos) {
                   PushFlatbuffer(schema, type, items->Get(pos));
                   return 1;  // Num. results.
                 });
  }

  // Overloads Lua next function to use __next key on the metatable.
  // This allows us to treat lua objects and lazy objects provided by our
  // callbacks uniformly.
  int Next(int index) const {
    // Check whether the (meta)table of this object has an associated "__next"
    // entry. This means, we registered our own callback. So we explicitly call
    // that.
    if (luaL_getmetafield(state_, index, kNextKey)) {
      // Callback is now on top of the stack, so adjust relative indices by 1.
      if (index < 0) {
        index--;
      }

      // Copy the reference to the table.
      lua_pushvalue(state_, index);

      // Move the key to top to have it as second argument for the callback.
      // Copy the key to the top.
      lua_pushvalue(state_, -3);

      // Remove the copy of the key.
      lua_remove(state_, -4);

      // Call the callback with (key and table as arguments).
      lua_pcall(state_, /*nargs=*/2 /* table, key */,
                /*nresults=*/2 /* key, item */, 0);

      // Next returned nil, it's the end.
      if (lua_isnil(state_, kIndexStackTop)) {
        // Remove nil value.
        // Results will be padded to `nresults` specified above, so we need
        // to remove two elements here.
        lua_pop(state_, 2);
        return 0;
      }

      return 2;  // Num. results.
    } else if (lua_istable(state_, index)) {
      return lua_next(state_, index);
    }

    // Remove the key.
    lua_pop(state_, 1);
    return 0;
  }

  static const ClassificationResult* GetAnnotationByName(
      const std::vector<ClassificationResult>& annotations, StringPiece name) {
    // Lookup annotation by collection.
    for (const ClassificationResult& annotation : annotations) {
      if (name.Equals(annotation.collection)) {
        return &annotation;
      }
    }
    TC3_LOG(ERROR) << "No annotation with collection: " << name << " found.";
    return nullptr;
  }

  static const ActionSuggestionAnnotation* GetAnnotationByName(
      const std::vector<ActionSuggestionAnnotation>& annotations,
      StringPiece name) {
    // Lookup annotation by name.
    for (const ActionSuggestionAnnotation& annotation : annotations) {
      if (name.Equals(annotation.name)) {
        return &annotation;
      }
    }
    TC3_LOG(ERROR) << "No annotation with name: " << name << " found.";
    return nullptr;
  }

  lua_State* state_;
};  // namespace libtextclassifier3

bool Compile(StringPiece snippet, std::string* bytecode);

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
