// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_FORCED_EVAL_H
#define EIGEN_CXX11_TENSOR_TENSOR_FORCED_EVAL_H

namespace Eigen {

/** \class TensorForcedEval
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Forces evaluation of a tensor expression into a temporary buffer.
  *
  *
  */
namespace internal {
// Traits of a forced-evaluation node. Everything except Flags is forwarded
// from the wrapped expression type (XprType) or from its own traits.
template<typename XprType>
struct traits<TensorForcedEvalOp<XprType> >
{
  // Traits of the wrapped expression; most members below are taken from it.
  typedef traits<XprType> XprTraits;

  typedef typename XprType::Scalar Scalar;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprTraits::PointerType PointerType;

  // How the wrapped expression is nested, with and without the reference.
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;

  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;

  enum {
    Flags = 0
  };
};

// eval<> for a dense forced-evaluation node: the resulting type is a const
// reference to the node itself.
template<typename XprType>
struct eval<TensorForcedEvalOp<XprType>, Eigen::Dense>
{
  typedef const TensorForcedEvalOp<XprType>& type;
};

// nested<> for a forced-evaluation node: the node is held by value
// (the nested type is the node type itself, not a reference to it).
template<typename XprType>
struct nested<TensorForcedEvalOp<XprType>, 1, typename eval<TensorForcedEvalOp<XprType> >::type>
{
  typedef TensorForcedEvalOp<XprType> type;
};

}  // end namespace internal



/**
 * Read-only expression node wrapping an expression of type XprType.
 * The node only stores the wrapped expression (as XprType::Nested) and exposes
 * it through expression(); the actual work is done by the matching
 * TensorEvaluator specialization.
 */
template<typename XprType>
class TensorForcedEvalOp : public TensorBase<TensorForcedEvalOp<XprType>, ReadOnlyAccessors>
{
  public:
    // Types forwarded from this node's traits.
    typedef typename internal::traits<TensorForcedEvalOp>::Scalar Scalar;
    typedef typename internal::traits<TensorForcedEvalOp>::StorageKind StorageKind;
    typedef typename internal::traits<TensorForcedEvalOp>::Index Index;
    typedef typename internal::nested<TensorForcedEvalOp>::type Nested;

    typedef typename NumTraits<Scalar>::Real RealScalar;
    // Coefficient type of the wrapped expression with const removed.
    typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;

    /** Builds a node wrapping \a expr. */
    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorForcedEvalOp(const XprType& expr)
        : m_xpr(expr) {}

    /** Returns the wrapped expression. */
    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    // The wrapped expression, stored using its own nesting type.
    typename XprType::Nested m_xpr;
};

namespace internal {
// Functor that default-constructs `numValues` objects of type CoeffReturnType
// in place, starting at `m_buffer`, when CoeffReturnType is not an arithmetic
// type. For arithmetic types the call does nothing.
template <typename Device, typename CoeffReturnType>
struct non_integral_type_placement_new {
  template <typename StorageType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(Index numValues, StorageType m_buffer) {
    // Arithmetic types are left uninitialized; nothing to construct.
    if (internal::is_arithmetic<CoeffReturnType>::value) {
      return;
    }
    // Initialize non-trivially constructible types, one slot at a time.
    for (Index slot = 0; slot < numValues; ++slot) {
      new (m_buffer + slot) CoeffReturnType();
    }
  }
};

// Specialization for SYCL devices: intentionally a no-op.
// The placement-new expression `new (m_buffer + i) CoeffReturnType()` used by
// the generic version causes the following compiler error for SYCL devices:
//   no matching function for call to 'operator new'
template <typename CoeffReturnType>
struct non_integral_type_placement_new<Eigen::SyclDevice, CoeffReturnType> {
  template <typename StorageType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(Index, StorageType) {}
};
} // end namespace internal

template<typename ArgType_, typename Device>
struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device>
{
  typedef const typename internal::remove_all<ArgType_>::type ArgType;
  typedef TensorForcedEvalOp<ArgType> XprType;
  typedef typename ArgType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
  typedef typename Eigen::internal::traits<XprType>::PointerType TensorPointerType;
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  enum {
    IsAligned         = true,
    PacketAccess      = (PacketType<CoeffReturnType, Device>::size > 1),
    BlockAccess       = internal::is_arithmetic<CoeffReturnType>::value,
    PreferBlockAccess = false,
    Layout            = TensorEvaluator<ArgType, Device>::Layout,
    RawAccess         = true
  };

  static const int NumDims = internal::traits<ArgType>::NumDimensions;

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
  typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;

  typedef typename internal::TensorMaterializedBlock<CoeffReturnType, NumDims,
                                                     Layout, Index>
      TensorBlock;
  //===--------------------------------------------------------------------===//

  TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device), m_op(op.expression()),
      m_device(device), m_buffer(NULL)
  { }

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_impl.dimensions(); }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
    const Index numValues =  internal::array_prod(m_impl.dimensions());
    m_buffer = m_device.get((CoeffReturnType*)m_device.allocate_temp(numValues * sizeof(CoeffReturnType)));

   internal::non_integral_type_placement_new<Device, CoeffReturnType>()(numValues, m_buffer);

    typedef TensorEvalToOp< const typename internal::remove_const<ArgType>::type > EvalTo;
    EvalTo evalToTmp(m_device.get(m_buffer), m_op);

    internal::TensorExecutor<
        const EvalTo, typename internal::remove_const<Device>::type,
        /*Vectorizable=*/internal::IsVectorizable<Device, const ArgType>::value,
        /*Tiling=*/internal::IsTileable<Device, const ArgType>::value>::
        run(evalToTmp, m_device);

    return true;
  }

#ifdef EIGEN_USE_THREADS
  template <typename EvalSubExprsCallback>
  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
      EvaluatorPointerType, EvalSubExprsCallback done) {
    const Index numValues = internal::array_prod(m_impl.dimensions());
    m_buffer = m_device.get((CoeffReturnType*)m_device.allocate_temp(
        numValues * sizeof(CoeffReturnType)));
    typedef TensorEvalToOp<const typename internal::remove_const<ArgType>::type>
        EvalTo;
    EvalTo evalToTmp(m_device.get(m_buffer), m_op);

    auto on_done = std::bind([](EvalSubExprsCallback done_) { done_(true); },
                             std::move(done));
    internal::TensorAsyncExecutor<
        const EvalTo, typename internal::remove_const<Device>::type,
        decltype(on_done),
        /*Vectorizable=*/internal::IsVectorizable<Device, const ArgType>::value,
        /*Tiling=*/internal::IsTileable<Device, const ArgType>::value>::
        runAsync(evalToTmp, m_device, std::move(on_done));
  }
#endif

  EIGEN_STRONG_INLINE void cleanup() {
    m_device.deallocate_temp(m_buffer);
    m_buffer = NULL;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    return m_buffer[index];
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return internal::ploadt<PacketReturnType, LoadMode>(m_buffer + index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  internal::TensorBlockResourceRequirements getResourceRequirements() const {
    return internal::TensorBlockResourceRequirements::any();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock
  block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
          bool /*root_of_expr_ast*/ = false) const {
    assert(m_buffer != NULL);
    return TensorBlock::materialize(m_buffer, m_impl.dimensions(), desc, scratch);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  EvaluatorPointerType data() const { return m_buffer; }

#ifdef EIGEN_USE_SYCL
  // binding placeholder accessors to a command group handler for SYCL
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_buffer.bind(cgh);
    m_impl.bind(cgh);
  }
#endif
 private:
  TensorEvaluator<ArgType, Device> m_impl;
  const ArgType m_op;
  const Device EIGEN_DEVICE_REF m_device;
  EvaluatorPointerType m_buffer;
};


} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_FORCED_EVAL_H
