// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert DATATYPE in ["QS8", "QU8"]
$assert CHANNEL_TILE >= 1
$assert CHANNEL_TILE <= 16
$assert ROW_TILE >= 3
$assert ROW_SUBTILE >= 3
$assert ROW_SUBTILE <= ROW_TILE
$assert REQUANTIZATION == "FP32"
#include <assert.h>
$if VARIANT == "LRINTF":
  #include <math.h>

#include <xnnpack/gavgpool.h>
#include <xnnpack/math.h>


$PARAMS_STRUCT = "fp32_scalar_" + VARIANT.lower()
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
// Multipass gavgpool micro-kernel (scalar code, FP32 requantization).
//
// For every channel, the values of all `rows` input rows are added up in a
// 32-bit accumulator that starts at init_bias. The sum is converted to float,
// multiplied by scale, clamped, rounded and stored as ${XINT8_T}.
//
// The rows are consumed in passes:
//   1. first pass:    rows 0..${ROW_TILE-1} plus init_bias are written to buffer;
//   2. middle passes: while more than ${ROW_SUBTILE} rows remain, the next
//                     ${ROW_SUBTILE} rows are added into buffer;
//   3. final pass:    the remaining 1..${ROW_SUBTILE} rows are added, rows past
//                     the end are read from zero, and the requantized result
//                     is written to output.
//
// Parameters:
//   rows         - number of input rows; must be greater than ${ROW_TILE}.
//   channels     - number of elements per row; must be non-zero.
//   input        - pointer to the first element of row 0.
//   input_stride - distance between the starts of consecutive rows, in bytes
//                  (it is added to the row pointers as a uintptr_t offset).
//   zero         - row that stands in for rows past the end in the final
//                  pass; up to `channels` elements may be read from it.
//                  Presumably filled with a value that leaves the sum
//                  unchanged - confirm against the caller.
//   buffer       - scratch space holding one int32_t accumulator per channel
//                  between passes.
//   output       - receives `channels` requantized elements.
//   params       - supplies init_bias, scale and the variant-specific
//                  clamping and rounding constants.
void xnn_${DATATYPE.lower()}_gavgpool_minmax_fp32_ukernel_${ROW_TILE}p${ROW_SUBTILE}x__scalar_${VARIANT.lower()}_c${CHANNEL_TILE}(
    size_t rows,
    size_t channels,
    const ${XINT8_T}* input,
    size_t input_stride,
    const ${XINT8_T}* zero,
    int32_t* buffer,
    ${XINT8_T}* output,
    const union xnn_${DATATYPE.lower()}_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(rows > ${ROW_TILE});
  assert(channels != 0);

  // One read pointer per row of the first pass; consecutive rows are
  // input_stride bytes apart.
  const ${XINT8_T}* i0 = input;
  $for M in range(1, ROW_TILE):
    const ${XINT8_T}* i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M-1} + input_stride);
  // Byte offset applied to a row pointer after it has been walked across a
  // row: ${ROW_TILE} rows forward, minus the distance the pointer advanced
  // during the pass.
  // NOTE(review): round_up_po2 is assumed to round up to a multiple of a
  // power-of-two argument. That equals the distance walked by the loops below
  // only if CHANNEL_TILE is a power of two, which the template asserts do not
  // check - confirm before generating other tile sizes.
  const size_t input_increment = ${ROW_TILE} * input_stride - round_up_po2(channels, ${CHANNEL_TILE}) * sizeof(${XINT8_T});

  // First pass: buffer[c] = init_bias + sum over the first ${ROW_TILE} rows.
  const int32_t vinit_bias = params->${PARAMS_STRUCT}.init_bias;
  int32_t* b = buffer;
  $if CHANNEL_TILE == 1:
    // One channel per iteration. Each load is issued two rows ahead of the
    // addition that consumes it.
    size_t c = channels;
    do {
      int32_t vacc = vinit_bias;
      $for M in range(2):
        const int32_t vi${M} = (int32_t) *i${M}++;

      $for M in range(2, ROW_TILE):
        vacc += vi${M-2};
        const int32_t vi${M} = (int32_t) *i${M}++;

      $for M in range(ROW_TILE - 2, ROW_TILE):
        vacc += vi${M};

      *b++ = vacc;
    } while (--c != 0);
  $else:
    // ${CHANNEL_TILE} channels per iteration. The loop always handles whole
    // tiles, so when channels is not a multiple of ${CHANNEL_TILE} the last
    // iteration reads past `channels` elements in every row and writes the
    // matching extra accumulators to buffer; both must be sized for that.
    for (ptrdiff_t c = (ptrdiff_t) channels; c > 0; c -= ${CHANNEL_TILE}) {
      $for C in range(CHANNEL_TILE):
        const int32_t vi0x${C} = (int32_t) i0[${C}];
      i0 += ${CHANNEL_TILE};

      // Seed the accumulators with row 0 plus the bias while row 1 loads.
      $for C in range(CHANNEL_TILE):
        int32_t vacc${C} = vi0x${C} + vinit_bias;
        const int32_t vi1x${C} = (int32_t) i1[${C}];
      i1 += ${CHANNEL_TILE};

      // Remaining rows: add the previously loaded row, then load the next.
      $for M in range(2, ROW_TILE):
        $for C in range(CHANNEL_TILE):
          vacc${C} += vi${M-1}x${C};
          const int32_t vi${M}x${C} = (int32_t) i${M}[${C}];
        i${M} += ${CHANNEL_TILE};

      // The last loaded row has not been added yet.
      $for C in range(CHANNEL_TILE):
        vacc${C} += vi${ROW_TILE-1}x${C};

      $for C in range(CHANNEL_TILE):
        b[${C}] = vacc${C};
      b += ${CHANNEL_TILE};
    }

  // Middle passes: the first pass consumed ${ROW_TILE} rows; as long as more
  // than ${ROW_SUBTILE} rows remain, add ${ROW_SUBTILE} more rows into buffer.
  // The loop exits with 1..${ROW_SUBTILE} rows left for the final pass.
  for (rows -= ${ROW_TILE}; rows > ${ROW_SUBTILE}; rows -= ${ROW_SUBTILE}) {
    // Re-aim the first ${ROW_SUBTILE} row pointers at the next group of rows.
    // NOTE(review): each pointer is rebuilt from pointer
    // M + ROW_TILE - ROW_SUBTILE plus input_increment (a ROW_TILE-row step).
    // This appears to land on consecutive unread rows only when
    // ROW_SUBTILE == ROW_TILE - confirm before generating other shapes.
    $for M in range(ROW_SUBTILE):
      i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment);

    // Restart at the beginning of the accumulators (shadows the outer b).
    int32_t* b = buffer;
    $if CHANNEL_TILE == 1:
      // Same load/add interleaving as the first pass, but the accumulator
      // starts from the partial sum already stored in buffer.
      size_t c = channels;
      do {
        int32_t vacc = *b;
        $for M in range(2):
          const int32_t vi${M} = (int32_t) *i${M}++;

        $for M in range(2, ROW_SUBTILE):
          vacc += vi${M-2};
          const int32_t vi${M} = (int32_t) *i${M}++;

        $for M in range(ROW_SUBTILE - 2, ROW_SUBTILE):
          vacc += vi${M};

        *b++ = vacc;
      } while (--c != 0);
    $else:
      // Whole tiles of ${CHANNEL_TILE} channels, with the same over-read and
      // over-write past `channels` as in the first pass.
      for (ptrdiff_t c = (ptrdiff_t) channels; c > 0; c -= ${CHANNEL_TILE}) {
        $for C in range(CHANNEL_TILE):
          int32_t vacc${C} = b[${C}];
          const int32_t vi0x${C} = (int32_t) i0[${C}];
        i0 += ${CHANNEL_TILE};

        $for M in range(1, ROW_SUBTILE):
          $for C in range(CHANNEL_TILE):
            vacc${C} += vi${M-1}x${C};
            const int32_t vi${M}x${C} = (int32_t) i${M}[${C}];
          i${M} += ${CHANNEL_TILE};

        $for C in range(CHANNEL_TILE):
          vacc${C} += vi${ROW_SUBTILE-1}x${C};

        $for C in range(CHANNEL_TILE):
          b[${C}] = vacc${C};
        b += ${CHANNEL_TILE};
      }
  }

  // Final pass setup: 1..${ROW_SUBTILE} rows are left. Row pointer 0 is always
  // valid; every other pointer whose index is not below `rows` is redirected
  // to `zero`. The generated tests alternate between the spellings
  // rows < M+1 and rows <= M, which are equivalent for integers.
  i0 = (const ${XINT8_T}*) ((uintptr_t) i${ROW_TILE - ROW_SUBTILE} + input_increment);
  $for M in range(1, ROW_SUBTILE):
    i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment);
    $if M % 2 == 1:
      if XNN_UNPREDICTABLE(rows < ${M+1}) {
        i${M} = zero;
      }
    $else:
      if XNN_UNPREDICTABLE(rows <= ${M}) {
        i${M} = zero;
      }

  // Requantization constants, loaded once before the output loops.
  const float vscale = params->${PARAMS_STRUCT}.scale;
  $if VARIANT == "FMAGIC":
    // FMAGIC: the scaled sum is clamped in float, vmagic_bias is added, and
    // the bit pattern of the result minus vmagic_bias_less_output_zero_point
    // becomes the output value.
    const float voutput_min_less_zero_point = params->fp32_scalar_fmagic.output_min_less_zero_point;
    const float voutput_max_less_zero_point = params->fp32_scalar_fmagic.output_max_less_zero_point;
    const float vmagic_bias = params->fp32_scalar_fmagic.magic_bias;
    const int32_t vmagic_bias_less_output_zero_point = params->fp32_scalar_fmagic.magic_bias_less_output_zero_point;
  $elif VARIANT == "IMAGIC":
    // IMAGIC: vmagic_bias is added to the scaled sum, the bit pattern of the
    // result is clamped as an integer to [vmagic_min, vmagic_max], and
    // vmagic_bias_less_zero_point is subtracted.
    const float vmagic_bias = params->fp32_scalar_imagic.magic_bias;
    const int32_t vmagic_min = params->fp32_scalar_imagic.magic_min;
    const int32_t vmagic_max = params->fp32_scalar_imagic.magic_max;
    const int32_t vmagic_bias_less_zero_point = params->fp32_scalar_imagic.magic_bias_less_zero_point;
  $elif VARIANT == "LRINTF":
    // LRINTF: the scaled sum is clamped in float, rounded with lrintf, and
    // voutput_zero_point is added.
    const float voutput_min_less_zero_point = params->fp32_scalar_lrintf.output_min_less_zero_point;
    const float voutput_max_less_zero_point = params->fp32_scalar_lrintf.output_max_less_zero_point;
    const int32_t voutput_zero_point = params->fp32_scalar_lrintf.output_zero_point;
  $if CHANNEL_TILE == 1:
    // Final pass, one channel per iteration: finish the sum, requantize,
    // store. `buffer` and `channels` themselves are consumed here.
    do {
      int32_t vacc = *buffer++;
      $for M in range(2):
        const int32_t vi${M} = (int32_t) *i${M}++;

      $for M in range(2, ROW_SUBTILE):
        vacc += vi${M-2};
        const int32_t vi${M} = (int32_t) *i${M}++;

      $for M in range(ROW_SUBTILE - 2, ROW_SUBTILE):
        vacc += vi${M};

      float vfpacc = (float) vacc * vscale;
      $if VARIANT == "FMAGIC":
        vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
        vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
        vfpacc += vmagic_bias;
        int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
      $elif VARIANT == "IMAGIC":
        vfpacc += vmagic_bias;
        int32_t vout = (int32_t) float_as_uint32(vfpacc);
        vout = math_max_s32(vout, vmagic_min);
        vout = math_min_s32(vout, vmagic_max);
        vout -= vmagic_bias_less_zero_point;
      $elif VARIANT == "LRINTF":
        vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
        vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
        const int32_t vrndacc = (int32_t) lrintf(vfpacc);
        int32_t vout = vrndacc + voutput_zero_point;

      *output++ = (${XINT8_T}) vout;
    } while (--channels != 0);
  $else:
    // Final pass, full tiles of ${CHANNEL_TILE} channels. Unlike the earlier
    // passes this loop stops before a partial tile, so nothing is read or
    // written past `channels` elements here.
    for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) {
      $for C in range(CHANNEL_TILE):
        int32_t vacc${C} = buffer[${C}];
        const int32_t vi0x${C} = (int32_t) i0[${C}];
      buffer += ${CHANNEL_TILE};
      i0 += ${CHANNEL_TILE};

      $for M in range(1, ROW_SUBTILE):
        $for C in range(CHANNEL_TILE):
          vacc${C} += vi${M-1}x${C};
          const int32_t vi${M}x${C} = (int32_t) i${M}[${C}];
        i${M} += ${CHANNEL_TILE};

      $for C in range(CHANNEL_TILE):
        vacc${C} += vi${ROW_SUBTILE-1}x${C};

      // Requantize: each step is applied to all channels of the tile before
      // the next step starts.
      $for C in range(CHANNEL_TILE):
        float vfpacc${C} = (float) vacc${C} * vscale;

      $if VARIANT == "FMAGIC":
        $for C in range(CHANNEL_TILE):
          vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point);

        $for C in range(CHANNEL_TILE):
          vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point);

        $for C in range(CHANNEL_TILE):
          vfpacc${C} += vmagic_bias;

        $for C in range(CHANNEL_TILE):
          int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C}) - vmagic_bias_less_output_zero_point;
      $elif VARIANT == "IMAGIC":
        $for C in range(CHANNEL_TILE):
          vfpacc${C} += vmagic_bias;

        $for C in range(CHANNEL_TILE):
          int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C});

        $for C in range(CHANNEL_TILE):
          vout${C} = math_max_s32(vout${C}, vmagic_min);

        $for C in range(CHANNEL_TILE):
          vout${C} = math_min_s32(vout${C}, vmagic_max);

        $for C in range(CHANNEL_TILE):
          vout${C} -= vmagic_bias_less_zero_point;
      $elif VARIANT == "LRINTF":
        $for C in range(CHANNEL_TILE):
          vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point);

        $for C in range(CHANNEL_TILE):
          vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point);

        $for C in range(CHANNEL_TILE):
          const int32_t vrndacc${C} = (int32_t) lrintf(vfpacc${C});

        $for C in range(CHANNEL_TILE):
          int32_t vout${C} = vrndacc${C} + voutput_zero_point;

      $for C in range(CHANNEL_TILE):
        output[${C}] = (${XINT8_T}) vout${C};
      output += ${CHANNEL_TILE};
    }
    // Fewer than ${CHANNEL_TILE} channels are left (possibly none).
    if XNN_UNLIKELY(channels != 0) {
      $if CHANNEL_TILE == 2:
        // With a tile of 2 exactly one channel is left, so it is handled
        // without a loop and without advancing any pointer.
        int32_t vacc = *buffer;
        $for M in range(2):
          const int32_t vi${M} = (int32_t) *i${M};

        $for M in range(2, ROW_SUBTILE):
          vacc += vi${M-2};
          const int32_t vi${M} = (int32_t) *i${M};

        $for M in range(ROW_SUBTILE - 2, ROW_SUBTILE):
          vacc += vi${M};

        float vfpacc = (float) vacc * vscale;
        $if VARIANT == "FMAGIC":
          vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
          vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
          vfpacc += vmagic_bias;
          int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
        $elif VARIANT == "IMAGIC":
          vfpacc += vmagic_bias;
          int32_t vout = (int32_t) float_as_uint32(vfpacc);
          vout = math_max_s32(vout, vmagic_min);
          vout = math_min_s32(vout, vmagic_max);
          vout -= vmagic_bias_less_zero_point;
        $elif VARIANT == "LRINTF":
          vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
          vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
          const int32_t vrndacc = (int32_t) lrintf(vfpacc);
          int32_t vout = vrndacc + voutput_zero_point;

        *output = (${XINT8_T}) vout;
      $else:
        // Handle the leftover channels one at a time, exactly like the
        // single-channel kernel does.
        do {
          int32_t vacc = *buffer++;
          $for M in range(2):
            const int32_t vi${M} = (int32_t) *i${M}++;

          $for M in range(2, ROW_SUBTILE):
            vacc += vi${M-2};
            const int32_t vi${M} = (int32_t) *i${M}++;

          $for M in range(ROW_SUBTILE - 2, ROW_SUBTILE):
            vacc += vi${M};

          float vfpacc = (float) vacc * vscale;
          $if VARIANT == "FMAGIC":
            vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
            vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
            vfpacc += vmagic_bias;
            int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
          $elif VARIANT == "IMAGIC":
            vfpacc += vmagic_bias;
            int32_t vout = (int32_t) float_as_uint32(vfpacc);
            vout = math_max_s32(vout, vmagic_min);
            vout = math_min_s32(vout, vmagic_max);
            vout -= vmagic_bias_less_zero_point;
          $elif VARIANT == "LRINTF":
            vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
            vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
            const int32_t vrndacc = (int32_t) lrintf(vfpacc);
            int32_t vout = vrndacc + voutput_zero_point;

          *output++ = (${XINT8_T}) vout;
        } while (--channels != 0);
    }
}
