#include <torch/csrc/jit/tensorexpr/cuda_codegen.h>
#include <torch/csrc/jit/tensorexpr/half_support.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/native/cuda/jit_utils.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h>
#include <torch/csrc/jit/codegen/fuser/cuda/resource_strings.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/cuda_random.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/registerizer.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_strided_native.h>
#endif

#include <unordered_map>
#include <utility>

namespace torch::jit::tensorexpr {

// RAII helper that registers a (variable, name) pair in a name look-up table
// on construction and erases the variable's entry again on destruction.
// Non-copyable, so one registration is erased exactly once.
// TODO: move this to a more shared place.
class ScopedVarName {
 public:
  // Registers `name` for `var` in `mapping`.
  // Throws std::runtime_error if `var` already has an entry; the table is left
  // unchanged in that case.
  ScopedVarName(VarNameMap* mapping, const VarPtr& var, const std::string& name)
      : mapping_(mapping), var_(var) {
    // insert() is a no-op when the key is already present, so its `second`
    // result tells us whether this would have been a duplicate.
    const bool was_inserted = mapping->insert(std::make_pair(var, name)).second;
    if (!was_inserted) {
      throw std::runtime_error("Duplicate var entry: " + var->name_hint());
    }
  }

  // Convenience overload that registers into the manager's own mapping.
  ScopedVarName(
      UniqueNameManager* manager,
      const VarPtr& var,
      const std::string& name)
      : ScopedVarName(&manager->unique_name_mapping_, var, name) {}

  ScopedVarName(const ScopedVarName&) = delete;
  ScopedVarName& operator=(const ScopedVarName&) = delete;

  // Removes the entry that the constructor added.
  ~ScopedVarName() noexcept(false) {
    mapping_->erase(var_);
  }

 private:
  VarNameMap* mapping_ = nullptr;
  VarPtr var_ = nullptr;
};

// Returns true only when intValue() yields a value for `expr` and that value
// is zero; an expression without a known integer value is never "zero".
static bool is_zero(const ExprPtr& expr) {
  const auto maybe_int = intValue(expr);
  if (!maybe_int) {
    return false;
  }
  return *maybe_int == 0;
}

// Shorthand for the NVRTC function table held by the global ATen context.
// Used in this file to reach driver entry points such as cuLaunchKernel.
static const at::cuda::NVRTC& nvrtc() {
  return at::globalContext().getNVRTC();
}

std::string CudaPrinter::dtypeToCppString(const Dtype& dtype) {
  switch (dtype.scalar_type()) {
    case ScalarType::Bool:
      return "bool";
    case ScalarType::Half:
      return "half";
    case ScalarType::BFloat16:
      return fuser::cuda::bfloat16_type_string;
    case ScalarType::Char:
      return "char";
    case ScalarType::Byte:
      return "unsigned char";
    case ScalarType::Short:
      return "short";
    case ScalarType::Long:
      return "long long";
    default:
      return dtype.ToCppString();
  }
}

// A Free is accepted only for buffers previously classified (by the Allocate
// visitor) as thread-local or cross-block; anything else is rejected.
void CudaAnalysis::visit(const FreePtr& v) {
  const auto buffer = v->buffer_var();
  const bool is_known_local = thread_local_bufs_.count(buffer) != 0 ||
      cross_block_bufs_.count(buffer) != 0;
  if (!is_known_local) {
    throw std::runtime_error("Global free not supported yet");
  }
}

void CudaAnalysis::visit(const AllocatePtr& v) {
  StmtPtr p = v->get_parent();
  while (p) {
    ForPtr for_v = to<For>(p);
    if (for_v) {
      if (for_v->loop_options().is_gpu_block_index()) {
        // TODO: This isn't right if there's a thread index at a higher level
        // than this.
        cross_block_bufs_.insert(v->buffer_var());
        return;
      } else if (for_v->loop_options().is_gpu_thread_index()) {
        thread_local_bufs_.insert(v->buffer_var());
        return;
      }
    }
    p = p->get_parent();
  }
  throw std::runtime_error("Global alloc not supported yet");
}

// PlacementAllocate statements are not handled by the CUDA backend: always
// throws std::runtime_error.
void CudaAnalysis::visit(const PlacementAllocatePtr& v) {
  throw std::runtime_error("Memory reuse not supported yet");
}

void CudaAnalysis::visit(const ForPtr& v) {
  // Recurse first.
  v->body()->accept(this);

  const LoopOptions& loop_options = v->loop_options();
  if (loop_options.is_gpu_block_index()) {
    int gpu_block_index = loop_options.gpu_block_index();
    if (gpu_block_index >= 3) {
      throw std::runtime_error("support only 3D gpu_block_index");
    }
    ExprPtr prev = nullptr;
    if (gpu_block_extents_.size() <= static_cast<size_t>(gpu_block_index)) {
      gpu_block_extents_.resize(gpu_block_index + 1);
    } else {
      prev = gpu_block_extents_[gpu_block_index];
    }
    if (!is_zero(v->start())) {
      throw std::runtime_error(
          "start must be zero for gpu_block_index: " +
          std::to_string(v->start()));
    }

    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (prev == nullptr) {
      gpu_block_extents_[gpu_block_index] = v->stop();
    } else if (prev->isConstant() && immediateEquals(prev, 1)) {
      // extents must be positive so if the current extent is 1 then even if the
      // stop is symbolic it's the max.
      gpu_block_extents_[gpu_block_index] = v->stop();
    } else {
      gpu_block_extents_[gpu_block_index] =
          IRSimplifier::simplify(alloc<Max>(prev, v->stop(), true));
    }
  } else if (loop_options.is_gpu_thread_index()) {
    int gpu_thread_index = loop_options.gpu_thread_index();
    if (gpu_thread_index >= 3) {
      throw std::runtime_error("support only 3D gpu_thread_index");
    }
    ExprPtr prev = nullptr;
    if (gpu_thread_extents_.size() <= static_cast<size_t>(gpu_thread_index)) {
      gpu_thread_extents_.resize(gpu_thread_index + 1);
    } else {
      prev = gpu_thread_extents_[gpu_thread_index];
    }
    if (!is_zero(v->start())) {
      throw std::runtime_error(
          "start must be zero for gpu_thread_index: " +
          std::to_string(v->start()));
    }

    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (prev == nullptr) {
      gpu_thread_extents_[gpu_thread_index] = v->stop();
    } else if (prev->isConstant() && immediateEquals(prev, 1)) {
      // extents must be positive so if the current extent is 1 then even if the
      // stop is symbolic it's the max.
      gpu_thread_extents_[gpu_thread_index] = v->stop();
    } else {
      gpu_thread_extents_[gpu_thread_index] =
          IRSimplifier::simplify(alloc<Max>(prev, v->stop(), true));
    }
  }
}

// Emits a one-dimensional array declaration "<type> <buffer>[<N>];" for an
// allocation, where N is the product of all of its dimensions.
// Throws std::runtime_error if any dimension has no known integer value.
void CudaPrinter::print_flat_alloc(const AllocatePtr& alloc) {
  const std::vector<ExprPtr> extents = alloc->dims();

  // TODO: this should be merged with the storage flattener.
  int64_t total_elements = 1;
  for (const auto& extent : extents) {
    const auto extent_value = intValue(extent);
    if (!extent_value) {
      throw std::runtime_error("Only integer dimensions are supported for now");
    }
    total_elements *= *extent_value;
  }

  os() << dtypeToCppString(alloc->dtype()) << " " << (*alloc->buffer_var());
  os() << "[" << total_elements << "];" << '\n';
}

// Prints the declaration for an allocation. Buffers the analysis recorded as
// cross-block are declared with the `__shared__` qualifier; buffers recorded
// as thread-local are declared as plain arrays. Any other allocation is an
// error.
void CudaPrinter::visit(const AllocatePtr& v) {
  // TODO: handle dynamic shapes here.
  const auto buffer = v->buffer_var();
  const bool is_cross_block =
      cuda_analysis_->cross_block_bufs().count(buffer) != 0;

  if (!is_cross_block &&
      cuda_analysis_->thread_local_bufs().count(buffer) == 0) {
    throw std::runtime_error("Encountered Alloc not local to block or thread");
  }

  emitIndent();
  if (is_cross_block) {
    os() << "__shared__ ";
  }
  print_flat_alloc(v);
}

// Free statements produce no output: the Allocate printer emits fixed-size
// array declarations, which need no matching release in the generated source.
void CudaPrinter::visit(const FreePtr& v) {
  // do nothing
}

// Loops get no CUDA-specific treatment here; printing is delegated to the
// base IRPrinter.
void CudaPrinter::visit(const ForPtr& v) {
  IRPrinter::visit(v);
}

void CudaPrinter::visit(const CastPtr& v) {
  std::string castFn = v->dtype().scalar_type() == ScalarType::Half
      ? "__float2half"
      : v->dtype().scalar_type() == ScalarType::BFloat16 ? "__float2bfloat16"
      : v->src_value()->dtype().scalar_type() == ScalarType::Half
      ? "__half2float"
      : v->src_value()->dtype().scalar_type() == ScalarType::BFloat16
      ? "__bfloat162float"
      : ("(" + dtypeToCppString(v->dtype()) + ")");
  os() << castFn << "(";
  v->src_value()->accept(this);
  os() << ")";
}

void CudaPrinter::visit(const IntrinsicsPtr& v) {
  if (v->op_type() == IntrinsicsOp::kRand) {
    os() << "Uint32ToFloat(" << *rand_func_ << "())";
    return;
  }

  std::string func_name = v->func_name();

  // get type of resulting expression.
  ScalarType returnType = v->param(0)->dtype().scalar_type();
  for (size_t i = 1; i < v->nparams(); ++i) {
    returnType = promoteTypes(returnType, v->param(i)->dtype().scalar_type());
  }

  if (returnType == ScalarType::Half || returnType == ScalarType::Float) {
    func_name = func_name + "f";
  }
  if (v->op_type() == IntrinsicsOp::kAbs &&
      !c10::isIntegralType(returnType, true)) {
    // since kAbs's func_name is `abs`, prefix `f` for floating point
    func_name = "f" + func_name;
  }
  if (v->op_type() == IntrinsicsOp::kIsNan) {
    func_name = "isnan";
  }

  os() << func_name << "(";
  for (const auto i : c10::irange(v->nparams())) {
    if (i > 0) {
      os() << ", ";
    }
    os() << *v->param(i);
  }
  os() << ")";
}

// External calls cannot be printed by this backend: always throws
// unimplemented_lowering for the offending node.
void CudaPrinter::visit(const ExternalCallPtr& v) {
  throw unimplemented_lowering(v);
}

void CudaPrinter::visit(const LoadPtr& v) {
  // TODO: find a better metric in using ldg or not. Support different dtypes.
  // Detects whether the load target is also a store target.
  // TODO: this is currently too wide. It detects whether a store-target
  // exists within the program. In fact, this check is only necessary within a
  // kernel.
  if (v->indices().empty()) {
    os() << *v->base_handle();
    return;
  }
  if (v->dtype().scalar_type() == ScalarType::Bool ||
      v->dtype().scalar_type() == ScalarType::Half ||
      v->dtype().scalar_type() == ScalarType::BFloat16) {
    // There's no __ldg overload for bool or half.
    os() << *v->base_handle() << "[" << *v->flat_index() << "]";
    return;
  }
  if (cuda_analysis_->is_buf_store_target(v->buf())) {
    // Cuda __ldg can only be applied on read-only buffers.
    os() << *v->base_handle() << "[" << *v->flat_index() << "]";
    return;
  }
  os() << "__ldg(" << *v->base_handle() << " + " << *v->flat_index() << ")";
}

// TODO: maybe this should be a more shared location?
// TODO: investigate how "ExprPtr" can be implicitly converted to "ExprHandle"
// as a bool.
static bool CheckEqual(const ExprPtr& lhs, const ExprPtr& rhs) {
  // The fast path. Checks if the pointers are the same.
  if (lhs == rhs) {
    return true;
  }
  ExprHandle diff = Sub::make(ExprHandle(lhs), ExprHandle(rhs));
  ExprHandle diff_s = IRSimplifier::simplify(diff);
  return immediateEquals(diff_s.node(), 0);
}

class AtomicAddFuser : public IRMutator {
 public:
  AtomicAddFuser(
      const std::unordered_set<VarPtr>& thread_local_bufs,
      const GPUMetaVarRewriter& metavars)
      : thread_local_bufs_(thread_local_bufs) {
    const std::vector<ExprPtr>& block_extents = metavars.gpu_block_extents();
    const std::vector<VarPtr>& block_vars = metavars.gpu_block_vars();
    for (size_t i = 0; i < block_extents.size(); ++i) {
      MetaVarExtent extent{block_extents[i], false};
      if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
        extent.trivial = true;
      } else {
        nontrivial_metavars_.insert(block_vars[i]);
      }
      metavars_[block_vars[i]] = extent;
    }

    const std::vector<ExprPtr>& thread_extents = metavars.gpu_thread_extents();
    const std::vector<VarPtr>& thread_vars = metavars.gpu_thread_vars();
    for (size_t i = 0; i < thread_extents.size(); ++i) {
      MetaVarExtent extent{thread_extents[i], false};
      if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
        extent.trivial = true;
      } else {
        nontrivial_metavars_.insert(thread_vars[i]);
      }
      metavars_[thread_vars[i]] = extent;
    }
  }

  StmtPtr mutate(const StorePtr& v) override {
    BufPtr buf = v->buf();

    // Thread locals never need to be atomic.
    if (thread_local_bufs_.count(buf->base_handle()) != 0) {
      return v;
    }

    ScalarType dtype = v->value()->dtype().scalar_type();
    if (dtype != ScalarType::Float && dtype != ScalarType::Double) {
      return v;
    }
    AddPtr add_v = to<Add>(v->value());
    if (!add_v) {
      return v;
    }
    LoadPtr load_v = to<Load>(add_v->lhs());
    if (!load_v) {
      return v;
    }
    if (v->base_handle() != load_v->base_handle()) {
      return v;
    }
    if (v->indices().empty() && load_v->indices().empty()) {
      return v;
    }
    bool index_equal = CheckEqual(v->flat_index(), load_v->flat_index());
    if (!index_equal) {
      return v;
    }

    // TODO: this checks that the metavars occur directly as an index, but this
    // is pessimistic, blockIdx.x + 1 is fine too if there is no overlapping.
    std::unordered_set<VarPtr> vars_to_find = nontrivial_metavars_;
    for (const ExprPtr& e : v->indices()) {
      if (VarPtr v = to<Var>(e)) {
        vars_to_find.erase(v);
      }
    }

    if (vars_to_find.empty()) {
      // All metavars accounted for.
      return v;
    }

    return alloc<AtomicAdd>(buf, v->indices(), add_v->rhs());
  }

 private:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const std::unordered_set<VarPtr>& thread_local_bufs_;
  struct MetaVarExtent {
    ExprPtr expr{nullptr};
    bool trivial{false};
  };
  std::unordered_map<VarPtr, MetaVarExtent> metavars_;
  std::unordered_set<VarPtr> nontrivial_metavars_;
};

// Prints a store as an assignment statement on its own indented line:
// "buf = value;" when the store has no indices, "buf[idx] = value;" otherwise.
void CudaPrinter::visit(const StorePtr& v) {
  emitIndent();
  const bool is_indexed = !v->indices().empty();
  if (is_indexed) {
    os() << *v->base_handle() << "[" << *v->flat_index() << "] = ";
  } else {
    os() << *v->base_handle() << " = ";
  }
  os() << *v->value() << ";" << '\n';
}

// Prints an AtomicAdd statement. Thread-local buffers are updated with a plain
// "+=" instead; every other buffer uses atomicAdd(&buf[idx], value).
void CudaPrinter::visit(const AtomicAddPtr& v) {
  emitIndent();
  const bool is_thread_local =
      cuda_analysis_->thread_local_bufs().count(v->base_handle()) > 0;
  if (!is_thread_local) {
    os() << "atomicAdd(&" << *v->base_handle() << "[" << *v->flat_index() << "]"
         << ", " << *v->value() << ");";
  } else {
    // atomicAdd only works on global and shared memory
    os() << *v->base_handle() << "[" << *v->flat_index()
         << "] += " << *v->value() << ";";
  }
  os() << '\n';
}

// Prints Max as "max(lhs,rhs)" for integral dtypes and as
// "maximum(lhs,rhs)" for every other dtype.
void CudaPrinter::visit(const MaxPtr& v) {
  const char* callee = v->dtype().is_integral() ? "max(" : "maximum(";
  os() << callee;
  v->lhs()->accept(this);
  os() << ",";
  v->rhs()->accept(this);
  os() << ")";
}

// Prints Min as "min(lhs,rhs)" for integral dtypes and as
// "minimum(lhs,rhs)" for every other dtype.
void CudaPrinter::visit(const MinPtr& v) {
  const char* callee = v->dtype().is_integral() ? "min(" : "minimum(";
  os() << callee;
  v->lhs()->accept(this);
  os() << ",";
  v->rhs()->accept(this);
  os() << ")";
}

// Prints an IfThenElse expression as a fully parenthesized conditional:
// "((cond) ? true_value : false_value)".
void CudaPrinter::visit(const IfThenElsePtr& v) {
  // Condition, wrapped in its own parentheses.
  os() << "((";
  v->condition()->accept(this);
  os() << ") ? ";

  // The two alternatives.
  v->true_value()->accept(this);
  os() << " : ";
  v->false_value()->accept(this);

  os() << ")";
}

void CudaPrinter::visit(const BlockPtr& v) {
  os() << "{" << '\n';
  indent_++;

  for (const StmtPtr& s : v->stmts()) {
    s->accept(this);
  }

  indent_--;
  emitIndent();
  os() << "}";
}

// Prints a Let as a typed local declaration with initializer:
// "<type> <var> = <value>;".
void CudaPrinter::visit(const LetPtr& v) {
  emitIndent();
  os() << dtypeToCppString(v->var()->dtype()) << " " << *v->var() << " = ";
  v->value()->accept(this);
  os() << ";" << '\n';
}

// IR mutator that hoists memory loads out of the statements that use them.
// While mutating a statement of a Block, each eligible Load is replaced by a
// fresh Var named "v" and recorded; after the statement has been processed,
// one `Let` per recorded load is inserted into the Block immediately before
// that statement.
//
// A Load is left in place (not hoisted) when any of the following holds:
//   * it is nested inside an IfThenElse expression;
//   * it is nested inside a Let statement;
//   * its buffer is listed in thread_local_bufs_;
//   * it has no indices;
//   * the enclosing Store writes the same buffer with the same number of
//     indices and every index compares equal via exprEquals;
//   * the load does not match the enclosing Store's buffer/index count and
//     that Store has no indices.
class PrioritizeLoad : public IRMutator {
 public:
  // Hoists `v` into a new Var unless one of the exclusions above applies.
  // NOTE(review): uses load_stack_.back(), so this presumably must only be
  // reached from within mutate(BlockPtr), which pushes a list — confirm.
  ExprPtr mutate(const LoadPtr& v) override {
    // Look at the declaration of this variable for more details.
    if (nested_if_then_else_ > 0) {
      return IRMutator::mutate(v);
    }
    if (nested_let_) {
      return IRMutator::mutate(v);
    }
    if (thread_local_bufs_.count(v->base_handle()) > 0) {
      return IRMutator::mutate(v);
    }
    if (v->indices().empty()) {
      return IRMutator::mutate(v);
    }
    if (nested_store_) {
      if (v->base_handle() == nested_store_->buf()->base_handle() &&
          v->indices().size() == nested_store_->indices().size()) {
        // also check indices
        bool same = true;
        for (const auto i : c10::irange(v->indices().size())) {
          if (!exprEquals(v->indices()[i], nested_store_->indices()[i])) {
            same = false;
            break;
          }
        }
        if (same) {
          return IRMutator::mutate(v);
        }
      } else if (nested_store_->indices().empty()) {
        return IRMutator::mutate(v);
      }
    }

    // Record (new var, mutated load) for the statement currently being
    // processed and substitute the var for the load.
    MemLoadList& load_list = load_stack_.back();
    VarPtr load_new_var = alloc<Var>("v", v->dtype());
    ExprPtr new_value = IRMutator::mutate(v);
    load_list.emplace_back(load_new_var, new_value);

    return load_new_var;
  }

  // Mutates the cast's source. If the source was a Load that got hoisted into
  // a Var, the cast is folded into the hoisted definition: the most recent
  // list entry is replaced by a new Var of the cast's dtype bound to
  // Cast(original hoisted value), and that new Var is returned.
  ExprPtr mutate(const CastPtr& v) override {
    LoadPtr src_load = to<Load>(v->src_value());
    ExprPtr new_src = v->src_value()->accept_mutator(this);
    VarPtr new_var = to<Var>(new_src);
    if (!src_load || !new_var) {
      return alloc<Cast>(v->dtype(), new_src);
    }

    // We just did the prioritize load, let's fold in the Cast.
    MemLoadList& load_list = load_stack_.back();
    assert(!load_list.empty());
    auto pair = load_list.back();
    assert(pair.first == new_var);
    load_list.pop_back();

    new_var = alloc<Var>("v", v->dtype());
    ExprPtr new_value = alloc<Cast>(v->dtype(), pair.second);
    load_list.emplace_back(new_var, new_value);
    return new_var;
  }

  // Tracks the Store being mutated so nested loads can be compared against
  // it; the previous value is restored afterwards.
  StmtPtr mutate(const StorePtr& v) override {
    StorePtr last = nested_store_;
    nested_store_ = v;
    StmtPtr s = IRMutator::mutate(v);
    nested_store_ = last;
    return s;
  }

  // Flags that a Let is being mutated so loads inside it are not hoisted.
  // The flag is reset to false (not to its previous value) afterwards.
  StmtPtr mutate(const LetPtr& v) override {
    nested_let_ = true;
    StmtPtr s = IRMutator::mutate(v);
    nested_let_ = false;
    return s;
  }

  // Processes each statement of the block with its own load list. The block
  // is modified in place: hoisted Lets are inserted before the statement and
  // the statement is replaced if mutation produced a new node. Iteration runs
  // over a copy of the statement list so these edits do not disturb it.
  StmtPtr mutate(const BlockPtr& v) override {
    std::list<StmtPtr> stmts = v->stmts();
    for (const StmtPtr& stmt : stmts) {
      PushList();
      StmtPtr stmt_new = stmt->accept_mutator(this);

      // Insert the Lets before the original statement, which is still in the
      // block at this point.
      AddMemLoadsFromList(v, stmt);
      PopList();

      if (stmt_new == stmt) {
        continue;
      }
      v->replace_stmt(stmt, stmt_new);
    }
    return v;
  }

  // Counts IfThenElse nesting depth so loads inside are left in place.
  ExprPtr mutate(const IfThenElsePtr& v) override {
    nested_if_then_else_++;
    ExprPtr new_v = IRMutator::mutate(v);
    nested_if_then_else_--;
    return new_v;
  }

 private:
  // (new variable, expression it is bound to) for one hoisted load.
  using MemLoadEntry = std::pair<VarPtr, ExprPtr>;
  // Hoisted loads collected for one statement, in discovery order.
  using MemLoadList = std::vector<MemLoadEntry>;
  // One list per statement currently being processed (innermost last).
  using MemoryLoadStack = std::vector<MemLoadList>;

  // Starts a fresh load list for the next statement.
  void PushList() {
    load_stack_.emplace_back();
  }

  // Discards the load list of the statement just processed.
  void PopList() {
    load_stack_.pop_back();
  }

  // Inserts a `Let` for every entry of the current load list into `block`,
  // each one before `last`, preserving the list order.
  void AddMemLoadsFromList(const BlockPtr& block, const StmtPtr& last) {
    MemLoadList& load_list = load_stack_.back();
    if (load_list.empty()) {
      return;
    }

    for (auto& pair : load_list) {
      StmtPtr news = alloc<Let>(pair.first, pair.second);
      block->insert_stmt_before(news, last);
    }
  }

  MemoryLoadStack load_stack_;
  // TODO: For now, we are not moving the loads with the IfThenElse.
  // Eventually, we should switch to a more generic structure like:
  // int v2 = IfThenElse(cond, true_v, false_v) + 2 ->
  //
  // int v;
  // if (cond) {
  //   v = true_v;
  // } else {
  //   v = false_v;
  // }
  // int v2 = v + 2;
  int nested_if_then_else_{0};
  // Store currently being mutated, or nullptr outside of any Store.
  StorePtr nested_store_{nullptr};
  // True while a Let statement is being mutated.
  bool nested_let_{false};
  // Buffers whose loads are never hoisted.
  // NOTE(review): nothing in this class adds to this set, so it appears to
  // stay empty — confirm whether it is meant to be populated.
  std::unordered_set<VarPtr> thread_local_bufs_;
};

// Returns a function name not yet present in taken_func_names and records it
// there. The prefix itself is used when free; otherwise "<prefix>_0",
// "<prefix>_1", ... are tried in order until an unused name is found.
std::string CudaCodeGen::GetUniqueFuncName(const std::string& func_prefix) {
  std::string candidate = func_prefix;
  for (int64_t suffix = 0; taken_func_names.count(candidate) != 0; ++suffix) {
    candidate = func_prefix + "_" + std::to_string(suffix);
  }

  taken_func_names.insert(candidate);
  return candidate;
}

// Returns true when, on all three axes, the current block reach and the
// current thread reach both equal (per exprEquals) the corresponding launch
// extents recorded by the analysis. Block axes are checked before thread
// axes, and checking stops at the first mismatch.
bool GPUMetaVarRewriter::isFullExtent() {
  const auto covers_all_axes = [](const auto& reach, const auto& extents) {
    for (int axis = 0; axis < 3; ++axis) {
      if (!exprEquals(reach[axis], extents[axis])) {
        return false;
      }
    }
    return true;
  };

  if (!covers_all_axes(
          current_block_reach_, cuda_analysis_->gpu_block_extents())) {
    return false;
  }
  return covers_all_axes(
      current_thread_reach_, cuda_analysis_->gpu_thread_extents());
}

// Rewrites a loop. For a loop bound to a GPU block or thread index the loop
// itself is dropped: its variable is substituted by the corresponding metavar
// in a clone of the body, the body is mutated with the reach of that axis
// widened by the loop's stop, and the mutated body is returned. Any other loop
// is returned as a clone with the mutated body.
//
// Throws std::runtime_error if the bound axis index is >= 3.
StmtPtr GPUMetaVarRewriter::mutate(const ForPtr& v) {
  const LoopOptions& options = v->loop_options();
  const bool binds_block = options.is_gpu_block_index();
  const bool binds_thread = !binds_block && options.is_gpu_thread_index();

  // Widens reach[axis] to cover this loop's stop and hands back the previous
  // value so it can be restored once the body has been processed.
  auto widen_reach = [&v](auto& reach, int axis) {
    ExprPtr previous = reach[axis];
    // Extents must be positive, assume >= 1.
    if (previous->isConstant() && immediateEquals(previous, 1)) {
      reach[axis] = v->stop();
    } else {
      reach[axis] =
          IRSimplifier::simplify(alloc<Max>(previous, v->stop(), true));
    }
    return previous;
  };

  StmtPtr body = v->body();
  ExprPtr saved_reach = nullptr;
  if (binds_block) {
    const int axis = options.gpu_block_index();
    if (axis >= 3) {
      throw std::runtime_error("support only 3D gpu_block_index");
    }
    saved_reach = widen_reach(current_block_reach_, axis);

    VarPtr metaVar = gpu_block_vars_[axis];
    body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
  } else if (binds_thread) {
    const int axis = options.gpu_thread_index();
    if (axis >= 3) {
      throw std::runtime_error("support only 3D gpu_thread_index");
    }
    saved_reach = widen_reach(current_thread_reach_, axis);

    VarPtr metaVar = gpu_thread_vars_[axis];
    body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
  }

  // Recurse into body block.
  body = Stmt::clone(body->accept_mutator(this));

  // Pop the internal reach off the stack; bound loops collapse to their body.
  if (binds_block) {
    current_block_reach_[options.gpu_block_index()] = saved_reach;
    return body;
  }
  if (binds_thread) {
    current_thread_reach_[options.gpu_thread_index()] = saved_reach;
    return body;
  }

  return v->cloneWithNewBody(body);
}

// Rewrites a Block. Every statement is mutated (or cloned if mutation left it
// unchanged) and the results are grouped into segments. If the current reach
// equals the full launch extent on every axis, the segments are simply
// flattened into a new Block. Otherwise each maskable segment is wrapped in
// `Cond(metavar < current reach)` guards, one per axis whose reach differs
// from the launch extent, and is surrounded by SyncThreads statements when at
// least one thread axis was masked. Unmasked segments are emitted as-is.
StmtPtr GPUMetaVarRewriter::mutate(const BlockPtr& v) {
  std::vector<Segment> innerSegments;
  Segment current;

  // Closes the segment under construction (if it has any statements) and
  // starts a new one with the given mask flag.
  auto pushAndReset = [&](bool mask) {
    if (!current.empty()) {
      innerSegments.push_back(current);
    }
    current.reset(mask);
  };

  // Here we're slicing the Block's contents into segments that should have
  // the same launch reach. Segments are comprised of all statements that aren't
  // loops - which are their own segments. Some operations, such as threading
  // and memory ops should never be masked and so also get their own segment.
  for (const StmtPtr& stmt : *v) {
    StmtPtr stmt_new = stmt->accept_mutator(this);
    if (stmt == stmt_new) {
      // The mutator returned the original node; clone it so the new Block
      // built below does not share it.
      stmt_new = Stmt::clone(stmt_new);
    }

    // Likewise, Allocate and Free should never be masked.
    if (to<Allocate>(stmt) || to<Free>(stmt)) {
      pushAndReset(false);
    }

    // If the current stmt *was* a loop, it's a segment boundary.
    if (ForPtr f = to<For>(stmt)) {
      pushAndReset(false);
    }

    current.stmts().push_back(stmt_new);
    // if the current segment should not be masked, it's a segment boundary on
    // the far side as well.
    if (!current.mask()) {
      pushAndReset(true);
    }
  }

  // Flush the trailing segment.
  if (!current.empty()) {
    innerSegments.push_back(current);
  }

  // We are max extent in all dimensions, so need no masks at this level.
  if (isFullExtent()) {
    // flatten inner segments.
    std::vector<StmtPtr> stmts;
    for (auto& v : innerSegments) {
      for (const auto& s : v.stmts()) {
        stmts.push_back(s);
      }
    }

    return alloc<Block>(stmts);
  }

  std::vector<StmtPtr> stmts;
  for (auto& segment : innerSegments) {
    bool need_sync = false;
    // We never mask loops, they'll mask their contents.
    if (!segment.mask()) {
      // Unmasked segments are closed right after their single statement in
      // the slicing loop above.
      TORCH_INTERNAL_ASSERT(segment.stmts().size() == 1, buildErrorMessage());
      stmts.push_back(segment.stmts()[0]);
      continue;
    }

    // If we get here, we must mask since we're not full reach and our direct
    // child isn't a For.
    StmtPtr inner = alloc<Block>(segment.stmts());
    // threads inside blocks.
    auto& thread_extents = cuda_analysis_->gpu_thread_extents();
    for (size_t i = 0; i < gpu_thread_vars_.size(); ++i) {
      if (!exprEquals(current_thread_reach_[i], thread_extents[i])) {
        need_sync = true;
        // Mask it against the current dimensions.
        inner = alloc<Cond>(
            alloc<CompareSelect>(
                gpu_thread_vars_[i],
                current_thread_reach_[i],
                CompareSelectOperation::kLT),
            inner,
            nullptr);
      }
    }
    // Block-axis masks wrap outside the thread-axis masks and do not by
    // themselves request a sync.
    auto& block_extents = cuda_analysis_->gpu_block_extents();
    for (size_t i = 0; i < gpu_block_vars_.size(); ++i) {
      if (!exprEquals(current_block_reach_[i], block_extents[i])) {
        // Mask it against the current dimensions.
        inner = alloc<Cond>(
            alloc<CompareSelect>(
                gpu_block_vars_[i],
                current_block_reach_[i],
                CompareSelectOperation::kLT),
            inner,
            nullptr);
      }
    }

    // The SyncThreads statements are placed around the masked region, never
    // inside the Cond guards.
    if (need_sync) {
      stmts.push_back(alloc<SyncThreads>());
    }
    stmts.push_back(inner);
    if (need_sync) {
      stmts.push_back(alloc<SyncThreads>());
    }
  }

  return alloc<Block>(stmts);
}

// Streams a list of expressions separated by ", " (no brackets, no trailing
// separator). An empty list prints nothing.
static std::ostream& operator<<(
    std::ostream& out,
    const std::vector<ExprPtr>& exprs) {
  bool is_first = true;
  for (const auto& expr : exprs) {
    if (!is_first) {
      out << ", ";
    }
    is_first = false;
    out << *expr;
  }
  return out;
}

// Source text written into the generated kernel file ahead of the kernel. It
// defines the NAN, POS_INFINITY and NEG_INFINITY macros by reinterpreting
// fixed 32-bit patterns as float via __int_as_float.
static const char* device_resource_string = R"(
#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)

)";

// Source text written into the generated kernel file ahead of the kernel. It
// defines the device templates `maximum` and `minimum`, which the printer
// calls for Max/Min on non-integral dtypes. Both return `a` when `a` is NaN;
// otherwise they return `a` if it compares greater (resp. less) than `b`,
// and `b` in every other case.
static const char* shared_resource_string = R"(
template<typename T>
__device__ T maximum(T a, T b) {
  return isnan(a) ? a : (a > b ? a : b);
}

template<typename T>
__device__ T minimum(T a, T b) {
  return isnan(a) ? a : (a < b ? a : b);
}

)";

// Generates the CUDA source for this codegen's statement and compiles it.
//
// Steps, in order:
//   1. Expand generic intrinsics and create the analysis, printer and
//      metavar-rewriter helpers.
//   2. Emit the source preamble (macros, maximum/minimum helpers, and the
//      random / half / bfloat16 support code when needed) followed by the
//      kernel signature built from the buffer arguments.
//   3. Run the analysis and the IR rewrites (metavar rewriting, atomic-add
//      fusion, registerization, load prioritization, half rewriting,
//      simplification), store the result via set_stmt(), and print the body.
//   4. Validate the block extents, derive thread_block_size_ for
//      call_with_numel() (-1 disables it), and build evaluators for the
//      block/thread extent expressions.
//   5. Hand the generated source to CompileToNVRTC().
//
// Throws std::runtime_error if a block extent is missing; the visitors and
// mutators invoked here may throw as well.
void CudaCodeGen::Initialize() {
  // TODO: handle multiple kernels.
  // TODO: handle dynamic dimension.
  // TODO: call nvrtc.
  // TODO: merge HasRand with CudaAnalysis.
  GenericIntrinsicsExpander intrinsics_expander;
  apply_mutator(&intrinsics_expander);

  HasRand has_rand_func(stmt());
  has_random_ = has_rand_func.has_rand();
  cuda_analysis_ = std::make_unique<CudaAnalysis>();
  printer_ =
      std::make_unique<CudaPrinter>(&oss_, cuda_analysis_.get(), has_random_);
  metavar_rewriter_ =
      std::make_unique<GPUMetaVarRewriter>(cuda_analysis_.get());

  // Check whether the statement uses the Half type, if so add the
  // half_support_literal.
  StmtPtr stmt_v = stmt();
  HalfChecker halfChecker(buffer_args());
  stmt_v->accept(&halfChecker);

  // Preamble: macros and helper templates come first, then optional support
  // code.
  os() << device_resource_string << shared_resource_string;

  if (has_random_) {
    os() << philox_random_string << '\n';
  }

  if (halfChecker.hasHalf()) {
    os() << fuser::cuda::half_support_literal << '\n';
  }
  if (halfChecker.hasBFloat16()) {
    os() << fuser::cuda::bfloat16_support_literal << '\n';
  }

  // Kernel signature.
  std::string func_name = GetUniqueFuncName(kernel_func_name());
  os() << "extern \"C\" __global__" << '\n';
#if defined(USE_ROCM)
  // CUDA has a default limit of threads per block (=flat work group size)
  // of 1024, but ROCm uses 256 by default. At the time of writing
  // (#45506), I am unaware of a stricter limit that TensorExpr imposes
  // (maybe for perf), so I use 1024 as maximum flat work group size.
  // We put a minimum value of 1, this is also used by hip (ROCm 3.8) in
  // the __launch_bound__ implementation. The arguments for the attribute
  // are (min, max), for details see the documentation at
  // https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
  os() << "__attribute__((amdgpu_flat_work_group_size(1, 1024)))" << std::endl;
#endif
  os() << "void " << func_name << "(";
  const std::vector<BufferArg> buffer_args = this->buffer_args();
  for (size_t i = 0; i < buffer_args.size(); i++) {
    if (i > 0) {
      os() << ", ";
    }
    const BufferArg& buffer_arg = buffer_args[i];
    VarPtr var = buffer_arg.var();
    Dtype dtype = buffer_arg.dtype();

    // Scalar (Var) arguments are passed by value, buffers as pointers.
    os() << printer_->dtypeToCppString(dtype)
         << (buffer_arg.isVar() ? " " : "* ")
         << name_manager()->get_unique_name(var);
  }
  VarPtr rand_seed;
  VarPtr rand_offset;
  if (has_random_) {
    // TODO: switch to kUint64 when it is available.
    // The Vars are created as kInt but are declared in the kernel signature
    // as "unsigned long long".
    rand_seed = alloc<Var>("rand_seed", kInt);
    rand_offset = alloc<Var>("rand_offset", kInt);
    std::string uint64_str = "unsigned long long";
    os() << ", " << uint64_str << " " << *rand_seed << ", " << uint64_str << " "
         << *rand_offset;
  }
  os() << ") {";
  os() << '\n';

  if (has_random_) {
    // Emit the per-thread Philox generator, keyed by the flat thread index
    // along x.
    VarPtr idx = alloc<Var>("idx", kInt);
    os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;" << '\n';
    VarPtr rand_func = printer_->rand_func();
    os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx << ", "
         << *rand_offset << ");" << '\n';
    os() << '\n';
  }

  // Analysis must run before the rewrites below, which consult its results.
  stmt_v->accept(cuda_analysis_.get());

  stmt_v = stmt_v->accept_mutator(metavar_rewriter_.get());

  AtomicAddFuser atomic_add_fuser(
      cuda_analysis_->thread_local_bufs(), *metavar_rewriter_);
  stmt_v = stmt_v->accept_mutator(&atomic_add_fuser);

  stmt_v = registerize(stmt_v);

  PrioritizeLoad prioritize_load;
  stmt_v = stmt_v->accept_mutator(&prioritize_load);

  // The registerizer might insert half-type scalars, we don't want this.
  HalfRewriter hsFix;
  stmt_v = stmt_v->accept_mutator(&hsFix);

  stmt_v = IRSimplifier::simplify(stmt_v);
  set_stmt(stmt_v);

  // Print the kernel body and close the function.
  stmt_v->accept(printer_.get());
  os() << '\n';
  os() << "}";

  // Check that all block extents had been set.
  const std::vector<ExprPtr>& gpu_block_extents =
      metavar_rewriter_->gpu_block_extents();
  for (size_t i = 0; i < gpu_block_extents.size(); i++) {
    if (!gpu_block_extents[i]) {
      throw std::runtime_error("Missing gpu_block_index: " + std::to_string(i));
    }
  }

  // Precompute block and thread extents for call_with_numel().  If
  // precomputation can't be done (block/thread extents aren't
  // constant), then disallow call_with_numel.
  auto block_extents = metavar_rewriter_->gpu_block_extents();
  auto thread_extents = metavar_rewriter_->gpu_thread_extents();
  bool canCallWithNumel =
      !has_random_ && !block_extents.empty() && !thread_extents.empty();
  // Every axis other than the first must have the constant extent 1.
  for (size_t i = 1; i < block_extents.size() && canCallWithNumel; i++) {
    canCallWithNumel = canCallWithNumel && block_extents[i]->isConstant() &&
        immediateAs<int>(block_extents[i]) == 1;
  }
  for (size_t i = 1; i < thread_extents.size() && canCallWithNumel; i++) {
    canCallWithNumel = canCallWithNumel && thread_extents[i]->isConstant() &&
        immediateAs<int>(thread_extents[i]) == 1;
  }
  if (canCallWithNumel && thread_extents[0]->isConstant()) {
    // We assume block_extents[0] is output.numel()/thread_block_size_.
    thread_block_size_ = immediateAs<int>(thread_extents[0]);
  } else {
    // Disable call_with_numel.
    thread_block_size_ = -1;
  }

  // Build an LLVM based eval expression for the extents
  block_extents_eval_.reserve(block_extents.size());
  std::vector<BufferArg> extents_buffer_args;

  // We need to extract the args that are used in the thread and block extents
  // from bufferArgs and only use those for the `ExprEval` below. Without this,
  // bufferArgs might contain arbitrary types that are not handled by LLVM and
  // hence would result in an error.
  std::unordered_set<VarPtr> vars_in_extents;
  for (const auto& be : block_extents) {
    auto v = VarFinder::find(be);
    vars_in_extents.insert(v.begin(), v.end());
  }
  for (const auto& te : thread_extents) {
    auto v = VarFinder::find(te);
    vars_in_extents.insert(v.begin(), v.end());
  }
  // arg_pos_in_extents_ gets one flag per buffer argument: true when that
  // argument's variable appears in an extent expression.
  for (const size_t i : c10::irange(buffer_args.size())) {
    if (vars_in_extents.count(buffer_args[i].var())) {
      extents_buffer_args.push_back(buffer_args[i]);
      arg_pos_in_extents_.push_back(true);
    } else {
      arg_pos_in_extents_.push_back(false);
    }
  }
  for (const auto& be : block_extents) {
#ifdef TORCH_ENABLE_LLVM
    block_extents_eval_.emplace_back(
        ExprEval<LLVMCodeGen>(ExprHandle(be), extents_buffer_args));
#else
    block_extents_eval_.emplace_back(ExprHandle(be), extents_buffer_args);
#endif
  }
  thread_extents_eval_.reserve(thread_extents.size());
  for (const auto& te : thread_extents) {
#ifdef TORCH_ENABLE_LLVM
    thread_extents_eval_.emplace_back(
        ExprEval<LLVMCodeGen>(ExprHandle(te), extents_buffer_args));
#else
    thread_extents_eval_.emplace_back(ExprHandle(te), extents_buffer_args);
#endif
  }

  GRAPH_DEBUG(
      "Fused TE CUDA kernel:\n",
      oss_.str(),
      "\n",
      "gpu_block_extents: (",
      metavar_rewriter_->gpu_block_extents(),
      ")\n",
      "gpu_thread_extents: (",
      metavar_rewriter_->gpu_thread_extents(),
      ")");

  CompileToNVRTC(oss_.str(), func_name);
}

// Launches the compiled kernel on a one-dimensional grid derived from `numel`.
//
// The grid is ceil(numel / thread_block_size_) blocks of thread_block_size_
// threads each; all other grid/block dimensions are 1.
//
// args:  one entry per buffer arg, in buffer_args() order. Entries for scalar
//        (isVar) args are passed to the driver as-is; entries for buffers get
//        one extra level of indirection.
// numel: number of elements to cover. Zero is a no-op.
//
// Asserts (TORCH_INTERNAL_ASSERT) that a positive thread block size was
// precomputed; a non-positive value means this entry point is disabled.
//
// NOTE(review): unlike call_raw(), this path does not append the random
// seed/offset kernel arguments; presumably it is never used for kernels that
// contain rand calls - confirm against callers.
void CudaCodeGen::call_with_numel(void** args, int64_t numel) {
  if (C10_UNLIKELY(numel == 0)) {
    return;
  }
  // The block size is used as a divisor below, so zero must be rejected along
  // with the negative "disabled" sentinel.
  TORCH_INTERNAL_ASSERT(
      thread_block_size_ > 0,
      "call_with_numel() requires a precomputed thread block size");

  auto const& buffer_args = this->buffer_args();
  // Ceil-division so that a partially filled trailing block is still launched.
  size_t gpu_block_extents =
      (numel + thread_block_size_ - 1) / thread_block_size_;
  size_t gpu_thread_extents = thread_block_size_;

  // In CUDA we need to pass pointers to pointers for buffers, thus we need to
  // go over args and add an extra indirection for such non-scalar
  // arguments.
  // Why? See some details here:
  // https://stackoverflow.com/questions/34388712/cannot-understand-how-jcuda-culaunchkernel-work
  std::vector<void*> ptr_to_args(buffer_args.size());
  for (size_t i = 0; i < buffer_args.size(); i++) {
    ptr_to_args[i] =
        buffer_args[i].isVar() ? args[i] : const_cast<void**>(&args[i]);
  }

  // Temporarily switch to the device this kernel was compiled for.
  const auto device = this->device().index();
  const auto prior_device = at::cuda::current_device();
  if (prior_device != device) {
    at::cuda::set_device(device);
  }

  auto stream = at::cuda::getCurrentCUDAStream();
  at::cuda::jit::initializeCudaContext();
  AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
      function_,
      gpu_block_extents,
      1,
      1,
      gpu_thread_extents,
      1,
      1,
      0,
      stream,
      ptr_to_args.data(),
      nullptr));

  if (prior_device != device) {
    at::cuda::set_device(prior_device);
  }
}

// Launches the compiled kernel with already-marshalled argument pointers.
//
// raw_args: one entry per buffer arg, in buffer_args() order. Entries for
//           scalar (isVar) args are passed to the driver as-is; entries for
//           buffers get one extra level of indirection.
//
// The block/thread extents are taken from the metavar rewriter: constant
// extents are read directly, the rest are evaluated from the subset of
// raw_args flagged in arg_pos_in_extents_. If any block extent evaluates to 0
// the launch is skipped. When the kernel uses random numbers, a Philox
// seed/offset pair is appended as two extra kernel arguments.
//
// Throws malformed_input if the extents are more than 3-D or if raw_args has
// fewer entries than buffer_args().
void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
  auto const& buffer_args = this->buffer_args();

  // Every buffer arg indexes into raw_args below; reject short argument lists
  // up front instead of reading past the end of the vector.
  if (raw_args.size() < buffer_args.size()) {
    throw malformed_input("cuda_codegen: wrong number of args in call_raw");
  }

  // TODO: move as much of this into the constructors.
  const std::vector<ExprPtr>& gpu_block_extents =
      metavar_rewriter_->gpu_block_extents();
  const std::vector<ExprPtr>& gpu_thread_extents =
      metavar_rewriter_->gpu_thread_extents();
  if (gpu_block_extents.size() > 3 || gpu_thread_extents.size() > 3) {
    throw malformed_input(
        "cuda_codegen: block or thread extent greater than 3D");
  }

  // Unused dimensions default to an extent of 1.
  std::vector<int64_t> gpu_block_extents_v(3, 1);
  std::vector<int64_t> gpu_thread_extents_v(3, 1);

  // evaluate all the block/thread extents into values
  // TODO: eventually, codegen these calculations and make them part of the
  // module.
  std::vector<void*> extent_args;
  // Only positions that have a flag in arg_pos_in_extents_ can feed the extent
  // evaluators, so never scan past the end of either container.
  const size_t extent_scan_count =
      raw_args.size() < arg_pos_in_extents_.size() ? raw_args.size()
                                                   : arg_pos_in_extents_.size();
  extent_args.reserve(extent_scan_count);
  for (size_t i = 0; i < extent_scan_count; ++i) {
    if (arg_pos_in_extents_[i]) {
      extent_args.push_back(raw_args[i]);
    }
  }
  for (size_t i = 0; i < gpu_block_extents.size(); i++) {
    if (gpu_block_extents[i]->isConstant()) {
      gpu_block_extents_v[i] = immediateAs<int64_t>(gpu_block_extents[i]);
      continue;
    }
    {
      // invocation of block_extents_eval_ isn't thread safe and this function
      // may be invoked by multiple threads
      std::lock_guard<std::mutex> guard(eval_lock_);
      gpu_block_extents_v[i] =
          block_extents_eval_[i].value<int64_t>(extent_args);
    }
  }
  for (size_t i = 0; i < gpu_thread_extents.size(); i++) {
    if (gpu_thread_extents[i]->isConstant()) {
      gpu_thread_extents_v[i] = immediateAs<int64_t>(gpu_thread_extents[i]);
      continue;
    }
    {
      // Same lock as for the block extents: the evaluators are shared state.
      std::lock_guard<std::mutex> guard(eval_lock_);
      gpu_thread_extents_v[i] =
          thread_extents_eval_[i].value<int64_t>(extent_args);
    }
  }

  // Skip launching the kernel if there are no elements to process.
  for (auto extent : gpu_block_extents_v) {
    if (extent == 0) {
      return;
    }
  }

  auto ptr_count = buffer_args.size();
  // If the kernel has a rand call in it, add two extra arguments for random
  // seed and offset.
  if (has_random_) {
    ptr_count += 2;
  }
  std::vector<void*> ptr_to_args(ptr_count);

  // In CUDA we need to pass pointers to pointers for buffers, thus we need to
  // go over raw_args and add an extra indirection for such non-scalar
  // arguments.
  // Why? See some details here:
  // https://stackoverflow.com/questions/34388712/cannot-understand-how-jcuda-culaunchkernel-work
  for (size_t i = 0; i < buffer_args.size(); i++) {
    ptr_to_args[i] =
        buffer_args[i].isVar() ? raw_args[i] : const_cast<void**>(&raw_args[i]);
  }

  // ptr_to_args stores the addresses of these two values and cuLaunchKernel
  // reads through them, so they are declared at function scope to stay alive
  // until the launch call below has returned.
  uint64_t rand_seed = uint64_t(-1);
  uint64_t rand_offset = uint64_t(-1);
  if (has_random_) {
    auto gen = at::cuda::detail::getDefaultCUDAGenerator();
    // TODO: total hack. Switch to numel when it is available.
    int64_t total_elements_per_thread = (1LL << 28);
    {
      // The default generator is shared; hold its mutex while drawing the
      // Philox inputs.
      std::lock_guard<std::mutex> lock(gen.mutex());
      auto philox_engine_inputs =
          at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
              total_elements_per_thread);
      rand_seed = philox_engine_inputs.first;
      rand_offset = philox_engine_inputs.second;
    }
    ptr_to_args[buffer_args.size()] = &rand_seed;
    ptr_to_args[buffer_args.size() + 1] = &rand_offset;
  }

  // Temporarily switch to the device this kernel was compiled for.
  auto prior_device = at::cuda::current_device();
  if (prior_device != this->device().index()) {
    at::cuda::set_device(this->device().index());
  }
  // Launch the kernels
  auto stream = at::cuda::getCurrentCUDAStream();
  at::cuda::jit::initializeCudaContext();
  AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
      function_,
      gpu_block_extents_v[0],
      gpu_block_extents_v[1],
      gpu_block_extents_v[2],
      gpu_thread_extents_v[0],
      gpu_thread_extents_v[1],
      gpu_thread_extents_v[2],
      0,
      stream,
      ptr_to_args.data(),
      nullptr));

  if (prior_device != this->device().index()) {
    at::cuda::set_device(prior_device);
  }
}

void CudaCodeGen::call(const std::vector<CallArg>& args) {
  if (args.size() != buffer_args().size()) {
    throw malformed_input("cuda_codegen: wrong number of args in call");
  }

  auto const& buffer_args = this->buffer_args();
  std::vector<void*> raw_args(buffer_args.size());
  for (size_t i = 0; i < buffer_args.size(); i++) {
    auto const& bufferArg = buffer_args[i];
    auto const& callArg = args[i];
    raw_args[i] = argToPtr(bufferArg, callArg);
  }
  call_raw(raw_args);
}

// Creates a tensor by forwarding every argument to
// at::native::empty_strided_cuda while a device guard for the requested
// device is active.
//
// device_opt must hold a value: std::optional::value() throws
// std::bad_optional_access otherwise.
at::Tensor CudaCodeGen::empty_strided(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    std::optional<c10::ScalarType> dtype_opt,
    std::optional<c10::Layout> layout_opt,
    std::optional<c10::Device> device_opt,
    std::optional<bool> pin_memory_opt) {
  const c10::Device target_device = device_opt.value();
  // The guard stays in scope for the duration of the allocation call.
  c10::DeviceGuard scoped_device(target_device);
  return at::native::empty_strided_cuda(
      size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
}

// Compiles `code` with NVRTC, loads the result as a driver module and stores
// the handle of the kernel named `func_name` in function_.
//
// code:      full source text of the kernel to compile.
// func_name: name of the kernel entry point to look up in the loaded module.
//
// Throws std::runtime_error containing the NVRTC log and the offending source
// when compilation fails; other NVRTC/driver failures surface through
// AT_CUDA_NVRTC_CHECK / AT_CUDA_DRIVER_CHECK.
void CudaCodeGen::CompileToNVRTC(
    const std::string& code,
    const std::string& func_name) {
  at::cuda::jit::initializeCudaContext();
  // Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work
  // properly in some scenarios
  auto prior_device = at::cuda::current_device();
  if (prior_device != this->device().index()) {
    at::cuda::set_device(this->device().index());
  }
  // Acquires device and NVRTC properties (for compile arch and occupancy
  // calculations)
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  int major = 0, minor = 0;
  bool compile_to_sass = false;
  fuser::cuda::codegenOutputQuery(prop, major, minor, compile_to_sass);

  // Creates the NVRTC program
  nvrtcProgram program{nullptr};
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
      &program, code.c_str(), nullptr, 0, nullptr, nullptr));
  // Take ownership immediately so the program is destroyed on every exit
  // path, including a failed compilation or a failure while fetching the log.
  ResourceGuard holdProgram(
      [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });

#if defined(USE_ROCM)
  std::vector<const char*> args = {"--std=c++17"};
  args.push_back("-hip-pch");
#else
  const std::string compute = std::string("--gpu-architecture=") +
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
      // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_)
      // which gives better backwards compatibility to work on older driver,
      // (since older driver doesn't necessarily recognize PTX emitted by new
      // toolkit);
      // Meanwhile, for forward compatibility (future device with
      // `compile_to_sass==false`), since SASS are not necessarily compatible,
      // we fallback to PTX instead.
      (compile_to_sass ? "sm_" : "compute_") +
#else
      "compute_" +
#endif
      std::to_string(major) + std::to_string(minor);
  const std::vector<const char*> args = {
      "--std=c++17", compute.c_str(), "-default-device"};
#endif

  auto result = nvrtc().nvrtcCompileProgram(
      program, static_cast<int>(args.size()), args.data());
  if (result != NVRTC_SUCCESS) {
    // Surface the compiler log together with the source that failed to build.
    size_t logsize = 0;
    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize));
    std::vector<char> log(logsize);
    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data()));
    std::stringstream cu;
    cu << log.data() << '\n';
    cu << "nvrtc compilation failed: " << '\n';
    cu << code << '\n';
    throw std::runtime_error(cu.str());
  }

  // Despite the name, `ptx` holds either PTX text or a CUBIN image depending
  // on compile_to_sass.
  size_t ptx_size = 0;
  std::vector<char> ptx;
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
  // compile_to_sass determines whether we are generating SASS or PTX, hence
  // the different API.
  auto getSize = compile_to_sass ? nvrtc().nvrtcGetCUBINSize
                                 : nvrtc().nvrtcGetPTXSize;
  auto getFunc = compile_to_sass ? nvrtc().nvrtcGetCUBIN
                                 : nvrtc().nvrtcGetPTX;
#else
  auto getSize = nvrtc().nvrtcGetPTXSize;
  auto getFunc = nvrtc().nvrtcGetPTX;
#endif
  AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size));
  ptx.resize(ptx_size);
  AT_CUDA_NVRTC_CHECK(getFunc(program, ptx.data()));

  // Load the compiled image and resolve the kernel entry point.
  CUmodule module{nullptr};
  AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module, ptx.data()));
  AT_CUDA_DRIVER_CHECK(
      nvrtc().cuModuleGetFunction(&function_, module, func_name.c_str()));

  if (prior_device != this->device().index()) {
    at::cuda::set_device(prior_device);
  }
}

// Out-of-line defaulted destructor: no explicit teardown beyond what the
// members' own destructors do.
CudaCodeGen::~CudaCodeGen() = default;

// Namespace-scope registration object associating CudaCodeGen with the name
// "cuda_codegen".
// NOTE(review): presumably RegisterCodeGen's constructor adds a factory for
// this class to a codegen registry keyed by that string - confirm against its
// definition.
RegisterCodeGen<CudaCodeGen> cuda_codegen_reg("cuda_codegen");

} // namespace torch::jit::tensorexpr
