#include <ATen/ATen.h>
#include <ATen/core/Dict.h>
#ifdef USE_RPC
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/csrc/utils/byte_order.h>
#include <string>
#include <utility>

namespace torch::jit {

using ::c10::IValue;

static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
  if (root.isObject()) {
    restoreAccurateTypeTags(root, root.type());
  }
}

// Pickled objects are stored in a form compatible with Python pickling.
// In torchscript List[T]/Dict[K, V] are statically typed and contain
// dynamic type tags that allow T, K, and V to be recovered. But this
// info is not stored in the Python pickling information. However, we
// can recover this information from the static type of the top-level
// object being unpickled, because we have a record of the type of the
// objects it contains as attributes.
// `IfPossible` - we can only do this recovery when we have an object as
// the top-level unpickled thing (which is guaranteed for Modules, but
// not for torch.load/torch.save). Otherwise we do not know the types
// of the contained objects and cannot restore the tags.
void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
  struct Work {
    TypePtr type;
    IValue value;
  };
  std::vector<Work> to_process = {{type_tag, root}};
  std::unordered_set<const void*> scanned;
  while (!to_process.empty()) {
    Work w = std::move(to_process.back());
    to_process.pop_back();
    // ensure we only scan each pointer value once, otherwise this
    // can become exponential (and if we allow recursive data in the future,
    // it would not terminiate).
    if (w.value.isPtrType()) {
      const void* key = w.value.internalToPointer();
      auto it = scanned.find(key);
      if (it != scanned.end()) {
        continue;
      }
      scanned.emplace_hint(it, key);
    }
    auto kind = w.type->kind();
    if (auto dyn = w.type->castRaw<c10::DynamicType>()) {
      kind = dyn->dynamicKind();
    }
    switch (kind) {
      case TensorType::Kind:
      case StorageType::Kind:
      case NumberType::Kind:
      case FloatType::Kind:
      case ComplexType::Kind:
      case IntType::Kind:
      case NoneType::Kind:
      case GeneratorType::Kind:
      case QuantizerType::Kind:
      case BoolType::Kind:
      case VarType::Kind:
      case CapsuleType::Kind:
      case PyObjectType::Kind:
      case StringType::Kind:
      case FunctionType::Kind:
      case DeviceObjType::Kind:
      case StreamObjType::Kind:
      case QSchemeType::Kind:
      case LayoutType::Kind:
      case MemoryFormatType::Kind:
      case ScalarTypeType::Kind:
      case RRefType::Kind:
      case AnyType::Kind:
      case AnyListType::Kind:
      case AnyTupleType::Kind:
      case AnyClassType::Kind:
      case AnyEnumType::Kind:
        // no op, there is nothing to tag
        break;
      case c10::SymIntType::Kind:
        // TODO: Can this really show up though? :think:
        TORCH_CHECK(!w.value.toSymInt().is_heap_allocated());
        // no op, there is nothing to tag
        break;
      case c10::SymFloatType::Kind:
        TORCH_CHECK(!w.value.toSymFloat().is_symbolic());
        // no op, there is nothing to tag
        break;
      case c10::SymBoolType::Kind:
        TORCH_CHECK(!w.value.toSymBool().is_heap_allocated());
        // no op, there is nothing to tag
        break;
      case DynamicType::Kind:
      case UnionType::Kind:
      case EnumType::Kind:
        // TODO(gmagogsfm): Implement serialization/deserialization of Enum.
        TORCH_INTERNAL_ASSERT(false);
      case TupleType::Kind: {
        auto t = w.value.toTuple();
        for (size_t i = 0; i < w.type->containedTypeSize(); ++i) {
          Work elem = {w.type->containedType(i), t->elements().at(i)};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case FutureType::Kind: {
        auto f = w.value.toFuture();
        if (f->completed()) {
          Work elem = {w.type->containedType(0), f->value()};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case AwaitType::Kind: {
        auto aw = w.value.toAwait();
        if (aw->completed()) {
          Work elem = {w.type->containedType(0), aw->wait()};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case OptionalType::Kind: {
        if (!w.value.isNone()) {
          Work elem = {w.type->containedType(0), w.value};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case ListType::Kind: {
        // specialized lists do not need their type refined, so we can exit
        // early here
        if (!w.value.isList()) {
          break;
        }
        auto elem_type = w.type->containedType(0);
        auto lst = w.value.toList();
        lst.unsafeSetElementType(elem_type);
        for (const IValue& item : lst) {
          Work elem = {elem_type, item};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case DictType::Kind: {
        auto d = w.value.toGenericDict();
        auto keyType = w.type->containedType(0);
        auto valType = w.type->containedType(1);
        d.unsafeSetKeyType(keyType);
        d.unsafeSetValueType(valType);
        for (const auto& item : d) {
          Work kelem = {keyType, item.key()};
          Work velem = {valType, item.value()};
          to_process.emplace_back(std::move(kelem));
          to_process.emplace_back(std::move(velem));
        }
      } break;
      // in both cases the dynamic type is a class, and we are going to tag with
      // the dynamic type
      case InterfaceType::Kind:
      case ClassType::Kind: {
        auto obj = w.value.toObject();
        auto typ = obj->type(); // note: intentionally using the dynamic type,
                                // the static type is potentially less accurate
        for (size_t i = 0; i < typ->numAttributes(); ++i) {
          Work elem = {typ->getAttribute(i), obj->getSlot(i)};
          to_process.emplace_back(std::move(elem));
        }
      };
    }
  }
}

namespace {
template <typename T>
bool is(const Type& type) {
  if (type.kind() == T::Kind) {
    return true;
  }
  if (auto dyn = type.castRaw<c10::DynamicType>()) {
    return dyn->tag() == c10::DynamicTypeTrait<T>::tagValue();
  }
  return false;
}
} // namespace

static void restoreContainerTypeTags(
    const IValue& ivalue,
    const TypePtr& type) {
  if (is<DictType>(*type)) {
    auto dict = ivalue.toGenericDict();
    dict.unsafeSetKeyType(type->containedType(0));
    dict.unsafeSetValueType(type->containedType(1));
  } else if (is<ListType>(*type)) {
    ivalue.toList().unsafeSetElementType(type->containedType(0));
  } else {
    AT_ERROR("Unknown type for tag restoration: " + type->annotation_str());
  }
}

IValue Unpickler::parse_ivalue() {
  run();
  TORCH_CHECK(
      stack_.size() == 1,
      "Unpickler expected 1 element on the stack, but found ",
      stack_.size());
  if (version_ <= 2) {
    // See [type tag serialization]
    restoreAccurateTypeTagsIfPossible(stack_[0]);
  }
  return stack_[0];
}

double Unpickler::readFloat() {
  AT_ASSERT(sizeof(double) == 8);
  double big_endian = read<double>();
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
  double little_endian = 0;

  // Pickle floats are big endian, so reverse the bytes
  auto big_endian_ptr = reinterpret_cast<const char*>(&big_endian);
  std::reverse_copy(
      big_endian_ptr,
      big_endian_ptr + sizeof(big_endian),
      reinterpret_cast<char*>(&little_endian));

  return little_endian;
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  return big_endian;
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
}

void Unpickler::run() {
  // Expect a PROTO opcode and protocol number at the start of blob
  auto opcode = readOpCode();
  TORCH_CHECK(
      opcode == PickleOpCode::PROTO,
      "Expected PROTO opcode at the start"
      " of pickle archive, found ",
      int(static_cast<uint8_t>(opcode)));
  uint8_t protocol = read<uint8_t>();
  TORCH_CHECK(
      protocol == 2,
      "Only Pickle protocol 2 is supported, found protocol = ",
      protocol);

  while (true) {
    PickleOpCode opcode = readInstruction();
    if (opcode == PickleOpCode::STOP) {
      return;
    }
  }
}
void Unpickler::setInput(size_t memo_id) {
  AT_ASSERT(!stack_.empty());
  if (memo_id >= memo_table_.size()) {
    memo_table_.insert(
        memo_table_.end(), memo_id - memo_table_.size(), IValue());
    memo_table_.push_back(stack_.back());
  } else {
    memo_table_[memo_id] = stack_.back();
  }
}

// emplace_back on bool vectors does not exist on some systems
// avoid it by calling push_back for bool
template <typename T>
inline void append(std::vector<T>& a, T&& e) {
  a.emplace_back(std::forward<T>(e));
}
template <>
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
inline void append<bool>(std::vector<bool>& a, bool&& e) {
  a.push_back(e);
}

static std::vector<int64_t> tupleToIntList(const IValue& v) {
  return fmap(v.toTupleRef().elements(), [](const IValue& v) -> int64_t {
    return v.toInt();
  });
}

// note we cannot use toIntList, toDoubleList because during unpickling the
// lists are not yet tagged
template <typename T>
static std::vector<T> convertList(const IValue& v) {
  return fmap(v.toListRef(), [](const IValue& elem) { return elem.to<T>(); });
}

PickleOpCode Unpickler::readInstruction() {
  auto opcode = readOpCode();
  switch (opcode) {
    case PickleOpCode::EMPTY_LIST: {
      stack_.emplace_back(c10::impl::GenericList(AnyType::get()));
    } break;
    case PickleOpCode::EMPTY_TUPLE: {
      if (empty_tuple_.isNone()) {
        // we only need one object, since tuples are not mutable.
        empty_tuple_ = c10::ivalue::Tuple::create(std::vector<IValue>());
      }
      stack_.emplace_back(empty_tuple_);
    } break;
    case PickleOpCode::BINPUT: {
      size_t memo_id = read<uint8_t>();
      setInput(memo_id);
    } break;
    case PickleOpCode::LONG_BINPUT: {
      TORCH_CHECK(
          std::numeric_limits<size_t>::max() >=
              std::numeric_limits<uint32_t>::max(),
          "Found a LONG_BINPUT opcode, but size_t on this system is "
          "not big enough to decode it");
      size_t memo_id = read<uint32_t>();
      setInput(memo_id);
    } break;
    case PickleOpCode::MARK: {
      // Mark location of the container ivalue in the stack
      marks_.push_back(stack_.size());
    } break;
    case PickleOpCode::NEWTRUE: {
      stack_.emplace_back(true);
    } break;
    case PickleOpCode::NEWFALSE: {
      stack_.emplace_back(false);
    } break;
    case PickleOpCode::NONE: {
      stack_.emplace_back();
    } break;
    case PickleOpCode::BININT1: {
      uint8_t value = read<uint8_t>();
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::BININT2: {
      uint16_t value = from_le16(read<uint16_t>());
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::BININT: {
      int32_t value = from_le32(read<int32_t>());
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::LONG1: {
      // Only read LONG1s with 8 as the length
      uint8_t length = read<uint8_t>();
      TORCH_CHECK(length == 8, "Expected length to be 8, got ", int(length));
      stack_.emplace_back(int64_t(from_le64(read<int64_t>())));
    } break;
    case PickleOpCode::BINUNICODE: {
      uint32_t length = from_le32(read<uint32_t>());
      stack_.emplace_back(readBytes(length));
    } break;
    case PickleOpCode::BINUNICODE8: {
      int64_t length = from_le64(read<int64_t>());
      stack_.emplace_back(readBytes(length));
    } break;
    case PickleOpCode::BINFLOAT:
      stack_.emplace_back(readFloat());
      break;
    case PickleOpCode::TUPLE: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      std::vector<IValue> elements;
      TORCH_CHECK(
          stack_.size() >= start,
          "Parsing error: wrong start index ",
          start,
          " for stack_ of size ",
          stack_.size());
      const auto tupleSize = stack_.size() - start;
      switch (tupleSize) {
        case 3: {
          auto e3 = pop(stack_);
          auto e2 = pop(stack_);
          auto e1 = pop(stack_);
          stack_.emplace_back(c10::ivalue::Tuple::create(
              std::move(e1), std::move(e2), std::move(e3)));
          break;
        }
        case 2: {
          auto e2 = pop(stack_);
          auto e1 = pop(stack_);
          stack_.emplace_back(
              c10::ivalue::Tuple::create(std::move(e1), std::move(e2)));
          break;
        }
        case 1:
          stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_)));
          break;
        default: {
          elements.reserve(stack_.size() - start);
          auto start_it = stack_.begin() + static_cast<std::ptrdiff_t>(start);
          for (auto it = start_it; it != stack_.end(); ++it) {
            elements.emplace_back(std::move(*it));
          }
          stack_.erase(start_it, stack_.end());
          stack_.emplace_back(c10::ivalue::Tuple::create(std::move(elements)));
          break;
        }
      }
    } break;
    case PickleOpCode::TUPLE1: {
      TORCH_CHECK(
          !stack_.empty(),
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 1 expected");
      stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_)));
    } break;
    case PickleOpCode::TUPLE2: {
      TORCH_CHECK(
          stack_.size() > 1,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 2 expected");
      auto e2 = pop(stack_);
      auto e1 = pop(stack_);
      stack_.emplace_back(
          c10::ivalue::Tuple::create(std::move(e1), std::move(e2)));
    } break;
    case PickleOpCode::TUPLE3: {
      TORCH_CHECK(
          stack_.size() > 2,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 3 expected");
      auto e3 = pop(stack_);
      auto e2 = pop(stack_);
      auto e1 = pop(stack_);
      stack_.emplace_back(c10::ivalue::Tuple::create(
          std::move(e1), std::move(e2), std::move(e3)));
    } break;
    case PickleOpCode::EMPTY_DICT:
      stack_.emplace_back(
          c10::impl::GenericDict(AnyType::get(), AnyType::get()));
      break;
    case PickleOpCode::APPENDS: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      TORCH_CHECK(
          start > 0 && start <= stack_.size(),
          "Parsing error: wrong start index ",
          start,
          " for stack_ of size ",
          stack_.size());
      auto list_ivalue = stack_.at(start - 1);
      readList(list_ivalue);
    } break;
    case PickleOpCode::APPEND: {
      TORCH_CHECK(
          stack_.size() >= 2, "Parsing error: missing elements in stack_.");
      auto list_ivalue = stack_.at(stack_.size() - 2);
      readListElements(list_ivalue, stack_.size() - 1);
    } break;
    case PickleOpCode::LIST: {
      IValue list_ivalue = c10::impl::GenericList(AnyType::get());
      readList(list_ivalue);
      stack_.push_back(std::move(list_ivalue));
    } break;
    case PickleOpCode::DICT: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      TORCH_CHECK(
          stack_.size() > start,
          "Parsing error: wrong start index ",
          start,
          " for stack_ which of size ",
          stack_.size());
      auto dict = c10::impl::GenericDict(AnyType::get(), AnyType::get());
      TORCH_CHECK(
          (stack_.size() - start) % 2 == 0,
          "Parsing error: stack_ is of size ",
          stack_.size(),
          " and start index is ",
          start,
          ", but stack_ is iterated by two elements at a time");
      for (size_t i = start; i < stack_.size(); i += 2) {
        dict.insert_or_assign(stack_[i], stack_[i + 1]);
      }
      stack_.erase(
          stack_.begin() + static_cast<std::ptrdiff_t>(start), stack_.end());
      stack_.emplace_back(std::move(dict));
    } break;
    case PickleOpCode::SETITEMS: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      TORCH_CHECK(
          start > 0 && start <= stack_.size(),
          "Parsing error: wrong start index for stack_");
      auto dict = stack_.at(start - 1).toGenericDict();
      TORCH_CHECK(
          (stack_.size() - start) % 2 == 0,
          "Parsing error: stack_ is of size ",
          stack_.size(),
          " and start index is ",
          start,
          ", but stack_ is iterated by two elemenst at a time");
      for (size_t i = start; i < stack_.size(); i += 2) {
        dict.insert_or_assign(stack_[i], stack_[i + 1]);
      }
      stack_.erase(
          stack_.begin() + static_cast<std::ptrdiff_t>(start), stack_.end());
    } break;
    case PickleOpCode::BINGET: {
      auto pos = read<uint8_t>();
      TORCH_CHECK(
          memo_table_.size() > pos,
          "Parsing error: out of bounds access at ",
          (size_t)pos,
          " to memo_table_ which is of size ",
          memo_table_.size());
      stack_.push_back(memo_table_.at(pos));
    } break;
    case PickleOpCode::LONG_BINGET: {
      auto pos = read<uint32_t>();
      TORCH_CHECK(
          memo_table_.size() > pos,
          "Parsing error: out of bounds access at ",
          (size_t)pos,
          " to memo_table_ which is of size ",
          memo_table_.size());
      stack_.push_back(memo_table_.at(pos));
    } break;
    case PickleOpCode::STOP:
      break;
    case PickleOpCode::GLOBAL: {
      // Module name, it's not needed for anything
      auto module_name = readString();
      auto class_name = readString();
      readGlobal(module_name, class_name);
    } break;
    case PickleOpCode::NEWOBJ: {
      TORCH_CHECK(!stack_.empty(), "Parsing error: stack_ is empty");
      // pop empty tuple, the actual action is stored in the globals_stack_
      stack_.pop_back();
    } break;
    // because we have NEWOBJ do nothing, BUILD and REDUCE end up doing
    // the same thing
    case PickleOpCode::BUILD:
    case PickleOpCode::REDUCE: {
      // stack is: <functor_idx> <functor_arg>
      // extract <functor_idx> and remove from the stack:
      TORCH_CHECK(
          stack_.size() > 1,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 2 expected");
      std::swap(*(stack_.end() - 2), *(stack_.end() - 1));
      size_t idx = stack_.back().toInt();
      stack_.pop_back();
      // stack is: <functor_arg>
      TORCH_CHECK(
          idx < globals_.size(),
          "Parsing error: out of bounds access to globals_");
      globals_.at(idx)();
    } break;
    case PickleOpCode::BINPERSID: {
      TORCH_CHECK(!stack_.empty(), "Parsing error: stack_ is empty");
      auto tuple = pop(stack_).toTuple();
      const auto& args = tuple->elements();
      AT_ASSERT(
          args.at(0).toStringRef() == "storage",
          "unknown PERSID key ",
          args.at(0).toStringRef());
      at::ScalarType type = args.at(1).toScalarType();
      const std::string& key = args.at(2).toStringRef();

      at::Device device(args.at(3).toStringRef());
      // remap device location if it's not meta
      if (device_ && !device.is_meta()) {
        device = *device_;
      }

      at::Storage storage;
      if (storage_context_ != nullptr && storage_context_->hasStorage(key)) {
        // for torch.package logic where storage may be loaded already
        storage = storage_context_->getStorage(key);
      } else {
        int64_t numel = args.at(4).toInt();
        caffe2::TypeMeta dtype = at::CPU(type).typeMeta();

        at::DataPtr storage_ptr;
        if (numel > 0) {
          // If there are no elements in the tensor, there's no point in
          // reading a zero (0) byte file from the input stream and paying
          // that cost.
          storage_ptr = read_record_(key);
        }

        storage = at::Storage(
            c10::Storage::use_byte_size_t(),
            numel * dtype.itemsize(),
            std::move(storage_ptr),
            /*allocator=*/nullptr,
            /*resizable=*/false); // NB: we didn't set any allocator for the
                                  // tensor
        if (storage_context_ != nullptr) {
          storage_context_->addStorage(key, storage);
        }
      }

      auto options = at::CPU(type).options();
      if (use_storage_device_) {
        options = options.device(storage.device());
        device = storage.device();
      }

      at::Tensor tensor;
      if (options.backend() == c10::Backend::QuantizedCPU) {
        tensor = at::_empty_affine_quantized({}, options, 0, 0)
                     .set_(storage, 0, {}, {});
      } else {
        tensor = at::empty({0}, options).set_(storage);
      }

      if (device.is_cuda() || device.is_xpu() || device.is_meta() ||
          device.is_hpu() || device.is_mps() || device.is_privateuseone()) {
        tensor = tensor.to(device, tensor.scalar_type());
      } else if (device.type() != DeviceType::CPU) {
        AT_ERROR(
            "supported devices include CPU, CUDA, HPU and ",
            c10::get_privateuse1_backend(),
            " however got ",
            DeviceTypeName(device.type(), false));
      }
      stack_.emplace_back(std::move(tensor));
    } break;
    case PickleOpCode::SETITEM: {
      // At this OpCode, stack looks like
      // | Stack Bottom |
      // | ......       |
      // | Dict         | -> (stack_size - 3)
      // | Key          | -> (stack_size - 2)
      // | Value        | -> (stack_size - 1)
      TORCH_CHECK(
          stack_.size() >= 3,
          "Parsing error: stack doesn't have enough elements");

      auto stack_size = stack_.size();
      auto dict_pos = stack_size - 3;
      auto key_pos = stack_size - 2;
      auto val_pos = stack_size - 1;

      TORCH_CHECK(
          (dict_pos < stack_size) && (key_pos < stack_size) &&
              (val_pos < stack_size),
          "Parsing error: attempted out-of-bounds access while processing SETITEM opcode");

      auto dict = stack_.at(dict_pos).toGenericDict();
      dict.insert_or_assign(stack_.at(key_pos), stack_.at(val_pos));
      stack_.erase(
          stack_.begin() + static_cast<std::ptrdiff_t>(key_pos), stack_.end());
    } break;
    default: {
      AT_ERROR(
          "Unknown opcode for unpickling at ",
          // NOLINTNEXTLINE(performance-no-int-to-ptr)
          reinterpret_cast<void*>(opcode),
          ": ",
          int(static_cast<uint8_t>(opcode)));
    } break;
  }
  return opcode;
}

void Unpickler::readGlobal(
    const std::string& module_name,
    const std::string& class_name) {
  if (this->skip_next_read_global) {
    // See [NOTE] skip_next_read_global
    this->skip_next_read_global--;
    if (this->skip_next_read_global == 1) {
      // Pass through to the correct handler
    } else if (this->skip_next_read_global == 0) {
      // Corresponds to the type of `Tensor` being unpickled
      if (module_name != "torch" || class_name != "Tensor") {
        TORCH_WARN(
            "Trying to load a Subclassed Tensor, it will be converted to at::Tensor in C++");
      }
      stack_.emplace_back(int64_t(globals_.size() - 1));
      return;
    } else {
      TORCH_CHECK(false, "INVALID VALUES")
    }
  }
  // TODO [unpickler refactor] __main__ isn't used by the pickler anymore, this
  // is only here for bc-compatibility reasons
  if (module_name == "__main__") {
    if (class_name == "TensorID") {
      globals_.emplace_back([this] {
        auto setitem_data = stack_.back();
        stack_.pop_back();
        TORCH_INTERNAL_ASSERT(
            !tensor_table_.empty(),
            "Pickler tried to write a tensor but had no tensor table to write to");
        stack_.emplace_back(tensor_table_.at(setitem_data.toInt()));
      });
    } else if (class_name == "IntList") {
      globals_.emplace_back([this] {
        stack_.back().toList().unsafeSetElementType(IntType::get());
      });
    } else {
      AT_ERROR("Unknown pickler class id", class_name);
    }
  } else if (module_name == "torch.jit._pickle") {
    if (class_name == "build_tensor_from_id") {
      globals_.emplace_back([this] {
        // Pop reduce arg off the stack
        auto data = stack_.back().toTupleRef().elements().at(0);
        stack_.pop_back();
        TORCH_CHECK(
            !tensor_table_.empty(),
            "Found a tensor table reference but Unpickler"
            " has no tensor table\n");
        stack_.emplace_back(tensor_table_.at(data.toInt()));
      });
    } else if (class_name == "restore_type_tag") {
      globals_.emplace_back([this] {
        auto tuple = stack_.back().toTuple();
        const auto& data = tuple->elements();
        auto type_str = data.at(1).toStringRef();
        stack_.pop_back();
        TypePtr type = nullptr;
        auto entry = type_cache_.find(type_str);
        if (entry != type_cache_.end()) {
          type = entry->second;
        } else {
          if (type_resolver_ == nullptr) {
            // If we haven't injected a custom way of retrieving types from
            // names, use a barebones type parser.
            type = type_parser_(type_str);
          } else {
            type = type_resolver_(type_str).type_;
          }
          type_cache_[type_str] = type;
        }
        // TODO: Use lookahead to avoid creating the tuple and immediately
        // destroying it here
        restoreContainerTypeTags(data.at(0), type);
        stack_.emplace_back(data.at(0));
      });
    } else {
      TypePtr elem_type = nullptr;
      if (class_name == "build_intlist") {
        elem_type = IntType::get();
      } else if (class_name == "build_tensorlist") {
        elem_type = TensorType::get();
      } else if (class_name == "build_doublelist") {
        elem_type = FloatType::get();
      } else if (class_name == "build_boollist") {
        elem_type = BoolType::get();
      } else {
        AT_ERROR("Unknown pickler class id ", class_name);
      }
      // Unpickle a list specialization (e.g. List[Tensor], List[int], ...)
      globals_.emplace_back([this, elem_type] {
        // Pop reduce arg off the stack
        auto data = stack_.back().toTupleRef().elements().at(0).toList();
        stack_.pop_back();
        data.unsafeSetElementType(elem_type);
        stack_.emplace_back(std::move(data));
      });
    }
  } else if (
      module_name == "torch._utils" &&
      (class_name == "_rebuild_tensor_v2" ||
       class_name == "_rebuild_qtensor")) {
    // Unpickle a tensor
    bool quantized = class_name == "_rebuild_qtensor";
    rebuildTensor(quantized);
  } else if (
      module_name == "torch._tensor" &&
      (class_name == "_rebuild_from_type_v2")) {
    // Unpickle a Tensor with Python attributes or
    // a Subclassed Tensor.
    rebuildTensorFromTypeV2();
  } else if (
      module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") {
    rebuildSparseTensor();
  } else if (module_name == "builtins" && class_name == "complex") {
    globals_.emplace_back([this] {
      auto tuple = pop(stack_).toTuple();
      const auto& elems = tuple->elements();
      AT_ASSERT(elems.size() == 2);
      auto complex =
          c10::complex<double>(elems.at(0).toDouble(), elems.at(1).toDouble());
      stack_.emplace_back(complex);
    });

  } else if (module_name == "collections" && class_name == "OrderedDict") {
    // collections.OrderedDict is used in tensor serialization for a tensor's
    // backward hooks (but they are not actually saved with this Pickler)
    globals_.emplace_back([this] {
      // drop the Tuple that was argument to OrderedDict, and replace it
      // with None OrderedDicts only appear in tensor deserialization and
      // their value is never used
      stack_.back() = IValue();
    });
  } else if (module_name == "torch" && class_name == "device") {
    globals_.emplace_back([this] {
      auto device_string = stack_.back().toTupleRef().elements().at(0);
      stack_.pop_back();
      stack_.emplace_back(c10::Device(device_string.toStringRef()));
    });
    stack_.emplace_back(int64_t(globals_.size() - 1));
    return;
  } else if (module_name == "torch.distributed.rpc" && class_name == "rref") {
#ifdef USE_RPC
    return rebuildRRef();
#else
    TORCH_INTERNAL_ASSERT(
        false,
        "RRef unpickling is only supported with the distributed package");
#endif
  } else if (module_name == "torch") {
    // Try to manually resolve several global enums
    // NOTE: this does not put a global into the global table,
    // like the other branches here because no REDUCE or BUILD will
    // be called on this value. Instead, we just put it on the stack
    // and return early
    std::optional<c10::ScalarType> scalar_type;
#define CHECK_SCALAR(_, name)          \
  if (class_name == #name "Storage") { \
    scalar_type = c10::k##name;        \
  }
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CHECK_SCALAR)
#undef CHECK_SCALAR
    if (scalar_type.has_value()) {
      stack_.emplace_back(int64_t(*scalar_type));
      return;
    }

    std::optional<at::QScheme> qscheme;
    for (int i = 0; i < at::COMPILE_TIME_NUM_QSCHEMES; ++i) {
      if (class_name == toString(static_cast<at::QScheme>(i))) {
        qscheme = static_cast<at::QScheme>(i);
      }
    }
    if (qscheme.has_value()) {
      stack_.emplace_back(int64_t(*qscheme));
      return;
    }
    TORCH_CHECK(
        false,
        "Unpickler found unknown torch global, 'torch.",
        class_name,
        "'");
  } else {
    TORCH_CHECK(
        type_resolver_,
        "Unpickler found unknown type ",
        module_name,
        ".",
        class_name);
    at::StrongTypePtr type =
        type_resolver_(c10::QualifiedName(module_name, class_name));
    if (auto enum_type = type.type_->cast<c10::EnumType>()) {
      globals_.emplace_back([this, enum_type] {
        auto val = stack_.back();
        stack_.pop_back();
        for (const auto& p : enum_type->enumNamesValues()) {
          if (p.second == val) {
            auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
                enum_type, p.first, p.second);
            stack_.emplace_back(std::move(enum_holder));
            return;
          }
        }
      });
    } else {
      // Otherwise, global is a class/object type.
      globals_.emplace_back([this, type] {
        auto val = stack_.back();
        stack_.pop_back();
        auto obj = obj_loader_(type, val);
        stack_.emplace_back(std::move(obj));
      });
    }
  }
  stack_.emplace_back(int64_t(globals_.size() - 1));
}

void Unpickler::rebuildSparseTensor() {
  globals_.emplace_back([this] {
    auto tup = pop(stack_).toTuple();
    const auto& elements = tup->elements();
    size_t idx = 0;
    auto layout = elements.at(idx++).toInt();
    at::Tensor result;
    switch (layout) {
      case static_cast<int>(c10::Layout::Sparse): {
        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
        bool requires_grad = elements.at(idx++).toBool();
        auto& indices_tensor = elements.at(idx++).toTensor();
        auto& values_tensor = elements.at(idx++).toTensor();
        auto options = values_tensor.options()
                           .layout(c10::Layout::Sparse)
                           .requires_grad(requires_grad);
        result = at::_sparse_coo_tensor_unsafe(
            indices_tensor, values_tensor, size, options);
        result = autograd::make_variable(result, options.requires_grad());
        break;
      }
      case static_cast<int>(c10::Layout::SparseCsr): {
        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
        bool requires_grad = elements.at(idx++).toBool();
        auto& crow_indices = elements.at(idx++).toTensor();
        auto& col_indices = elements.at(idx++).toTensor();
        auto& values_tensor = elements.at(idx++).toTensor();
        auto options = values_tensor.options()
                           .layout(c10::Layout::SparseCsr)
                           .requires_grad(requires_grad);
        result = at::_sparse_csr_tensor_unsafe(
            crow_indices, col_indices, values_tensor, size, options);
        result =
            autograd::make_variable(std::move(result), options.requires_grad());
        break;
      }
      default:
        TORCH_CHECK(
            false,
            "Unsupported sparse tensor layout type in serialization ",
            static_cast<c10::Layout>(layout));
        break;
    }
    stack_.emplace_back(std::move(result));
  });
}

void Unpickler::rebuildTensor(bool quantized) {
  globals_.emplace_back([this, quantized] {
    auto tup = pop(stack_).toTuple();
    const auto& elements = tup->elements();
    size_t idx = 0;
    auto& storage_tensor = elements.at(idx++).toTensor();
    int64_t storage_offset = elements.at(idx++).toInt();
    std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
    std::vector<int64_t> stride = tupleToIntList(elements.at(idx++));
    at::Tensor result;
    if (quantized) {
      auto qparams_tuple = elements.at(idx++).toTuple();
      const auto& qparams = qparams_tuple->elements();
      auto qscheme = static_cast<at::QScheme>(qparams.at(0).toInt());
      switch (qscheme) {
        case at::kPerTensorAffine: {
          double q_scale = qparams.at(1).toDouble();
          int64_t q_zero_point = qparams.at(2).toInt();
          result = at::_empty_affine_quantized(
              {0}, storage_tensor.options(), q_scale, q_zero_point);
        } break;
        case at::kPerChannelAffineFloatQParams:
        case at::kPerChannelAffine: {
          const auto& scales = qparams.at(1).toTensor();
          const auto& zero_points = qparams.at(2).toTensor();
          int64_t axis = qparams.at(3).toInt();
          result = at::_empty_per_channel_affine_quantized(
              {0}, scales, zero_points, axis, storage_tensor.options());
        } break;
        default:
          TORCH_CHECK(
              false,
              "Unsupported tensor quantization type in serialization ",
              toString(qscheme));
          break;
      }
    } else {
      result = at::empty({0}, storage_tensor.options());
    }
    bool requires_grad = elements.at(idx++).toBool();
    idx++; // backwards hooks is empty
    at::TensorImpl* impl = result.unsafeGetTensorImpl();
    impl->set_storage_keep_dtype(storage_tensor.storage());
    impl->set_storage_offset(storage_offset);
    impl->set_sizes_and_strides(size, stride);
    result = autograd::make_variable(result, requires_grad);

    // Handle if math_bits were pickled.
    // See `args` of _reduce_ex_internal
    // for a regular tensor (final else case).
    // Tensors pickled before this patch didn't
    // have this argument for storing MathBits,
    // in that case, we do nothing.
    // NOTE: `math_bits` is the 7th arg.
    // NOTE: This is only meant for regular tensor and not quantized
    //       which also has 7 args serialized.
    if (!quantized && elements.size() == 7) {
      auto math_bits = elements.at(idx++).toGenericDict();
      torch::jit::setTensorMetadata(result, math_bits);
    }

    stack_.emplace_back(std::move(result));
  });
}

void Unpickler::rebuildTensorFromTypeV2() {
  // [NOTE] skip_next_read_global
  // When rebuilding Tensor with Python Attr or Subclassed Tensor,
  // we receive `(func, type(self), args, state)` on stack for
  // `rebuildTensorFromTypeV2`.
  // Thus next call to readGlobal corresponds to `func` which is
  // the function to rebuild the base tensor.
  // The call after `func` to readGlobal corresponds to `type` of the
  // Tensor where we raise warning if the type is not `torch.Tensor`.
  this->skip_next_read_global = 2;
  auto curr_globals_idx = globals_.size();
  globals_.emplace_back([this, curr_globals_idx] {
    // args is a tuple with following data
    //  (function to rebuild base tensor, type of tensor,
    //   arguments to construct base tensor, Python State (as dict))
    auto args = pop(stack_).toTuple();
    size_t tup_idx = 0;
    const auto args_elems = args->elements();
    auto base_tensor_args = args_elems.at(tup_idx + 2).toTuple();
    auto py_state = args_elems.at(tup_idx + 3).toGenericDict();
    if (!py_state.empty()) {
      TORCH_WARN(
          "Loading Tensor with Python attributes will return at::Tensor with Python attributes being discarded");
    }
    // This calls the function to rebuild the
    // base tensor.
    // Eg. `rebuildTensor`, `rebuildSpareTensor`.
    stack_.emplace_back(base_tensor_args);
    globals_[curr_globals_idx + 1]();
    stack_.emplace_back(pop(stack_));
  });
}

#ifdef USE_RPC
void Unpickler::rebuildRRef() {
  globals_.emplace_back([this] {
    // It is the same as how rref is unpickled in python,
    // see PyRRef::unpickle
    auto tuple = std::move(stack_.back()).toTuple();
    const auto& args = tuple->elements();
    stack_.pop_back();
    TORCH_INTERNAL_ASSERT(
        args.size() == distributed::rpc::RFD_TUPLE_SIZE,
        "Pickled RRefForkData must contain 7 numbers.");
    auto ownerId =
        static_cast<int16_t>(args.at(distributed::rpc::OWNER_IDX).toInt());
    // const reference will extend the lifetime of the temporary variable
    const auto& rrefId = distributed::rpc::RRefId(
        static_cast<int16_t>(args.at(distributed::rpc::RREFID_ON_IDX).toInt()),
        static_cast<int64_t>(args.at(distributed::rpc::RREFID_ID_IDX).toInt()));
    const auto& forkId = distributed::rpc::RRefId(
        static_cast<int16_t>(args.at(distributed::rpc::FORKID_ON_IDX).toInt()),
        static_cast<int64_t>(args.at(distributed::rpc::FORKID_ID_IDX).toInt()));
    auto parent =
        static_cast<int16_t>(args.at(distributed::rpc::PARENT_IDX).toInt());
    const auto& typeStr = static_cast<std::string>(
        args.at(distributed::rpc::TYPE_IDX).toStringRef());
    auto rrefForkData = distributed::rpc::RRefForkData(
        ownerId, rrefId, forkId, parent, typeStr);
    auto& ctx = distributed::rpc::RRefContext::getInstance();
    c10::intrusive_ptr<distributed::rpc::RRef> rref;
    TORCH_INTERNAL_ASSERT(
        type_resolver_ != nullptr, "type_resolver_ is nullptr.");
    at::StrongTypePtr type = type_resolver_(c10::QualifiedName(typeStr));
    rref = ctx.getOrCreateRRef(rrefForkData, type.type_);
    ctx.notifyOwnerAndParentOfFork(
        rrefForkData.forkId_, rrefForkData.parent_, rref);
    stack_.emplace_back(
        c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref));
  });
  stack_.emplace_back(int64_t(globals_.size() - 1));
  return;
}
#endif

void Unpickler::readSlowWithBuffer(char* dest, size_t sz) {
  // First, read any partial from buffer (may be 0).
  // We explicitly assume that sz > buffer_remaining_,
  // and that sz is never bigger than buffer_.size().
  AT_ASSERT(sz > buffer_remaining_);
  const size_t from_old_buf = buffer_remaining_;
  if (from_old_buf != 0) {
    memcpy(dest, buffer_.data() + buffer_pos_, from_old_buf);
  }
  const size_t needed = sz - from_old_buf;
  // Full read into the buffer. The calls here all explicitly
  // assume that one buffer will be enough for any sz.
  AT_ASSERT(sz <= buffer_.size());
  buffer_remaining_ = reader_(buffer_.data(), buffer_.size());
  if (buffer_remaining_ < needed) {
    AT_ERROR("Unexpected end of pickler archive.");
  }
  memcpy(dest + from_old_buf, buffer_.data(), needed);
  buffer_pos_ = needed; // assignment (0'ed from read)
  buffer_remaining_ -= needed;
}

// Read a number of bytes from the input stream
std::string Unpickler::readBytes(size_t length) {
  std::string data;
  static const size_t kSmallString = 64;
  TORCH_CHECK(
      length <= data.max_size(),
      "Parsing error: can't read ",
      length,
      " bytes to a string");
  if (length <= buffer_remaining_) {
    // Fast-path: entirely in buffer.
    data.assign(buffer_.data() + buffer_pos_, length);
    buffer_pos_ += length;
    buffer_remaining_ -= length;
  } else if (length <= kSmallString) {
    // If the string is smallish, do a full buffer read,
    // and read out of that buffer.
    data.resize(length);
    readSlowWithBuffer(&data[0], length);
  } else {
    // Otherwise, for larger strings, read what we can from
    // the buffer, and then read directly to the destination.
    const size_t from_old_buf = buffer_remaining_;
    if (from_old_buf != 0) {
      data.reserve(length);
      data.append(buffer_.data() + buffer_pos_, from_old_buf);
    }
    data.resize(length);
    const size_t needed = length - from_old_buf;
    size_t nread = reader_(&data[from_old_buf], needed);
    if (nread != needed) {
      AT_ERROR("Unexpected end of pickler archive.");
    }
    buffer_remaining_ = 0;
    // buffer_pos_ has no meaning with buffer_remaining_ == 0.
  }
  return data;
}

void Unpickler::readListElements(IValue list_ivalue, size_t start) {
  auto num_elements = stack_.size() - start;
  auto elements = c10::ArrayRef<IValue>(stack_).slice(start);
  if (list_ivalue.isIntList()) {
    auto list = std::move(list_ivalue).toIntList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toInt());
    }
  } else if (list_ivalue.isTensorList()) {
    auto list = std::move(list_ivalue).toTensorList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toTensor());
    }
  } else if (list_ivalue.isDoubleList()) {
    auto list = std::move(list_ivalue).toDoubleList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toDouble());
    }
  } else if (list_ivalue.isBoolList()) {
    auto list = std::move(list_ivalue).toBoolList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.push_back(elem.toBool());
    }
  } else if (list_ivalue.isList()) {
    auto list = std::move(list_ivalue).toList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem);
    }
  } else {
    AT_ERROR("Unknown IValue list kind: ", list_ivalue.tagKind());
  }
  stack_.erase(
      stack_.begin() + static_cast<std::ptrdiff_t>(start), stack_.end());
}

// Pop all the list items off of the stack and append them to the list at
// the corresponding MARK
void Unpickler::readList(IValue list_ivalue) {
  TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
  size_t start = marks_.back();
  marks_.pop_back();
  readListElements(std::move(list_ivalue), start);
}

inline bool is_valid_python_id_char(char c) {
  return c == '_' || c == '.' || (c >= '0' && c <= '9') ||
      (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
}

// Read a newline terminated string
std::string Unpickler::readString() {
  std::string ss;
  while (true) {
    auto* const bufferStart = buffer_.data() + buffer_pos_;
    const auto bufferLeft = buffer_.size() - buffer_pos_;
    char* const newlinePtr =
        static_cast<char*>(memchr(bufferStart, '\n', bufferLeft));
    if (newlinePtr) {
      // read up to newline and we are done.
      auto const charsRead = newlinePtr - bufferStart;
      ss.append(bufferStart, charsRead);
      buffer_remaining_ -= charsRead + 1;
      buffer_pos_ += charsRead + 1;
      break;
    } else {
      // read whole buffer, refill
      for (const char* p = bufferStart; p < bufferStart + bufferLeft; ++p) {
        // Simple check just in case there is no terminating '\n'
        TORCH_CHECK(
            is_valid_python_id_char(*p),
            "Found character '",
            int(uint8_t(*p)),
            "' in string, ",
            "strings must be qualified Python identifiers");
      }
      ss.append(bufferStart, bufferLeft);
      buffer_remaining_ = reader_(buffer_.data(), buffer_.size());
      buffer_pos_ = 0;
    }
  }
  return ss;
}

} // namespace torch::jit
