/*
 * Copyright 2024 Advanced Micro Devices, Inc.
 *
 * SPDX-License-Identifier: MIT
 */

/**
 * This pass:
 * - vectorizes lowered input/output loads and stores
 * - vectorizes low and high 16-bit loads and stores by merging them into
 *   a single 32-bit load or store (except load_interpolated_input, which has
 *   to keep bit_size=16)
 * - performs DCE of output stores that overwrite the previous value by writing
 *   into the same slot and component.
 *
 * Vectorization is only local within basic blocks. No vectorization occurs
 * across basic block boundaries, barriers (only TCS outputs), emits (only
 * GS outputs), and output load <-> output store dependencies.
 *
 * All loads and stores must be scalar. 64-bit loads and stores are forbidden.
 *
 * For each basic block, the time complexity is O(n*log(n)) where n is
 * the number of IO instructions within that block.
 */

#include "nir.h"
#include "nir_builder.h"
#include "util/u_dynarray.h"

/* Grouping comparator for IO intrinsics.
 *
 * Returns 0 when "a" and "b" may be merged into one vectorized load/store.
 * Otherwise returns 1 or -1, which gives qsort a consistent order between
 * instructions that must stay in different groups, so that every group of
 * vectorizable instructions ends up contiguous after sorting.
 */
static int
compare_is_not_vectorizable(nir_intrinsic_instr *a, nir_intrinsic_instr *b)
{
/* Order two values that are already known to be different. */
#define IO_ORDER(x, y) ((x) > (y) ? 1 : -1)

   if (a->intrinsic != b->intrinsic)
      return IO_ORDER(a->intrinsic, b->intrinsic);

   /* Both instructions are the same intrinsic from here on, so a source that
    * exists for "a" is looked up the same way for "b".
    */
   nir_src *off_a = nir_get_io_offset_src(a);
   nir_src *off_b = nir_get_io_offset_src(b);
   if (off_a && off_a->ssa != off_b->ssa)
      return IO_ORDER(off_a->ssa->index, off_b->ssa->index);

   nir_src *arr_a = nir_get_io_arrayed_index_src(a);
   nir_src *arr_b = nir_get_io_arrayed_index_src(b);
   if (arr_a && arr_a->ssa != arr_b->ssa)
      return IO_ORDER(arr_a->ssa->index, arr_b->ssa->index);

   /* src[0] holds the barycentrics or the vertex index for these two. */
   bool src0_matters =
      a->intrinsic == nir_intrinsic_load_interpolated_input ||
      a->intrinsic == nir_intrinsic_load_input_vertex;
   if (src0_matters && a->src[0].ssa != b->src[0].ssa)
      return IO_ORDER(a->src[0].ssa->index, b->src[0].ssa->index);

   nir_io_semantics sem_a = nir_intrinsic_io_semantics(a);
   nir_io_semantics sem_b = nir_intrinsic_io_semantics(b);

   if (sem_a.location != sem_b.location)
      return IO_ORDER(sem_a.location, sem_b.location);

   /* The mediump flag can't be merged. */
   if (sem_a.medium_precision != sem_b.medium_precision)
      return IO_ORDER(sem_a.medium_precision, sem_b.medium_precision);

   /* Per-view and non-per-view attributes stay separate. */
   if (sem_a.per_view != sem_b.per_view)
      return IO_ORDER(sem_a.per_view, sem_b.per_view);

   if (sem_a.interp_explicit_strict != sem_b.interp_explicit_strict)
      return IO_ORDER(sem_a.interp_explicit_strict,
                      sem_b.interp_explicit_strict);

   /* load_interpolated_input is the only intrinsic whose low and high
    * 16-bit halves can't be merged.
    */
   if (a->intrinsic == nir_intrinsic_load_interpolated_input &&
       sem_a.high_16bits != sem_b.high_16bits)
      return IO_ORDER(sem_a.high_16bits, sem_b.high_16bits);

   nir_shader *shader =
      nir_cf_node_get_function(&a->instr.block->cf_node)->function->shader;

   /* Types only matter if the driver didn't ask to ignore them. */
   if (shader->options->io_options & nir_io_vectorizer_ignores_types)
      return 0;

   bool is_src_typed = nir_intrinsic_has_src_type(a);
   unsigned type_a = is_src_typed ? nir_intrinsic_src_type(a)
                                  : nir_intrinsic_dest_type(a);
   unsigned type_b = is_src_typed ? nir_intrinsic_src_type(b)
                                  : nir_intrinsic_dest_type(b);

   if (type_a != type_b)
      return IO_ORDER(type_a, type_b);

   return 0;

#undef IO_ORDER
}

static int
compare_intr(const void *xa, const void *xb)
{
   nir_intrinsic_instr *a = *(nir_intrinsic_instr **)xa;
   nir_intrinsic_instr *b = *(nir_intrinsic_instr **)xb;

   int comp = compare_is_not_vectorizable(a, b);
   if (comp)
      return comp;

   /* qsort isn't stable. This ensures that later stores aren't moved before earlier stores. */
   return a->instr.index > b->instr.index ? 1 : -1;
}

/* Replace the scalar loads chan[start .. start + count - 1] with one vector
 * load placed before the earliest of them.
 *
 * chan[0..3] are the low/full channels and chan[4..7] the high 16-bit
 * channels of the same slot. When merge_low_high_16_to_32 is set, chan[i] and
 * chan[4 + i] are both consumed: the new load is 32 bits per component and
 * the original 16-bit values are recovered with unpack_32_2x16_split_x/y.
 */
static void
vectorize_load(nir_intrinsic_instr *chan[8], unsigned start, unsigned count,
               bool merge_low_high_16_to_32)
{
   nir_intrinsic_instr *first = NULL;

   /* Find the first instruction where the vectorized load will be
    * inserted.
    */
   for (unsigned i = start; i < start + count; i++) {
      if (!first || chan[i]->instr.index < first->instr.index)
         first = chan[i];

      if (merge_low_high_16_to_32 &&
          chan[4 + i]->instr.index < first->instr.index)
         first = chan[4 + i];
   }

   /* Insert the vectorized load. */
   nir_builder b = nir_builder_at(nir_before_instr(&first->instr));
   nir_intrinsic_instr *new_intr =
      nir_intrinsic_instr_create(b.shader, first->intrinsic);

   new_intr->num_components = count;
   nir_def_init(&new_intr->instr, &new_intr->def, count,
                merge_low_high_16_to_32 ? 32 : first->def.bit_size);
   memcpy(new_intr->src, first->src,
          nir_intrinsic_infos[first->intrinsic].num_srcs * sizeof(nir_src));
   nir_intrinsic_copy_const_indices(new_intr, first);

   /* "start" is a channel index (high_16bits * 4 + component). For a run of
    * high-half channels it is >= 4, while the component index of the new
    * load must stay within the slot; high_16bits itself is inherited from
    * "first" by the copy above.
    */
   nir_intrinsic_set_component(new_intr, start % 4);

   if (merge_low_high_16_to_32) {
      /* The merged load covers both halves, so it is a plain 32-bit load. */
      nir_io_semantics sem = nir_intrinsic_io_semantics(new_intr);
      sem.high_16bits = 0;
      nir_intrinsic_set_io_semantics(new_intr, sem);
      nir_intrinsic_set_dest_type(new_intr,
                                  (nir_intrinsic_dest_type(new_intr) & ~16) | 32);
   }

   nir_builder_instr_insert(&b, &new_intr->instr);
   nir_def *def = &new_intr->def;

   /* Replace the scalar loads. */
   if (merge_low_high_16_to_32) {
      for (unsigned i = start; i < start + count; i++) {
         nir_def *comp = nir_channel(&b, def, i - start);

         nir_def_rewrite_uses(&chan[i]->def,
                              nir_unpack_32_2x16_split_x(&b, comp));
         nir_def_rewrite_uses(&chan[4 + i]->def,
                              nir_unpack_32_2x16_split_y(&b, comp));
         nir_instr_remove(&chan[i]->instr);
         nir_instr_remove(&chan[4 + i]->instr);
      }
   } else {
      for (unsigned i = start; i < start + count; i++) {
         nir_def_replace(&chan[i]->def, nir_channel(&b, def, i - start));
      }
   }
}

/* Merge the scalar stores chan[start .. start + count - 1] into the last of
 * them (in instruction order), which becomes the vector store; the other
 * scalar stores are removed.
 *
 * chan[0..3] are the low/full channels and chan[4..7] the high 16-bit
 * channels of the same slot. When merge_low_high_16_to_32 is set, chan[i] and
 * chan[4 + i] are packed with pack_32_2x16_split into one 32-bit component.
 *
 * "start" is a channel index and can be >= 4 for a run of high-half
 * channels. Component numbers and xfb slots are therefore derived from
 * (index % 4), and the temporary arrays are indexed relative to "start".
 */
static void
vectorize_store(nir_intrinsic_instr *chan[8], unsigned start, unsigned count,
                bool merge_low_high_16_to_32)
{
   nir_intrinsic_instr *last = NULL;

   /* Find the last instruction where the vectorized store will be
    * inserted.
    */
   for (unsigned i = start; i < start + count; i++) {
      last = !last || chan[i]->instr.index > last->instr.index ?
                chan[i] : last;
      if (merge_low_high_16_to_32) {
         last = !last || chan[4 + i]->instr.index > last->instr.index ?
                   chan[4 + i] : last;
      }
   }

   /* Change the last instruction to a vectorized store. Update xfb first
    * because we need to read some info from "last" before overwriting it.
    */
   if (nir_intrinsic_has_io_xfb(last)) {
      /* xfb[0] holds components X/Y (io_xfb), xfb[1] holds Z/W (io_xfb2).
       * Entries are addressed by the component within the slot, which is the
       * same addressing vectorize_slot uses for the high-half channels.
       */
      nir_io_xfb xfb[2] = {{{{0}}}};

      for (unsigned i = start; i < start + count; i++) {
         unsigned c = i % 4;

         xfb[c / 2].out[c % 2] =
            (c < 2 ? nir_intrinsic_io_xfb(chan[i]) :
                     nir_intrinsic_io_xfb2(chan[i])).out[c % 2];

         /* Merging low and high 16 bits to 32 bits is not possible
          * with xfb in some cases. (and it's not implemented for
          * cases where it's possible)
          */
         assert(!xfb[c / 2].out[c % 2].num_components ||
                !merge_low_high_16_to_32);
      }

      /* Now vectorize xfb info by merging the individual elements. */
      for (unsigned i = start; i < start + count; i++) {
         unsigned ci = i % 4;

         /* mediump means that xfb upconverts to 32 bits when writing to
          * memory.
          */
         unsigned xfb_comp_size =
            nir_intrinsic_io_semantics(chan[i]).medium_precision ?
                  32 : chan[i]->src[0].ssa->bit_size;

         for (unsigned j = i + 1; j < start + count; j++) {
            unsigned cj = j % 4;

            /* NOTE(review): this condition is kept exactly as it was;
             * confirm the units and direction of the offset comparison.
             */
            if (xfb[ci / 2].out[ci % 2].buffer != xfb[cj / 2].out[cj % 2].buffer ||
                xfb[ci / 2].out[ci % 2].offset != xfb[cj / 2].out[cj % 2].offset +
                xfb_comp_size * (j - i))
               break;

            xfb[ci / 2].out[ci % 2].num_components++;
            memset(&xfb[cj / 2].out[cj % 2], 0, sizeof(xfb[cj / 2].out[cj % 2]));
         }
      }

      nir_intrinsic_set_io_xfb(last, xfb[0]);
      nir_intrinsic_set_io_xfb2(last, xfb[1]);
   }

   /* Update gs_streams: 2 bits per component of the new vector. */
   unsigned gs_streams = 0;
   for (unsigned i = start; i < start + count; i++) {
      gs_streams |= (nir_intrinsic_io_semantics(chan[i]).gs_streams & 0x3) <<
                    ((i - start) * 2);
   }

   nir_io_semantics sem = nir_intrinsic_io_semantics(last);
   sem.gs_streams = gs_streams;

   /* Update other flags. A "no_*" flag survives only if every merged store
    * has it; "invariant" is set if any merged store has it.
    */
   for (unsigned i = start; i < start + count; i++) {
      if (!nir_intrinsic_io_semantics(chan[i]).no_sysval_output)
         sem.no_sysval_output = 0;
      if (!nir_intrinsic_io_semantics(chan[i]).no_varying)
         sem.no_varying = 0;
      if (nir_intrinsic_io_semantics(chan[i]).invariant)
         sem.invariant = 1;
   }

   if (merge_low_high_16_to_32) {
      /* Update "no" flags for high bits. */
      for (unsigned i = start; i < start + count; i++) {
         if (!nir_intrinsic_io_semantics(chan[4 + i]).no_sysval_output)
            sem.no_sysval_output = 0;
         if (!nir_intrinsic_io_semantics(chan[4 + i]).no_varying)
            sem.no_varying = 0;
         if (nir_intrinsic_io_semantics(chan[4 + i]).invariant)
            sem.invariant = 1;
      }

      /* Update the type. */
      sem.high_16bits = 0;
      nir_intrinsic_set_src_type(last,
                                 (nir_intrinsic_src_type(last) & ~16) | 32);
   }

   /* TODO: Merge names? */

   /* Update the rest. The component must stay within the slot even when the
    * run consists of high-half channels (start >= 4).
    */
   nir_intrinsic_set_io_semantics(last, sem);
   nir_intrinsic_set_component(last, start % 4);
   nir_intrinsic_set_write_mask(last, BITFIELD_MASK(count));
   last->num_components = count;

   nir_builder b = nir_builder_at(nir_before_instr(&last->instr));

   /* Replace the stored scalar with the vector. "value" is indexed relative
    * to "start" and sized for all 8 channels so that no run can overflow it.
    */
   nir_def *value[8];
   assert(count <= ARRAY_SIZE(value));

   if (merge_low_high_16_to_32) {
      for (unsigned i = start; i < start + count; i++) {
         value[i - start] = nir_pack_32_2x16_split(&b, chan[i]->src[0].ssa,
                                                   chan[4 + i]->src[0].ssa);
      }
   } else {
      for (unsigned i = start; i < start + count; i++)
         value[i - start] = chan[i]->src[0].ssa;
   }

   nir_src_rewrite(&last->src[0], nir_vec(&b, value, count));

   /* Remove the scalar stores. */
   for (unsigned i = start; i < start + count; i++) {
      if (chan[i] != last)
         nir_instr_remove(&chan[i]->instr);
      if (merge_low_high_16_to_32 && chan[4 + i] != last)
         nir_instr_remove(&chan[4 + i]->instr);
   }
}

/* Vectorize a vector of scalar instructions. chan[8] are the channels.
 * (the last 4 are the high 16-bit channels)
 *
 * "mask" has one bit per entry of chan[] that holds an instruction.
 * Returns true if any load or store was vectorized.
 */
static bool
vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask)
{
   bool progress = false;

   /* First, merge low and high 16-bit halves into 32 bits separately when
    * possible. Then vectorize what's left.
    */
   for (int merge_low_high_16_to_32 = 1; merge_low_high_16_to_32 >= 0;
        merge_low_high_16_to_32--) {
      unsigned scan_mask;

      if (merge_low_high_16_to_32) {
         /* Get the subset of the mask where both low and high bits are set. */
         scan_mask = 0;
         for (unsigned i = 0; i < 4; i++) {
            unsigned low_high_bits = BITFIELD_BIT(i) | BITFIELD_BIT(i + 4);

            if ((mask & low_high_bits) == low_high_bits) {
               /* Merging low and high 16 bits to 32 bits is not possible
                * with xfb in some cases. (and it's not implemented for
                * cases where it's possible)
                */
               if (nir_intrinsic_has_io_xfb(chan[i])) {
                  unsigned hi = i + 4;

                  if ((i < 2 ? nir_intrinsic_io_xfb(chan[i])
                             : nir_intrinsic_io_xfb2(chan[i])).out[i % 2].num_components ||
                      (i < 2 ? nir_intrinsic_io_xfb(chan[hi])
                             : nir_intrinsic_io_xfb2(chan[hi])).out[i % 2].num_components)
                     continue;
               }

               /* The GS stream must be the same for both halves. */
               if ((nir_intrinsic_io_semantics(chan[i]).gs_streams & 0x3) !=
                   (nir_intrinsic_io_semantics(chan[4 + i]).gs_streams & 0x3))
                  continue;

               /* Pairs taken here are removed from "mask" so that the
                * second (non-merging) iteration doesn't see them again.
                */
               scan_mask |= BITFIELD_BIT(i);
               mask &= ~low_high_bits;
            }
         }
      } else {
         scan_mask = mask;
      }

      /* Scan the low channels (bits 0-3) and the high 16-bit channels
       * (bits 4-7) separately. A run of consecutive bits must not span both
       * halves: that would put e.g. low Z/W and high X/Y into one vector,
       * which has more components than fit after the starting component.
       * In the merging iteration only bits 0-3 are ever set.
       */
      for (unsigned half = 0; half < 2; half++) {
         unsigned half_mask = scan_mask & (0xfu << (half * 4));

         while (half_mask) {
            int start, count;

            u_bit_scan_consecutive_range(&half_mask, &start, &count);

            if (count == 1 && !merge_low_high_16_to_32)
               continue; /* There is nothing to vectorize. */

            bool is_load = nir_intrinsic_infos[chan[start]->intrinsic].has_dest;

            if (is_load)
               vectorize_load(chan, start, count, merge_low_high_16_to_32);
            else
               vectorize_store(chan, start, count, merge_low_high_16_to_32);

            progress = true;
         }
      }
   }

   return progress;
}

/* Vectorize all IO instructions gathered in "io_instructions" and clear the
 * array so that the caller can reuse it for the next batch.
 *
 * Returns true if the shader was changed, either by vectorizing or by
 * removing a store that a later store in the same batch overwrites.
 */
static bool
vectorize_batch(struct util_dynarray *io_instructions)
{
   unsigned num_instr = util_dynarray_num_elements(io_instructions, void *);

   /* We need at least 2 instructions to have something to do. */
   if (num_instr <= 1) {
      /* Clear the array. The next block will reuse it. */
      util_dynarray_clear(io_instructions);
      return false;
   }

   /* The instructions are sorted such that groups of vectorizable
    * instructions are next to each other. Multiple incompatible
    * groups of vectorizable instructions can occur in this array.
    * The reason why 2 groups would be incompatible is that they
    * could have a different intrinsic, indirect index, array index,
    * vertex index, barycentrics, or location. Each group is vectorized
    * separately.
    *
    * This reorders instructions in the array, but not in the shader.
    */
   qsort(io_instructions->data, num_instr, sizeof(void*), compare_intr);

   nir_intrinsic_instr *chan[8] = {0}, *prev = NULL;
   unsigned chan_mask = 0;
   bool progress = false;

   /* Vectorize all groups.
    *
    * The channels for each group are gathered. If 2 stores overwrite
    * the same channel, the earlier store is DCE'd here.
    */
   util_dynarray_foreach(io_instructions, nir_intrinsic_instr *, intr) {
      /* If the next instruction is not vectorizable, vectorize what
       * we have gathered so far.
       */
      if (prev && compare_is_not_vectorizable(prev, *intr)) {
         /* We need at least 2 instructions to have something to do. */
         if (util_bitcount(chan_mask) > 1)
            progress |= vectorize_slot(chan, chan_mask);

         prev = NULL;
         memset(chan, 0, sizeof(chan));
         chan_mask = 0;
      }

      /* Channel index within the slot: 0-3 for low/full channels,
       * 4-7 for the high 16-bit channels.
       */
      unsigned index = nir_intrinsic_io_semantics(*intr).high_16bits * 4 +
                       nir_intrinsic_component(*intr);
      bool is_store = !nir_intrinsic_infos[(*intr)->intrinsic].has_dest;

      /* This performs DCE of output stores because the previous value
       * is being overwritten. Removing an instruction changes the shader,
       * so it must be reported as progress even if nothing is vectorized
       * afterwards.
       */
      if (is_store && chan[index]) {
         nir_instr_remove(&chan[index]->instr);
         progress = true;
      }

      /* Gather the channel. */
      chan[index] = *intr;
      prev = *intr;
      chan_mask |= BITFIELD_BIT(index);
   }

   /* Vectorize the last group. */
   if (prev && util_bitcount(chan_mask) > 1)
      progress |= vectorize_slot(chan, chan_mask);

   /* Clear the array. The next block will reuse it. */
   util_dynarray_clear(io_instructions);
   return progress;
}

/* Vectorize scalar lowered IO loads and stores within each basic block.
 *
 * "modes" selects what to vectorize and may only contain nir_var_shader_in
 * and/or nir_var_shader_out. Returns true if the shader was changed.
 *
 * Instructions are gathered per block into a batch. A batch is flushed
 * (vectorized and restarted) at the end of the block, at an output load/store
 * dependency on the same channel, at a barrier that affects outputs, and at
 * emit_vertex.
 */
bool
nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
{
   assert(!(modes & ~(nir_var_shader_in | nir_var_shader_out)));

   if (shader->info.stage == MESA_SHADER_FRAGMENT &&
       shader->options->io_options & nir_io_prefer_scalar_fs_inputs)
      modes &= ~nir_var_shader_in;

   if ((shader->info.stage == MESA_SHADER_TESS_CTRL ||
        shader->info.stage == MESA_SHADER_GEOMETRY) &&
       util_bitcount(modes) == 2) {
      /* When vectorizing TCS and GS IO, inputs can ignore barriers and emits,
       * but that is only done when outputs are ignored, so vectorize them
       * separately.
       *
       * Both passes must always run: evaluate them into locals instead of
       * using "||" directly on the calls, which would skip the output pass
       * whenever the input pass made progress.
       */
      bool progress_in = nir_opt_vectorize_io(shader, nir_var_shader_in);
      bool progress_out = nir_opt_vectorize_io(shader, nir_var_shader_out);

      return progress_in || progress_out;
   }

   /* Initialize dynamic arrays. */
   struct util_dynarray io_instructions;
   util_dynarray_init(&io_instructions, NULL);
   bool global_progress = false;

   nir_foreach_function_impl(impl, shader) {
      bool progress = false;

      /* compare_intr and vectorize_load/store order by instr.index. */
      nir_metadata_require(impl, nir_metadata_instr_index);

      nir_foreach_block(block, impl) {
         /* One bit per (location, channel) that has been loaded/stored as an
          * output in the current batch.
          */
         BITSET_DECLARE(has_output_loads, NUM_TOTAL_VARYING_SLOTS * 8);
         BITSET_DECLARE(has_output_stores, NUM_TOTAL_VARYING_SLOTS * 8);
         BITSET_ZERO(has_output_loads);
         BITSET_ZERO(has_output_stores);

         /* Gather load/store intrinsics within the block. */
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            bool is_load = nir_intrinsic_infos[intr->intrinsic].has_dest;
            bool is_output = false;
            nir_io_semantics sem = {0};
            unsigned index = 0;

            if (nir_intrinsic_has_io_semantics(intr)) {
               sem = nir_intrinsic_io_semantics(intr);
               assert(sem.location < NUM_TOTAL_VARYING_SLOTS);
               index = sem.location * 8 + sem.high_16bits * 4 +
                       nir_intrinsic_component(intr);
            }

            switch (intr->intrinsic) {
            case nir_intrinsic_load_input:
            case nir_intrinsic_load_per_primitive_input:
            case nir_intrinsic_load_input_vertex:
            case nir_intrinsic_load_interpolated_input:
            case nir_intrinsic_load_per_vertex_input:
               if (!(modes & nir_var_shader_in))
                  continue;
               break;

            case nir_intrinsic_load_output:
            case nir_intrinsic_load_per_vertex_output:
            case nir_intrinsic_load_per_primitive_output:
            case nir_intrinsic_store_output:
            case nir_intrinsic_store_per_vertex_output:
            case nir_intrinsic_store_per_primitive_output:
               if (!(modes & nir_var_shader_out))
                  continue;

               /* Break the batch if an output load is followed by an output
                * store to the same channel and vice versa.
                */
               if (BITSET_TEST(is_load ? has_output_stores : has_output_loads,
                               index)) {
                  progress |= vectorize_batch(&io_instructions);
                  BITSET_ZERO(has_output_loads);
                  BITSET_ZERO(has_output_stores);
               }
               is_output = true;
               break;

            case nir_intrinsic_barrier:
               /* Don't vectorize across TCS barriers. */
               if (modes & nir_var_shader_out &&
                   nir_intrinsic_memory_modes(intr) & nir_var_shader_out) {
                  progress |= vectorize_batch(&io_instructions);
                  BITSET_ZERO(has_output_loads);
                  BITSET_ZERO(has_output_stores);
               }
               continue;

            case nir_intrinsic_emit_vertex:
               /* Don't vectorize across GS emits.
                *
                * NOTE(review): only emit_vertex is handled here; confirm
                * whether other emit intrinsics (e.g. a with-counter variant)
                * can reach this pass and need the same treatment.
                */
               progress |= vectorize_batch(&io_instructions);
               BITSET_ZERO(has_output_loads);
               BITSET_ZERO(has_output_stores);
               continue;

            default:
               continue;
            }

            /* Only scalar 16 and 32-bit instructions are allowed. */
            ASSERTED nir_def *value = is_load ? &intr->def : intr->src[0].ssa;
            assert(value->num_components == 1);
            assert(value->bit_size == 16 || value->bit_size == 32);

            util_dynarray_append(&io_instructions, void *, intr);
            if (is_output)
               BITSET_SET(is_load ? has_output_loads : has_output_stores, index);
         }

         /* Never vectorize across basic block boundaries. */
         progress |= vectorize_batch(&io_instructions);
      }

      nir_metadata_preserve(impl, progress ? (nir_metadata_block_index |
                                              nir_metadata_dominance) :
                                             nir_metadata_all);
      global_progress |= progress;
   }
   util_dynarray_fini(&io_instructions);

   return global_progress;
}
