// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert PIXEL_TILE >= 1
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <xnnpack/ibilinear.h>


// Scalar bilinear-interpolation microkernel template. PIXEL_TILE is a template
// parameter (asserted >= 1 at the top of this file) giving the number of
// output pixels computed per iteration of the unrolled main loop.
//
// For every channel and every output pixel the kernel:
//   1. reads two row pointers from `input` (named "top" and "bottom" in the
//      code) and adds `input_offset` bytes to each,
//   2. loads two adjacent floats (left, right) through each row pointer,
//   3. loads two weights: the horizontal alpha followed by the vertical alpha,
//   4. interpolates left->right within each row using the horizontal alpha,
//      then top->bottom between the two row results using the vertical alpha,
//   5. stores the single resulting float to `output`.
//
// Parameters:
//   output_pixels   - pixels produced per channel; asserted non-zero.
//   channels        - number of passes over the pixel list; asserted non-zero.
//   input           - array holding 2 pointers per output pixel; it is re-read
//                     from the beginning for every channel.
//   input_offset    - byte offset added to every pointer taken from `input`.
//   weights         - 2 floats per output pixel; re-read from the beginning
//                     for every channel.
//   output          - destination; advanced continuously and never rewound, so
//                     channels * output_pixels floats are written in total.
//   input_increment - bytes added to `input_offset` after each channel;
//                     asserted to be a multiple of sizeof(float).
void xnn_f32_ibilinear_chw_ukernel__scalar_p${PIXEL_TILE}(
    size_t output_pixels,
    size_t channels,
    const float**restrict input,
    size_t input_offset,
    const float*restrict weights,
    float*restrict output,
    size_t input_increment)
{
  assert(output_pixels != 0);
  assert(channels != 0);
  assert(input_increment % sizeof(float) == 0);

  size_t c = channels;
  do {
    // Each channel walks the same pointer list and weight list from the start;
    // only input_offset differs between channels.
    const float** i = input;
    const float* w = weights;

    size_t p = output_pixels;
    $if PIXEL_TILE > 1:
      // Main loop: PIXEL_TILE pixels per iteration. Each stage below is emitted
      // once per pixel in the tile, grouped by stage rather than by pixel.
      for (; p >= ${PIXEL_TILE}; p -= ${PIXEL_TILE}) {
        // Row pointers: two consecutive entries of the pointer list per pixel.
        $for P in range(PIXEL_TILE):
          const float* itl${P} = (const float*) ((uintptr_t) i[${P * 2}] + input_offset);
          const float* ibl${P} = (const float*) ((uintptr_t) i[${P * 2 + 1}] + input_offset);
        i += ${PIXEL_TILE} * 2;

        // Weights: horizontal alpha, then vertical alpha, per pixel.
        $for P in range(PIXEL_TILE):
           const float valphah${ABC[P]} = w[${P * 2}];
           const float valphav${ABC[P]} = w[${P * 2 + 1}];
        w += ${PIXEL_TILE} * 2;

        // Load the left ([0]) and right ([1]) samples of both rows.
        $for P in range(PIXEL_TILE):
          const float vtl${ABC[P]} = itl${P}[0];
          const float vtr${ABC[P]} = itl${P}[1];
          const float vbl${ABC[P]} = ibl${P}[0];
          const float vbr${ABC[P]} = ibl${P}[1];

        // Horizontal differences (right - left) for each row.
        $for P in range(PIXEL_TILE):
          const float vtd${ABC[P]} = vtr${ABC[P]} - vtl${ABC[P]};
          const float vbd${ABC[P]} = vbr${ABC[P]} - vbl${ABC[P]};

        // Horizontal interpolation: left + (right - left) * alphah.
        $for P in range(PIXEL_TILE):
          const float vt${ABC[P]} = vtl${ABC[P]} + vtd${ABC[P]} * valphah${ABC[P]};
          const float vb${ABC[P]} = vbl${ABC[P]} + vbd${ABC[P]} * valphah${ABC[P]};

        // Vertical difference (bottom - top) of the interpolated rows.
        $for P in range(PIXEL_TILE):
          const float vd${ABC[P]} = vb${ABC[P]} - vt${ABC[P]};

        // Vertical interpolation: top + (bottom - top) * alphav.
        $for P in range(PIXEL_TILE):
          const float vo${ABC[P]} = vt${ABC[P]} + vd${ABC[P]} * valphav${ABC[P]};

        $for P in range(PIXEL_TILE):
          output[${P}] = vo${ABC[P]};
        output += ${PIXEL_TILE};
      }

      // Remainder loop: handles the pixels left over when output_pixels is not
      // a multiple of PIXEL_TILE, one pixel at a time with the same arithmetic.
      for (; p >= 1; p -= 1) {
        const float* itl = (const float*) ((uintptr_t) i[0] + input_offset);
        const float* ibl = (const float*) ((uintptr_t) i[1] + input_offset);
        i += 2;

        const float valphah = w[0];
        const float valphav = w[1];
        w += 2;

        const float vtl = itl[0];
        const float vtr = itl[1];
        const float vbl = ibl[0];
        const float vbr = ibl[1];

        const float vtd = vtr - vtl;
        const float vbd = vbr - vbl;

        const float vt = vtl + vtd * valphah;
        const float vb = vbl + vbd * valphah;

        const float vd = vb - vt;

        const float vo = vt + vd * valphav;

        *output++ = vo;
      }
    $else:
      // Single-pixel variant: the do/while body runs before its condition is
      // checked, which relies on output_pixels being non-zero (asserted above).
      do {
        const float* itl = (const float*) ((uintptr_t) i[0] + input_offset);
        const float* ibl = (const float*) ((uintptr_t) i[1] + input_offset);
        i += 2;

        const float valphah = w[0];
        const float valphav = w[1];
        w += 2;

        const float vtl = itl[0];
        const float vtr = itl[1];
        const float vbl = ibl[0];
        const float vbr = ibl[1];

        const float vtd = vtr - vtl;
        const float vbd = vbr - vbl;

        const float vt = vtl + vtd * valphah;
        const float vb = vbl + vbd * valphah;

        const float vd = vb - vt;

        const float vo = vt + vd * valphav;

        *output++ = vo;
      } while (--p != 0);

    // Shift every row pointer by input_increment bytes for the next channel.
    input_offset += input_increment;

    c--;
  } while (c != 0);
}
