// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert ((TILE_HEIGHT & (TILE_HEIGHT-1) == 0) and TILE_HEIGHT != 0)
$assert ((TILE_WIDTH & (TILE_WIDTH-1) == 0) and TILE_WIDTH != 0)
$assert TILE_HEIGHT in [1, 2, 4]
$assert SIZE in [8, 16, 32, 64]
$assert TYPE in ["int8_t", "int16_t", "int", "float", "int64_t", "double"]
$assert (TILE_WIDTH * SIZE <= 128)
$SUFFIX = "float" if TYPE in ["float", "double"] else "int"

#include <assert.h>

#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/transpose.h>

// Scalar transpose micro-kernel template.
//
// Treats input as rows of TYPE elements spaced input_stride bytes apart and
// writes the transpose: the element at input row r, column c is stored at
// position r of output row c, where output rows are spaced output_stride
// bytes apart. Work proceeds in strips of TILE_WIDTH input columns; within a
// strip, TILE_HEIGHT input rows are consumed per main-loop iteration.
//
// Parameters:
//   input         - first element of the source block (reinterpreted as TYPE).
//   output        - first element of the destination block.
//   input_stride  - distance in bytes between consecutive input rows; asserted
//                   to be at least block_width * sizeof(TYPE).
//   output_stride - distance in bytes between consecutive output rows; asserted
//                   to be at least block_height * sizeof(TYPE).
//   block_width   - number of input columns to transpose. The do/while body
//                   runs before this is tested, so it is presumably expected
//                   to be non-zero (TODO confirm against callers).
//   block_height  - number of input rows to transpose.
//
// Every input row is read across the full tile width even when fewer than
// TILE_WIDTH columns remain, so reads can extend past block_width elements of
// a row; the function carries the XNN_OOB_READS annotation.
//
// NOTE(review): the tail handling only tests bits 1 and 0 of the leftover row
// count, and output_reset assumes the output pointers advanced by
// block_height rounded down to a multiple of 2. That matches TILE_HEIGHT of
// 1, 2 or 4 only; a larger TILE_HEIGHT would need extra tail cases.
void xnn_x${SIZE}_transposec_ukernel__${TILE_HEIGHT}x${TILE_WIDTH}_scalar_${SUFFIX}(
    const uint${SIZE}_t *input,
    uint${SIZE}_t * output,
    size_t input_stride,
    size_t output_stride,
    size_t block_width,
    size_t block_height) XNN_OOB_READS
{
  assert(output_stride >= block_height * sizeof(${TYPE}));
  assert(input_stride >= block_width * sizeof(${TYPE}));

  const size_t tile_height = ${TILE_HEIGHT};
  const size_t tile_width = ${TILE_WIDTH};
  const size_t tile_wbytes = tile_width * sizeof(${TYPE});
  // input_reset: byte delta applied to i0 after a strip. It undoes the rows
  // walked by the main loop and moves tile_width elements to the right. The
  // value is computed in size_t and may wrap; it is added via uintptr_t.
  // output_reset: byte delta applied to every output pointer after a strip. It
  // moves tile_width output rows down and rewinds the distance the pointer
  // advanced while writing the strip.
  $if TILE_HEIGHT == 1:
    const size_t input_reset = tile_wbytes - block_height * input_stride;
    const size_t output_reset = tile_width * output_stride - block_height * sizeof(${TYPE});
  $else:
    const size_t input_reset = tile_wbytes - round_down_po2(block_height, tile_height) * input_stride;
    const size_t output_reset = tile_width * output_stride - round_down_po2(block_height, 2) * sizeof(${TYPE});
  // Byte distance covered by one main-loop iteration (tile_height input rows).
  const size_t input_offset = tile_height * input_stride;

  // One read pointer per input row of the current tile.
  const ${TYPE}* i0 = (const ${TYPE}*) input;
  $for N in range(1, TILE_HEIGHT):
    const ${TYPE}* i${N} = (const ${TYPE}*) ((uintptr_t) i${N-1} + input_stride);

  // One write pointer per output row (input column) of the current strip.
  ${TYPE}* o0 = (${TYPE}*) output;
  $for N in range(1, TILE_WIDTH):
    ${TYPE}* o${N} = (${TYPE}*) ((uintptr_t) o${N-1} + output_stride);

  do {
    $if TILE_WIDTH > 1:
      // Write pointers for columns at or beyond block_width are redirected to
      // o0, so their stores land on o0 locations and are overwritten by the
      // column 0 store that follows.
      if XNN_UNPREDICTABLE(block_width < 2) {
        o1 = o0;
      }
    $for N in range(2, TILE_WIDTH, 2):
      if XNN_UNPREDICTABLE(block_width <= ${N}) {
        o${N} = o0;
      }
      if XNN_UNPREDICTABLE(block_width < ${N+2}) {
        o${N+1} = o0;
      }
    // Rows still to be processed in this strip.
    size_t bh = block_height;
    for (; bh >= ${TILE_HEIGHT}; bh -= ${TILE_HEIGHT}) {
      // Columns are stored from last to first: because out-of-range column
      // pointers alias o0, the value of column 0 is written last and wins.
      $for M in reversed(range(TILE_WIDTH)):
        $for N in range(TILE_HEIGHT):
          *o${M}++ = i${N}[${M}];
      // Advance every row pointer by tile_height rows.
      $for N in range(TILE_HEIGHT):
        i${N} = (const ${TYPE}*) ((uintptr_t) i${N} + input_offset);
    }
    $if TILE_HEIGHT > 2:
      // Fewer than TILE_HEIGHT rows remain: copy two rows if bit 1 of bh is
      // set, then one more if bit 0 is set. i tracks the row read by the final
      // single-row copy.
      const ${TYPE}* i = i0;
      if (bh & 2) {
        $for M in reversed(range(TILE_WIDTH)):
          o${M}[0] = i0[${M}];
          o${M}[1] = i1[${M}];
          o${M} += 2;
        // Rows 0 and 1 are consumed; the single-row tail reads from row 2.
        i = i2;
      }
      if (bh & 1) {
        $for M in reversed(range(TILE_WIDTH)):
          o${M}[0] = i[${M}];
      }
    $elif TILE_HEIGHT > 1:
      // At most one row remains; it is written without advancing the output
      // pointers.
      if (bh & 1) {
        $for M in reversed(range(TILE_WIDTH)):
          o${M}[0] = i0[${M}];
      }

    // Move to the next strip: apply input_reset to i0, rebuild the other row
    // pointers from it, and apply output_reset to every output pointer.
    i0 = (const ${TYPE}*) ((uintptr_t) i0 + input_reset);
    $for N in range(1, TILE_HEIGHT):
      i${N} = (const ${TYPE}*) ((uintptr_t) i${N-1} + input_stride);
    $for N in range(TILE_WIDTH):
      o${N} = (${TYPE}*) ((uintptr_t) o${N} + output_reset);
    // doz is presumably a saturating subtraction (difference or zero), so the
    // loop ends once all columns are consumed - TODO confirm.
    block_width = doz(block_width, tile_width);
  } while (block_width != 0);
}
