#include <torch/csrc/jit/frontend/function_schema_parser.h>

#include <ATen/core/Reduction.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/type_factory.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parse_string_literal.h>
#include <torch/csrc/jit/frontend/schema_type_parser.h>

#include <limits>
#include <memory>
#include <optional>
#include <vector>

using at::TypeKind;
using c10::Argument;
using c10::FunctionSchema;
using c10::IValue;
using c10::ListType;
using c10::OperatorName;

namespace torch::jit {

namespace {
struct SchemaParser {
  // `allow_typevars` is forwarded to the SchemaTypeParser.
  // NOTE(review): the Source is built with Source::DONT_COPY, which
  // presumably means it views `str` instead of copying it -- if so, `str`
  // must outlive this parser. Confirm against Source's definition.
  explicit SchemaParser(const std::string& str, bool allow_typevars)
      : L(std::make_shared<Source>(
            c10::string_view(str),
            std::nullopt,
            0,
            nullptr,
            Source::DONT_COPY)),
        type_parser(L, /*parse_complete_tensor_types*/ false, allow_typevars) {}

  // Parses either a bare operator name (`ns::op.overload`) or a full schema
  // (`ns::op.overload(args) -> returns`).
  //
  // Returns an OperatorName when no '(' follows the name, otherwise a
  // FunctionSchema. In the argument list, `*` makes every following argument
  // keyword-only, and `...` (vararg) must be the last element; vararg schemas
  // may not declare default values. The return part is either `...`, a
  // parenthesized list (which may end in `...`), or one unparenthesized
  // return. Throws ErrorReport on malformed input.
  std::variant<OperatorName, FunctionSchema> parseDeclaration() {
    OperatorName name = parseName();

    // If there is no parentheses coming, then this is just the operator name
    // without an argument list
    if (L.cur().kind != '(') {
      return OperatorName(std::move(name));
    }

    std::vector<Argument> arguments;
    std::vector<Argument> returns;
    bool kwarg_only = false;
    bool is_vararg = false;
    bool is_varret = false;
    size_t idx = 0;
    parseList('(', ',', ')', [&] {
      if (is_vararg)
        throw(
            ErrorReport(L.cur())
            << "... must be the last element of the argument list");
      if (L.nextIf('*')) {
        kwarg_only = true;
      } else if (L.nextIf(TK_DOTS)) {
        is_vararg = true;
      } else {
        arguments.push_back(parseArgument(
            idx++, /*is_return=*/false, /*kwarg_only=*/kwarg_only));
      }
    });

    // check if all arguments are not-default for vararg schemas
    if (is_vararg) {
      for (const auto& arg : arguments) {
        if (arg.default_value().has_value()) {
          throw(
              ErrorReport(L.cur())
              << "schemas with vararg (...) can't have default value args");
        }
      }
    }

    idx = 0;
    L.expect(TK_ARROW);
    if (L.nextIf(TK_DOTS)) {
      is_varret = true;
    } else if (L.cur().kind == '(') {
      parseList('(', ',', ')', [&] {
        if (is_varret) {
          throw(
              ErrorReport(L.cur())
              << "... must be the last element of the return list");
        }
        if (L.nextIf(TK_DOTS)) {
          is_varret = true;
        } else {
          returns.push_back(
              parseArgument(idx++, /*is_return=*/true, /*kwarg_only=*/false));
        }
      });
    } else {
      returns.push_back(
          parseArgument(0, /*is_return=*/true, /*kwarg_only=*/false));
    }

    return FunctionSchema(
        std::move(name.name),
        std::move(name.overload_name),
        std::move(arguments),
        std::move(returns),
        is_vararg,
        is_varret);
  }

  // Parses `ident`, `ident::ident`, optionally followed by `.overload`.
  // Rejects the overload names "default" and anything starting with "__"
  // via TORCH_CHECK.
  c10::OperatorName parseName() {
    std::string name = L.expect(TK_IDENT).text();
    if (L.nextIf(':')) {
      L.expect(':');
      name += "::";
      name += L.expect(TK_IDENT).text();
    }
    std::string overload_name;
    if (L.nextIf('.')) {
      overload_name = L.expect(TK_IDENT).text();
    }
    // default is used as an attribute on the `OpOverloadPacket`
    // (obtained using `torch.ops.aten.foo`) to get the operator
    // overload with overload name as an empty string
    // and so shouldn't be used as an overload name
    // also disallow dunder attribute names to be overload names
    bool is_a_valid_overload_name =
        !((overload_name == "default") || (overload_name.rfind("__", 0) == 0));
    TORCH_CHECK(
        is_a_valid_overload_name,
        overload_name,
        " is not a legal overload name for aten operators");
    return {name, overload_name};
  }

  // Parses one or more declarations separated by newlines, then requires
  // end of input.
  std::vector<std::variant<OperatorName, FunctionSchema>> parseDeclarations() {
    std::vector<std::variant<OperatorName, FunctionSchema>> results;
    do {
      results.emplace_back(parseDeclaration());
    } while (L.nextIf(TK_NEWLINE));
    L.expect(TK_EOF);
    return results;
  }

  // Parses a single declaration, tolerating one trailing newline, then
  // requires end of input.
  std::variant<OperatorName, FunctionSchema> parseExactlyOneDeclaration() {
    auto result = parseDeclaration();
    L.nextIf(TK_NEWLINE);
    L.expect(TK_EOF);
    return result;
  }

  // Parses one argument or return entry.
  //
  // The grammar is: a type; then an optional `[N]` size hint, which wraps
  // the type in a list and may be followed by an alias annotation and `?`;
  // then a name. The name is optional for returns and required for
  // arguments. Arguments (never returns) may also carry `= default`.
  // `kwarg_only` only takes effect for non-returns. The index argument is
  // unused.
  Argument parseArgument(size_t /*idx*/, bool is_return, bool kwarg_only) {
    // fake and real type coincide except for Layout/MemoryFormat/ScalarType
    // the fake type for these is Int instead
    auto p = type_parser.parseFakeAndRealType();
    auto fake_type = std::move(std::get<0>(p));
    auto real_type = std::move(std::get<1>(p));
    auto alias_info = std::move(std::get<2>(p));
    std::optional<int32_t> N;
    std::optional<IValue> default_value;
    std::optional<std::string> alias_set;
    std::string name;
    if (L.nextIf('[')) {
      // note: an array with a size hint can only occur at the Argument level
      fake_type = ListType::create(std::move(fake_type));
      real_type = ListType::create(std::move(real_type));
      auto size_tok = L.expect(TK_NUMBER);
      const long long size_hint = std::stoll(size_tok.text());
      // N is stored as int32_t; reject hints that would silently truncate.
      if (size_hint < 0 || size_hint > std::numeric_limits<int32_t>::max()) {
        throw(
            ErrorReport(size_tok.range)
            << "list size hint " << size_tok.text() << " is out of range");
      }
      N = static_cast<int32_t>(size_hint);
      L.expect(']');
      auto container = type_parser.parseAliasAnnotation();
      if (alias_info) {
        // The element carries alias info but the list itself has none:
        // synthesize a container alias that inherits the element's write flag.
        if (!container) {
          container.emplace();
          container->setIsWrite(alias_info->isWrite());
        }
        container->addContainedType(std::move(*alias_info));
      }
      alias_info = std::move(container);
      if (L.nextIf('?')) {
        fake_type =
            c10::TypeFactory::create<c10::OptionalType>(std::move(fake_type));
        real_type =
            c10::TypeFactory::create<c10::OptionalType>(std::move(real_type));
      }
    }
    if (is_return) {
      // optionally field names in return values
      if (L.cur().kind == TK_IDENT) {
        name = L.next().text();
      } else {
        name = "";
      }
    } else {
      name = L.expect(TK_IDENT).text();
      if (L.nextIf('=')) {
        // NB: this means we have to unswizzle default too
        default_value =
            parseDefaultValue(*fake_type, fake_type->kind(), *real_type, N);
      }
    }
    return Argument(
        std::move(name),
        std::move(fake_type),
        std::move(real_type),
        N,
        std::move(default_value),
        !is_return && kwarg_only,
        std::move(alias_info));
  }

  // True if `type` is ScalarType, or an Optional whose contained type is
  // (recursively) ScalarType.
  bool isPossiblyOptionalScalarType(const c10::Type& type) {
    if (type.kind() == at::ScalarTypeType::Kind) {
      return true;
    }
    if (type.kind() == at::OptionalType::Kind) {
      for (const auto& inner : type.containedTypes()) {
        if (isPossiblyOptionalScalarType(*inner))
          return true;
      }
    }
    return false;
  }

  // Parses one scalar default-value literal.
  //
  // Accepted forms:
  //  * True / False / None;
  //  * a string literal;
  //  * a fixed set of identifiers, mapped to the int64 value of the
  //    corresponding enum. dtype names are accepted only when `real_type`
  //    is (optional) ScalarType;
  //  * an optionally negated number.
  //
  // `kind` (the fake type's kind, with DynamicType unwrapped) decides how a
  // number is interpreted:
  //  * a `j` suffix gives a purely imaginary complex;
  //  * ComplexType without `j` gives a real-valued complex;
  //  * FloatType, or a literal containing '.' or 'e', gives a double;
  //  * anything else gives an int64.
  IValue parseSingleConstant(
      const c10::Type& type,
      TypeKind kind,
      const c10::Type& real_type) {
    if (kind == c10::TypeKind::DynamicType) {
      return parseSingleConstant(
          type, type.expectRef<c10::DynamicType>().dynamicKind(), real_type);
    }
    switch (L.cur().kind) {
      case TK_TRUE:
        L.next();
        return true;
      case TK_FALSE:
        L.next();
        return false;
      case TK_NONE:
        L.next();
        return IValue();
      case TK_STRINGLITERAL: {
        auto token = L.next();
        return parseStringLiteral(token.range, token.text());
      }
      case TK_IDENT: {
        auto tok = L.next();
        auto text = tok.text();
        // NB: float/complex/long are here for BC purposes. Other dtypes
        // are handled via str2dtype.
        // Please don't add more cases to this if-else block.
        if ("float" == text) {
          return static_cast<int64_t>(at::kFloat);
        } else if ("complex" == text) {
          return static_cast<int64_t>(at::kComplexFloat);
        } else if ("long" == text) {
          return static_cast<int64_t>(at::kLong);
        } else if ("strided" == text) {
          return static_cast<int64_t>(at::kStrided);
        } else if ("Mean" == text) {
          return static_cast<int64_t>(at::Reduction::Mean);
        } else if ("contiguous_format" == text) {
          return static_cast<int64_t>(c10::MemoryFormat::Contiguous);
        }
        if (isPossiblyOptionalScalarType(real_type)) {
          const auto& str2dtype = c10::getStringToDtypeMap();
          auto it = str2dtype.find(text);
          if (it != str2dtype.end()) {
            return static_cast<int64_t>(it->second);
          }
        }
        // The identifier has already been consumed; point the error at it
        // rather than at whatever token follows.
        throw(ErrorReport(tok.range) << "invalid numeric default value");
      }
      default: {
        std::string n;
        if (L.nextIf('-'))
          n = "-" + L.expect(TK_NUMBER).text();
        else
          n = L.expect(TK_NUMBER).text();

        const bool is_imaginary = !n.empty() && n.back() == 'j';
        if (is_imaginary) {
          // e.g. `2j` / `-1.5j`: strip the suffix and use it as the imag part.
          auto imag = std::stod(n.substr(0, n.size() - 1));
          return c10::complex<double>(0, imag);
        } else if (kind == TypeKind::ComplexType) {
          // A real literal for a complex argument (e.g. `complex x=1.0`)
          // must not have its last digit stripped as if it were a `j`.
          return c10::complex<double>(std::stod(n), 0);
        } else if (
            kind == TypeKind::FloatType || n.find('.') != std::string::npos ||
            n.find('e') != std::string::npos) {
          return std::stod(n);
        } else {
          int64_t v = std::stoll(n);
          return v;
        }
      }
    }
  }

  // Converts already-parsed constants into a typed list according to the
  // element `kind` (complex, float, int or bool; DynamicType is unwrapped).
  // Throws ErrorReport at `range` for any other element kind.
  IValue convertToList(
      const c10::Type& type,
      TypeKind kind,
      const SourceRange& range,
      const std::vector<IValue>& vs) {
    switch (kind) {
      case TypeKind::ComplexType:
        return fmap(vs, [](const IValue& v) { return v.toComplexDouble(); });
      case TypeKind::FloatType:
        return fmap(vs, [](const IValue& v) { return v.toDouble(); });
      case TypeKind::IntType:
        return fmap(vs, [](const IValue& v) { return v.toInt(); });
      case TypeKind::BoolType:
        return fmap(vs, [](const IValue& v) { return v.toBool(); });
      case TypeKind::DynamicType:
        return convertToList(
            type, type.expectRef<c10::DynamicType>().dynamicKind(), range, vs);
      default:
        throw(
            ErrorReport(range)
            << "lists are only supported for float, int, complex and bool types");
    }
  }

  // Parses `[c0, c1, ...]` (possibly empty) where each element is a single
  // constant of the given element type, then converts to a typed list.
  IValue parseConstantList(
      const c10::Type& type,
      TypeKind kind,
      const c10::Type& real_type) {
    auto tok = L.expect('[');
    std::vector<IValue> vs;
    if (L.cur().kind != ']') {
      do {
        vs.push_back(parseSingleConstant(type, kind, real_type));
      } while (L.nextIf(','));
    }
    L.expect(']');
    return convertToList(type, kind, tok.range, vs);
  }

  // Tensor-like defaults: only the literal `None` is accepted.
  IValue parseTensorDefault(const SourceRange& /*range*/) {
    L.expect(TK_NONE);
    return IValue();
  }

  // Parses the default value that follows `=` for an argument of type
  // `arg_type` / `real_type`.
  //
  // `arg_N` is the `[N]` size hint, if any. With a size hint, a single
  // scalar default is repeated N times to form the list.
  IValue parseDefaultValue(
      const c10::Type& arg_type,
      TypeKind kind,
      const c10::Type& real_type,
      std::optional<int32_t> arg_N) {
    auto range = L.cur().range;
    switch (kind) {
      case TypeKind::TensorType:
      case TypeKind::GeneratorType:
      case TypeKind::QuantizerType:
        return parseTensorDefault(range);
      case TypeKind::StringType:
      case TypeKind::OptionalType:
      case TypeKind::NumberType:
      case TypeKind::IntType:
      case TypeKind::BoolType:
      case TypeKind::FloatType:
      case TypeKind::ComplexType:
        return parseSingleConstant(arg_type, kind, real_type);
      case TypeKind::DeviceObjType: {
        auto device_text =
            parseStringLiteral(range, L.expect(TK_STRINGLITERAL).text());
        return c10::Device(device_text);
      }
      case TypeKind::ListType: {
        auto elem_type = arg_type.containedType(0);
        auto real_elem_type = real_type.containedType(0);
        if (L.cur().kind == TK_IDENT) {
          return parseTensorDefault(range);
        } else if (arg_N && L.cur().kind != '[') {
          // `int[2] x=1` means "[1, 1]": repeat the scalar N times.
          IValue v = parseSingleConstant(
              *elem_type, elem_type->kind(), *real_elem_type);
          std::vector<IValue> repeated(*arg_N, v);
          return convertToList(*elem_type, elem_type->kind(), range, repeated);
        } else {
          return parseConstantList(
              *elem_type, elem_type->kind(), *real_elem_type);
        }
      }
      case TypeKind::DynamicType:
        return parseDefaultValue(
            arg_type,
            arg_type.expectRef<c10::DynamicType>().dynamicKind(),
            real_type,
            arg_N);
      default:
        throw(ErrorReport(range) << "unexpected type, file a bug report");
    }
    return IValue(); // silence warnings
  }

  // Parses `begin item (sep item)* end`, invoking `callback` once per item.
  // An empty list (`begin end`) is allowed. Passing TK_NOTHING as `begin`
  // or `end` skips consuming that delimiter.
  void parseList(
      int begin,
      int sep,
      int end,
      c10::function_ref<void()> callback) {
    if (begin != TK_NOTHING)
      L.expect(begin);
    if (L.cur().kind != end) {
      do {
        callback();
      } while (L.nextIf(sep));
    }
    if (end != TK_NOTHING)
      L.expect(end);
  }

  // Declaration order matters: `type_parser` is constructed from `L` in the
  // constructor's initializer list, so `L` must be declared first.
  Lexer L;
  SchemaTypeParser type_parser;
};
} // namespace

std::variant<OperatorName, FunctionSchema> parseSchemaOrName(
    const std::string& schemaOrName,
    bool allow_typevars) {
  // We're ignoring aten and prim for BC reasons
  if (schemaOrName.rfind("aten::", 0) == 0 ||
      schemaOrName.rfind("prim::", 0) == 0) {
    allow_typevars = true;
  }
  return SchemaParser(schemaOrName, allow_typevars)
      .parseExactlyOneDeclaration();
}

// Parses `schema` and requires it to be a full function schema. A bare
// operator name is rejected via TORCH_CHECK.
FunctionSchema parseSchema(const std::string& schema, bool allow_typevars) {
  auto parsed = parseSchemaOrName(schema, allow_typevars);
  auto* as_schema = std::get_if<FunctionSchema>(&parsed);
  TORCH_CHECK(
      as_schema != nullptr,
      "Tried to parse a function schema but only the operator name was given");
  return std::move(*as_schema);
}

// Parses `name` and requires it to be a bare operator name. A full schema
// with an argument list is rejected via TORCH_CHECK. Uses the default
// `allow_typevars` of parseSchemaOrName.
OperatorName parseName(const std::string& name) {
  auto parsed = parseSchemaOrName(name);
  auto* as_name = std::get_if<OperatorName>(&parsed);
  TORCH_CHECK(
      as_name != nullptr,
      "Tried to parse an operator name but function schema was given");
  return std::move(*as_name);
}

} // namespace torch::jit
