// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE >= 1
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert OP in ["ADD", "DIV", "RDIV", "MAX", "MIN", "MUL", "SUB", "RSUB", "SQRDIFF"]
$assert ACTIVATION in ["LINEAR", "MINMAX", "RELU"]
#include <assert.h>

#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/vbinary.h>


$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
$OP_FUNC = {
$  "ADD": lambda x: "%s + vb" % x,
$  "DIV": lambda x: "%s / vb" % x,
$  "RDIV": lambda x: "vb / %s" % x,
$  "MAX": lambda x: "%s(%s, vb)" % (MAX_F32, x),
$  "MIN": lambda x: "%s(%s, vb)" % (MIN_F32, x),
$  "MUL": lambda x: "%s * vb" % x,
$  "SUB": lambda x: "%s - vb" % x,
$  "RSUB": lambda x: "vb - %s" % x,
$  "SQRDIFF": lambda x: "%s - vb" % x,
$}[OP]
$SUFFIX = {"LINEAR": "", "RELU": "_relu", "MINMAX": "_minmax"}[ACTIVATION]
$PARAMS = {"LINEAR": "xnn_f32_default_params", "RELU": "xnn_f32_relu_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION]
// Element-wise f32 binary operation where the second operand is a single
// value: each element read from a is combined with the one float at *b, the
// optional activation is applied, and the result is written to y.
//
//   n      - amount of input to process, in bytes; must be non-zero and a
//            multiple of sizeof(float).
//   a      - input elements; n / sizeof(float) floats are read.
//   b      - pointer to the second operand; only *b is read.
//   y      - output; one float is written per input element.
//   params - activation parameters; only the min/max clamping variant reads
//            them (scalar.min and scalar.max).
void xnn_f32_v${OP.lower()}c${SUFFIX}_ukernel__${"wasm" if WASM else "scalar"}_x${BATCH_TILE}(
    size_t n,
    const float* a,
    const float* b,
    float* y,
    const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(a != NULL);
  assert(b != NULL);
  assert(y != NULL);

  $if ACTIVATION == "MINMAX":
    // Output clamping bounds, loaded once before the loops.
    const float vy_min = params->scalar.min;
    const float vy_max = params->scalar.max;

  // The second operand is loaded once and reused for every element of a.
  const float vb = *b;
  $if BATCH_TILE > 1:
    // Main loop: process BATCH_TILE elements per iteration while at least that
    // many remain. n counts bytes, so it shrinks by BATCH_TILE floats each pass.
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      $for N in range(BATCH_TILE):
        const float va${ABC[N]} = a[${N}];
      a += ${BATCH_TILE};

      $for N in range(BATCH_TILE):
        float vy${ABC[N]} = ${OP_FUNC("va" + ABC[N])};

      $if OP == "SQRDIFF":
        // The values above hold the plain differences; square them in place.
        $for N in range(BATCH_TILE):
          vy${ABC[N]} = vy${ABC[N]} * vy${ABC[N]};

      $if ACTIVATION == "MINMAX":
        // Clamp from below with vy_min first, then from above with vy_max.
        $for N in range(BATCH_TILE):
          vy${ABC[N]} = ${MAX_F32}(vy${ABC[N]}, vy_min);

        $for N in range(BATCH_TILE):
          vy${ABC[N]} = ${MIN_F32}(vy${ABC[N]}, vy_max);
      $elif ACTIVATION == "RELU":
        // Clamp from below at zero.
        $for N in range(BATCH_TILE):
          vy${ABC[N]} = ${MAX_F32}(vy${ABC[N]}, 0.0f);

      $for N in range(BATCH_TILE):
        y[${N}] = vy${ABC[N]};
      y += ${BATCH_TILE};
    }
    // Remainder: fewer than BATCH_TILE elements are left after the main loop.
    if XNN_UNLIKELY(n != 0) {
      $if BATCH_TILE > 2:
        // Between 1 and BATCH_TILE - 1 elements remain; handle them one at a
        // time with the same operation and activation as the main loop.
        do {
          const float va = *a++;
          float vy = ${OP_FUNC("va")};
          $if OP == "SQRDIFF":
            vy = vy * vy;
          $if ACTIVATION == "MINMAX":
            vy = ${MAX_F32}(vy, vy_min);
            vy = ${MIN_F32}(vy, vy_max);
          $elif ACTIVATION == "RELU":
            vy = ${MAX_F32}(vy, 0.0f);
          *y++ = vy;
          n -= sizeof(float);
        } while (n != 0);
      $else:
        // BATCH_TILE is 2 on this path, so exactly one element remains and no
        // loop or pointer advance is needed.
        const float va = *a;
        float vy = ${OP_FUNC("va")};
        $if OP == "SQRDIFF":
          vy = vy * vy;
        $if ACTIVATION == "MINMAX":
          vy = ${MAX_F32}(vy, vy_min);
          vy = ${MIN_F32}(vy, vy_max);
        $elif ACTIVATION == "RELU":
          vy = ${MAX_F32}(vy, 0.0f);
        *y = vy;
    }
  $else:
    // BATCH_TILE is 1: a single loop handles one element per iteration, so no
    // separate remainder path is needed.
    for (; n >= sizeof(float); n -= sizeof(float)) {
      const float va = *a++;
      float vy = ${OP_FUNC("va")};
      $if OP == "SQRDIFF":
        vy = vy * vy;
      $if ACTIVATION == "MINMAX":
        vy = ${MAX_F32}(vy, vy_min);
        vy = ${MIN_F32}(vy, vy_max);
      $elif ACTIVATION == "RELU":
        vy = ${MAX_F32}(vy, 0.0f);
      *y++ = vy;
    }
}
