// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_BLOCK_H
#define EIGEN_CXX11_TENSOR_TENSOR_BLOCK_H

namespace Eigen {
namespace internal {

// -------------------------------------------------------------------------- //
// Forward declarations for templates defined below.
template <typename Scalar, typename IndexType, int NumDims, int Layout>
class TensorBlockIO;

// -------------------------------------------------------------------------- //
// Helper function to compute strides for densely stored buffer of given
// dimensions.

// TODO(ezhulenev): We compute strides 1000 times in different evaluators, use
// this function instead everywhere.
template <int Layout, typename IndexType, int NumDims>
EIGEN_ALWAYS_INLINE DSizes<IndexType, NumDims> strides(
    const DSizes<IndexType, NumDims>& dimensions) {
  // Strides of a densely stored buffer: the innermost dimension (the first one
  // for ColMajor, the last one otherwise) has stride 1, and each following
  // dimension has the stride of the previous dimension multiplied by the size
  // of the previous dimension.
  DSizes<IndexType, NumDims> result;
  if (NumDims == 0) return result;

  const bool col_major =
      static_cast<int>(Layout) == static_cast<int>(ColMajor);

  // TODO(ezhulenev): Use templates to unroll this loop (similar to
  // h_array_reduce in CXX11meta.h)? Benchmark it.
  //
  // Walk the dimensions from the innermost to the outermost one, carrying the
  // index of the previously processed (more inner) dimension along.
  int inner = col_major ? 0 : NumDims - 1;
  result[inner] = 1;
  for (int step = 1; step < NumDims; ++step) {
    const int outer = col_major ? step : NumDims - 1 - step;
    result[outer] = result[inner] * dimensions[inner];
    inner = outer;
  }

  return result;
}

template <int Layout, typename IndexType, size_t NumDims>
EIGEN_ALWAYS_INLINE DSizes<IndexType, NumDims> strides(
    const Eigen::array<IndexType, NumDims>& dimensions) {
  // Wrap the plain array into DSizes and defer to the DSizes overload.
  const DSizes<IndexType, NumDims> dense_dimensions(dimensions);
  return strides<Layout>(dense_dimensions);
}

template <int Layout, std::ptrdiff_t... Indices>
EIGEN_STRONG_INLINE DSizes<std::ptrdiff_t, sizeof...(Indices)> strides(
    const Sizes<Indices...>& sizes) {
  // Convert the compile-time sizes into DSizes and defer to the DSizes
  // overload.
  typedef DSizes<std::ptrdiff_t, sizeof...(Indices)> DynamicSizes;
  return strides<Layout>(DynamicSizes(sizes));
}

// -------------------------------------------------------------------------- //

// Tensor block shape type defines what are the shape preference for the blocks
// extracted from the larger tensor.
//
// Example: blocks of 100 elements from the large 100x100 tensor:
// - tensor: 100x100
// - target_block_size: 100
//
// TensorBlockShapeType:
//  - kUniformAllDims: 100 blocks of size 10x10
//  - kSkewedInnerDims: 100 blocks of size 100x1 (or 1x100 depending on a column
//                      or row major layout)
enum class TensorBlockShapeType {
  // Block size budget is spread evenly across all dimensions.
  kUniformAllDims = 0,
  // Block size budget is spent on the inner dimensions first.
  kSkewedInnerDims = 1
};

// Block evaluation preferences of a tensor expression: the preferred block
// shape, the target block size (in coefficients), and the estimated cost of
// computing a single block coefficient.
struct TensorBlockResourceRequirements {
  TensorBlockShapeType shape_type;  // target block shape
  size_t size;                      // target block size
  TensorOpCost cost_per_coeff;      // cost of computing a single block element

#ifdef EIGEN_HIPCC
  // HIPCC requires the constructor that is implicitly invoked by the "merge" /
  // "any" routines to be explicitly declared as a "device fun"; otherwise it
  // errors out complaining about the lack of a matching constructor.
  EIGEN_DEVICE_FUNC
  TensorBlockResourceRequirements(TensorBlockShapeType shape_type_,
                                  size_t size_, TensorOpCost cost_)
      : shape_type(shape_type_), size(size_), cost_per_coeff(cost_) {}
#endif

  // Builds requirements for blocks of `Scalar` coefficients from a size budget
  // given in bytes. The resulting block size is never smaller than a single
  // coefficient.
  template <typename Scalar>
  EIGEN_DEVICE_FUNC static TensorBlockResourceRequirements withShapeAndSize(
      TensorBlockShapeType shape_type, size_t size_in_bytes,
      TensorOpCost cost) {
    const size_t num_coeffs = size_in_bytes / sizeof(Scalar);
    return {shape_type, numext::maxi(size_t(1), num_coeffs), cost};
  }

  // Same as above, with a default cost per coefficient.
  template <typename Scalar>
  EIGEN_DEVICE_FUNC static TensorBlockResourceRequirements withShapeAndSize(
      TensorBlockShapeType shape_type, size_t size_in_bytes) {
    // This default cost per coefficient is valid for most materialized tensor
    // block evaluation implementations, because they typically just read
    // coefficients from the underlying tensor storage, and write to the tensor
    // block buffer (scratch or destination memory, reads and writes have linear
    // access pattern). We ignore the fixed cost of block evaluation, because in
    // practice it should be negligible.
    //
    // Lazy block evaluation adds the cost of calling a functor for each
    // coefficient.
    //
    // All non-trivial block evaluation implementations must provide their own
    // cost approximation (e.g. shuffling inner dimension has a much higher cost
    // because it reads memory randomly, although the total number of moved
    // bytes is the same).
    const TensorOpCost default_cost = {/*bytes_loaded=*/sizeof(Scalar),
                                       /*bytes_stored=*/sizeof(Scalar),
                                       /*compute_cycles=*/0};
    return withShapeAndSize<Scalar>(shape_type, size_in_bytes, default_cost);
  }

  // Requirements for blocks skewed towards the inner dimensions.
  template <typename Scalar>
  EIGEN_DEVICE_FUNC static TensorBlockResourceRequirements skewed(
      size_t size_in_bytes) {
    return withShapeAndSize<Scalar>(TensorBlockShapeType::kSkewedInnerDims,
                                    size_in_bytes);
  }

  // Requirements for blocks that are uniform across all dimensions.
  template <typename Scalar>
  EIGEN_DEVICE_FUNC static TensorBlockResourceRequirements uniform(
      size_t size_in_bytes) {
    return withShapeAndSize<Scalar>(TensorBlockShapeType::kUniformAllDims,
                                    size_in_bytes);
  }

  // Combines the requirements of two expressions field by field (see the
  // private `merge` overloads below for the per-field rules).
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE TensorBlockResourceRequirements
  merge(const TensorBlockResourceRequirements& lhs,
        const TensorBlockResourceRequirements& rhs) {
    const TensorBlockShapeType merged_shape =
        merge(lhs.shape_type, rhs.shape_type);
    const size_t merged_size = merge(lhs.size, rhs.size);
    const TensorOpCost merged_cost =
        merge(lhs.cost_per_coeff, rhs.cost_per_coeff);
    return {merged_shape, merged_size, merged_cost};
  }

  // Adds `cost` to the per-coefficient cost and returns `*this` to allow
  // chaining.
  EIGEN_DEVICE_FUNC TensorBlockResourceRequirements& addCostPerCoeff(
      TensorOpCost cost) {
    cost_per_coeff += cost;
    return *this;
  }

  // This is a resource requirement that should be returned from expressions
  // that do not have any block evaluation preference (e.g. default tensor
  // expression with raw buffer access).
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE TensorBlockResourceRequirements any() {
    return {TensorBlockShapeType::kUniformAllDims, 1, {0, 0, 0}};
  }

 private:
  using Requirements = TensorBlockResourceRequirements;

  // Merged block size is the larger of the two sizes.
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE size_t merge(size_t lhs_size, size_t rhs_size) {
    return numext::maxi(lhs_size, rhs_size);
  }

  // Merged shape is skewed if either side asks for a skewed shape, and uniform
  // otherwise.
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE TensorBlockShapeType
  merge(TensorBlockShapeType lhs, TensorBlockShapeType rhs) {
    const bool any_skewed = lhs == TensorBlockShapeType::kSkewedInnerDims ||
                            rhs == TensorBlockShapeType::kSkewedInnerDims;
    if (any_skewed) return TensorBlockShapeType::kSkewedInnerDims;
    return TensorBlockShapeType::kUniformAllDims;
  }

  // Merged cost is the sum of the two costs.
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE TensorOpCost merge(TensorOpCost lhs_cost,
                                                TensorOpCost rhs_cost) {
    return lhs_cost + rhs_cost;
  }
};

// -------------------------------------------------------------------------- //
// TensorBlockDescriptor specifies a block offset within a tensor and the block
// sizes along each of the tensor dimensions.

template <int NumDims, typename IndexType = Eigen::Index>
class TensorBlockDescriptor {
 public:
  typedef DSizes<IndexType, NumDims> Dimensions;

  // If we evaluate a Tensor assignment, and expression on the left, already has
  // a memory buffer, then we might do performance optimization, and evaluate
  // the root expression directly into the final output memory. Some time it's
  // possible to reuse it for materializing subexpressions inside an expression
  // tree, to to avoid dynamic memory allocation.
  //
  // The pointer type of the underlying storage is erased, because passing
  // Scalar type through all the expression evaluation layers is way too many
  // templates. In practice destination buffer type should always match the
  // evaluated expression scalar type.
  class DestinationBuffer {
   public:
    enum DestinationBufferKind : int {
      // The above explicit specification of "int" as the enum basetype is
      // needed to get around a HIPCC link error ("the field type is not
      // amp-compatible")
      // which is issued for class members with the enum type.
      // TODO(rocm):
      // remove the "int" basetype once HIPCC has been fixed to not error out
      // in the above scenario.

      // Destination buffer is not defined (`m_data` == nullptr).
      kEmpty,

      // Tensor block defined by an owning tensor block descriptor can fit
      // contiguously into the destination buffer. In this case it's safe to
      // materialize tensor block in the destination buffer, wrap it in a
      // TensorMap, and use to build Eigen expression on top of it.
      kContiguous,

      // Destination buffer strides do not match strides of the contiguously
      // stored block, and it's impossible to define a TensorMap over this
      // buffer. However if we are evaluating a root of an expression tree, we
      // still can materialize an output into this destination, because we can
      // guarantee that no one will ever access it through block API.
      //
      // In theory it is possible to build valid TensorStriding<TensorMap>
      // expression on top of this destination buffer, however it has
      // inefficient coeff/packet access, and defeats the purpose of fast block
      // evaluation API.
      kStrided
    };

    template <typename Scalar>
    Scalar* data() const {
      eigen_assert(m_data_type_size == sizeof(Scalar));
      return static_cast<Scalar*>(m_data);
    }

    const Dimensions& strides() const { return m_strides; }
    const DestinationBufferKind& kind() const { return m_kind; }

   private:
    friend class TensorBlockDescriptor;

    DestinationBuffer() : m_data(NULL), m_data_type_size(0), m_kind(kEmpty) {}

    template <typename Scalar>
    DestinationBuffer(Scalar* data, const Dimensions& strides,
                      DestinationBufferKind kind)
        : m_data(static_cast<void*>(data)),
          m_data_type_size(sizeof(Scalar)),
          m_strides(strides),
          m_kind(kind) {}

    template <int Layout, typename Scalar>
    static DestinationBuffer make(const TensorBlockDescriptor& desc,
                                  Scalar* data, const Dimensions& strides) {
      return DestinationBuffer(data, strides, kind<Layout>(desc, strides));
    }

    template <int Layout>
    static DestinationBufferKind kind(const TensorBlockDescriptor& desc,
                                      const Dimensions& strides) {
      const Dimensions& desc_dims = desc.dimensions();
      const Dimensions& desc_strides = internal::strides<Layout>(desc_dims);
      for (int i = 0; i < NumDims; ++i) {
        if (desc_dims[i] == 1) continue;
        if (desc_strides[i] != strides[i]) return kStrided;
      }
      return kContiguous;
    }

    // Storage pointer is type erased, to reduce template bloat, but we still
    // keep the size of the underlying element type for error checking.
    void* m_data;
    size_t m_data_type_size;

    // Destination buffer dimensions always match the dimensions of a tensor
    // block descriptor it belongs to, however strides might be different.
    Dimensions m_strides;

    DestinationBufferKind m_kind;
  };

  TensorBlockDescriptor(const IndexType offset, const Dimensions& dimensions,
                        const DestinationBuffer& destination)
      : m_offset(offset),
        m_dimensions(dimensions),
        m_destination(destination) {}

  TensorBlockDescriptor(const IndexType offset, const Dimensions& dimensions)
      : m_offset(offset),
        m_dimensions(dimensions),
        m_destination(DestinationBuffer()) {}

  IndexType offset() const { return m_offset; }
  const Dimensions& dimensions() const { return m_dimensions; }
  IndexType dimension(int index) const { return m_dimensions[index]; }
  IndexType size() const { return array_prod<IndexType>(m_dimensions); }

  const DestinationBuffer& destination() const { return m_destination; }

  template <int Layout, typename Scalar>
  void AddDestinationBuffer(Scalar* dst_base, const Dimensions& dst_strides) {
    eigen_assert(dst_base != NULL);
    m_destination =
        DestinationBuffer::template make<Layout>(*this, dst_base, dst_strides);
  }

  template <int Layout, typename Scalar, typename DstStridesIndexType>
  void AddDestinationBuffer(
      Scalar* dst_base,
      const DSizes<DstStridesIndexType, NumDims>& dst_strides) {
    // DSizes constructor will do index type promotion if it's safe.
    AddDestinationBuffer<Layout>(dst_base, Dimensions(dst_strides));
  }

  TensorBlockDescriptor& DropDestinationBuffer() {
    m_destination.m_data = NULL;
    m_destination.m_kind = DestinationBuffer::kEmpty;
    return *this;
  }

  bool HasDestinationBuffer() const {
    return m_destination.kind() != DestinationBuffer::kEmpty;
  }

  // Returns a copy of `*this` with updated offset.
  TensorBlockDescriptor WithOffset(IndexType offset) const {
    return TensorBlockDescriptor(offset, m_dimensions, m_destination);
  }

 private:
  // Offset and dimensions are immutable after construction. Block descriptor
  // can only be mutated by adding or dropping destination.
  const IndexType m_offset;
  const Dimensions m_dimensions;
  DestinationBuffer m_destination;
};

// -------------------------------------------------------------------------- //
// TensorBlockMapper is responsible for iterating over the blocks of a tensor.

template <int NumDims, int Layout, typename IndexType = Eigen::Index>
class TensorBlockMapper {
  typedef TensorBlockDescriptor<NumDims, IndexType> BlockDescriptor;

 public:
  typedef DSizes<IndexType, NumDims> Dimensions;

  // A default-constructed mapper reports zero blocks (`m_total_block_count`
  // has a default member initializer); it is expected to be overwritten by a
  // mapper built with the constructor below before it is used.
  TensorBlockMapper() = default;

  // Splits a tensor of the given `dimensions` into blocks according to the
  // requested shape type and target block size.
  TensorBlockMapper(const DSizes<IndexType, NumDims>& dimensions,
                    const TensorBlockResourceRequirements& requirements)
      : m_tensor_dimensions(dimensions), m_requirements(requirements) {
    // Compute block dimensions and the total number of blocks.
    InitializeBlockDimensions();
  }

  // Total number of blocks the tensor is split into.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType blockCount() const {
    return m_total_block_count;
  }

  // Number of coefficients in a full (not clipped at the boundary) block.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType blockTotalSize() const {
    return m_block_dimensions.TotalSize();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DSizes<IndexType, NumDims>&
  blockDimensions() const {
    return m_block_dimensions;
  }

  // Returns the descriptor (offset into the tensor and block dimensions) of
  // the block with the given linear `block_index`. Blocks that touch the
  // tensor boundary are clipped to the remaining tensor extent.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE BlockDescriptor
  blockDescriptor(IndexType block_index) const {
    static const bool isColMajor = Layout == static_cast<int>(ColMajor);

    IndexType offset = 0;
    DSizes<IndexType, NumDims> dimensions;

    if (NumDims == 0) return BlockDescriptor(offset, dimensions);

    // Iterate outer -> inner dimensions.
    for (int i = NumDims - 1; i >= 0; --i) {
      const int dim = isColMajor ? i : NumDims - i - 1;

      // Peel off the block coordinate along `dim` from the linear index.
      const IndexType idx = block_index / m_block_strides[dim];
      block_index -= idx * m_block_strides[dim];

      // First tensor coordinate covered by this block along `dim`.
      const IndexType coord = idx * m_block_dimensions[dim];
      dimensions[dim] = numext::mini(m_tensor_dimensions[dim] - coord,
                                     m_block_dimensions[dim]);
      offset += coord * m_tensor_strides[dim];
    }

    return {offset, dimensions};
  }

 private:
  // Computes `m_block_dimensions`, `m_total_block_count`, and the tensor/block
  // strides from `m_tensor_dimensions` and `m_requirements`.
  void InitializeBlockDimensions() {
    // Requested block shape and size.
    const TensorBlockShapeType shape_type = m_requirements.shape_type;
    IndexType target_block_size =
        numext::maxi<IndexType>(1, static_cast<IndexType>(m_requirements.size));

    IndexType tensor_size = m_tensor_dimensions.TotalSize();

    // Corner case: one of the dimensions is zero. Logic below is too complex
    // to handle this case on a general basis, just use unit block size.
    // Note: we must not yield blocks with zero dimensions (recipe for
    // overflows/underflows, divisions by zero and NaNs later).
    if (tensor_size == 0) {
      for (int i = 0; i < NumDims; ++i) {
        m_block_dimensions[i] = 1;
      }
      m_total_block_count = 0;
      return;
    }

    // If tensor fits into a target block size, evaluate it as a single block.
    if (tensor_size <= target_block_size) {
      m_block_dimensions = m_tensor_dimensions;
      m_total_block_count = 1;
      // The only valid block index is `0`, and in this case we do not need
      // to compute real strides for tensor or blocks (see blockDescriptor).
      for (int i = 0; i < NumDims; ++i) {
        m_tensor_strides[i] = 0;
        m_block_strides[i] = 1;
      }
      return;
    }

    static const bool isColMajor = Layout == static_cast<int>(ColMajor);

    // Block shape skewed towards inner dimension.
    if (shape_type == TensorBlockShapeType::kSkewedInnerDims) {
      IndexType coeff_to_allocate = target_block_size;

      // Walk inner -> outer dimensions, giving each dimension as much of the
      // remaining budget as it can take.
      for (int i = 0; i < NumDims; ++i) {
        const int dim = isColMajor ? i : NumDims - i - 1;
        m_block_dimensions[dim] =
            numext::mini(coeff_to_allocate, m_tensor_dimensions[dim]);
        coeff_to_allocate = divup(
            coeff_to_allocate,
            numext::maxi(static_cast<IndexType>(1), m_block_dimensions[dim]));
      }
      eigen_assert(coeff_to_allocate == 1);

    } else if (shape_type == TensorBlockShapeType::kUniformAllDims) {
      // Tensor will not fit within 'target_block_size' budget: calculate tensor
      // block dimension sizes based on "square" dimension size target.
      const IndexType dim_size_target = convert_index<IndexType>(
          std::pow(static_cast<float>(target_block_size),
                   1.0f / static_cast<float>(m_block_dimensions.rank())));

      for (int i = 0; i < NumDims; ++i) {
        // TODO(andydavis) Adjust the inner most 'block_dim_size' to make it
        // a multiple of the packet size. Note that reducing
        // 'block_dim_size' in this manner can increase the number of
        // blocks, and so will amplify any per-block overhead.
        m_block_dimensions[i] =
            numext::mini(dim_size_target, m_tensor_dimensions[i]);
      }

      // Add any un-allocated coefficients to inner dimension(s).
      IndexType total_size = m_block_dimensions.TotalSize();
      for (int i = 0; i < NumDims; ++i) {
        const int dim = isColMajor ? i : NumDims - i - 1;

        if (m_block_dimensions[dim] < m_tensor_dimensions[dim]) {
          const IndexType total_size_other_dims =
              total_size / m_block_dimensions[dim];
          const IndexType alloc_avail =
              divup<IndexType>(target_block_size, total_size_other_dims);
          if (alloc_avail == m_block_dimensions[dim]) {
            // Insufficient excess coefficients to allocate.
            break;
          }
          m_block_dimensions[dim] =
              numext::mini(m_tensor_dimensions[dim], alloc_avail);
          total_size = total_size_other_dims * m_block_dimensions[dim];
        }
      }

    } else {
      eigen_assert(false);  // unknown block shape
    }

    eigen_assert(m_block_dimensions.TotalSize() >=
                 numext::mini<IndexType>(target_block_size,
                                         m_tensor_dimensions.TotalSize()));

    // Calculate block counts by dimension and total block count.
    DSizes<IndexType, NumDims> block_count;
    for (int i = 0; i < NumDims; ++i) {
      block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]);
    }
    m_total_block_count = array_prod(block_count);

    // Calculate block strides (used for enumerating blocks).
    m_tensor_strides = strides<Layout>(m_tensor_dimensions);
    m_block_strides = strides<Layout>(block_count);
  }

  DSizes<IndexType, NumDims> m_tensor_dimensions;
  TensorBlockResourceRequirements m_requirements;

  DSizes<IndexType, NumDims> m_block_dimensions;
  // Initialized to zero so that `blockCount()` on a default-constructed mapper
  // does not read an indeterminate value.
  IndexType m_total_block_count = 0;

  DSizes<IndexType, NumDims> m_tensor_strides;
  DSizes<IndexType, NumDims> m_block_strides;
};

// -------------------------------------------------------------------------- //
// TensorBlockScratchAllocator is responsible for allocating temporary buffers
// for block evaluation (output or input block materialization). Given that
// Eigen expression traversal order is deterministic, all temporary allocations
// are happening in the same order, and usually have exactly the same size.
// Scratch allocator keeps a trace of all dynamic allocations, and after the
// first block evaluation is completed, we should be able to reuse all the
// temporary buffers for the next block evaluation.

template <typename Device>
class TensorBlockScratchAllocator {
 public:
  // Keeps a reference to `device`; no memory is allocated until the first call
  // to `allocate`.
  explicit TensorBlockScratchAllocator(const Device& device)
      : m_device(device), m_allocation_index(0) {}

  // Returns every buffer that was ever handed out back to the device.
  ~TensorBlockScratchAllocator() {
    for (const Allocation& allocation : m_allocations) {
      m_device.deallocate(allocation.ptr);
    }
  }

  // Returns a buffer of at least `size` bytes for the current allocation slot
  // and advances to the next slot. A buffer that already exists for the slot
  // is reused if it is large enough, and replaced with a larger one otherwise.
  void* allocate(size_t size) {
    // TODO(ezhulenev): Remove when replaced with inlined vector.
    if (m_allocations.capacity() == 0) m_allocations.reserve(8);

    const int allocation_count = static_cast<int>(m_allocations.size());

    // Allocation index can't be larger than the number of allocations.
    eigen_assert(m_allocation_index <= allocation_count);

    if (m_allocation_index >= allocation_count) {
      // No buffer exists for this slot yet: make a new allocation.
      Allocation fresh;
      fresh.ptr = m_device.allocate(size);
      fresh.size = size;
      m_allocations.push_back(fresh);
    } else {
      Allocation& slot = m_allocations[m_allocation_index];
      if (slot.size < size) {
        // The existing buffer can't fit the requested size: deallocate it and
        // replace it with a larger allocation.
        m_device.deallocate(slot.ptr);
        slot.ptr = m_device.allocate(size);
        slot.size = size;
      }
      // Otherwise the existing buffer is large enough and is reused as is.
    }

    const Allocation& current = m_allocations[m_allocation_index];
    eigen_assert(current.ptr != NULL);
    eigen_assert(current.size >= size);

    ++m_allocation_index;
    return current.ptr;
  }

  // Rewinds to the first slot, so that the buffers are reused by the next
  // sequence of `allocate` calls.
  void reset() { m_allocation_index = 0; }

 private:
  struct Allocation {
    void* ptr;
    size_t size;
  };

  const Device& m_device;
  int m_allocation_index;
  // TODO(ezhulenev): This should be an inlined vector.
  std::vector<Allocation> m_allocations;
};

// -------------------------------------------------------------------------- //
// TensorBlockKind represents all possible block kinds, that can be produced by
// TensorEvaluator::evalBlock function.
enum TensorBlockKind {
  // Lazy expression block: it has to be assigned to a destination using
  // TensorBlockAssign.
  kExpr = 0,

  // Block that is a view into a memory buffer owned by an underlying Tensor
  // expression (e.g. a view into a Tensor buffer).
  kView = 1,

  // Block that was materialized in a scratch memory buffer allocated with
  // TensorBlockScratchAllocator. Like a `kExpr` block, it must be copied to a
  // destination.
  kMaterializedInScratch = 2,

  // Block that was materialized directly into the final output memory buffer.
  // For example, if the left side of an assignment is a Tensor, the block can
  // be materialized directly in the destination memory.
  //
  // If strides in the output buffer do not match tensor block strides, the
  // Tensor expression will be invalid, and should not be used by
  // TensorBlockAssign or for constructing another block expression.
  kMaterializedInOutput = 3
};

// -------------------------------------------------------------------------- //
// TensorBlockNotImplemented should be used to define the TensorBlock typedef
// in TensorEvaluators that do not support block evaluation.

class TensorBlockNotImplemented {
 public:
  // A `void` expression type marks the absence of a block expression.
  using XprType = void;
};

// -------------------------------------------------------------------------- //
// XprScalar extracts Scalar type from the Eigen expressions (if expression type
// is not void). It's required to be able to define lazy block expression for
// argument types, that do not support block evaluation.

// Primary template: forwards the `Scalar` type of the expression.
template <typename XprType>
struct XprScalar {
  using type = typename XprType::Scalar;
};

// A `void` expression has no scalar type, so `type` is `void` as well.
template <>
struct XprScalar<void> {
  using type = void;
};

// -------------------------------------------------------------------------- //
// TensorMaterializedBlock is a fully evaluated block of the original tensor,
// and XprType is just a TensorMap over the data. This block type is typically
// used to materialize blocks of tensor expressions, that can't be efficiently
// represented as lazy Tensor expressions with fast coeff/packet operations,
// e.g. we materialize all broadcasts into evaluated blocks.
//
// TensorMaterializedBlock does not own its memory buffer, it's either a memory
// buffer that backs the original expression (e.g. block is just a view into a
// Tensor), or a memory buffer allocated with scratch allocator, and in this
// case the scratch allocator will deallocate it at the end of block based
// expression execution.
//
// If the block was evaluated directly into the output buffer, and strides in
// the output buffer do not match block strides, the TensorMap expression will
// be invalid, and should never be used in block assignment or any other tensor
// expression.

template <typename Scalar, int NumDims, int Layout,
          typename IndexType = Eigen::Index>
class TensorMaterializedBlock {
 public:
  typedef DSizes<IndexType, NumDims> Dimensions;
  typedef TensorMap<const Tensor<Scalar, NumDims, Layout> > XprType;

  TensorMaterializedBlock(TensorBlockKind kind, const Scalar* data,
                          const Dimensions& dimensions, bool valid_expr = true)
      : m_kind(kind),
        m_data(data),
        m_dimensions(dimensions),
        m_expr(m_data, m_dimensions),
        m_valid_expr(valid_expr) {
    eigen_assert(m_kind == internal::TensorBlockKind::kView ||
                 m_kind == internal::TensorBlockKind::kMaterializedInScratch ||
                 m_kind == internal::TensorBlockKind::kMaterializedInOutput);
  }

  TensorBlockKind kind() const { return m_kind; }
  // NOTE(ezhulenev): Returning XprType by value like in other block types
  // causes asan failures. The theory is that XprType::Nested doesn't work
  // properly for TensorMap.
  const XprType& expr() const {
    eigen_assert(m_valid_expr);
    return m_expr;
  }
  const Scalar* data() const { return m_data; }
  void cleanup() {}

  typedef internal::TensorBlockDescriptor<NumDims, IndexType> TensorBlockDesc;

  // TensorMaterializedBlock can be backed by different types of storage:
  //
  //   (1) Contiguous block of memory allocated with scratch allocator.
  //   (2) Contiguous block of memory reused from tensor block descriptor
  //       destination buffer.
  //   (3) Strided block of memory reused from tensor block descriptor
  //       destination buffer.
  //
  class Storage {
   public:
    Scalar* data() const { return m_data; }
    const Dimensions& dimensions() const { return m_dimensions; }
    const Dimensions& strides() const { return m_strides; }

    TensorMaterializedBlock AsTensorMaterializedBlock() const {
      return TensorMaterializedBlock(
          m_materialized_in_output
              ? internal::TensorBlockKind::kMaterializedInOutput
              : internal::TensorBlockKind::kMaterializedInScratch,
          m_data, m_dimensions, !m_strided_storage);
    }

   private:
    friend class TensorMaterializedBlock;

    Storage(Scalar* data, const Dimensions& dimensions,
            const Dimensions& strides, bool materialized_in_output,
            bool strided_storage)
        : m_data(data),
          m_dimensions(dimensions),
          m_strides(strides),
          m_materialized_in_output(materialized_in_output),
          m_strided_storage(strided_storage) {}

    Scalar* m_data;
    Dimensions m_dimensions;
    Dimensions m_strides;
    bool m_materialized_in_output;
    bool m_strided_storage;
  };

  // Creates a storage for materialized block either from the block descriptor
  // destination buffer, or allocates a new buffer with scratch allocator.
  template <typename TensorBlockScratch>
  EIGEN_STRONG_INLINE static Storage prepareStorage(
      TensorBlockDesc& desc, TensorBlockScratch& scratch,
      bool allow_strided_storage = false) {
    // Try to reuse destination as an output block buffer.
    typedef typename TensorBlockDesc::DestinationBuffer DestinationBuffer;

    if (desc.destination().kind() == DestinationBuffer::kContiguous) {
      Scalar* buffer = desc.destination().template data<Scalar>();
      desc.DropDestinationBuffer();
      return Storage(buffer, desc.dimensions(),
                     internal::strides<Layout>(desc.dimensions()),
                     /*materialized_in_output=*/true,
                     /*strided_storage=*/false);

    } else if (desc.destination().kind() == DestinationBuffer::kStrided &&
               allow_strided_storage) {
      Scalar* buffer = desc.destination().template data<Scalar>();
      desc.DropDestinationBuffer();
      return Storage(buffer, desc.dimensions(), desc.destination().strides(),
                     /*materialized_in_output=*/true, /*strided_storage=*/true);

    } else {
      void* mem = scratch.allocate(desc.size() * sizeof(Scalar));
      return Storage(static_cast<Scalar*>(mem), desc.dimensions(),
                     internal::strides<Layout>(desc.dimensions()),
                     /*materialized_in_output=*/false,
                     /*strided_storage=*/false);
    }
  }

  // Creates a materialized block for the given descriptor from a memory buffer.
  template <typename DataDimensions, typename TensorBlockScratch>
  EIGEN_STRONG_INLINE static TensorMaterializedBlock materialize(
      const Scalar* data, const DataDimensions& data_dims,
      TensorBlockDesc& desc, TensorBlockScratch& scratch) {
    eigen_assert(array_size<DataDimensions>::value == desc.dimensions().size());

    // If a tensor block dimensions covers a contiguous block of the underlying
    // memory, we can skip block buffer memory allocation, and construct a block
    // from existing `data` memory buffer.
    //
    // Example: (RowMajor layout)
    //   data_dims:          [11, 12, 13, 14]
    //   desc.dimensions():  [1,   1,  3, 14]
    //
    // In this case we can construct a TensorBlock starting at
    // `data + desc.offset()`, with a `desc.dimensions()` block sizes.
    static const bool is_col_major = Layout == ColMajor;

    // Find out how many inner dimensions have a matching size.
    int num_matching_inner_dims = 0;
    for (int i = 0; i < NumDims; ++i) {
      int dim = is_col_major ? i : NumDims - i - 1;
      if (data_dims[dim] != desc.dimensions()[dim]) break;
      ++num_matching_inner_dims;
    }

    // All the outer dimensions must be of size `1`, except a single dimension
    // before the matching inner dimension (`3` in the example above).
    bool can_use_direct_access = true;
    for (int i = num_matching_inner_dims + 1; i < NumDims; ++i) {
      int dim = is_col_major ? i : NumDims - i - 1;
      if (desc.dimension(dim) != 1) {
        can_use_direct_access = false;
        break;
      }
    }

    if (can_use_direct_access) {
      const Scalar* block_start = data + desc.offset();
      return TensorMaterializedBlock(internal::TensorBlockKind::kView,
                                     block_start, desc.dimensions());

    } else {
      // Reuse destination buffer or allocate new buffer with scratch allocator.
      const Storage storage = prepareStorage(desc, scratch);

      typedef internal::TensorBlockIO<Scalar, IndexType, NumDims, Layout>
          TensorBlockIO;
      typedef typename TensorBlockIO::Dst TensorBlockIODst;
      typedef typename TensorBlockIO::Src TensorBlockIOSrc;

      TensorBlockIOSrc src(internal::strides<Layout>(Dimensions(data_dims)),
                           data, desc.offset());
      TensorBlockIODst dst(storage.dimensions(), storage.strides(),
                           storage.data());

      TensorBlockIO::Copy(dst, src);
      return storage.AsTensorMaterializedBlock();
    }
  }

 private:
  TensorBlockKind m_kind;
  const Scalar* m_data;
  Dimensions m_dimensions;
  XprType m_expr;
  bool m_valid_expr;
};

// -------------------------------------------------------------------------- //
// TensorCwiseUnaryBlock is a lazy tensor expression block that applies UnaryOp
// functor to the blocks produced by the underlying Tensor expression.

template <typename UnaryOp, typename ArgTensorBlock>
class TensorCwiseUnaryBlock {
  static const bool NoArgBlockAccess =
      internal::is_void<typename ArgTensorBlock::XprType>::value;

 public:
  typedef typename conditional<
      NoArgBlockAccess, void,
      TensorCwiseUnaryOp<UnaryOp, const typename ArgTensorBlock::XprType> >::
      type XprType;

  typedef typename XprScalar<XprType>::type Scalar;

  TensorCwiseUnaryBlock(const ArgTensorBlock& arg_block, const UnaryOp& functor)
      : m_arg_block(arg_block), m_functor(functor) {}

  TensorBlockKind kind() const { return internal::TensorBlockKind::kExpr; }

  XprType expr() const { return XprType(m_arg_block.expr(), m_functor); }
  const Scalar* data() const { return NULL; }
  void cleanup() { m_arg_block.cleanup(); }

 private:
  ArgTensorBlock m_arg_block;
  UnaryOp m_functor;
};

// -------------------------------------------------------------------------- //
// TensorCwiseBinaryBlock is a lazy tensor expression block that applies
// BinaryOp functor to the blocks produced by the underlying Tensor expressions.

template <typename BinaryOp, typename LhsTensorBlock, typename RhsTensorBlock>
class TensorCwiseBinaryBlock {
  static const bool NoArgBlockAccess =
      internal::is_void<typename LhsTensorBlock::XprType>::value ||
      internal::is_void<typename RhsTensorBlock::XprType>::value;

 public:
  typedef typename conditional<
      NoArgBlockAccess, void,
      TensorCwiseBinaryOp<BinaryOp, const typename LhsTensorBlock::XprType,
                          const typename RhsTensorBlock::XprType> >::type
      XprType;

  typedef typename XprScalar<XprType>::type Scalar;

  TensorCwiseBinaryBlock(const LhsTensorBlock& left_block,
                         const RhsTensorBlock& right_block,
                         const BinaryOp& functor)
      : m_left_block(left_block),
        m_right_block(right_block),
        m_functor(functor) {}

  TensorBlockKind kind() const { return internal::TensorBlockKind::kExpr; }

  XprType expr() const {
    return XprType(m_left_block.expr(), m_right_block.expr(), m_functor);
  }

  const Scalar* data() const { return NULL; }

  void cleanup() {
    m_left_block.cleanup();
    m_right_block.cleanup();
  }

 private:
  LhsTensorBlock m_left_block;
  RhsTensorBlock m_right_block;
  BinaryOp m_functor;
};

// -------------------------------------------------------------------------- //
// TensorUnaryExprBlock is a lazy tensor expression block that can construct
// an arbitrary tensor expression from a block of the underlying type (this is a
// generalization of the TensorCwiseUnaryBlock for arbitrary expressions).

template <typename BlockFactory, typename ArgTensorBlock>
class TensorUnaryExprBlock {
  // Expression type reported by the wrapped argument block.
  typedef typename ArgTensorBlock::XprType ArgXprType;

  // True when the argument block reports `void` as its expression type.
  static const bool NoArgBlockAccess = internal::is_void<ArgXprType>::value;

 public:
  // Expression type the factory produces for the argument expression, or
  // `void` if the argument block has a `void` expression type.
  typedef typename conditional<
      NoArgBlockAccess, void,
      typename BlockFactory::template XprType<ArgXprType>::type>::type XprType;

  typedef typename XprScalar<XprType>::type Scalar;

  // Stores copies of the argument block and of the factory.
  TensorUnaryExprBlock(const ArgTensorBlock& arg_block,
                       const BlockFactory& factory)
      : m_argument(arg_block), m_expr_factory(factory) {}

  // This block always reports the `kExpr` kind and a NULL data pointer.
  TensorBlockKind kind() const { return internal::TensorBlockKind::kExpr; }
  const Scalar* data() const { return NULL; }

  // Asks the factory to build an expression from the argument expression.
  XprType expr() const {
    return m_expr_factory.expr(m_argument.expr());
  }

  // Forwards the cleanup request to the argument block.
  void cleanup() {
    m_argument.cleanup();
  }

 private:
  ArgTensorBlock m_argument;
  BlockFactory m_expr_factory;
};

// -------------------------------------------------------------------------- //
// TensorTernaryExprBlock is a lazy tensor expression block that can construct
// an arbitrary tensor expression from three blocks of the underlying type.

template <typename BlockFactory, typename Arg1TensorBlock,
          typename Arg2TensorBlock, typename Arg3TensorBlock>
class TensorTernaryExprBlock {
  // Expression types reported by the three wrapped argument blocks.
  typedef typename Arg1TensorBlock::XprType Arg1XprType;
  typedef typename Arg2TensorBlock::XprType Arg2XprType;
  typedef typename Arg3TensorBlock::XprType Arg3XprType;

  // True when at least one argument block reports `void` as its expression
  // type.
  static const bool NoArgBlockAccess = internal::is_void<Arg1XprType>::value ||
                                       internal::is_void<Arg2XprType>::value ||
                                       internal::is_void<Arg3XprType>::value;

 public:
  // Expression type the factory produces for the three argument expressions,
  // or `void` if any argument block has a `void` expression type.
  typedef typename conditional<
      NoArgBlockAccess, void,
      typename BlockFactory::template XprType<Arg1XprType, Arg2XprType,
                                              Arg3XprType>::type>::type XprType;

  typedef typename XprScalar<XprType>::type Scalar;

  // Stores copies of the three argument blocks and of the factory.
  TensorTernaryExprBlock(const Arg1TensorBlock& arg1_block,
                         const Arg2TensorBlock& arg2_block,
                         const Arg3TensorBlock& arg3_block,
                         const BlockFactory& factory)
      : m_first(arg1_block),
        m_second(arg2_block),
        m_third(arg3_block),
        m_expr_factory(factory) {}

  // This block always reports the `kExpr` kind and a NULL data pointer.
  TensorBlockKind kind() const { return internal::TensorBlockKind::kExpr; }
  const Scalar* data() const { return NULL; }

  // Asks the factory to build an expression from the three argument
  // expressions.
  XprType expr() const {
    return m_expr_factory.expr(m_first.expr(), m_second.expr(),
                               m_third.expr());
  }

  // Forwards the cleanup request to each argument block, in argument order.
  void cleanup() {
    m_first.cleanup();
    m_second.cleanup();
    m_third.cleanup();
  }

 private:
  Arg1TensorBlock m_first;
  Arg2TensorBlock m_second;
  Arg3TensorBlock m_third;
  BlockFactory m_expr_factory;
};

// -------------------------------------------------------------------------- //
// StridedLinearBufferCopy provides a method to copy data between two linear
// buffers with different strides, with optimized paths for scatter/gather.

template <typename Scalar, typename IndexType>
class StridedLinearBufferCopy {
  typedef typename packet_traits<Scalar>::type Packet;
  enum {
    Vectorizable = packet_traits<Scalar>::Vectorizable,
    PacketSize = packet_traits<Scalar>::size
  };

 public:
  // Copy strategy, selected at compile time by the caller. Specifying linear
  // copy kind statically gives ~30% speedup for small sizes.
  enum class Kind {
    Linear = 0,       // src_stride == 1 && dst_stride == 1
    Scatter = 1,      // src_stride == 1 && dst_stride != 1
    FillLinear = 2,   // src_stride == 0 && dst_stride == 1
    FillScatter = 3,  // src_stride == 0 && dst_stride != 1
    Gather = 4,       // dst_stride == 1
    Random = 5        // everything else
  };

  // Destination of a copy: writes start at `data[offset]` and consecutive
  // elements are `stride` apart.
  struct Dst {
    Dst(IndexType dst_offset, IndexType dst_stride, Scalar* dst_data)
        : offset(dst_offset), stride(dst_stride), data(dst_data) {}

    IndexType offset;
    IndexType stride;
    Scalar* data;
  };

  // Source of a copy: reads start at `data[offset]` and consecutive elements
  // are `stride` apart.
  struct Src {
    Src(IndexType src_offset, IndexType src_stride, const Scalar* src_data)
        : offset(src_offset), stride(src_stride), data(src_data) {}

    IndexType offset;
    IndexType stride;
    const Scalar* data;
  };

  // Copies `count` elements from `src` to `dst` using the strategy `kind`.
  template <typename StridedLinearBufferCopy::Kind kind>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void Run(const Dst& dst,
                                                        const Src& src,
                                                        const size_t count) {
    Run<kind>(count, dst.offset, dst.stride, dst.data, src.offset, src.stride,
              src.data);
  }

 private:
  // Unpacked implementation behind the public `Run`. Scalars that are not
  // vectorizable always take a plain strided loop; otherwise the loop shape is
  // chosen by `kind`, and the stride preconditions of that kind are asserted.
  template <typename StridedLinearBufferCopy::Kind kind>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void Run(
      const IndexType count, const IndexType dst_offset,
      const IndexType dst_stride, Scalar* EIGEN_RESTRICT dst_data,
      const IndexType src_offset, const IndexType src_stride,
      const Scalar* EIGEN_RESTRICT src_data) {
    const Scalar* from = src_data + src_offset;
    Scalar* to = dst_data + dst_offset;

    if (!Vectorizable) {
      for (Index k = 0; k < count; ++k) {
        to[k * dst_stride] = from[k * src_stride];
      }
      return;
    }

    // Largest start index from which a whole packet still fits into `count`.
    const IndexType last_packet_start = count - PacketSize;
    IndexType pos = 0;

    switch (kind) {
      case Kind::Linear: {
        // Linear copy from `src` to `dst`: four packets per iteration first,
        // then single packets, then the scalar tail.
        eigen_assert(src_stride == 1 && dst_stride == 1);
        const IndexType last_unrolled_start = count - 4 * PacketSize;
        while (pos <= last_unrolled_start) {
          for (int j = 0; j < 4; ++j) {
            Packet p = ploadu<Packet>(from + pos + j * PacketSize);
            pstoreu<Scalar, Packet>(to + pos + j * PacketSize, p);
          }
          pos += 4 * PacketSize;
        }
        while (pos <= last_packet_start) {
          Packet p = ploadu<Packet>(from + pos);
          pstoreu<Scalar, Packet>(to + pos, p);
          pos += PacketSize;
        }
        while (pos < count) {
          to[pos] = from[pos];
          ++pos;
        }
        break;
      }

      case Kind::Scatter: {
        // Scatter from `src` to `dst`: contiguous loads, strided stores.
        eigen_assert(src_stride == 1 && dst_stride != 1);
        while (pos <= last_packet_start) {
          Packet p = ploadu<Packet>(from + pos);
          pscatter<Scalar, Packet>(to + pos * dst_stride, p, dst_stride);
          pos += PacketSize;
        }
        while (pos < count) {
          to[pos * dst_stride] = from[pos];
          ++pos;
        }
        break;
      }

      case Kind::FillLinear: {
        // Fill `dst` with value at `*src`.
        eigen_assert(src_stride == 0 && dst_stride == 1);
        const IndexType last_unrolled_start = count - 4 * PacketSize;
        Packet p = pload1<Packet>(from);
        while (pos <= last_unrolled_start) {
          for (int j = 0; j < 4; ++j) {
            pstoreu<Scalar, Packet>(to + pos + j * PacketSize, p);
          }
          pos += 4 * PacketSize;
        }
        while (pos <= last_packet_start) {
          pstoreu<Scalar, Packet>(to + pos, p);
          pos += PacketSize;
        }
        while (pos < count) {
          to[pos] = *from;
          ++pos;
        }
        break;
      }

      case Kind::FillScatter: {
        // Scatter `*src` into `dst`.
        eigen_assert(src_stride == 0 && dst_stride != 1);
        Packet p = pload1<Packet>(from);
        while (pos <= last_packet_start) {
          pscatter<Scalar, Packet>(to + pos * dst_stride, p, dst_stride);
          pos += PacketSize;
        }
        while (pos < count) {
          to[pos * dst_stride] = *from;
          ++pos;
        }
        break;
      }

      case Kind::Gather: {
        // Gather from `src` into `dst`: strided loads, contiguous stores.
        eigen_assert(dst_stride == 1);
        while (pos <= last_packet_start) {
          Packet p =
              pgather<Scalar, Packet>(from + pos * src_stride, src_stride);
          pstoreu<Scalar, Packet>(to + pos, p);
          pos += PacketSize;
        }
        while (pos < count) {
          to[pos] = from[pos * src_stride];
          ++pos;
        }
        break;
      }

      case Kind::Random: {
        // Random: both sides are strided, copy one coefficient at a time.
        while (pos < count) {
          to[pos * dst_stride] = from[pos * src_stride];
          ++pos;
        }
        break;
      }

      default:
        eigen_assert(false);
    }
  }
};

// -------------------------------------------------------------------------- //
// TensorBlockIO copies data from `src` tensor block, to the `dst` tensor block.
// It's possible to specify src->dst dimension mapping for the copy operation.
// Dimensions of `dst` specify how many elements have to be copied, for the
// `src` we need to know only stride to navigate through source memory buffer.

// Copies a block of `Scalar` coefficients between two strided memory buffers.
// All members are static: the class only groups the `Dst`/`Src` descriptors
// and the `Copy` overloads for a given scalar type, index type, rank and
// layout.
template <typename Scalar, typename IndexType, int NumDims, int Layout>
class TensorBlockIO {
  // In column-major layout dimension `0` is treated as the innermost one, in
  // row-major layout it is dimension `NumDims - 1`.
  static const bool IsColMajor = (Layout == ColMajor);

  // One-dimensional strided copy used for every innermost run of elements.
  typedef StridedLinearBufferCopy<Scalar, IndexType> LinCopy;

 public:
  typedef DSizes<IndexType, NumDims> Dimensions;
  // Maps a dst dimension index to the corresponding src dimension index.
  typedef DSizes<int, NumDims> DimensionsMap;

  // Destination of the copy: block sizes, strides of the destination buffer,
  // the buffer itself and the offset in the buffer at which writing starts.
  struct Dst {
    Dst(const Dimensions& dst_dims, const Dimensions& dst_strides, Scalar* dst,
        IndexType dst_offset = 0)
        : dims(dst_dims), strides(dst_strides), data(dst), offset(dst_offset) {}

    Dimensions dims;
    Dimensions strides;
    Scalar* data;
    IndexType offset;
  };

  // Source of the copy: strides of the source buffer, the buffer itself and
  // the offset in the buffer at which reading starts. No sizes are stored, the
  // number of copied elements is defined by `Dst::dims`.
  struct Src {
    Src(const Dimensions& src_strides, const Scalar* src,
        IndexType src_offset = 0)
        : strides(src_strides), data(src), offset(src_offset) {}

    Dimensions strides;
    const Scalar* data;
    IndexType offset;
  };

  // Copies data to `dst` from `src`, using provided dimensions mapping:
  //
  //   src_dimension_index = dst_to_src_dim_map[dst_dimension_index]
  //
  // Returns the number of copied elements.
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType Copy(
      const Dst& dst, const Src& src, const DimensionsMap& dst_to_src_dim_map) {
    // Copy single scalar value from `src` to `dst`.
    if (NumDims == 0) {
      *(dst.data + dst.offset) = *(src.data + src.offset);
      return 1;
    }

    // Both `dst` and `src` must have contiguous innermost dimension. We also
    // accept the special case with stride '0', because it's used as a trick to
    // implement broadcasting.
    {
      int inner_dim = IsColMajor ? 0 : NumDims - 1;
      EIGEN_UNUSED_VARIABLE(inner_dim);
      eigen_assert(dst.strides[inner_dim] == 1 || dst.strides[inner_dim] == 0);
      eigen_assert(src.strides[inner_dim] == 1 || src.strides[inner_dim] == 0);
    }

    // Give a shorter name to `dst_to_src_dim_map`.
    const DimensionsMap& dim_map = dst_to_src_dim_map;

    // Do not squeeze reordered inner dimensions.
    int num_squeezable_dims = NumSqueezableInnerDims(dim_map);

    // NOTE: We find the innermost dimension (contiguous in memory) in the dst
    // block, and we write data linearly into that dimension, reading it from
    // the src. If dimensions are reordered, we might end up reading data from
    // the src with `stride != 1`.
    //
    // NOTE: Random-Read/Linear-Write can be up to ~2X faster than
    // Linear-Read/Random-Write: https://stackoverflow.com/a/54935680

    // Find the innermost dimension in the dst whose size is not 1. This is the
    // effective inner dim.
    int num_size_one_inner_dims = 0;
    for (int i = 0; i < num_squeezable_dims; ++i) {
      const int dst_dim = IsColMajor ? i : NumDims - i - 1;
      if (dst.dims[dst_dim] != 1) break;
      num_size_one_inner_dims++;
    }

    // If all dimensions are of size 1, just copy a scalar from `src` to `dst`.
    if (num_size_one_inner_dims == NumDims) {
      *(dst.data + dst.offset) = *(src.data + src.offset);
      return 1;
    }

    // Outermost dimension in the dst with `stride == 1` (contiguous in memory).
    const int dst_stride1_dim = IsColMajor
                                    ? num_size_one_inner_dims
                                    : NumDims - num_size_one_inner_dims - 1;

    // Dimension in the src that corresponds to the dst innermost dimension.
    const int src_dim_for_dst_stride1_dim =
        NumDims == 0 ? 1 : dim_map[dst_stride1_dim];

    // Size of the innermost dimension (length of contiguous blocks of memory).
    IndexType dst_inner_dim_size = NumDims == 0 ? 1 : dst.dims[dst_stride1_dim];

    // Squeeze multiple inner dims into one if they are contiguous in `dst` and
    // `src` memory, so we can do less linear copy calls.
    // NOTE: `num_size_one_inner_dims` is reused here as the count of all inner
    // dimensions folded into the effective inner dimension, so the iterator
    // setup below starts after them.
    for (int i = num_size_one_inner_dims + 1; i < num_squeezable_dims; ++i) {
      const int dst_dim = IsColMajor ? i : NumDims - i - 1;
      const IndexType dst_stride = dst.strides[dst_dim];
      const IndexType src_stride = src.strides[dim_map[dst_dim]];
      if (dst_inner_dim_size == dst_stride && dst_stride == src_stride) {
        dst_inner_dim_size *= dst.dims[dst_dim];
        ++num_size_one_inner_dims;
      } else {
        break;
      }
    }

    // Setup strides to read data from `src` and write to `dst`.
    IndexType input_offset = src.offset;
    IndexType output_offset = dst.offset;
    IndexType input_stride =
        NumDims == 0 ? 1 : src.strides[src_dim_for_dst_stride1_dim];
    IndexType output_stride = NumDims == 0 ? 1 : dst.strides[dst_stride1_dim];

    // At most `NumDims - 1` outer dimensions need an iterator state; the
    // length is clamped to 1 so that the array is never declared with a
    // non-positive size.
    const int at_least_1_dim = NumDims <= 1 ? 1 : NumDims - 1;
    array<BlockIteratorState, at_least_1_dim> it;

    // Initialize block iterator state. Squeeze away any dimension of size 1.
    int idx = 0;  // currently initialized iterator state index
    for (int i = num_size_one_inner_dims; i < NumDims - 1; ++i) {
      const int dst_dim = IsColMajor ? i + 1 : NumDims - i - 2;
      if (dst.dims[dst_dim] == 1) continue;

      it[idx].size = dst.dims[dst_dim];
      it[idx].input_stride = src.strides[dim_map[dst_dim]];
      it[idx].output_stride = dst.strides[dst_dim];

      // Distance between the first and the last element of this dimension,
      // subtracted from the offsets when the dimension wraps around.
      it[idx].input_span = it[idx].input_stride * (it[idx].size - 1);
      it[idx].output_span = it[idx].output_stride * (it[idx].size - 1);

      idx++;
    }

    // Iterate copying data from src to dst.
    const IndexType block_total_size = NumDims == 0 ? 1 : dst.dims.TotalSize();

// Expands to the whole copy loop for one statically known `LinCopy::Kind`:
// every iteration copies `dst_inner_dim_size` elements with
// `LinCopy::Run<KIND>` and then advances the outer dimension iterator states,
// carrying into the next state whenever one of them wraps around. The
// expansion ends with `return num_copied;`, so each use below returns from
// `Copy`.
#define COPY_INNER_DIM(KIND)                                           \
  IndexType num_copied = 0;                                            \
  for (num_copied = 0; num_copied < block_total_size;                  \
       num_copied += dst_inner_dim_size) {                             \
    LinCopy::template Run<KIND>(                                       \
        typename LinCopy::Dst(output_offset, output_stride, dst.data), \
        typename LinCopy::Src(input_offset, input_stride, src.data),   \
        dst_inner_dim_size);                                           \
                                                                       \
    for (int j = 0; j < idx; ++j) {                                    \
      if (++it[j].count < it[j].size) {                                \
        input_offset += it[j].input_stride;                            \
        output_offset += it[j].output_stride;                          \
        break;                                                         \
      }                                                                \
      it[j].count = 0;                                                 \
      input_offset -= it[j].input_span;                                \
      output_offset -= it[j].output_span;                              \
    }                                                                  \
  }                                                                    \
  return num_copied;

    // Translate the runtime inner strides into the matching compile-time copy
    // kind; the conditions mirror the ones documented on `LinCopy::Kind`.
    if (input_stride == 1 && output_stride == 1) {
      COPY_INNER_DIM(LinCopy::Kind::Linear);
    } else if (input_stride == 1 && output_stride != 1) {
      COPY_INNER_DIM(LinCopy::Kind::Scatter);
    } else if (input_stride == 0 && output_stride == 1) {
      COPY_INNER_DIM(LinCopy::Kind::FillLinear);
    } else if (input_stride == 0 && output_stride != 1) {
      COPY_INNER_DIM(LinCopy::Kind::FillScatter);
    } else if (output_stride == 1) {
      COPY_INNER_DIM(LinCopy::Kind::Gather);
    } else {
      COPY_INNER_DIM(LinCopy::Kind::Random);
    }

#undef COPY_INNER_DIM
  }

  // Copy from `src` to `dst` with an identity src->dst dimension map. Returns
  // the number of copied elements.
  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexType Copy(const Dst& dst,
                                                              const Src& src) {
    DimensionsMap dst_to_src_map;
    for (int i = 0; i < NumDims; ++i) dst_to_src_map[i] = i;
    return Copy(dst, src, dst_to_src_map);
  }

 private:
  // Iteration state of a single outer dimension of the copied block. All
  // fields are zero-initialized by the default constructor.
  struct BlockIteratorState {
    BlockIteratorState()
        : size(0),
          count(0),
          input_stride(0),
          output_stride(0),
          input_span(0),
          output_span(0) {}

    IndexType size;           // size of the dimension in `dst`
    IndexType count;          // current position within the dimension
    IndexType input_stride;   // src offset increment for one step
    IndexType output_stride;  // dst offset increment for one step
    IndexType input_span;     // input_stride * (size - 1)
    IndexType output_span;    // output_stride * (size - 1)
  };

  // Compute how many inner dimensions it's allowed to squeeze when doing IO
  // between two tensor blocks. It's safe to squeeze inner dimensions, only
  // if they are not reordered.
  // Returns the number of leading (innermost first) dimensions that `dim_map`
  // maps to themselves.
  static int NumSqueezableInnerDims(const DimensionsMap& dim_map) {
    int num_squeezable_dims = 0;
    for (int i = 0; i < NumDims; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;
      if (dim_map[dim] != dim) break;
      num_squeezable_dims++;
    }
    return num_squeezable_dims;
  }
};

// -------------------------------------------------------------------------- //
// TensorBlockAssignment assigns a block expression of type `TensorBlockExpr` to
// a Tensor block defined by `desc`, backed by a memory buffer at `target`.
//
// Currently there is no way to write from a Tensor expression to a block of
// memory, if dimensions are reordered. If you need to do that, you should
// materialize a Tensor block expression into a memory buffer, and then use
// TensorBlockIO to copy data between two memory buffers with a custom
// `target->src` dimension map (see definition above).
//
// Also currently the innermost dimension of `target` must have a stride '1'
// (contiguous in memory). This restriction could be lifted with a `pscatter`,
// but in practice it's never needed, and there is a similar TensorBlockIO
// workaround for that.
//
// TODO(ezhulenev): TensorBlockAssignment is a special case of TensorBlockIO
// where `src` is a tensor expression. Explore if it is possible to rewrite IO
// to use expressions instead of pointers, and after that TensorBlockAssignment
// will become an alias to IO.
template <typename Scalar, int NumDims, typename TensorBlockExpr,
          typename IndexType = Eigen::Index>
class TensorBlockAssignment {
  // We will use coeff/packet path to evaluate block expressions.
  typedef TensorEvaluator<const TensorBlockExpr, DefaultDevice>
      TensorBlockEvaluator;

  typedef DSizes<IndexType, NumDims> Dimensions;

  enum {
    Vectorizable = packet_traits<Scalar>::Vectorizable,
    PacketSize = packet_traits<Scalar>::size
  };

  // Assigns `count` consecutive coefficients of `eval`, starting at
  // `eval_offset`, to the contiguous range beginning at `target`. The primary
  // template is the scalar path and reads one coefficient at a time.
  template <bool Vectorizable, typename Evaluator>
  struct InnerDimAssign {
    EIGEN_ALWAYS_INLINE static void Run(Scalar* target, IndexType count,
                                        const Evaluator& eval,
                                        IndexType eval_offset) {
      for (IndexType i = 0; i < count; ++i) {
        target[i] = eval.coeff(eval_offset + i);
      }
    }
  };

  // Packet path: copies four packets per iteration while they fit, then single
  // packets, and finishes the remaining coefficients with scalar reads.
  template <typename Evaluator>
  struct InnerDimAssign<true, Evaluator> {
    EIGEN_ALWAYS_INLINE static void Run(Scalar* target, IndexType count,
                                        const Evaluator& eval,
                                        IndexType eval_offset) {
      typedef typename packet_traits<Scalar>::type Packet;

      // Largest start indices from which four packets / one packet still fit.
      const IndexType unrolled_size = count - 4 * PacketSize;
      const IndexType vectorized_size = count - PacketSize;
      IndexType i = 0;

      for (; i <= unrolled_size; i += 4 * PacketSize) {
        for (int j = 0; j < 4; ++j) {
          const IndexType idx = eval_offset + i + j * PacketSize;
          Packet p = eval.template packet<Unaligned>(idx);
          pstoreu<Scalar>(target + i + j * PacketSize, p);
        }
      }

      for (; i <= vectorized_size; i += PacketSize) {
        Packet p = eval.template packet<Unaligned>(eval_offset + i);
        pstoreu<Scalar>(target + i, p);
      }

      for (; i < count; ++i) {
        target[i] = eval.coeff(eval_offset + i);
      }
    }
  };

 public:
  // Describes the memory the block expression is written to: block sizes,
  // strides of the destination buffer, the buffer itself and the offset in
  // the buffer at which writing starts.
  struct Target {
    Target(const Dimensions& target_dims, const Dimensions& target_strides,
           Scalar* target_data, IndexType target_offset = 0)
        : dims(target_dims),
          strides(target_strides),
          data(target_data),
          offset(target_offset) {}

    Dimensions dims;
    Dimensions strides;
    Scalar* data;
    IndexType offset;
  };

  // Convenience factory for `Target`.
  static Target target(const Dimensions& target_dims,
                       const Dimensions& target_strides, Scalar* target_data,
                       IndexType target_offset = 0) {
    return Target(target_dims, target_strides, target_data, target_offset);
  }

  // Same as above, for dimensions and strides stored with other index types.
  template <typename TargetDimsIndexType, typename TargetStridesIndexType>
  static Target target(
      const DSizes<TargetDimsIndexType, NumDims>& target_dims,
      const DSizes<TargetStridesIndexType, NumDims>& target_strides,
      Scalar* target_data, IndexType target_offset = 0) {
    // DSizes constructor will do index type promotion if it's safe.
    return Target(Dimensions(target_dims), Dimensions(target_strides),
                  target_data, target_offset);
  }

  // Evaluates `expr` with a `DefaultDevice` evaluator and writes all of its
  // coefficients into `target`. The expression is read linearly from its
  // first coefficient, and the destination is traversed according to
  // `target.strides`, starting at `target.offset`.
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void Run(
      const Target& target, const TensorBlockExpr& expr) {
    // Prepare evaluator for block expression.
    DefaultDevice default_device;
    TensorBlockEvaluator eval(expr, default_device);

    // Tensor block expression dimension should match destination dimensions.
    eigen_assert(dimensions_match(target.dims, eval.dimensions()));

    static const int Layout = TensorBlockEvaluator::Layout;
    static const bool is_col_major = Layout == ColMajor;

    // Initialize output inner dimension size based on a layout. A rank-0
    // target holds a single coefficient and has no dimension to look at, so
    // `target.dims` and `target.strides` must not be subscripted in that case.
    const IndexType output_size = NumDims == 0 ? 1 : target.dims.TotalSize();
    const int inner_dim_idx = is_col_major ? 0 : NumDims - 1;
    IndexType output_inner_dim_size =
        NumDims == 0 ? IndexType(1) : target.dims[inner_dim_idx];

    // Target inner dimension stride must be '1'.
    eigen_assert(NumDims == 0 || target.strides[inner_dim_idx] == 1);

    // Squeeze multiple inner dims into one if they are contiguous in `target`.
    IndexType num_squeezed_dims = 0;
    for (Index i = 1; i < NumDims; ++i) {
      const Index dim = is_col_major ? i : NumDims - i - 1;
      const IndexType target_stride = target.strides[dim];

      if (output_inner_dim_size == target_stride) {
        output_inner_dim_size *= target.dims[dim];
        num_squeezed_dims++;
      } else {
        break;
      }
    }

    // Initialize output block iterator state. Dimensions in this array are
    // always in inner_most -> outer_most order (col major layout).
    array<BlockIteratorState, NumDims> it;

    int idx = 0;  // currently initialized iterator state index
    for (Index i = num_squeezed_dims; i < NumDims - 1; ++i) {
      const Index dim = is_col_major ? i + 1 : NumDims - i - 2;

      it[idx].count = 0;
      it[idx].size = target.dims[dim];
      it[idx].output_stride = target.strides[dim];
      // Distance between the first and the last element of this dimension,
      // subtracted from the output offset when the dimension wraps around.
      it[idx].output_span = it[idx].output_stride * (it[idx].size - 1);
      idx++;
    }

    // We read block expression from the beginning, and start writing data to
    // `target` at given offset.
    IndexType input_offset = 0;
    IndexType output_offset = target.offset;

    // Iterate copying data from `eval` to `target`.
    for (IndexType i = 0; i < output_size; i += output_inner_dim_size) {
      // Assign to `target` at current offset.
      InnerDimAssign<Vectorizable && TensorBlockEvaluator::PacketAccess,
                     TensorBlockEvaluator>::Run(target.data + output_offset,
                                                output_inner_dim_size, eval,
                                                input_offset);

      // Move input offset forward by the number of assigned coefficients.
      input_offset += output_inner_dim_size;

      // Update index: advance the first outer dimension that has not reached
      // its end, resetting every dimension that wrapped around before it.
      for (int j = 0; j < idx; ++j) {
        if (++it[j].count < it[j].size) {
          output_offset += it[j].output_stride;
          break;
        }
        it[j].count = 0;
        output_offset -= it[j].output_span;
      }
    }
  }

 private:
  // Iteration state of a single outer dimension of the target block.
  struct BlockIteratorState {
    BlockIteratorState()
        : count(0), size(0), output_stride(0), output_span(0) {}

    IndexType count;          // current position within the dimension
    IndexType size;           // size of the dimension
    IndexType output_stride;  // target offset increment for one step
    IndexType output_span;    // output_stride * (size - 1)
  };
};

// -------------------------------------------------------------------------- //

}  // namespace internal
}  // namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_BLOCK_H
