// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gagan Goel <gagan.nith@gmail.com>
// Copyright (C) 2017 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
#define EIGEN_CXX11_TENSOR_TENSOR_TRACE_H

namespace Eigen {

/** \class TensorTrace
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor Trace class.
  *
  *
  */

namespace internal {
template<typename Dims, typename XprType>
struct traits<TensorTraceOp<Dims, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
  static const int Layout = XprTraits::Layout;
};

template<typename Dims, typename XprType>
struct eval<TensorTraceOp<Dims, XprType>, Eigen::Dense>
{
  typedef const TensorTraceOp<Dims, XprType>& type;
};

template<typename Dims, typename XprType>
struct nested<TensorTraceOp<Dims, XprType>, 1, typename eval<TensorTraceOp<Dims, XprType> >::type>
{
  typedef TensorTraceOp<Dims, XprType> type;
};

} // end namespace internal


template<typename Dims, typename XprType>
class TensorTraceOp : public TensorBase<TensorTraceOp<Dims, XprType> >
{
  public:
    typedef typename Eigen::internal::traits<TensorTraceOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename XprType::CoeffReturnType CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorTraceOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorTraceOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorTraceOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTraceOp(const XprType& expr, const Dims& dims)
      : m_xpr(expr), m_dims(dims) {
    }

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const Dims& dims() const { return m_dims; }

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const typename internal::remove_all<typename XprType::Nested>::type& expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
    const Dims m_dims;
};


// Eval as rvalue
template<typename Dims, typename ArgType, typename Device>
struct TensorEvaluator<const TensorTraceOp<Dims, ArgType>, Device>
{
  typedef TensorTraceOp<Dims, ArgType> XprType;
  static const int NumInputDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  static const int NumReducedDims = internal::array_size<Dims>::value;
  static const int NumOutputDims = NumInputDims - NumReducedDims;
  typedef typename XprType::Index Index;
  typedef DSizes<Index, NumOutputDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess = false,
    PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,
    RawAccess = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
    : m_impl(op.expression(), device), m_traceDim(1), m_device(device)
  {

    EIGEN_STATIC_ASSERT((NumOutputDims >= 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((NumReducedDims >= 2) || ((NumReducedDims == 0) && (NumInputDims == 0)), YOU_MADE_A_PROGRAMMING_MISTAKE);

    for (int i = 0; i < NumInputDims; ++i) {
      m_reduced[i] = false;
    }

    const Dims& op_dims = op.dims();
    for (int i = 0; i < NumReducedDims; ++i) {
      eigen_assert(op_dims[i] >= 0);
      eigen_assert(op_dims[i] < NumInputDims);
      m_reduced[op_dims[i]] = true;
    }

    // All the dimensions should be distinct to compute the trace
    int num_distinct_reduce_dims = 0;
    for (int i = 0; i < NumInputDims; ++i) {
      if (m_reduced[i]) {
        ++num_distinct_reduce_dims;
      }
    }

    eigen_assert(num_distinct_reduce_dims == NumReducedDims);

    // Compute the dimensions of the result.
    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();

    int output_index = 0;
    int reduced_index = 0;
    for (int i = 0; i < NumInputDims; ++i) {
      if (m_reduced[i]) {
        m_reducedDims[reduced_index] = input_dims[i];
        if (reduced_index > 0) {
          // All the trace dimensions must have the same size
          eigen_assert(m_reducedDims[0] == m_reducedDims[reduced_index]);
        }
        ++reduced_index;
      }
      else {
        m_dimensions[output_index] = input_dims[i];
        ++output_index;
      }
    }

    if (NumReducedDims != 0) {
      m_traceDim = m_reducedDims[0];
    }

    // Compute the output strides
    if (NumOutputDims > 0) {
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        m_outputStrides[0] = 1;
        for (int i = 1; i < NumOutputDims; ++i) {
          m_outputStrides[i] = m_outputStrides[i - 1] * m_dimensions[i - 1];
        }
      }
      else {
        m_outputStrides.back() = 1;
        for (int i = NumOutputDims - 2; i >= 0; --i) {
          m_outputStrides[i] = m_outputStrides[i + 1] * m_dimensions[i + 1];
        }
      }
    }

    // Compute the input strides
    if (NumInputDims > 0) {
      array<Index, NumInputDims> input_strides;
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        input_strides[0] = 1;
        for (int i = 1; i < NumInputDims; ++i) {
          input_strides[i] = input_strides[i - 1] * input_dims[i - 1];
        }
      }
      else {
        input_strides.back() = 1;
        for (int i = NumInputDims - 2; i >= 0; --i) {
          input_strides[i] = input_strides[i + 1] * input_dims[i + 1];
        }
      }

      output_index = 0;
      reduced_index = 0;
      for (int i = 0; i < NumInputDims; ++i) {
        if(m_reduced[i]) {
          m_reducedStrides[reduced_index] = input_strides[i];
          ++reduced_index;
        }
        else {
          m_preservedStrides[output_index] = input_strides[i];
          ++output_index;
        }
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    return m_dimensions;
  }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    // Initialize the result
    CoeffReturnType result = internal::cast<int, CoeffReturnType>(0);
    Index index_stride = 0;
    for (int i = 0; i < NumReducedDims; ++i) {
      index_stride += m_reducedStrides[i];
    }

    // If trace is requested along all dimensions, starting index would be 0
    Index cur_index = 0;
    if (NumOutputDims != 0)
      cur_index = firstInput(index);
    for (Index i = 0; i < m_traceDim; ++i) {
        result += m_impl.coeff(cur_index);
        cur_index += index_stride;
    }

    return result;
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {

    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
    for (int i = 0; i < PacketSize; ++i) {
        values[i] = coeff(index + i);
    }
    PacketReturnType result = internal::ploadt<PacketReturnType, LoadMode>(values);
    return result;
  }

#ifdef EIGEN_USE_SYCL
  // binding placeholder accessors to a command group handler for SYCL
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_impl.bind(cgh);
  }
#endif

 protected:
  // Given the output index, finds the first index in the input tensor used to compute the trace
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index firstInput(Index index) const {
    Index startInput = 0;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumOutputDims - 1; i > 0; --i) {
        const Index idx = index / m_outputStrides[i];
        startInput += idx * m_preservedStrides[i];
        index -= idx * m_outputStrides[i];
      }
      startInput += index * m_preservedStrides[0];
    }
    else {
      for (int i = 0; i < NumOutputDims - 1; ++i) {
        const Index idx = index / m_outputStrides[i];
        startInput += idx * m_preservedStrides[i];
        index -= idx * m_outputStrides[i];
      }
      startInput += index * m_preservedStrides[NumOutputDims - 1];
    }
    return startInput;
  }

  Dimensions m_dimensions;
  TensorEvaluator<ArgType, Device> m_impl;
  // Initialize the size of the trace dimension
  Index m_traceDim;
  const Device EIGEN_DEVICE_REF m_device;
  array<bool, NumInputDims> m_reduced;
  array<Index, NumReducedDims> m_reducedDims;
  array<Index, NumOutputDims> m_outputStrides;
  array<Index, NumReducedDims> m_reducedStrides;
  array<Index, NumOutputDims> m_preservedStrides;
};


} // End namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
