// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert CHANNEL_TILE >= 1
$assert KERNEL_TILE >= 2
$assert ACCUMULATORS >= 1
$assert ACCUMULATORS <= KERNEL_TILE
$assert ACTIVATION in ["LINEAR", "MINMAX"]
$assert ACTIVATION != "LINEAR" or not WASM
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <xnnpack/dwconv.h>
#include <xnnpack/math.h>


$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
$SUFFIX = {"LINEAR": "", "MINMAX": "_minmax"}[ACTIVATION]
$PARAMS = {"LINEAR": "xnn_f32_default_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION]
// Depthwise convolution microkernel: ${KERNEL_TILE} taps per output pixel, ${CHANNEL_TILE} channel(s) per
// main-loop iteration, ${ACCUMULATORS} partial accumulator(s) per channel.
//
// For each of the output_width output pixels:
//   - input[0] .. input[${KERNEL_TILE - 1}] are the per-tap input pointers. Every pointer that is not
//     `zero` is displaced by input_offset bytes before use; afterwards `input` advances by
//     input_stride bytes.
//   - weights are read in groups of ${CHANNEL_TILE} channel(s): ${CHANNEL_TILE} bias value(s) followed by
//     ${KERNEL_TILE} x ${CHANNEL_TILE} kernel values, tap-major. A trailing partial group is still indexed
//     with the full ${CHANNEL_TILE}-channel stride.
//   - `channels` results are stored to output, which then advances by output_increment more bytes.
$if ACTIVATION == "MINMAX":
  // Every result is clamped to [params->scalar.min, params->scalar.max].
void xnn_f32_dwconv${SUFFIX}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__${"wasm" if WASM else "scalar"}${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
    size_t channels,
    size_t output_width,
    const float** input,
    const float* weights,
    float* output,
    size_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const float* zero,
    const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(channels != 0);
  assert(output_width != 0);

  $if ACTIVATION == "MINMAX":
    const float vmin = params->scalar.min;
    const float vmax = params->scalar.max;
  do {
    // Load the per-tap input pointers. Pointers equal to `zero` are used unchanged;
    // all others are displaced by input_offset bytes.
    $for K in range(KERNEL_TILE):
      const float* i${K} = input[${K}];
      assert(i${K} != NULL);
      if XNN_UNPREDICTABLE(i${K} != zero) {
        i${K} = (const float*) ((uintptr_t) i${K} + input_offset);
      }
    // input_stride is a byte stride.
    input = (const float**) ((uintptr_t) input + input_stride);

    size_t c = channels;
    const float* w = weights;
    $if CHANNEL_TILE > 1:
      // Main loop: ${CHANNEL_TILE} channels per iteration, one full weight group each.
      for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
        $for C in range(CHANNEL_TILE):
          float vacc${C}p0 = w[${C}];

        $for K in range(KERNEL_TILE):

          $for C in range(CHANNEL_TILE):
            const float vi${K}x${C} = i${K}[${C}];
          i${K} += ${CHANNEL_TILE};

          $for C in range(CHANNEL_TILE):
            const float vk${K}x${C} = w[${(K + 1) * CHANNEL_TILE + C}];
            $if 1 <= K < ACCUMULATORS:
              float vacc${C}p${K} = vi${K}x${C} * vk${K}x${C};
            $else:
              vacc${C}p${K % ACCUMULATORS} = math_muladd_f32(vi${K}x${C}, vk${K}x${C}, vacc${C}p${K % ACCUMULATORS});

        w += ${(KERNEL_TILE + 1) * CHANNEL_TILE};

        $if ACCUMULATORS > 1:
          // Add up all accumulators to vacc0p0..vacc${CHANNEL_TILE - 1}p0
          $ACC_SLICE = 1
          $while ACC_SLICE < ACCUMULATORS:
            $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
              $if A + ACC_SLICE < ACCUMULATORS:
                $for C in range(CHANNEL_TILE):
                  vacc${C}p${A} = vacc${C}p${A} + vacc${C}p${A + ACC_SLICE};
            $ACC_SLICE *= 2

        $if ACTIVATION == "MINMAX":
          $for C in range(CHANNEL_TILE):
            float vacc${C} = ${MAX_F32}(vacc${C}p0, vmin);

          $for C in range(CHANNEL_TILE):
            vacc${C} = ${MIN_F32}(vacc${C}, vmax);

          $for C in range(CHANNEL_TILE):
            output[${C}] = vacc${C};
        $else:
          $for C in range(CHANNEL_TILE):
            output[${C}] = vacc${C}p0;
        output += ${CHANNEL_TILE};
      }
      // Remainder: one channel at a time from the last (partial) weight group. After `*w++`
      // consumes the bias, tap K of the current channel sits at w[(K + 1) * ${CHANNEL_TILE} - 1].
      for (; c >= 1; c -= 1) {
        float vacc0p0 = *w++;

        $for K in range(KERNEL_TILE):
          const float vi${K} = *i${K}++;
          const float vk${K} = w[${(K + 1) * CHANNEL_TILE - 1}];
          $if 1 <= K < ACCUMULATORS:
            float vacc0p${K} = vi${K} * vk${K};
          $else:
            vacc0p${K % ACCUMULATORS} = math_muladd_f32(vi${K}, vk${K}, vacc0p${K % ACCUMULATORS});

        $if ACCUMULATORS > 1:
          // Add up all accumulators to vacc0p0
          $ACC_SLICE = 1
          $while ACC_SLICE < ACCUMULATORS:
            $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
              $if A + ACC_SLICE < ACCUMULATORS:
                vacc0p${A} = vacc0p${A} + vacc0p${A + ACC_SLICE};
            $ACC_SLICE *= 2

        $if ACTIVATION == "MINMAX":
          float vacc0 = ${MAX_F32}(vacc0p0, vmin);
          vacc0 = ${MIN_F32}(vacc0, vmax);
          *output++ = vacc0;
        $else:
          *output++ = vacc0p0;
      }
    $else:
      // Single-channel tile: each weight group is 1 bias followed by ${KERNEL_TILE} taps.
      do {
        float vacc0p0 = w[0];
        $for K in range(KERNEL_TILE):

          const float vi${K} = *i${K}++;
          const float vk${K} = w[${K+1}];
          $if 1 <= K < ACCUMULATORS:
            float vacc0p${K} = vi${K} * vk${K};
          $else:
            vacc0p${K % ACCUMULATORS} = math_muladd_f32(vi${K}, vk${K}, vacc0p${K % ACCUMULATORS});

        w += ${KERNEL_TILE + 1};

        $if ACCUMULATORS > 1:
          // Add up all accumulators to vacc0p0
          $ACC_SLICE = 1
          $while ACC_SLICE < ACCUMULATORS:
            $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
              $if A + ACC_SLICE < ACCUMULATORS:
                vacc0p${A} = vacc0p${A} + vacc0p${A + ACC_SLICE};
            $ACC_SLICE *= 2

        $if ACTIVATION == "MINMAX":
          float vacc0 = ${MAX_F32}(vacc0p0, vmin);
          vacc0 = ${MIN_F32}(vacc0, vmax);
          *output++ = vacc0;
        $else:
          *output++ = vacc0p0;
      } while (--c != 0);

    // output_increment is an extra byte offset applied after this pixel's `channels` outputs.
    output = (float*) ((uintptr_t) output + output_increment);
  } while (--output_width != 0);
}
