#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_symnode.h>

namespace pybind11::detail {

bool type_caster<c10::SymInt>::load(py::handle src, bool) {
  if (torch::is_symint(src)) {
    auto node = src.attr("node");
    if (py::isinstance<c10::SymNodeImpl>(node)) {
      value = c10::SymInt(py::cast<c10::SymNode>(node));
      return true;
    }

    value = c10::SymInt(static_cast<c10::SymNode>(
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(node)));
    return true;
  }

  auto raw_obj = src.ptr();

  if (THPVariable_Check(raw_obj)) {
    auto& var = THPVariable_Unpack(raw_obj);
    if (var.numel() == 1 &&
        at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) {
      auto scalar = var.item();
      TORCH_INTERNAL_ASSERT(scalar.isIntegral(/*include bool*/ false));
      value = scalar.toSymInt();
      return true;
    }
  }

  if (THPUtils_checkIndex(raw_obj)) {
    value = c10::SymInt{THPUtils_unpackIndex(raw_obj)};
    return true;
  }
  return false;
}

py::handle type_caster<c10::SymInt>::cast(
    const c10::SymInt& si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (si.is_symbolic()) {
    auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
        si.toSymNodeImplUnowned());
    if (py_node) {
      // Return the Python directly (unwrap)
      return torch::get_symint_class()(py_node->getPyObj()).release();
    } else {
      // Wrap the C++ into Python
      auto inner = py::cast(si.toSymNode());
      if (!inner) {
        throw python_error();
      }
      return torch::get_symint_class()(inner).release();
    }
  } else {
    auto m = si.maybe_as_int();
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    return py::cast(*m).release();
  }
}

bool type_caster<c10::SymFloat>::load(py::handle src, bool) {
  if (torch::is_symfloat(src)) {
    value = c10::SymFloat(static_cast<c10::SymNode>(
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
    return true;
  }

  auto raw_obj = src.ptr();
  if (THPUtils_checkDouble(raw_obj)) {
    value = c10::SymFloat{THPUtils_unpackDouble(raw_obj)};
    return true;
  }
  return false;
}

py::handle type_caster<c10::SymFloat>::cast(
    const c10::SymFloat& si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (si.is_symbolic()) {
    // TODO: generalize this to work with C++ backed class
    auto* py_node =
        dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
    TORCH_INTERNAL_ASSERT(py_node);
    return torch::get_symfloat_class()(py_node->getPyObj()).release();
  } else {
    return py::cast(si.as_float_unchecked()).release();
  }
}

bool type_caster<c10::SymBool>::load(py::handle src, bool) {
  if (torch::is_symbool(src)) {
    value = c10::SymBool(static_cast<c10::SymNode>(
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
    return true;
  }

  auto raw_obj = src.ptr();
  if (THPUtils_checkBool(raw_obj)) {
    value = c10::SymBool{THPUtils_unpackBool(raw_obj)};
    return true;
  }
  return false;
}

py::handle type_caster<c10::SymBool>::cast(
    const c10::SymBool& si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (auto m = si.maybe_as_bool()) {
    return py::cast(*m).release();
  } else {
    // TODO: generalize this to work with C++ backed class
    auto* py_node =
        dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
    TORCH_INTERNAL_ASSERT(py_node);
    return torch::get_symbool_class()(py_node->getPyObj()).release();
  }
}

bool type_caster<c10::Scalar>::load(py::handle src, bool) {
  TORCH_INTERNAL_ASSERT(
      0, "pybind11 loading for c10::Scalar NYI (file a bug if you need it)");
}

py::handle type_caster<c10::Scalar>::cast(
    const c10::Scalar& scalar,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (scalar.isIntegral(/*includeBool*/ false)) {
    // We have to be careful here; we cannot unconditionally route through
    // SymInt because integer data from Tensors can easily be MIN_INT or
    // very negative, which conflicts with the allocated range.
    if (scalar.isSymbolic()) {
      return py::cast(scalar.toSymInt()).release();
    } else {
      if (scalar.type() == at::ScalarType::UInt64) {
        return py::cast(scalar.toUInt64()).release();
      } else {
        return py::cast(scalar.toLong()).release();
      }
    }
  } else if (scalar.isFloatingPoint()) {
    // This isn't strictly necessary but we add it for symmetry
    if (scalar.isSymbolic()) {
      return py::cast(scalar.toSymFloat()).release();
    } else {
      return py::cast(scalar.toDouble()).release();
    }
  } else if (scalar.isBoolean()) {
    if (scalar.isSymbolic()) {
      return py::cast(scalar.toSymBool()).release();
    }
    return py::cast(scalar.toBool()).release();
  } else if (scalar.isComplex()) {
    return py::cast(scalar.toComplexDouble()).release();
  } else {
    TORCH_INTERNAL_ASSERT(0, "unrecognized scalar type ", scalar.type());
  }
}

} // namespace pybind11::detail
