
#include <torch/csrc/jit/passes/autocast.h>

#include <ATen/autocast_mode.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <optional>

#include <stack>
#include <unordered_set>
#include <vector>

namespace torch::jit {

namespace {

bool autocast_enabled = true;

struct AutocastContext {
  bool gpu_enabled = false;
  bool cpu_enabled = false;
  c10::ScalarType gpu_scalar_type = c10::ScalarType::Undefined;
  c10::ScalarType cpu_scalar_type = c10::ScalarType::Undefined;

  operator bool() const {
    return gpu_enabled || cpu_enabled;
  }
};

struct AutocastScope {
  Value* instance = nullptr;
  AutocastContext context;
  void stack(const AutocastContext& parent_context) {}
};

bool isAutocastNode(Value* value) {
  const auto class_name = getModuleName(value);
  return class_name.has_value() &&
      (*class_name == "__torch__.torch.cuda.amp.autocast_mode.autocast" ||
       *class_name == "__torch__.torch.cpu.amp.autocast_mode.autocast" ||
       *class_name == "__torch__.torch.amp.autocast_mode.autocast");
}

// If we have an autocast instance, return it
//
// This is the pattern we're looking for (this is done after
//  autocast.__init__() has been inlined)
//
// %4 : bool = prim::Constant[value=1]()
// %5 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::CreateObject()
//  = prim::SetAttr[name="_enabled"](%5, %4)
//
// Notes:
//  1. There's no guarantee that the autocast instance is in the same block
//    as the prim::Enter() node
//  2. `prim::SetAttr` must follow `prim::CreateObject()` in the same block,
//    but there might be other nodes in between
//
std::optional<AutocastScope> parseAutocast(
    Value* value,
    const AutocastContext& context) {
  if (!isAutocastNode(value)) {
    // Not an autocast...
    return std::nullopt;
  }
  if (value->node()->kind() == prim::CreateObject) {
    AutocastScope scope;
    scope.instance = value;
    scope.context = context;
    std::optional<bool> enabled;
    std::string device;
    c10::ScalarType dtype = c10::ScalarType::Undefined;
    for (Use use : value->uses()) {
      // TODO: support runtime flag
      if (use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "_enabled") {
        // Search for `prim::SetAttr[name="_enabled"]`
        enabled = constant_as<bool>(use.user->input(1));
        TORCH_CHECK(
            enabled.has_value(),
            "Autocast _enabled argument must be a constant");
      } else if (
          use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "device") {
        // Search for `prim::SetAttr[name="device"]`
        auto ret = constant_as<std::string>(use.user->input(1));
        TORCH_CHECK(
            ret.has_value(), "Autocast device argument must be a constant");
        device = ret.value();
      } else if (
          use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "fast_dtype") {
        // Search for `prim::SetAttr[name="fast_dtype"]`
        auto ret = constant_as<c10::ScalarType>(use.user->input(1));
        if (ret.has_value()) {
          dtype = ret.value();
        }
      }
    }
    TORCH_CHECK(enabled.has_value(), "Autocast missing _enabled attribute");
    TORCH_CHECK(!device.empty(), "Autocast missing device attribute");
    if (dtype == c10::ScalarType::Undefined) {
      dtype = at::autocast::get_autocast_dtype(c10::Device(device).type());
    }
    TORCH_CHECK(
        dtype != c10::ScalarType::Undefined,
        "Autocast has invalid fast_dtype attribute");
    if (device == "cuda" || device == "mps") {
      scope.context.gpu_enabled = enabled.value();
      scope.context.gpu_scalar_type = dtype;
    } else if (device == "cpu") {
      scope.context.cpu_enabled = enabled.value();
      scope.context.cpu_scalar_type = dtype;
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "unrecognized device for autocast pass: ", device);
    }
    return scope;
  } else {
    // We only support simple and static autocast expressions. For example,
    // the following should report an error (since the autocast would not
    // work as expected)
    //
    //    autocast_on = autocast(enabled=True)
    //    autocast_off = autocast(enabled=False)
    //    with autocast_on if condition else autocast_off:
    //        ...
    //
    // TODO: better error message
    //
    AT_ERROR("Unsupported autocast syntax");
  }

  return std::nullopt;
}

void castTensorInputs(
    Node* node,
    Symbol cast_op,
    const AutocastContext& context) {
  if (!context) {
    return;
  }

  const auto graph = node->owningGraph();

  std::unordered_set<Value*> casted_inputs;
  // need to also keep the inputs in order, otherwise tracing fails
  // sanity checks because casting ops are inserted in random order
  std::vector<Value*> casted_inputs_ordered;
  for (auto input : node->inputs()) {
    // TODO: update cast_op signature to take dynamic context flags
    auto input_tensor_type = input->type()->cast<TensorType>();
    if (input_tensor_type && input->node()->kind() != cast_op) {
      auto has_inserted = casted_inputs.insert(input);
      if (has_inserted.second) {
        casted_inputs_ordered.push_back(input);
      }
    }
  }

  WithInsertPoint insert_point(node);

  for (auto input : casted_inputs_ordered) {
    if (cast_op == aten::_autocast_to_full_precision) {
      const auto new_input = graph->insert(
          cast_op,
          {input,
           graph->insertConstant(IValue(context.gpu_enabled)),
           graph->insertConstant(IValue(context.cpu_enabled))});
      node->replaceInputWith(input, new_input);
    } else if (cast_op == aten::_autocast_to_reduced_precision) {
      const auto new_input = graph->insert(
          cast_op,
          {input,
           graph->insertConstant(IValue(context.gpu_enabled)),
           graph->insertConstant(IValue(context.cpu_enabled)),
           graph->insertConstant(IValue(context.gpu_scalar_type)),
           graph->insertConstant(IValue(context.cpu_scalar_type))});
      node->replaceInputWith(input, new_input);
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "unrecognized cast_op symbol: ", cast_op.toQualString());
    }
  }
}

bool hasExplicitDtypeArgument(Node* node) {
  if (node->hasNamedInput("dtype")) {
    Value* dtype_arg = node->namedInput("dtype");
    return dtype_arg->type()->kind() != TypeKind::NoneType;
  }
  return false;
}

void castInputsToWidestType(Node* node, const AutocastContext& context) {
  if (!context) {
    return;
  }
  // Figure out the widest type
  // (really, just looking for any float32 inputs)
  //
  // TODO: revisit this (do we need to consider float64 types?)
  //
  for (auto input : node->inputs()) {
    if (auto tensor_type = input->type()->cast<TensorType>()) {
      const auto dtype = tensor_type->scalarType();
      if (!dtype.has_value() || *dtype == at::ScalarType::Float) {
        castTensorInputs(node, aten::_autocast_to_full_precision, context);
        return;
      }
    }
  }
}

// Users can call torch.is_autocast_enabled() or is_autocast_cpu_enabled() to
// determine whether autocasting is enabled. With JIT-scripted functions, we
// actually need to return true if eager autocast OR jit autocast are enabled.
//
// In the case where JIT autocast is enabled, we replace
//    %x : bool = aten::is_autocast_enabled()
// with a constant "True".
//
// More context on eager vs JIT autocasting:
//
// Autocasting actually has two settings: eager autocasting, and JIT
// autocasting. Eager autocasting is the thread-local setting that turns on
// the relevant bit in the dispatcher settings. JIT autocasting is the pass
// implemented in this file, which makes changes to the graph to insert casting
// ops in order to achieve the same behavior as eager autocasting.
//
// If eager autocasting is enabled at the time when a JIT-scripted function is
// invoked, then autocasting will occur regardless of what the JIT-autocasting
// settings are.
void updateAutocastEnabledCheck(Node* node, bool is_jit_enabled) {
  if (!is_jit_enabled) {
    return;
  }

  auto graph = node->owningGraph();

  WithInsertPoint insert_point(node);

  Value* true_constant = graph->insertConstant(IValue(true));
  node->output()->replaceAllUsesWith(true_constant);
  node->destroy();
}

// [Note: implicit type promotion in Autocast]
//
// Casting policy below mostly follows pytorch/aten/src/ATen/autocast.cpp, with
// a few exceptions, e.g. `aten::add`, which is needed to be put to promotion
// list for JIT autocast.
// The reason is that in eager amp, some binary ops promote inputs implicitly
// inside the operation, e.g. `aten::add` with fp16 & fp32 inputs would both be
// casted to fp32. In backward, autograd would cast dgrad to match their
// scalar_type in forward graph. So inputs with mismatched scalar_type would
// get the different dgrad.
// While in JIT, autodiff doesn't do this, so implicit cast is not visible to
// autodiff and backward dgrad for mismatched inputs would ended up with dgrads
// in the same scalar_type. This has caused downstream operations, which
// expects dgrad to be the same scalar type to throw mismatch error.
//
// TODO: Use the list from AMP eager directly
void handleBlock(Block* block, AutocastContext initial_state) {
  std::stack<AutocastScope> autocast_stack;

  std::optional<bool> incompatible_amp = std::nullopt;

  // The current autocast enabled/disabled state
  auto current_state = [&] {
    return autocast_stack.empty() ? initial_state
                                  : autocast_stack.top().context;
  };

  for (Node* node : block->nodes()) {
    switch (node->kind()) {
      case prim::CallFunction:
        // TODO: limit it only to amp related node;
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        TORCH_INTERNAL_ASSERT(
            !incompatible_amp.has_value() || incompatible_amp.value(),
            "Calls are not expected with AMP & JIT");
        incompatible_amp = true;
        break;

      case prim::CallMethod:
        // TODO: limit it only to amp related node;
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
          const auto& name = node->s(attr::name);
          const auto& function = class_type->getMethod(name);
          if (!function.isGraphFunction()) {
            TORCH_INTERNAL_ASSERT(
                !incompatible_amp.has_value() || incompatible_amp.value(),
                "Calls are not expected with AMP & JIT");
            incompatible_amp = true;
          }
        } else {
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || incompatible_amp.value(),
              "Unexpected prim::CallMethod form with AMP & JIT");
          incompatible_amp = true;
        }
        break;

      case prim::Enter:
        if (auto autocast_scope =
                parseAutocast(node->input(), current_state())) {
          if (node->hasUses()) {
            // TODO: better error message
            AT_ERROR("`with autocast() as ...` is not supported");
          }
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.push(*autocast_scope);
        }
        break;

      case prim::Exit:
        if (isAutocastNode(node->input(0))) {
          TORCH_INTERNAL_ASSERT(!autocast_stack.empty());
          TORCH_INTERNAL_ASSERT(autocast_stack.top().instance == node->input());
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.pop();
        }
        break;

      case aten::is_autocast_enabled:
        updateAutocastEnabledCheck(node, current_state().gpu_enabled);
        break;

      case aten::is_autocast_cpu_enabled:
        updateAutocastEnabledCheck(node, current_state().cpu_enabled);
        break;

      // CastPolicy::fp16 (cast all inputs to float16)
      case aten::_convolution:
      case aten::conv1d:
      case aten::conv2d:
      case aten::conv3d:
      case aten::conv_tbc:
      case aten::conv_transpose1d:
      case aten::convolution:
      case aten::cudnn_convolution:
      case aten::cudnn_convolution_transpose:
      case aten::prelu:
      case aten::addmm:
      case aten::addmv:
      case aten::addr:
      case aten::matmul:
      case aten::mm:
      case aten::mv:
      case aten::linear:
      case aten::addbmm:
      case aten::baddbmm:
      case aten::bmm:
      case aten::chain_matmul:
      case aten::_thnn_fused_lstm_cell:
      case aten::_thnn_fused_gru_cell:
      case aten::lstm_cell:
      case aten::gru_cell:
      case aten::rnn_tanh_cell:
      case aten::rnn_relu_cell:
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_reduced_precision, current_state());
        }
        break;

      // CastPolicy::fp32 (cast all inputs to float32)
      case aten::native_layer_norm:
      case aten::acos:
      case aten::asin:
      case aten::cosh:
      case aten::erfinv:
      case aten::exp:
      case aten::expm1:
      case aten::log:
      case aten::log10:
      case aten::log2:
      case aten::log1p:
      case aten::reciprocal:
      case aten::rsqrt:
      case aten::sinh:
      case aten::tan:
      case aten::pow:
      case aten::softplus:
      case aten::gelu:
      case aten::layer_norm:
      case aten::group_norm:
      case aten::frobenius_norm:
      case aten::nuclear_norm:
      case aten::cosine_similarity:
      case aten::cosine_embedding_loss:
      case aten::nll_loss:
      case aten::nll_loss2d:
      case aten::hinge_embedding_loss:
      case aten::kl_div:
      case aten::l1_loss:
      case aten::smooth_l1_loss:
      case aten::mse_loss:
      case aten::margin_ranking_loss:
      case aten::multilabel_margin_loss:
      case aten::soft_margin_loss:
      case aten::triplet_margin_loss:
      case aten::multi_margin_loss:
      case aten::binary_cross_entropy_with_logits:
      case aten::dist:
      case aten::pdist:
      case aten::cdist:
      case aten::renorm:
      case aten::logsumexp:
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // CastPolicy::fp32_set_opt_dtype
      case aten::prod:
      case aten::log_softmax:
      case aten::cumprod:
      case aten::cumsum:
      case aten::sum:
        if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // cast softmax to fp32 only on GPU
      case aten::softmax:
        if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
          auto context = current_state();
          context.cpu_enabled = false;
          castTensorInputs(node, aten::_autocast_to_full_precision, context);
        }
        break;

      // CastPolicy::promote (promote inputs to the widest type)
      case aten::addcdiv:
      case aten::addcmul:
      case aten::atan2:
      case aten::bilinear:
      case aten::cat:
      case aten::cross:
      case aten::dot:
      case aten::equal:
      case aten::index_put:
      case aten::stack:
      case aten::tensordot:
      // add, sub, mul, div were added to autocast jit, because aten implicit
      // type promotion is not visible to JIT and could cause dtype mismatch on
      // backward
      // see [Note: implicit type promotion in Autocast]
      case aten::add:
      case aten::sub:
      case aten::mul:
      case aten::div:
        if (!node->schema().is_mutable()) {
          castInputsToWidestType(node, current_state());
        }
        break;

      // Banned in autocast, see binary_cross_entropy_banned()
      case aten::binary_cross_entropy:
        if (current_state()) {
          AT_ERROR("Unsafe to autocast");
        }
    }

    // process sub-blocks, if any
    for (Block* sub_block : node->blocks()) {
      handleBlock(sub_block, current_state());
    }
  }

  // Sanity check: make sure there's no unbalanced transition
  TORCH_INTERNAL_ASSERT(autocast_stack.empty());
}

} // namespace

bool setAutocastMode(bool value) {
  auto old_value = autocast_enabled;
  autocast_enabled = value;
  return old_value;
}

bool autocastEnabled() {
  return autocast_enabled;
}

void Autocast(const std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("\nBefore Autocast: ", graph);
  if (autocastEnabled()) {
    AutocastContext init = {
        at::autocast::is_autocast_enabled(at::kCUDA),
        at::autocast::is_autocast_enabled(at::kCPU),
        at::autocast::get_autocast_dtype(at::kCUDA),
        at::autocast::get_autocast_dtype(at::kCPU)};
    handleBlock(graph->block(), init);
  }
  GRAPH_DUMP("\nAfter Autocast: ", graph);
}

} // namespace torch::jit
