// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert SSE in [2, 4]
$assert not XOP or AVX
$assert not AVX or SSE == 4
$assert DATATYPE in ["S8", "U8"]
$assert CHANNEL_TILE % 8 == 0
$assert CHANNEL_TILE >= 8
$assert PIXEL_TILE == 1
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

$if XOP:
  #if defined(__GNUC__) || defined(__clang__)
    #include <x86intrin.h>
  #else
    #include <immintrin.h>
    #include <ammintrin.h>
  #endif
$else:
  $SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
  #include <${SSE_HEADER}>

#include <xnnpack/common.h>
#include <xnnpack/ibilinear.h>
#include <xnnpack/unaligned.h>


$XINT8_T = {"S8": "int8_t", "U8": "uint8_t"}[DATATYPE]
$_MM_CVTEPX8_EPI16 = {"S8": "_mm_cvtepi8_epi16", "U8": "_mm_cvtepu8_epi16"}[DATATYPE]
$_MM_SRXI_EPI32 = {"S8": "_mm_srai_epi32", "U8": "_mm_srli_epi32"}[DATATYPE]
$_MM_SRXI_EPI16 = {"S8": "_mm_srai_epi16", "U8": "_mm_srli_epi16"}[DATATYPE]
$_MM_PACKXS_EPI16 = {"S8": "_mm_packs_epi16", "U8": "_mm_packus_epi16"}[DATATYPE]
$ISA = "xop" if XOP else "avx" if AVX else {2: "sse2", 3: "ssse3", 4: "sse41"}[SSE]
// Bilinear interpolation micro-kernel for 8-bit elements, producing one output
// pixel per iteration of the outer loop.
//
// For every output pixel the kernel consumes four row pointers from `input`
// (named top-left, top-right, bottom-left and bottom-right below) and two
// int16_t coefficients from `weights`: weights[0] is applied horizontally
// (alphah) and weights[1] vertically (alphav). The arithmetic treats 0x0800 as
// the unit value of both coefficients:
//
//   top    = tl * (0x0800 - alphah) + tr * alphah
//   delta  = (bl - tl) * (0x0800 - alphah) + (br - tr) * alphah
//   output = (top * 0x0800 + delta * alphav + 0x00200000) >> 22
//
// The result is narrowed to 8 bits with saturating pack instructions.
//
// Parameters:
//   output_pixels    - number of pixels to produce; must be non-zero.
//   channels         - number of elements per pixel; must be non-zero.
//   input            - four pointers per output pixel; advanced by 4 per pixel.
//   input_offset     - byte offset added to each of the four input pointers.
//   weights          - two int16_t coefficients per output pixel.
//   output           - destination; `channels` elements are written per pixel.
//   output_increment - byte offset added to `output` after each pixel.
//
// The remainder path loads 8 elements from every input row even when fewer
// channels are left, which is presumably what XNN_OOB_READS declares.
void xnn_${DATATYPE.lower()}_ibilinear_ukernel__${ISA}_c${CHANNEL_TILE}${"" if PIXEL_TILE == 1 else "x%d" % PIXEL_TILE}(
    size_t output_pixels,
    size_t channels,
    const ${XINT8_T}**restrict input,
    size_t input_offset,
    const int16_t*restrict weights,
    ${XINT8_T}*restrict output,
    size_t output_increment) XNN_OOB_READS
{
  assert(output_pixels != 0);
  assert(channels != 0);

  do {
    // Row pointers for this pixel: i0 = top-left, i1 = top-right,
    // i2 = bottom-left, i3 = bottom-right.
    const ${XINT8_T}* i0 = (const ${XINT8_T}*) ((uintptr_t) input[0] + input_offset);
    const ${XINT8_T}* i1 = (const ${XINT8_T}*) ((uintptr_t) input[1] + input_offset);
    const ${XINT8_T}* i2 = (const ${XINT8_T}*) ((uintptr_t) input[2] + input_offset);
    const ${XINT8_T}* i3 = (const ${XINT8_T}*) ((uintptr_t) input[3] + input_offset);
    input += 4;

    // Fetch both 16-bit coefficients with a single 32-bit load. `weights` is
    // only int16_t-aligned by type, so the load goes through the unaligned
    // helper instead of dereferencing the pointer as `int`.
    const __m128i valpha = _mm_cvtsi32_si128(unaligned_load_s32(weights));
    weights += 2;
    // Broadcast alphah to every 16-bit lane.
    __m128i valphah = _mm_shufflelo_epi16(valpha, _MM_SHUFFLE(0, 0, 0, 0));
    valphah = _mm_unpacklo_epi64(valphah, valphah);
    // Broadcast alphav: to every 16-bit lane for SSE2 (16-bit multiplies), or
    // zero-extended to every 32-bit lane otherwise (32-bit multiply).
    $if SSE == 2:
      __m128i valphav = _mm_shufflelo_epi16(valpha, _MM_SHUFFLE(1, 1, 1, 1));
      valphav = _mm_unpacklo_epi64(valphav, valphav);
    $else:
      __m128i valphav = _mm_srli_epi32(valpha, 16);
      valphav = _mm_shuffle_epi32(valphav, _MM_SHUFFLE(0, 0, 0, 0));

    // Turn valphah into interleaved (alphah, 0x0800 - alphah) pairs so that a
    // single _mm_madd_epi16 blends the right and left samples. The non-SSE4
    // variant computes the same value as ~alphah + 0x0801 in the odd lanes.
    $if SSE == 4:
      valphah = _mm_blend_epi16(valphah, _mm_sub_epi16(_mm_set1_epi32(0x08000000), valphah), 0xAA);
    $else:
      valphah = _mm_xor_si128(valphah, _mm_set1_epi32(0xFFFF0000));
      valphah = _mm_add_epi16(valphah, _mm_set1_epi32(0x08010000));

    // 1 << 21, i.e. half of the final 22-bit right shift. Its low 16 bits are
    // zero, so adding it with _mm_add_epi16 equals a 32-bit add.
    const __m128i vrounding = _mm_set1_epi32(0x00200000);

    size_t c = channels;
    $if CHANNEL_TILE > 8:
      // Main loop: a full channel tile per iteration.
      for (; c >= ${CHANNEL_TILE} * sizeof(${XINT8_T}); c -= ${CHANNEL_TILE} * sizeof(${XINT8_T})) {
        $if SSE == 4:
          const __m128i vtl${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0));
          const __m128i vtr${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1));
          const __m128i vbl${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2));
          const __m128i vbr${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3));
          $for C in range(8, CHANNEL_TILE, 8):
            const __m128i vtl${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i0 + ${C})));
            const __m128i vtr${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i1 + ${C})));
            const __m128i vbl${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i2 + ${C})));
            const __m128i vbr${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i3 + ${C})));
        $else:
          __m128i vtl${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i0);
          __m128i vtr${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i1);
          __m128i vbl${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i2);
          __m128i vbr${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i3);
          $for C in range(8, CHANNEL_TILE, 8):
            __m128i vtl${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i0 + ${C}));
            __m128i vtr${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i1 + ${C}));
            __m128i vbl${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i2 + ${C}));
            __m128i vbr${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i3 + ${C}));
        i0 += ${CHANNEL_TILE};
        i1 += ${CHANNEL_TILE};
        i2 += ${CHANNEL_TILE};
        i3 += ${CHANNEL_TILE};

        $if SSE != 4:
          $if DATATYPE == "U8":
            __m128i vzero = _mm_setzero_si128();
            $for C in range(0, CHANNEL_TILE, 8):
              vtl${ABC[C:C+8]} = _mm_unpacklo_epi8(vtl${ABC[C:C+8]}, vzero);
              vtr${ABC[C:C+8]} = _mm_unpacklo_epi8(vtr${ABC[C:C+8]}, vzero);
              vbl${ABC[C:C+8]} = _mm_unpacklo_epi8(vbl${ABC[C:C+8]}, vzero);
              vbr${ABC[C:C+8]} = _mm_unpacklo_epi8(vbr${ABC[C:C+8]}, vzero);
          $else:
            $for C in range(0, CHANNEL_TILE, 8):
              vtl${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vtl${ABC[C:C+8]}, vtl${ABC[C:C+8]}), 8);
              vtr${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vtr${ABC[C:C+8]}, vtr${ABC[C:C+8]}), 8);
              vbl${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vbl${ABC[C:C+8]}, vbl${ABC[C:C+8]}), 8);
              vbr${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vbr${ABC[C:C+8]}, vbr${ABC[C:C+8]}), 8);

        $for C in range(0, CHANNEL_TILE, 8):
          const __m128i vdr${ABC[C:C+8]} = _mm_sub_epi16(vbr${ABC[C:C+8]}, vtr${ABC[C:C+8]});
          const __m128i vt${ABC[C:C+4]} = _mm_madd_epi16(_mm_unpacklo_epi16(vtr${ABC[C:C+8]}, vtl${ABC[C:C+8]}), valphah);
          const __m128i vdl${ABC[C:C+8]} = _mm_sub_epi16(vbl${ABC[C:C+8]}, vtl${ABC[C:C+8]});
          const __m128i vt${ABC[C+4:C+8]} = _mm_madd_epi16(_mm_unpackhi_epi16(vtr${ABC[C:C+8]}, vtl${ABC[C:C+8]}), valphah);

        $for C in range(0, CHANNEL_TILE, 8):
          const __m128i vd${ABC[C:C+4]} = _mm_madd_epi16(_mm_unpacklo_epi16(vdr${ABC[C:C+8]}, vdl${ABC[C:C+8]}), valphah);
          const __m128i vd${ABC[C+4:C+8]} = _mm_madd_epi16(_mm_unpackhi_epi16(vdr${ABC[C:C+8]}, vdl${ABC[C:C+8]}), valphah);

        $if SSE == 4:
          $for C in range(0, CHANNEL_TILE, 4):
            __m128i vacc${ABC[C:C+4]} = _mm_mullo_epi32(vd${ABC[C:C+4]}, valphav);
        $else:
          $for C in range(0, CHANNEL_TILE, 4):
            __m128i vacc${ABC[C:C+4]} = _mm_slli_epi32(_mm_mulhi_epu16(vd${ABC[C:C+4]}, valphav), 16);

          $for C in range(0, CHANNEL_TILE, 4):
            vacc${ABC[C:C+4]} = _mm_add_epi16(_mm_mullo_epi16(vd${ABC[C:C+4]}, valphav), vacc${ABC[C:C+4]});

        $for C in range(0, CHANNEL_TILE, 4):
          vacc${ABC[C:C+4]} = _mm_add_epi32(_mm_slli_epi32(vt${ABC[C:C+4]}, 11), vacc${ABC[C:C+4]});

        $for C in range(0, CHANNEL_TILE, 4):
          vacc${ABC[C:C+4]} = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc${ABC[C:C+4]}, vrounding), 22);

        $for C in range(0, CHANNEL_TILE, 8):
          const __m128i vacc${ABC[C:C+8]} = _mm_packs_epi32(vacc${ABC[C:C+4]}, vacc${ABC[C+4:C+8]});

        $for C in range(0, CHANNEL_TILE, 16):
          $if C + 8 < CHANNEL_TILE:
            const __m128i vo${ABC[C:C+16]} = ${_MM_PACKXS_EPI16}(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]});
          $else:
            const __m128i vo${ABC[C:C+8]} = ${_MM_PACKXS_EPI16}(vacc${ABC[C:C+8]}, vacc${ABC[C:C+8]});

        _mm_storeu_si128((__m128i*) output, vo${ABC[0:16]});
        $for C in range(16, CHANNEL_TILE, 16):
          $if C + 8 < CHANNEL_TILE:
            _mm_storeu_si128((__m128i*) (output + ${C}), vo${ABC[C:C+16]});
          $else:
            _mm_storel_epi64((__m128i*) (output + ${C}), vo${ABC[C:C+8]});
        output += ${CHANNEL_TILE};
      }
    // Process 8 channels per iteration.
    for (; c >= 8 * sizeof(${XINT8_T}); c -= 8 * sizeof(${XINT8_T})) {
      $if SSE == 4:
        const __m128i vtl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0));
        i0 += 8;
        const __m128i vtr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1));
        i1 += 8;
        const __m128i vbl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2));
        i2 += 8;
        const __m128i vbr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3));
        i3 += 8;
      $else:
        __m128i vtl01234567 = _mm_loadl_epi64((const __m128i*) i0);
        i0 += 8;
        __m128i vtr01234567 = _mm_loadl_epi64((const __m128i*) i1);
        i1 += 8;
        __m128i vbl01234567 = _mm_loadl_epi64((const __m128i*) i2);
        i2 += 8;
        __m128i vbr01234567 = _mm_loadl_epi64((const __m128i*) i3);
        i3 += 8;

      $if SSE != 4:
        $if DATATYPE == "U8":
          __m128i vzero = _mm_setzero_si128();
          vtl01234567 = _mm_unpacklo_epi8(vtl01234567, vzero);
          vtr01234567 = _mm_unpacklo_epi8(vtr01234567, vzero);
          vbl01234567 = _mm_unpacklo_epi8(vbl01234567, vzero);
          vbr01234567 = _mm_unpacklo_epi8(vbr01234567, vzero);
        $else:
          vtl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtl01234567, vtl01234567), 8);
          vtr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtr01234567, vtr01234567), 8);
          vbl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbl01234567, vbl01234567), 8);
          vbr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbr01234567, vbr01234567), 8);

      // vt = tr * alphah + tl * (0x0800 - alphah); vd is the same blend applied
      // to the bottom-minus-top differences.
      const __m128i vdr01234567 = _mm_sub_epi16(vbr01234567, vtr01234567);
      const __m128i vt0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vtr01234567, vtl01234567), valphah);
      const __m128i vdl01234567 = _mm_sub_epi16(vbl01234567, vtl01234567);
      const __m128i vt4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vtr01234567, vtl01234567), valphah);

      const __m128i vd0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vdr01234567, vdl01234567), valphah);
      const __m128i vd4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vdr01234567, vdl01234567), valphah);

      // vacc = low 32 bits of vd * alphav. Without a 32-bit multiply (SSE2) the
      // product is assembled from 16-bit high/low partial products.
      $if SSE == 4:
        __m128i vacc0123 = _mm_mullo_epi32(vd0123, valphav);
        __m128i vacc4567 = _mm_mullo_epi32(vd4567, valphav);
      $else:
        __m128i vacc0123 = _mm_slli_epi32(_mm_mulhi_epu16(vd0123, valphav), 16);
        __m128i vacc4567 = _mm_slli_epi32(_mm_mulhi_epu16(vd4567, valphav), 16);

        vacc0123 = _mm_add_epi16(_mm_mullo_epi16(vd0123, valphav), vacc0123);
        vacc4567 = _mm_add_epi16(_mm_mullo_epi16(vd4567, valphav), vacc4567);

      vacc0123 = _mm_add_epi32(_mm_slli_epi32(vt0123, 11), vacc0123);
      vacc4567 = _mm_add_epi32(_mm_slli_epi32(vt4567, 11), vacc4567);

      vacc0123 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc0123, vrounding), 22);
      vacc4567 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc4567, vrounding), 22);

      const __m128i vacc01234567 = _mm_packs_epi32(vacc0123, vacc4567);

      const __m128i vo01234567 = ${_MM_PACKXS_EPI16}(vacc01234567, vacc01234567);

      _mm_storel_epi64((__m128i*) output, vo01234567);
      output += 8;
    }
    // Remainder of 1-7 channels: compute a full 8-lane result (the loads below
    // still fetch 8 elements per row) and store only the `c` valid elements.
    if XNN_UNLIKELY(c != 0) {
      $if SSE == 4:
        const __m128i vtl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0));
        const __m128i vtr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1));
        const __m128i vbl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2));
        const __m128i vbr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3));
      $else:
        __m128i vtl01234567 = _mm_loadl_epi64((const __m128i*) i0);
        __m128i vtr01234567 = _mm_loadl_epi64((const __m128i*) i1);
        __m128i vbl01234567 = _mm_loadl_epi64((const __m128i*) i2);
        __m128i vbr01234567 = _mm_loadl_epi64((const __m128i*) i3);

      $if SSE != 4:
        $if DATATYPE == "U8":
          __m128i vzero = _mm_setzero_si128();
          vtl01234567 = _mm_unpacklo_epi8(vtl01234567, vzero);
          vtr01234567 = _mm_unpacklo_epi8(vtr01234567, vzero);
          vbl01234567 = _mm_unpacklo_epi8(vbl01234567, vzero);
          vbr01234567 = _mm_unpacklo_epi8(vbr01234567, vzero);
        $else:
          vtl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtl01234567, vtl01234567), 8);
          vtr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtr01234567, vtr01234567), 8);
          vbl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbl01234567, vbl01234567), 8);
          vbr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbr01234567, vbr01234567), 8);

      const __m128i vdr01234567 = _mm_sub_epi16(vbr01234567, vtr01234567);
      const __m128i vt0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vtr01234567, vtl01234567), valphah);
      const __m128i vdl01234567 = _mm_sub_epi16(vbl01234567, vtl01234567);
      const __m128i vt4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vtr01234567, vtl01234567), valphah);

      const __m128i vd0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vdr01234567, vdl01234567), valphah);
      const __m128i vd4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vdr01234567, vdl01234567), valphah);

      $if SSE == 4:
        __m128i vacc0123 = _mm_mullo_epi32(vd0123, valphav);
        __m128i vacc4567 = _mm_mullo_epi32(vd4567, valphav);
      $else:
        __m128i vacc0123 = _mm_slli_epi32(_mm_mulhi_epu16(vd0123, valphav), 16);
        __m128i vacc4567 = _mm_slli_epi32(_mm_mulhi_epu16(vd4567, valphav), 16);

        vacc0123 = _mm_add_epi16(_mm_mullo_epi16(vd0123, valphav), vacc0123);
        vacc4567 = _mm_add_epi16(_mm_mullo_epi16(vd4567, valphav), vacc4567);

      vacc0123 = _mm_add_epi32(_mm_slli_epi32(vt0123, 11), vacc0123);
      vacc4567 = _mm_add_epi32(_mm_slli_epi32(vt4567, 11), vacc4567);

      vacc0123 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc0123, vrounding), 22);
      vacc4567 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc4567, vrounding), 22);

      const __m128i vacc01234567 = _mm_packs_epi32(vacc0123, vacc4567);

      __m128i vo01234567 = ${_MM_PACKXS_EPI16}(vacc01234567, vacc01234567);

      // Emit 4, then 2, then 1 element according to the bits of c, shifting
      // the already-stored elements out of the low lanes each time.
      if (c & (4 * sizeof(${XINT8_T}))) {
        unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vo01234567));
        output += 4;
        vo01234567 = _mm_srli_epi64(vo01234567, 32);
      }
      $if SSE == 4:
        if (c & (2 * sizeof(${XINT8_T}))) {
          unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vo01234567, 0));
          output += 2;
          vo01234567 = _mm_srli_epi32(vo01234567, 16);
        }
        if (c & (1 * sizeof(${XINT8_T}))) {
          *output++ = (${XINT8_T}) _mm_extract_epi8(vo01234567, 0);
        }
      $else:
        uint32_t vo0123 = (uint32_t) _mm_cvtsi128_si32(vo01234567);
        if (c & (2 * sizeof(${XINT8_T}))) {
          unaligned_store_u16(output, (uint16_t) vo0123);
          output += 2;
          vo0123 >>= 16;
        }
        if (c & (1 * sizeof(${XINT8_T}))) {
          *output++ = (${XINT8_T}) vo0123;
        }
    }

    output = (${XINT8_T}*) ((uintptr_t) output + output_increment);
  } while (--output_pixels != 0);
}
