// in memory description of all ATen Ops similar to Caffe2 schema
// once C10 exists this can be removed, or stubbed out, but we need
// it now to implement correct semantic checking for script
#pragma once

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/core/op_registration/op_allowlist.h>
#include <ATen/core/stack.h>
#include <c10/util/Exception.h>
#include <c10/util/overloaded.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/jit/runtime/operator_options.h>
#include <torch/library.h>

#include <ATen/core/function_schema.h>
#include <ATen/core/symbol.h>

#include <array>
#include <functional>
#include <initializer_list>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

namespace torch::jit {

// Forward declaration of the JIT IR node type. This header only ever uses
// `const Node*`, so the full definition is not required here.
struct Node;
// Make the c10 schema vocabulary types available unqualified in torch::jit.
using ::c10::Argument;
using ::c10::FunctionSchema;
using ::c10::Symbol;

// Plain function pointer that, given a node, produces the Operation to run.
// Operator::getOperation() invokes it with the node pointer it was handed.
using OperationCreator = Operation (*)(const Node*);

// Tags reported by Operator::getTags() for JIT-only operators, which have no
// c10::OperatorHandle to query.
//
// This is an `inline constexpr` variable rather than a `const` object in an
// unnamed namespace: an unnamed namespace in a header gives every translation
// unit its own internal-linkage copy, and the inline member function that
// odr-uses it would then refer to a different entity in each translation
// unit. An inline variable has a single definition program-wide.
inline constexpr std::array<at::Tag, 1> kJitOnlyOperatorTags = {
    at::Tag::pt2_compliant_tag};

/*
 * Note: JIT relies on Operator instances having static lifetime because, for
 * example, it stores a non-owning FunctionSchema* pointer in the Node class,
 * which points to the function schema stored in the Operator instance.
 * Also, jit::Operator is meant to store more operator-related information like
 * symbolic derivatives, which also requires them to have static lifetime
 * so that changes to symbolic derivatives are remembered.
 *
 * Currently, the JIT operator library contains a jit::Operator instance
 * with a wrapper for each c10 operator. The c10 operator library registers
 * those wrappers using listeners in register_c10_ops.cpp.
 * TODO Instead of doing it this way, we should only have pure-jit ops in
 * the jit library but have the JIT operator lookup look into the c10 library
 * too.
 */

// An Operator is a thin wrapper around either a pure JIT operator (e.g. prim
// ops) or a c10 operator, allowing some common operations and abstracting away
// the concrete operator nature.
struct TORCH_API Operator {
 private:
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  struct C10Operator final {
    c10::OperatorHandle handle_;
    Operation op_;
  };
  struct UnparsedFunctionSchema final {
    std::string schema_string_;
    mutable std::optional<c10::AliasAnalysisKind> alias_analysis_;
  };
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  struct JitOnlyOperator final {
    // The only valid transition for schema_ is from right->left, i.e.
    // when the schema gets parsed.
    mutable std::variant<FunctionSchema, UnparsedFunctionSchema> schema_;

    std::variant<Operation, OperationCreator> op_;
  };

 public:
  Operator(c10::OperatorHandle opHandle, Operation operation)
      : op_(C10Operator{std::move(opHandle), std::move(operation)}) {}

  Operator(
      std::string schema,
      Operation op,
      c10::AliasAnalysisKind alias_analysis)
      : op_(JitOnlyOperator{
            UnparsedFunctionSchema{std::move(schema), alias_analysis},
            Operation(std::move(op))}) {}

  Operator(
      std::string name,
      std::string overload_name,
      std::vector<Argument> arguments,
      std::vector<Argument> returns,
      Operation op,
      c10::AliasAnalysisKind alias_analysis)
      : op_(JitOnlyOperator{
            FunctionSchema(varArgSchemaWithName(
                std::move(name),
                std::move(overload_name),
                std::move(arguments),
                std::move(returns),
                alias_analysis)),
            std::move(op)}) {}

  Operator(
      std::string schema,
      OperationCreator op_creator,
      c10::AliasAnalysisKind alias_analysis)
      : op_(JitOnlyOperator{
            UnparsedFunctionSchema{std::move(schema), alias_analysis},
            op_creator}) {}

  // Helper constructor to register `op` to run
  // run for _every_ IR Node where n.kind() == name, regardless of arguments.
  // This is accomplished by marking the schema varargs and having no required
  // arguments.
  Operator(
      Symbol name,
      OperationCreator op_creator,
      c10::AliasAnalysisKind alias_analysis)
      : op_(JitOnlyOperator{
            FunctionSchema(varArgSchemaWithName(name, alias_analysis)),
            op_creator}) {}

  Operation getOperation(const Node* node = nullptr) const {
    return std::visit(
        c10::overloaded(
            [](const C10Operator& op) { return op.op_; },
            [node](const JitOnlyOperator& op) {
              return std::visit(
                  c10::overloaded(
                      [](const Operation& op) { return op; },
                      [node](const OperationCreator& op_creator) {
                        return op_creator(node);
                      }),
                  op.op_);
            }),
        op_);
  }

  Operation getOperationForDispatchKey(c10::DispatchKey dk) const {
    // TODO: some sort of caching mechanism?
    return std::visit(
        c10::overloaded(
            [dk](const C10Operator& op) {
              return Operation([op, dk](Stack& stack) {
                op.handle_.callBoxedForDispatchKey(dk, stack);
              });
            },
            [](const JitOnlyOperator& op) {
              TORCH_CHECK(
                  false,
                  "calling a JIT operator for dispatch key is not supported");
              return Operation(nullptr);
            }),
        op_);
  }

  const FunctionSchema& schema() const {
    return std::visit(
        c10::overloaded(
            [](const C10Operator& op) -> const FunctionSchema& {
              return op.handle_.schema();
            },
            [](const JitOnlyOperator& op) -> const FunctionSchema& {
              // we lazily parse schema initialized from strings so that
              // we do less work during static operator registration
              if (op.schema_.index() == 1) {
                auto& unmaterializedSchema =
                    std::get<UnparsedFunctionSchema>(op.schema_);
                FunctionSchema schema =
                    parseSchema(unmaterializedSchema.schema_string_);
                if (unmaterializedSchema.alias_analysis_.has_value()) {
                  // TODO What if it gets set later?
                  schema.setAliasAnalysis(
                      *unmaterializedSchema.alias_analysis_);
                }
                op.schema_ = std::move(schema);
              }
              return std::get<FunctionSchema>(op.schema_);
            }),
        op_);
  }

  c10::ArrayRef<at::Tag> getTags() const {
    return std::visit(
        c10::overloaded(
            [](const C10Operator& op) { return op.handle_.getTags(); },
            [](const JitOnlyOperator& op) {
              // JitOnlyOperators don't have an c10::OperatorHandle or a way to
              // specify tags. We're grandfathering them all into
              // pt2_compliant_tag, but for anything else, please just stop
              // using JitOnlyOperator.
              return c10::ArrayRef<at::Tag>(kJitOnlyOperatorTags);
            }),
        op_);
  }

  bool isC10Op() const {
    return op_.index() == 0;
  }

  c10::AliasAnalysisKind aliasAnalysisKind() const {
    const FunctionSchema& schemaRef = schema();
    c10::AliasAnalysisKind alias_analysis = schemaRef.aliasAnalysis();

    TORCH_CHECK(
        alias_analysis == AliasAnalysisKind::FROM_SCHEMA ||
            !schemaRef.hasAnyAliasInfo(),
        "In operator registration: Tried to register operator ",
        schemaRef,
        " with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA.");
    return alias_analysis;
  }

  bool hasOperation() const {
    return std::visit(
        c10::overloaded(
            [](const C10Operator&) { return true; },
            [](const JitOnlyOperator& op) { return op.op_.index() == 0; }),
        op_);
  }

 private:
  static FunctionSchema varArgSchemaWithName(
      Symbol name,
      AliasAnalysisKind alias_analysis) {
    auto result = FunctionSchema(
        name,
        "",
        {},
        {},
        /*is_vararg*/ true,
        /*is_varret*/ true);
    result.setAliasAnalysis(alias_analysis);
    return result;
  }

  static FunctionSchema varArgSchemaWithName(
      std::string name,
      std::string overload_name,
      std::vector<Argument> arguments,
      std::vector<Argument> returns,
      AliasAnalysisKind alias_analysis) {
    auto result = FunctionSchema(
        std::move(name),
        std::move(overload_name),
        std::move(arguments),
        std::move(returns),
        /*is_vararg*/ false,
        /*is_varret*/ false);
    result.setAliasAnalysis(alias_analysis);
    return result;
  }

  std::variant<C10Operator, JitOnlyOperator> op_;
};

// Produces a string for `schema`. The exact format is decided by the
// implementation, which is not part of this header.
TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);

// Returns, by value, a vector of shared pointers to operators.
// NOTE(review): presumably every registered operator - confirm against the
// implementation.
TORCH_API const std::vector<std::shared_ptr<Operator>> getAllOperators();
// Returns a reference to the vector of operators associated with `name`.
// NOTE(review): the lifetime of the referenced vector is determined by the
// implementation - confirm before holding on to the reference.
TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(
    Symbol name);
// Returns operators in the order in which OpOverloadPacket resolves them.
TORCH_API std::vector<std::shared_ptr<Operator>> getAllSortedOperatorsFor(
    Symbol name);

// Given an operator name that includes an overload name, find the specific
// operator it refers to. May return nullptr if no such operator exists.
TORCH_API std::shared_ptr<Operator> findOperatorFor(
    const c10::OperatorName& full_name);

// Returns symbols considered similar to `input_op`; the similarity criterion
// is defined by the implementation.
TORCH_API std::vector<Symbol> findSimilarOperators(Symbol input_op);

// Registers `op`; the argument is taken by rvalue reference.
TORCH_API void registerOperator(Operator&& op);
// Removes the registration identified by `schema`.
TORCH_API void deregisterOperator(const FunctionSchema& schema);

// XXX: this function is meant to be used with string literals only!
TORCH_API std::shared_ptr<Operator> getOperatorForLiteral(
    const char* signature);

// Ensure the thing that registers c10 ops is defined.
// Otherwise, our registry will not have c10 ops. You can run into this
// scenario if you're querying registered ops during static init.
//
// This fn is defined in register_c10_ops.cpp
TORCH_API void ensure_c10_registerer_defined();

// Used to assert that unschematized operators have an analysis method written
TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym);

// Factory producing an optional Operator from a schema string.
//
// This is the plain `const char*` form: it always constructs the Operator and
// returns it inside an engaged optional. Sibling overloads accept
// torch::detail::SelectiveStr<true> / SelectiveStr<false> instead, so the
// schema argument can come from a compile-time selective-registration
// function and decide at compile time whether an Operator is produced at all.
template <typename Func>
std::optional<Operator> OperatorGenerator(
    const char* schema_str,
    Func&& op,
    AliasAnalysisKind alias_analysis) {
  std::string schema_text(schema_str);
  return std::optional<Operator>(
      Operator(std::move(schema_text), std::forward<Func>(op), alias_analysis));
}

// Selected-schema overload: converts the SelectiveStr<true> to its underlying
// C string and delegates to the `const char*` overload.
template <typename Func>
std::optional<Operator> OperatorGenerator(
    torch::detail::SelectiveStr<true> schema_str,
    Func&& op,
    AliasAnalysisKind alias_analysis) {
  const char* const raw_schema = static_cast<const char*>(schema_str);
  return OperatorGenerator(raw_schema, std::forward<Func>(op), alias_analysis);
}

// Deselected-schema overload: a SelectiveStr<false> argument means no
// Operator is generated, so every argument is ignored and the result is an
// empty optional.
template <typename Func>
std::optional<Operator> OperatorGenerator(
    torch::detail::SelectiveStr<false> schema_str,
    Func&& op,
    AliasAnalysisKind alias_analysis) {
  return std::optional<Operator>{};
}

// Factory producing an Operator from an explicit schema description (name,
// overload name, arguments, returns) rather than a schema string. Always
// returns an engaged optional.
//
// The by-value parameters are intentionally non-const so they can be moved
// into the Operator constructor, which also takes them by value; declaring
// them `const` would force a second copy of each string and vector.
// Top-level const on a by-value parameter is not part of the function's
// signature, so callers are unaffected.
template <typename Func>
std::optional<Operator> OperatorGenerator(
    std::string name,
    std::string overload_name,
    std::vector<c10::Argument> arguments,
    std::vector<c10::Argument> returns,
    Func&& op,
    AliasAnalysisKind alias_analysis) {
  return std::optional<Operator>(Operator(
      std::move(name),
      std::move(overload_name),
      std::move(arguments),
      std::move(returns),
      std::forward<Func>(op),
      alias_analysis));
}

} // namespace torch::jit
