#include <torch/csrc/jit/passes/liveness.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <iostream>
#include <memory>

namespace torch::jit {

// LivenessAnalyzer computes "bailout" liveness which is equivalent to
// "{LIVE_IN} or {GEN}" or "{LIVE_OUT} - {KILL}"
struct LivenessAnalyzer {
  explicit LivenessAnalyzer(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  std::unordered_map<Node*, std::vector<Value*>> run() {
    std::vector<Node*> counters;
    insertExplicitUsesOfLoopCounters(graph_->block(), counters);

    // we implement the canonical fixed-point liveness
    // the analysis is run until there are no more changes
    // to liveness sets for each node
    do {
      changed_ = false;
      processBlock(graph_->block(), SparseBitVector{});
    } while (changed_);

    removeCounterNodes(counters);
    std::unordered_map<Node*, std::vector<Value*>> result;

    for (const auto& e : liveness_sets_) {
      result.insert({e.first, toValueVector(e.second)});
    }
    return result;
  }

  // temporary make loop counts live for the duration of the loop
  // as they are needed by BailOuts in the loop
  void insertExplicitUsesOfLoopCounters(
      Block* b,
      std::vector<Node*>& counters) {
    for (auto it : b->nodes()) {
      if (it->kind() == prim::Loop) {
        LoopView lv(it);
        WithInsertPoint guard(lv.bodyBlock());
        auto ctc = graph_->create(prim::Store, {lv.currentTripCount()}, 0);
        graph_->insertNode(ctc);
        counters.push_back(ctc);
        auto mtc = graph_->create(prim::Store, {lv.maxTripCount()}, 0);
        graph_->insertNode(mtc);
        counters.push_back(mtc);
      }

      for (auto ib : it->blocks()) {
        insertExplicitUsesOfLoopCounters(ib, counters);
      }
    }
  }

  void removeCounterNodes(std::vector<Node*>& counters) {
    for (auto n : counters) {
      n->destroy();
    }
  }

  void dump(
      const std::unordered_map<Node*, std::vector<Value*>>& liveness_sets) {
    std::cout << "Liveness info:\n";
    for (auto e : liveness_sets) {
      if (!e.first->outputs().empty()) {
        std::cout << e.first->outputs()[0]->debugName();
      }

      std::cout << " " << e.first->kind().toQualString();
      std::cout << " = ";
      dump(e.second);
      std::cout << '\n';
    }
    std::cout << "graph :\n";
    graph_->dump();
  }

  void dump(const std::vector<Value*>& set) {
    bool first = true;
    std::cout << "[";
    for (auto el : set) {
      if (first) {
        first = false;
      } else {
        std::cout << ", ";
      }
      std::cout << el->debugName() << "(" << el->unique() << ")";
    }
    std::cout << "]";
  }

 private:
  SparseBitVector toSparseBitVector(at::ArrayRef<Value*> values) {
    SparseBitVector sbv;
    for (auto v : values) {
      ids_to_values_[v->unique()] = v;
      sbv.set(v->unique());
    }
    return sbv;
  }

  std::vector<Value*> toValueVector(const SparseBitVector& sbv) {
    std::vector<Value*> vec;
    for (auto id : sbv) {
      vec.push_back(ids_to_values_[id]);
    }
    return vec;
  }

  SparseBitVector processBlock(Block* b, SparseBitVector liveness) {
    // block outputs are the uses
    auto block_outputs = toSparseBitVector(b->outputs());
    liveness |= block_outputs;

    SparseBitVector defs;
    for (Node* it : b->nodes().reverse()) {
      // kill outputs
      liveness -= toSparseBitVector(it->outputs());
      if (it->kind() == prim::Loop) {
        LoopView lv(it);
        // N.B. merge in changes from the loop header
        auto loop_header = *lv.bodyBlock()->nodes().begin();
        auto loop_block = liveness | liveness_sets_[loop_header];
        loop_block = processBlock(lv.bodyBlock(), loop_block);
        // loop block's inputs die outside loop's block
        loop_block -= toSparseBitVector(lv.bodyBlock()->inputs());
        liveness |= loop_block;
      } else if (it->kind() == prim::If) {
        IfView iv(it);
        auto true_liveness = processBlock(iv.thenBlock(), liveness);
        auto false_liveness = processBlock(iv.elseBlock(), liveness);
        liveness |= true_liveness;
        liveness |= false_liveness;
      }
      liveness |= toSparseBitVector(it->inputs());
      // `|=` returns true if new bits were set in LHS
      // after or/union with `liveness`
      auto changed = liveness_sets_[it] |= liveness;
      changed_ = changed_ | changed;
    }
    return liveness;
  }

  std::shared_ptr<Graph> graph_;
  bool changed_{false};
  std::map<Node*, SparseBitVector> liveness_sets_;
  std::map<size_t, Value*> ids_to_values_;
};

std::unordered_map<Node*, std::vector<Value*>> BuildLivenessSets(
    std::shared_ptr<Graph> graph) {
  LivenessAnalyzer la(std::move(graph));
  return la.run();
}

} // namespace torch::jit
