#pragma once

#include <c10/util/Exception.h>

namespace torch::jit {

// Intrusive doubly linked lists with sane reverse iterators.
// The header file is named generic_graph_node_list.h because it is ONLY
// used for Graph's Node lists, and if you want to use it for other
// things, you will have to do some refactoring.
//
// At the moment, the templated type T must support a few operations:
//
//  - It must have a field: T* next_in_graph[2] = { nullptr, nullptr };
//    which is used for the intrusive linked list pointers.
//
//  - It must have a method 'destroy()', which removes T from the
//    list and frees a T.
//
// In practice, we are only using it with Node and const Node.  'destroy()'
// needs to be renegotiated if you want to use this somewhere else.
//
// Regardless of the iteration direction, iterators always physically point
// to the element they logically point to, rather than exhibiting the
// off-by-one behavior of standard library reverse iterators (such as those
// of std::list).

// The list includes two sentinel nodes, one at the beginning and one at the
// end, with a circular link between them. It is an error to insert nodes after
// the end sentinel node but before the beginning node:

// Visualization showing only the next() links:
//  HEAD -> first -> second  -> ... -> last -> TAIL
//   ^------------------------------------------

// Visualization showing only the prev() links:
//  HEAD <- first <- second  <- ... <- last <- TAIL
//   ------------------------------------------^

// Indices into T::next_in_graph. Slot 0 holds the forward (next) link and
// slot 1 holds the reverse (prev) link.
// Declared `inline constexpr` so that every translation unit including this
// header refers to the same entity instead of a private internal-linkage copy.
inline constexpr int kNextDirection = 0;
inline constexpr int kPrevDirection = 1;

// Forward declaration of the list view so the aliases below can name it
// before the template is defined.
template <typename T>
struct generic_graph_node_list;

// Forward declaration of the iterator used by the list view.
template <typename T>
struct generic_graph_node_list_iterator;

// Node is only forward-declared in this header; the aliases below merely name
// specializations for it and do not require its definition.
struct Node;
// List views over mutable and read-only nodes.
using graph_node_list = generic_graph_node_list<Node>;
using const_graph_node_list = generic_graph_node_list<const Node>;
// Iterators over mutable and read-only nodes.
using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;
using const_graph_node_list_iterator =
    generic_graph_node_list_iterator<const Node>;

// Bidirectional iterator over an intrusive list whose elements are linked
// through T::next_in_graph.
//
// The iterator stores the node it currently points at together with a
// direction index. operator++ follows next_in_graph[d], operator-- follows the
// opposite slot, so the same type serves as both forward and reverse iterator.
// Dereferencing yields the node pointer itself (T*), not a reference to T.
template <typename T>
struct generic_graph_node_list_iterator {
  // A default-constructed iterator points at no node and walks forward.
  // Advancing it trips the assertions in operator++/operator--.
  generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {}
  // cur: node the iterator points at.
  // d:   index into next_in_graph that operator++ follows
  //      (kNextDirection or kPrevDirection).
  generic_graph_node_list_iterator(T* cur, int d) : cur(cur), d(d) {}
  generic_graph_node_list_iterator(
      const generic_graph_node_list_iterator& rhs) = default;
  generic_graph_node_list_iterator(
      generic_graph_node_list_iterator&& rhs) noexcept = default;
  generic_graph_node_list_iterator& operator=(
      const generic_graph_node_list_iterator& rhs) = default;
  generic_graph_node_list_iterator& operator=(
      generic_graph_node_list_iterator&& rhs) noexcept = default;

  // Returns the node currently pointed at (may be a sentinel or nullptr).
  T* operator*() const {
    return cur;
  }
  T* operator->() const {
    return cur;
  }

  // Moves one step in the iteration direction. Asserts that the iterator
  // points at a node.
  generic_graph_node_list_iterator& operator++() {
    AT_ASSERT(cur);
    cur = cur->next_in_graph[d];
    return *this;
  }
  generic_graph_node_list_iterator operator++(int) {
    generic_graph_node_list_iterator old = *this;
    ++(*this);
    return old;
  }

  // Moves one step against the iteration direction. Asserts that the iterator
  // points at a node.
  generic_graph_node_list_iterator& operator--() {
    AT_ASSERT(cur);
    cur = cur->next_in_graph[reverseDir()];
    return *this;
  }
  generic_graph_node_list_iterator operator--(int) {
    generic_graph_node_list_iterator old = *this;
    --(*this);
    return old;
  }

  // Erases the current node without invalidating this iterator.
  // Named differently from destroy() so that ->/. mix-ups do not silently
  // call the wrong one.
  // After the call the iterator points at the entry that preceded the erased
  // node in the iteration direction, so a following ++ lands on the entry that
  // came after it.
  void destroyCurrent() {
    // Same guard as operator++/operator--: fail loudly instead of
    // dereferencing a null node.
    AT_ASSERT(cur);
    T* n = cur;
    // Step back before destroying, because destroy() unlinks and frees n.
    cur = cur->next_in_graph[reverseDir()];
    n->destroy();
  }

  // Returns an iterator at the same node that walks in the opposite direction.
  generic_graph_node_list_iterator reverse() const {
    return generic_graph_node_list_iterator(cur, reverseDir());
  }

 private:
  // Index of the next_in_graph slot opposite to d.
  int reverseDir() const {
    return d == kNextDirection ? kPrevDirection : kNextDirection;
  }
  T* cur;
  int d; // direction 0 is forward 1 is reverse, see next_in_graph
};

template <typename T>
struct generic_graph_node_list {
  using iterator = generic_graph_node_list_iterator<T>;
  using const_iterator = generic_graph_node_list_iterator<const T>;
  generic_graph_node_list_iterator<T> begin() {
    return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d);
  }
  generic_graph_node_list_iterator<const T> begin() const {
    return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d);
  }
  generic_graph_node_list_iterator<T> end() {
    return generic_graph_node_list_iterator<T>(head->next_in_graph[!d], d);
  }
  generic_graph_node_list_iterator<const T> end() const {
    return generic_graph_node_list_iterator<const T>(
        head->next_in_graph[!d], d);
  }
  generic_graph_node_list_iterator<T> rbegin() {
    return reverse().begin();
  }
  generic_graph_node_list_iterator<const T> rbegin() const {
    return reverse().begin();
  }
  generic_graph_node_list_iterator<T> rend() {
    return reverse().end();
  }
  generic_graph_node_list_iterator<const T> rend() const {
    return reverse().end();
  }
  generic_graph_node_list reverse() {
    return generic_graph_node_list(head->next_in_graph[!d], !d);
  }
  const generic_graph_node_list reverse() const {
    return generic_graph_node_list(head->next_in_graph[!d], !d);
  }
  T* front() {
    return head->next_in_graph[d];
  }
  const T* front() const {
    return head->next_in_graph[d];
  }
  T* back() {
    return head->next_in_graph[!d];
  }
  const T* back() const {
    return head->next_in_graph[!d];
  }
  generic_graph_node_list(T* head, int d) : head(head), d(d) {}

 private:
  T* head; // both head and tail are sentinel nodes
           // the first real node is head->next_in_graph[d]
           // the tail sentinel is head->next_in_graph[!d]
  int d;
};

// Two iterators compare equal when they point at the same node. The
// traversal direction does not take part in the comparison.
template <typename T>
static inline bool operator==(
    generic_graph_node_list_iterator<T> a,
    generic_graph_node_list_iterator<T> b) {
  T* const lhs_node = *a;
  T* const rhs_node = *b;
  return lhs_node == rhs_node;
}

// Two iterators compare unequal when they point at different nodes. The
// traversal direction does not take part in the comparison.
template <typename T>
static inline bool operator!=(
    generic_graph_node_list_iterator<T> a,
    generic_graph_node_list_iterator<T> b) {
  T* const lhs_node = *a;
  T* const rhs_node = *b;
  return lhs_node != rhs_node;
}

} // namespace torch::jit

namespace std {

// Iterator traits for the graph node list iterator. They describe it as a
// bidirectional iterator whose elements are node pointers (T*).
template <typename T>
struct iterator_traits<torch::jit::generic_graph_node_list_iterator<T>> {
  using iterator_category = bidirectional_iterator_tag;
  using value_type = T*;
  using difference_type = int64_t;
  using pointer = T**;
  using reference = T*&;
};

} // namespace std
