#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/NativeFunctions.h>
#include <c10/util/irange.h>
#include <torch/library.h>
#include <ATen/native/MathBitFallThroughLists.h>

namespace at {

  // TODO: add a note explaining the design decisions
  // ZeroTensors are designed to be immutable. Thus, we error out when an in-place operation is performed on ZeroTensors
  //
  // Boxed fallback for the ZeroTensor dispatch key.
  //
  // Behavior, in order:
  //   1. Scan the schema's alias annotations. If annotated arguments disagree
  //      on whether they are written to, raise an error (such operators are
  //      not handled here).
  //   2. If the annotated arguments are all non-write (view-like operator),
  //      redispatch with the stack untouched.
  //   3. Otherwise, walk every Tensor / TensorList argument on the stack and
  //      replace each ZeroTensor by a scalar zero expanded to the same sizes.
  //      A ZeroTensor passed as a written-to argument is an error.
  //   4. Redispatch with the (possibly rewritten) stack.
  //
  // op:            handle of the operator being dispatched; only its schema is
  //                inspected before redispatching.
  // dispatch_keys: key set of the current call; it is masked with
  //                FULL_AFTER(ZeroTensor) before redispatching.
  // stack:         boxed arguments; the last `schema.arguments().size()`
  //                entries are the operator's arguments and may be overwritten
  //                in place.
  static void zeroTensorFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
    const auto& arguments = op.schema().arguments();
    const auto num_arguments = arguments.size();
    // The subtraction below is unsigned; make a short stack fail loudly in
    // debug builds instead of silently wrapping around.
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(stack->size() >= num_arguments);
    const auto stack_start = stack->size() - num_arguments;

    // Key set used by both redispatch sites below: the incoming keys masked
    // with the FULL_AFTER set for ZeroTensor.
    const auto redispatch_keys =
        dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor);

    // Replacement for a ZeroTensor argument: a scalar zero with the same
    // options, expanded to the original sizes.
    const auto materialize = [](const Tensor& zero_tensor) {
      return at::zeros({}, zero_tensor.options()).expand(zero_tensor.sizes());
    };

    // Empty when no argument carries alias info; otherwise the isWrite() value
    // shared by every annotated argument.
    std::optional<bool> is_write;
    for (const auto i : c10::irange(num_arguments)) {
      const auto& alias_info = arguments[i].alias_info();
      if (alias_info != nullptr) {
        if (is_write.has_value()) {
          TORCH_CHECK(*is_write == alias_info->isWrite(),
            "Unsupported operator for ", "ZeroTensorFallback: ", op.schema().name(),
            ". ZeroTensor fallback doesn't work for operators with a mix of "
            "mutable and non-mutable inputs that alias with outputs, "
            "this must be implemented manually.  "
            "If you got this error on a core op, please report a bug to PyTorch.");
        } else {
          is_write = alias_info->isWrite();
        }
      }
    }

    if (is_write.has_value() && !*is_write) {
      // We assume that view operators automatically handle the ZeroTensor bit
      // correctly by propagating the dispatch key in key_set.
      // This is not necessarily always right, so you should test these cases.
      op.redispatchBoxed(redispatch_keys, stack);
      return;
    }

    for (const auto i : c10::irange(num_arguments)) {
      auto& ivalue = (*stack)[stack_start + i];
      if (!(ivalue.isTensor() || ivalue.isTensorList())) {
        continue;
      }
      const auto& argument = arguments[i];

      // Reaching this loop means every annotated argument is a write (the
      // non-write case returned above), so "has alias info" == "is mutated".
      const bool mut_arg = argument.alias_info() != nullptr;
      if (mut_arg) {
        // Was already tested by is_write loop above
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
      }

      if (ivalue.isTensor()) {
        // Move the tensor out of the stack slot, possibly replace it, and
        // store it back.
        auto tensor = std::move(ivalue).toTensor();
        if (tensor._is_zerotensor()) {
          TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
                    "obtained using .clone() if you want a mutable tensor.");
          tensor = materialize(tensor);
        }
        (*stack)[stack_start + i] = std::move(tensor);
      } else if (ivalue.isTensorList()) {
        // Same treatment element-wise for tensor lists.
        auto tensors = std::move(ivalue).toTensorList();
        for (const auto j : c10::irange(tensors.size())) {
          const Tensor& tensor = tensors[j];
          if (tensor._is_zerotensor()) {
            // TODO: assert requires_grad=False
            //_like should not propagate zerotensor dispatch key
            TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
                    "obtained using .clone() if you want a mutable tensor.");
            tensors[j] = materialize(tensor);
          }
        }
        (*stack)[stack_start + i] = std::move(tensors);
      }
    }

    op.redispatchBoxed(redispatch_keys, stack);
  }


  // Registers zeroTensorFallback as the boxed fallback kernel for the
  // ZeroTensor dispatch key under the `_` namespace.
  TORCH_LIBRARY_IMPL(_, ZeroTensor, m) {
    auto boxed_fallback = torch::CppFunction::makeFromBoxedFunction<&zeroTensorFallback>();
    m.fallback(std::move(boxed_fallback));
  }

  // Per-operator ZeroTensor registrations for the aten namespace. Every
  // operator named here is registered as a fallthrough for the ZeroTensor key
  // instead of going through the boxed fallback.
  TORCH_LIBRARY_IMPL(aten, ZeroTensor, m) {
    // Registered in this exact order.
    constexpr const char* kFallthroughOps[] = {
      "zeros_like",
      "mul.Scalar",
      "add.Scalar",
      "copy_",
      "clone",
      "dot",
      "vdot",
    };
    for (const char* op_name : kFallthroughOps) {
      m.impl(op_name, torch::CppFunction::makeFallthrough());
    }

    // The functions in the list below have a specific registration in
    // native_functions.yaml and do not use the fallback.
    // m.impl("mul.Tensor", torch::CppFunction::makeFallthrough());
    // m.impl("add.Tensor", torch::CppFunction::makeFallthrough());
    // m.impl("linalg_cross", torch::CppFunction::makeFallthrough());

    // Macro-provided registration lists (view functions, tensor utilities and
    // constructors), applied to the same library handle.
    TORCH_VIEW_FNS(m)
    TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
  }
} // namespace at
