#include <ATen/core/dynamic_type.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/mobile/prim_ops_registery.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>

namespace torch::jit {

char const* toString(OpCode op);
namespace mobile {
// Creates a function that only carries its qualified name. Instructions,
// operators, constants, types and a schema can be attached afterwards through
// the append_* / set_register_size / setSchema methods.
Function::Function(c10::QualifiedName name)
    : name_(std::move(name)) {}

// Creates a function from an already-assembled bytecode `code` and an
// optional `schema`. All three arguments are taken by value and moved into
// the corresponding members.
Function::Function(
    c10::QualifiedName name,
    Code code,
    std::optional<c10::FunctionSchema> schema)
    : name_(std::move(name)),
      code_(std::move(code)),
      schema_(std::move(schema)) {
  // Nothing else to do: every member is initialized above.
}

// Returns the fully-qualified name this function was constructed with.
const c10::QualifiedName& Function::qualname() const {
  const c10::QualifiedName& qualified = this->name_;
  return qualified;
}

// Appends one bytecode instruction together with its debug handle.
//
// Validation and recording of the instruction itself are shared with the
// 3-argument overload (which throws if `op` is not supported on mobile,
// before anything is recorded). The handle is pushed afterwards, so
// get_debug_handle(pc) matches instructions_[pc] as long as every instruction
// of this function goes through this overload.
void Function::append_instruction(
    OpCode op,
    int64_t X,
    int64_t N,
    int64_t dbg_handle) {
  append_instruction(op, X, N);
  code_.debug_handles_.emplace_back(dbg_handle);
}

// Appends one bytecode instruction (opcode plus its X / N operands) without a
// debug handle. Throws if `op` is not supported by the mobile interpreter; in
// that case nothing is appended.
void Function::append_instruction(OpCode op, int64_t X, int64_t N) {
  const bool supported_on_mobile = isOpSupportedInMobile(op);
  TORCH_CHECK(
      supported_on_mobile, toString(op), " is not supported in mobile module.");
  code_.instructions_.emplace_back(op, X, N);
}

// Records an operator reference by name only; it is resolved to a callable
// later by initialize_operators().
//
// `num_specified_args` is stored as -1 when absent, so op_names_ and
// operator_input_sizes_ always grow together.
void Function::append_operator(
    const std::string& name,
    const std::string& overload_name,
    const std::optional<int>& num_specified_args) {
  // Keep the original opname in code_
  code_.op_names_.emplace_back(name, overload_name);
  const int input_size =
      num_specified_args.has_value() ? *num_specified_args : -1;
  code_.operator_input_sizes_.emplace_back(input_size);
}

// Formats an operator name for diagnostics: "name" when there is no overload
// name, otherwise "name.overload_name".
std::string operator_str(const c10::OperatorName& opname) {
  if (opname.overload_name.empty()) {
    return opname.name;
  }
  std::string qualified = opname.name;
  qualified += '.';
  qualified += opname.overload_name;
  return qualified;
}

// Resolves every recorded operator name (see append_operator) into a callable
// stored in code_.operators_.
//
// Returns true when every operator was found. In that case code_.initialized
// is set and later calls return immediately. Operators that cannot be found
// leave their slot empty and keep code_.initialized false, so a later call
// retries the whole resolution.
//
// If `should_check_operators` is true, a TORCH_CHECK failure listing the
// missing operators is raised instead of returning false.
bool Function::initialize_operators(bool should_check_operators) {
  if (code_.initialized) {
    return true;
  }
  std::unordered_set<std::string> unsupported_op_names;
  code_.operators_.resize(code_.op_names_.size());
  // size_t index: matches the container's size type (the previous `unsigned`
  // index could not represent every valid position on 64-bit targets).
  for (size_t i = 0; i < code_.op_names_.size(); ++i) {
    const auto& opname = code_.op_names_[i];
    // append_operator stores -1 when the argument count was not specified.
    const int num_args = code_.operator_input_sizes_[i];
    const std::optional<int> num_specified_args =
        num_args < 0 ? std::nullopt : std::optional<int>(num_args);
    auto func = makeOperatorFunction(opname, num_specified_args);
    if (func.has_value()) {
      // `func` is a local temporary, so move the callable instead of copying.
      code_.operators_[i] = std::move(*func);
    } else {
      unsupported_op_names.insert(operator_str(opname));
    }
  }
  // Every lookup failure inserts a name, so an empty set means success.
  const bool all_ops_supported = unsupported_op_names.empty();
  if (should_check_operators) {
    TORCH_CHECK(
        all_ops_supported,
        "Following ops cannot be found: [",
        c10::Join(", ", unsupported_op_names),
        "]. Please check if the operator library is included in the build. If built with selected ops, check if these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post it in https://discuss.pytorch.org/c/mobile/");
  }
  code_.initialized = all_ops_supported;
  return all_ops_supported;
}

// Appends a copy of `constant` to this function's constant table.
void Function::append_constant(const c10::IValue& constant) {
  auto& constant_table = code_.constants_;
  constant_table.emplace_back(constant);
}

// Appends `type` to this function's type table.
void Function::append_type(const at::TypePtr& type) {
  auto& type_table = code_.types_;
  type_table.emplace_back(type);
}

// Records a non-owning pointer to `function` in this function's callee
// table. The caller must keep `function` alive for as long as this code may
// use it.
void Function::append_function(mobile::Function& function) {
  mobile::Function* const callee = &function;
  code_.functions_.push_back(callee);
}

void Function::set_register_size(size_t size) {
  code_.register_size_ = size;
}

// Returns the debug handle recorded for instruction index `pc`. Throws if no
// handle was recorded at that index.
int64_t Function::get_debug_handle(size_t pc) const {
  const auto& handles = code_.debug_handles_;
  const bool in_range = pc < handles.size();
  TORCH_CHECK(in_range, "Module debug info index out of boundary.");
  return handles[pc];
}

// Installs `schema` (replacing any previous one) and returns this function so
// calls can be chained.
torch::jit::Function& Function::setSchema(c10::FunctionSchema schema) {
  schema_.emplace(std::move(schema));
  torch::jit::Function& self = *this;
  return self;
}

// True when a schema was supplied at construction or via setSchema().
bool Function::hasSchema() const {
  return schema_ != std::nullopt;
}

// Returns this function's schema.
//
// Throws (TORCH_CHECK) when no schema has been set. Previously the empty
// std::optional was dereferenced unconditionally, which is undefined
// behavior. Callers that want to probe first can use hasSchema().
const c10::FunctionSchema& Function::getSchema() const {
  TORCH_CHECK(
      schema_.has_value(),
      "Schema is not available for function ",
      name_.qualifiedName());
  return *schema_;
}

// Executes this function on `stack` with the mobile interpreter.
//
// Operators are resolved first; with checking enabled this throws if any
// cannot be found. If a schema is present, the inputs on `stack` are
// normalized against it (resolving optional args, if any) before the
// interpreter runs. This entry point passes no keyword arguments.
void Function::run(Stack& stack) {
  initialize_operators(/*should_check_operators=*/true);
  const bool has_schema = hasSchema();
  if (has_schema) {
    const std::unordered_map<std::string, IValue> no_kwargs{};
    getSchema().checkAndNormalizeInputs<c10::DynamicType>(stack, no_kwargs);
  }
  InterpreterState interpreter(code_);
  interpreter.run(stack);
}

// Runs this function on `stack` and returns (a copy of) the first value left
// on the stack afterwards.
//
// Throws (TORCH_CHECK) if the stack is empty after running. Previously
// `stack.front()` was called unconditionally, which is undefined behavior on
// an empty vector.
at::IValue Function::operator()(Stack& stack) {
  run(stack);
  TORCH_CHECK(
      !stack.empty(),
      "Stack is empty after running function ",
      name_.qualifiedName(),
      "; there is no value to return.");
  return stack.front();
}

size_t Function::num_inputs() const {
  return schema_->arguments().size();
}

// Makes sure operators are resolved (throwing if any are missing), then hands
// this function's code to `f`. The stack argument is not used here. Always
// returns true.
bool Function::call(
    Stack& /*unused*/,
    c10::function_ref<void(const mobile::Code&)> f) {
  initialize_operators(/*should_check_operators=*/true);
  const Code& resolved_code = code_;
  f(resolved_code);
  return true;
}

// Read-only access to this function's bytecode container.
const Code& Function::get_code() const {
  return this->code_;
}

// Mutable access to this function's bytecode container.
Code& Function::get_code() {
  return this->code_;
}

// Forwards to getInterpretersExceptionDebugHandles(). The handles come from
// that interpreter-level accessor, not from this function's own code_.
const std::vector<int64_t>& Function::getExceptionDebugHandles() const {
  const std::vector<int64_t>& handles = getInterpretersExceptionDebugHandles();
  return handles;
}

// Resolves `opname` to a callable that runs the operator on a Stack.
//
// Lookup order: promoted mobile prim ops, then JIT operators
// (findOperatorFor), then the c10 dispatcher. Returns std::nullopt when none
// of them provides the operator. Throws if the dispatcher entry has no
// schema.
//
// `num_specified_args`, when present and non-negative, is the number of
// arguments available from the model. It is used for backward compatibility:
// if it is smaller than the schema's argument count, the returned callable
// pushes the schema default values for the missing positional arguments
// (keeping trailing out arguments last) before invoking the operator.
// Promoted prim ops are returned unwrapped.
std::optional<std::function<void(Stack&)>> makeOperatorFunction(
    const c10::OperatorName& opname,
    std::optional<int> num_specified_args) {
  std::function<void(Stack&)> fn;
  const auto full_name = c10::toString(opname);

  if (mobile::hasPrimOpsFn(full_name)) {
    fn = mobile::getPrimOpsFn(full_name);
    return fn;
  }

  // Schema arguments of the resolved operator. The pointee is owned by the
  // Operator (kept alive by the shared_ptr captured in `fn`) or by the
  // dispatcher's registration of the operator.
  const std::vector<c10::Argument>* pArgs = nullptr;
  if (std::shared_ptr<Operator> jit_op = findOperatorFor(opname)) {
    fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
    pArgs = &jit_op->schema().arguments();
  } else {
    auto op = c10::Dispatcher::singleton().findSchema(opname);
    if (!op.has_value()) {
      return std::nullopt;
    }
    TORCH_CHECK(
        op->hasSchema(), "arguments are missing for operator ", opname);
    fn = [op](Stack& stack) { op->callBoxed(&stack); };
    pArgs = &op->schema().arguments();
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs);

  // Only a non-negative count means the argument count is known from the
  // model. A negative value would otherwise reach the wrapper below and be
  // converted to a huge size_t.
  const bool needs_defaults = num_specified_args.has_value() &&
      *num_specified_args >= 0 &&
      *num_specified_args < static_cast<int64_t>(pArgs->size());
  if (!needs_defaults) {
    return fn;
  }

  // Capture the schema pointer by value (instead of a reference to a local
  // reference) and move the inner callable rather than copying it.
  fn = [inner = std::move(fn),
        num_specified = *num_specified_args,
        pArgs](Stack& stack) {
    const auto& args = *pArgs;
    // The following logic pops and temporarily stores all out arguments
    // from the stack (which can be 0 or more, and always appended to the
    // schema), in order to push the necessary default values. Finally,
    // the out arguments are pushed back into the stack.
    // args is non-empty here because 0 <= num_specified < args.size().
    std::vector<IValue> out_args;
    for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) {
      out_args.push_back(stack.back());
      stack.pop_back();
    }
    TORCH_CHECK(
        static_cast<size_t>(num_specified) >= out_args.size(),
        "The number of output arguments is: ",
        out_args.size(),
        ", which is more than the number of specified arguments: ",
        num_specified);
    const size_t start_index =
        static_cast<size_t>(num_specified) - out_args.size();
    for (size_t i = start_index; i < (args.size() - out_args.size()); ++i) {
      TORCH_CHECK(
          args[i].default_value().has_value(),
          "Error happened at preparing for default values for the argument. The ",
          i,
          "th argument ",
          args[i].name(),
          " does not have a specified value or default value. ");

      stack.emplace_back(args[i].default_value());
    }
    stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
    inner(stack);
  };
  return fn;
}

// Returns the upgrader function registered under `qualified_name`, creating it
// from `instructions`, `constants`, `types` and `register_size` on first use.
// If the name is already registered, the stored function is returned and the
// other arguments are ignored.
//
// The returned reference points into a function-local static
// std::unordered_map. References to its elements stay valid across later
// insertions (rehashing only invalidates iterators).
Function& Function::registerFunc(
    const std::string& qualified_name,
    const std::vector<Instruction>& instructions,
    const std::vector<c10::IValue>& constants,
    const std::vector<c10::TypePtr>& types,
    const size_t register_size) {
  static std::unordered_map<c10::QualifiedName, Function>
      upgrader_function_holder;
  c10::QualifiedName name(qualified_name);
  auto found = upgrader_function_holder.find(name);
  if (found != upgrader_function_holder.end()) {
    return found->second;
  }
  // Build the function completely before publishing it. append_instruction
  // throws on opcodes unsupported on mobile. Previously the entry was
  // inserted first, so such a failure left a half-populated function cached,
  // and later calls returned it without any error.
  Function func(name);
  for (auto const& inst : instructions) {
    func.append_instruction(inst.op, inst.X, inst.N);
  }
  for (auto const& constant : constants) {
    func.append_constant(constant);
  }
  for (auto const& type : types) {
    func.append_type(type);
  }
  func.set_register_size(register_size);
  return upgrader_function_holder.emplace(std::move(name), std::move(func))
      .first->second;
}

} // namespace mobile
} // namespace torch::jit
