/*
 * Copyright © 2021 Valve Corporation
 * SPDX-License-Identifier: MIT
 */

#include "ir3.h"
#include "ir3_nir.h"
#include "util/ralloc.h"

/* Lower several macro-instructions needed for shader subgroup support that
 * must be turned into if statements. We do this after RA and post-RA
 * scheduling to give the scheduler a chance to rearrange them, because RA
 * may need to insert OPC_META_READ_FIRST to handle splitting live ranges, and
 * also because some (e.g. BALLOT and READ_FIRST) must produce a shared
 * register that cannot be spilled to a normal register until after the if,
 * which makes implementing spilling more complicated if they are already
 * lowered.
 */

/* Redirect one logical predecessor edge of block: the first slot of
 * block->predecessors that holds old_pred is overwritten with new_pred.
 * If old_pred is not found, block is left untouched.
 */
static void
replace_pred(struct ir3_block *block, struct ir3_block *old_pred,
             struct ir3_block *new_pred)
{
   unsigned idx = 0;

   /* Advance to the first slot referring to old_pred (or to the end). */
   while (idx < block->predecessors_count &&
          block->predecessors[idx] != old_pred)
      idx++;

   if (idx < block->predecessors_count)
      block->predecessors[idx] = new_pred;
}

/* Same as replace_pred() but operating on the physical predecessor array:
 * the first entry equal to old_pred is replaced by new_pred, and nothing
 * changes if there is no such entry.
 */
static void
replace_physical_pred(struct ir3_block *block, struct ir3_block *old_pred,
                      struct ir3_block *new_pred)
{
   const unsigned count = block->physical_predecessors_count;

   for (unsigned p = 0; p < count; p++) {
      if (block->physical_predecessors[p] != old_pred)
         continue;

      /* Only the first match is rewritten. */
      block->physical_predecessors[p] = new_pred;
      break;
   }
}

/* Append a "mov dst, #immed" to block.
 *
 * The new destination copies dst's register number, flags and wrmask. The
 * immediate source and both cat1 types are 16-bit when dst has IR3_REG_HALF
 * set and 32-bit otherwise. The repeat count is derived from the highest
 * bit set in the write mask.
 */
static void
mov_immed(struct ir3_register *dst, struct ir3_block *block, unsigned immed)
{
   const unsigned half = dst->flags & IR3_REG_HALF;

   struct ir3_instruction *mov = ir3_instr_create(block, OPC_MOV, 1, 1);
   struct ir3_register *imm_dst = ir3_dst_create(mov, dst->num, dst->flags);
   struct ir3_register *imm_src =
      ir3_src_create(mov, INVALID_REG, half | IR3_REG_IMMED);

   imm_dst->wrmask = dst->wrmask;
   imm_src->uim_val = immed;

   if (half) {
      mov->cat1.dst_type = TYPE_U16;
      mov->cat1.src_type = TYPE_U16;
   } else {
      mov->cat1.dst_type = TYPE_U32;
      mov->cat1.src_type = TYPE_U32;
   }

   mov->repeat = util_last_bit(imm_dst->wrmask) - 1;
}

/* Append a register-to-register mov from src to dst at the end of block.
 *
 * Only the IR3_REG_HALF and IR3_REG_SHARED flags of each operand are carried
 * over to the new registers; the write masks are copied as-is. The cat1
 * types are chosen independently for each side from its IR3_REG_HALF flag.
 */
static void
mov_reg(struct ir3_block *block, struct ir3_register *dst,
        struct ir3_register *src)
{
   const unsigned kept_flags = IR3_REG_HALF | IR3_REG_SHARED;

   struct ir3_instruction *copy = ir3_instr_create(block, OPC_MOV, 1, 1);
   struct ir3_register *copy_dst =
      ir3_dst_create(copy, dst->num, dst->flags & kept_flags);
   struct ir3_register *copy_src =
      ir3_src_create(copy, src->num, src->flags & kept_flags);

   copy_dst->wrmask = dst->wrmask;
   copy_src->wrmask = src->wrmask;

   if (dst->flags & IR3_REG_HALF)
      copy->cat1.dst_type = TYPE_U16;
   else
      copy->cat1.dst_type = TYPE_U32;

   if (src->flags & IR3_REG_HALF)
      copy->cat1.src_type = TYPE_U16;
   else
      copy->cat1.src_type = TYPE_U32;

   copy->repeat = util_last_bit(copy_dst->wrmask) - 1;
}

/* Append the two-source instruction "dst = src0 <opc> src1" to block.
 *
 * All three operands of the new instruction get the IR3_REG_HALF flag of dst
 * (and no other flag); register numbers and write masks are taken from the
 * corresponding argument.
 */
static void
binop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
      struct ir3_register *src0, struct ir3_register *src1)
{
   struct ir3_register *const inputs[] = {src0, src1};
   const unsigned half = dst->flags & IR3_REG_HALF;

   struct ir3_instruction *alu = ir3_instr_create(block, opc, 1, 2);

   struct ir3_register *alu_dst = ir3_dst_create(alu, dst->num, half);
   alu_dst->wrmask = dst->wrmask;

   for (unsigned s = 0; s < ARRAY_SIZE(inputs); s++) {
      struct ir3_register *alu_src = ir3_src_create(alu, inputs[s]->num, half);
      alu_src->wrmask = inputs[s]->wrmask;
   }

   alu->repeat = util_last_bit(alu_dst->wrmask) - 1;
}

/* Append the three-source instruction "dst = <opc>(src0, src1, src2)" to
 * block.
 *
 * As with binop(), every operand of the new instruction only inherits the
 * IR3_REG_HALF flag of dst, while register numbers and write masks come from
 * the matching argument.
 */
static void
triop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
      struct ir3_register *src0, struct ir3_register *src1,
      struct ir3_register *src2)
{
   struct ir3_register *const inputs[] = {src0, src1, src2};
   const unsigned half = dst->flags & IR3_REG_HALF;

   struct ir3_instruction *alu = ir3_instr_create(block, opc, 1, 3);

   struct ir3_register *alu_dst = ir3_dst_create(alu, dst->num, half);
   alu_dst->wrmask = dst->wrmask;

   for (unsigned s = 0; s < ARRAY_SIZE(inputs); s++) {
      struct ir3_register *alu_src = ir3_src_create(alu, inputs[s]->num, half);
      alu_src->wrmask = inputs[s]->wrmask;
   }

   alu->repeat = util_last_bit(alu_dst->wrmask) - 1;
}

/* Emit "dst = src0 <reduce op> src1" at the end of block.
 *
 * Every reduce op except MUL_U maps onto the ALU opcode of the same name and
 * is emitted as one binop(). MUL_U becomes a single MUL_S24 for half
 * registers and a three-instruction sequence for full registers.
 */
static void
do_reduce(struct ir3_block *block, reduce_op_t opc,
          struct ir3_register *dst, struct ir3_register *src0,
          struct ir3_register *src1)
{
   switch (opc) {
   case REDUCE_OP_MUL_U:
      if (!(dst->flags & IR3_REG_HALF)) {
         /* 32-bit multiplication macro - see ir3_nir_imul */
         binop(block, OPC_MULL_U, dst, src0, src1);
         triop(block, OPC_MADSH_M16, dst, src0, src1, dst);
         triop(block, OPC_MADSH_M16, dst, src1, src0, dst);
      } else {
         binop(block, OPC_MUL_S24, dst, src0, src1);
      }
      break;

/* Reduce ops whose name matches the opcode implementing them. */
#define SAME_NAMED_OP(name)                                                    \
   case REDUCE_OP_##name:                                                      \
      binop(block, OPC_##name, dst, src0, src1);                               \
      break;

   SAME_NAMED_OP(ADD_U)
   SAME_NAMED_OP(ADD_F)
   SAME_NAMED_OP(MUL_F)
   SAME_NAMED_OP(MIN_U)
   SAME_NAMED_OP(MIN_S)
   SAME_NAMED_OP(MIN_F)
   SAME_NAMED_OP(MAX_U)
   SAME_NAMED_OP(MAX_S)
   SAME_NAMED_OP(MAX_F)
   SAME_NAMED_OP(AND_B)
   SAME_NAMED_OP(OR_B)
   SAME_NAMED_OP(XOR_B)

#undef SAME_NAMED_OP
   }
}

static struct ir3_block *
split_block(struct ir3 *ir, struct ir3_block *before_block,
            struct ir3_instruction *instr)
{
   struct ir3_block *after_block = ir3_block_create(ir);
   list_add(&after_block->node, &before_block->node);

   for (unsigned i = 0; i < ARRAY_SIZE(before_block->successors); i++) {
      after_block->successors[i] = before_block->successors[i];
      if (after_block->successors[i])
         replace_pred(after_block->successors[i], before_block, after_block);
   }

   for (unsigned i = 0; i < before_block->physical_successors_count; i++) {
      replace_physical_pred(before_block->physical_successors[i],
                            before_block, after_block);
   }

   ralloc_steal(after_block, before_block->physical_successors);
   after_block->physical_successors = before_block->physical_successors;
   after_block->physical_successors_sz = before_block->physical_successors_sz;
   after_block->physical_successors_count =
      before_block->physical_successors_count;

   before_block->successors[0] = before_block->successors[1] = NULL;
   before_block->physical_successors = NULL;
   before_block->physical_successors_count = 0;
   before_block->physical_successors_sz = 0;

   foreach_instr_from_safe (rem_instr, &instr->node,
                            &before_block->instr_list) {
      list_del(&rem_instr->node);
      list_addtail(&rem_instr->node, &after_block->instr_list);
      rem_instr->block = after_block;
   }

   after_block->divergent_condition = before_block->divergent_condition;
   before_block->divergent_condition = false;
   return after_block;
}

/* Make succ the index'th logical successor of pred. Besides storing the
 * successor pointer this registers pred as a predecessor of succ and adds
 * the corresponding physical edge through the ir3 block helpers.
 */
static void
link_blocks(struct ir3_block *pred, struct ir3_block *succ, unsigned index)
{
   pred->successors[index] = succ;
   ir3_block_add_predecessor(succ, pred);
   ir3_block_link_physical(pred, succ);
}

/* Terminate pred with a jump instruction and make succ its only (index 0)
 * successor.
 */
static void
link_blocks_jump(struct ir3_block *pred, struct ir3_block *succ)
{
   ir3_JUMP(pred);
   link_blocks(pred, succ, 0);
}

/* Terminate pred with a conditional branch.
 *
 * A branch instruction with opcode opc and the given instruction flags is
 * created in pred. When condition is non-NULL, its first destination becomes
 * the single source of the branch. target is linked as successor 0 and
 * fallthrough as successor 1. pred is flagged as ending in a divergent
 * condition for every opcode other than OPC_BALL and OPC_BANY.
 */
static void
link_blocks_branch(struct ir3_block *pred, struct ir3_block *target,
                   struct ir3_block *fallthrough, unsigned opc, unsigned flags,
                   struct ir3_instruction *condition)
{
   struct ir3_instruction *br =
      ir3_instr_create(pred, opc, 0, condition ? 1 : 0);
   br->flags |= flags;

   if (condition != NULL) {
      struct ir3_register *def = condition->dsts[0];
      struct ir3_register *use = ir3_src_create(br, def->num, def->flags);
      use->def = def;
   }

   link_blocks(pred, target, 0);
   link_blocks(pred, fallthrough, 1);

   switch (opc) {
   case OPC_BALL:
   case OPC_BANY:
      break;
   default:
      pred->divergent_condition = true;
      break;
   }
}

/* Build an "if" without an else between before_block and after_block.
 *
 * A fresh block is inserted right after before_block in the block list.
 * before_block is terminated with a branch (see link_blocks_branch()) that
 * targets the new block and falls through to after_block; the new block in
 * turn jumps to after_block. The new, still otherwise empty, block is
 * returned so the caller can fill in the body.
 */
static struct ir3_block *
create_if(struct ir3 *ir, struct ir3_block *before_block,
          struct ir3_block *after_block, unsigned opc, unsigned flags,
          struct ir3_instruction *condition)
{
   struct ir3_block *body = ir3_block_create(ir);

   list_add(&body->node, &before_block->node);

   /* before_block -> {body, after_block} */
   link_blocks_branch(before_block, body, after_block, opc, flags, condition);

   /* body -> after_block */
   link_blocks_jump(body, after_block);

   return body;
}

/* Lower a single subgroup macro instruction.
 *
 * OPC_READ_FIRST_MACRO is rewritten in place into a mov, and any opcode that
 * is not one of the handled macros is left alone; in both cases false is
 * returned and *block is unchanged.
 *
 * For every other handled macro the current block is split at instr, the
 * control flow implementing the macro is created between the two halves,
 * *block is updated to the block following the new control flow, instr is
 * removed from its instruction list and true is returned. Callers iterating
 * over the instruction list must restart their iteration in that case.
 */
static bool
lower_instr(struct ir3 *ir, struct ir3_block **block, struct ir3_instruction *instr)
{
   switch (instr->opc) {
   case OPC_BALLOT_MACRO:
   case OPC_ANY_MACRO:
   case OPC_ALL_MACRO:
   case OPC_ELECT_MACRO:
   case OPC_READ_COND_MACRO:
   case OPC_SCAN_MACRO:
   case OPC_SCAN_CLUSTERS_MACRO:
      /* These need new control flow; handled below. */
      break;
   case OPC_READ_FIRST_MACRO:
      /* Moves to shared registers read the first active fiber, so we can just
       * turn read_first.macro into a move. However we must still use the macro
       * and lower it late because in ir3_cp we need to distinguish between
       * moves where all source fibers contain the same value, which can be copy
       * propagated, and moves generated from API-level ReadFirstInvocation
       * which cannot.
       */
      assert(instr->dsts[0]->flags & IR3_REG_SHARED);
      instr->opc = OPC_MOV;
      instr->cat1.dst_type = TYPE_U32;
      instr->cat1.src_type =
         (instr->srcs[0]->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
      return false;
   default:
      return false;
   }

   /* instr and everything after it move to after_block; the lowered control
    * flow is built between before_block and after_block.
    */
   struct ir3_block *before_block = *block;
   struct ir3_block *after_block = split_block(ir, before_block, instr);

   if (instr->opc == OPC_SCAN_MACRO) {
      /* The pseudo-code for the scan macro is:
       *
       * while (true) {
       *    header:
       *    if (elect()) {
       *       exit:
       *       exclusive = reduce;
       *       inclusive = src OP exclusive;
       *       reduce = inclusive;
       *       break;
       *    }
       *    footer:
       * }
       *
       * This is based on the blob's sequence, and carefully crafted to avoid
       * using the shared register "reduce" except in move instructions, since
       * using it in the actual OP isn't possible for half-registers.
       */
      struct ir3_block *header = ir3_block_create(ir);
      list_add(&header->node, &before_block->node);

      struct ir3_block *exit = ir3_block_create(ir);
      list_add(&exit->node, &header->node);

      struct ir3_block *footer = ir3_block_create(ir);
      list_add(&footer->node, &exit->node);
      footer->reconvergence_point = true;

      after_block->reconvergence_point = true;

      link_blocks_jump(before_block, header);

      link_blocks_branch(header, exit, footer, OPC_GETONE,
                         IR3_INSTR_NEEDS_HELPERS, NULL);

      link_blocks_jump(exit, after_block);
      ir3_block_link_physical(exit, footer);

      /* Back edge of the loop. */
      link_blocks_jump(footer, header);

      struct ir3_register *exclusive = instr->dsts[0];
      struct ir3_register *inclusive = instr->dsts[1];
      struct ir3_register *reduce = instr->dsts[2];
      struct ir3_register *src = instr->srcs[0];

      mov_reg(exit, exclusive, reduce);
      do_reduce(exit, instr->cat1.reduce_op, inclusive, src, exclusive);
      mov_reg(exit, reduce, inclusive);
   } else if (instr->opc == OPC_SCAN_CLUSTERS_MACRO) {
      /* The pseudo-code for the scan clusters macro is:
       *
       * while (true) {
       *    body:
       *    scratch = reduce;
       *
       *    inclusive = inclusive_src OP scratch;
       *
       *    static if (is exclusive scan)
       *       exclusive = exclusive_src OP scratch
       *
       *    if (getlast()) {
       *       store:
       *       reduce = inclusive;
       *       if (elect())
       *           break;
       *    } else {
       *       break;
       *    }
       * }
       * after_block:
       */
      struct ir3_block *body = ir3_block_create(ir);
      list_add(&body->node, &before_block->node);

      struct ir3_block *store = ir3_block_create(ir);
      list_add(&store->node, &body->node);

      after_block->reconvergence_point = true;

      link_blocks_jump(before_block, body);

      link_blocks_branch(body, store, after_block, OPC_GETLAST, 0, NULL);

      /* The fallthrough of this branch is the loop's back edge to body. */
      link_blocks_branch(store, after_block, body, OPC_GETONE,
                         IR3_INSTR_NEEDS_HELPERS, NULL);

      struct ir3_register *reduce = instr->dsts[0];
      struct ir3_register *inclusive = instr->dsts[1];
      struct ir3_register *inclusive_src = instr->srcs[1];

      /* We need to perform the following operations:
       *  - inclusive = inclusive_src OP reduce
       *  - exclusive = exclusive_src OP reduce (iff exclusive scan)
       * Since reduce is initially in a shared register, we need to copy it to a
       * scratch register before performing the operations.
       *
       * The scratch register used is:
       *  - an explicitly allocated one if op is 32b mul_u.
       *    - necessary because we cannot do 'foo = foo mul_u bar' since mul_u
       *      clobbers its destination.
       *  - exclusive if this is an exclusive scan (and not 32b mul_u).
       *    - since we calculate inclusive first.
       *  - inclusive otherwise.
       *
       * In all cases, this is the last destination.
       */
      struct ir3_register *scratch = instr->dsts[instr->dsts_count - 1];

      mov_reg(body, scratch, reduce);
      do_reduce(body, instr->cat1.reduce_op, inclusive, inclusive_src, scratch);

      /* exclusive scan */
      if (instr->srcs_count == 3) {
         struct ir3_register *exclusive_src = instr->srcs[2];
         struct ir3_register *exclusive = instr->dsts[2];
         do_reduce(body, instr->cat1.reduce_op, exclusive, exclusive_src,
                   scratch);
      }

      mov_reg(store, reduce, inclusive);
   } else {
      /* For ballot, the destination must be initialized to 0 before we do
       * the movmsk because the condition may be 0 and then the movmsk will
       * be skipped.
       */
      if (instr->opc == OPC_BALLOT_MACRO) {
         mov_immed(instr->dsts[0], before_block, 0);
      }

      struct ir3_instruction *condition = NULL;
      unsigned branch_opc = 0;
      unsigned branch_flags = 0;

      /* Macros that take a predicate branch on the instruction defining
       * their first source; elect has no condition.
       */
      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ALL_MACRO:
         condition = instr->srcs[0]->def->instr;
         break;
      default:
         break;
      }

      /* Select the branch opcode implementing the macro's "if". */
      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
         after_block->reconvergence_point = true;
         branch_opc = OPC_BR;
         break;
      case OPC_ANY_MACRO:
         branch_opc = OPC_BANY;
         break;
      case OPC_ALL_MACRO:
         branch_opc = OPC_BALL;
         break;
      case OPC_ELECT_MACRO:
         after_block->reconvergence_point = true;
         branch_opc = OPC_GETONE;
         branch_flags = instr->flags & IR3_INSTR_NEEDS_HELPERS;
         break;
      default:
         unreachable("bad opcode");
      }

      struct ir3_block *then_block =
         create_if(ir, before_block, after_block, branch_opc, branch_flags,
                   condition);

      /* Fill in the body of the "if" (and, where needed, the default value
       * written before it).
       */
      switch (instr->opc) {
      case OPC_ALL_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ELECT_MACRO:
         mov_immed(instr->dsts[0], then_block, 1);
         mov_immed(instr->dsts[0], before_block, 0);
         break;

      case OPC_BALLOT_MACRO: {
         unsigned comp_count = util_last_bit(instr->dsts[0]->wrmask);
         struct ir3_instruction *movmsk =
            ir3_instr_create(then_block, OPC_MOVMSK, 1, 0);
         ir3_dst_create(movmsk, instr->dsts[0]->num, instr->dsts[0]->flags);
         movmsk->repeat = comp_count - 1;
         break;
      }

      case OPC_READ_COND_MACRO: {
         struct ir3_instruction *mov =
            ir3_instr_create(then_block, OPC_MOV, 1, 1);
         ir3_dst_create(mov, instr->dsts[0]->num, instr->dsts[0]->flags);
         /* Create a placeholder source, then overwrite it with a copy of
          * the macro's second source.
          */
         struct ir3_register *new_src = ir3_src_create(mov, 0, 0);
         *new_src = *instr->srcs[1];
         mov->cat1.dst_type = TYPE_U32;
         mov->cat1.src_type =
            (new_src->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
         mov->flags |= IR3_INSTR_NEEDS_HELPERS;
         break;
      }

      default:
         unreachable("bad opcode");
      }
   }

   /* Continue lowering in after_block; the macro itself (now the first
    * instruction moved there by split_block) is dropped.
    */
   *block = after_block;
   list_delinit(&instr->node);
   return true;
}

/* Lower every subgroup macro in *block.
 *
 * Whenever lower_instr() splits the block, *block is advanced to the block
 * holding the remaining instructions and the scan restarts there, so on
 * return *block is the last block produced from the original one.
 *
 * Returns true if at least one instruction was changed, either by building
 * new control flow or by the in-place rewrite of OPC_READ_FIRST_MACRO.
 * Returns false when the block contained nothing to lower.
 */
static bool
lower_block(struct ir3 *ir, struct ir3_block **block)
{
   bool progress = false;

   bool inner_progress;
   do {
      inner_progress = false;
      foreach_instr (instr, &(*block)->instr_list) {
         /* lower_instr() turns read_first.macro into a plain mov but
          * returns false because no block was split, so record that change
          * here. The opcode has to be sampled before the call since the
          * rewrite replaces it.
          */
         bool rewritten_in_place = instr->opc == OPC_READ_FIRST_MACRO;

         if (lower_instr(ir, block, instr)) {
            /* restart the loop with the new block we created because the
             * iterator has been invalidated.
             */
            progress = inner_progress = true;
            break;
         }

         progress |= rewritten_in_place;
      }
   } while (inner_progress);

   return progress;
}

bool
ir3_lower_subgroups(struct ir3 *ir)
{
   bool progress = false;

   foreach_block (block, &ir->block_list)
      progress |= lower_block(ir, &block);

   return progress;
}

/* Instruction filter for nir_shader_lower_instructions(): accepts exactly the
 * reduce, inclusive_scan and exclusive_scan intrinsics. The data argument is
 * not used.
 */
static bool
filter_scan_reduce(const nir_instr *instr, const void *data)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   const nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   return intrin->intrinsic == nir_intrinsic_reduce ||
          intrin->intrinsic == nir_intrinsic_inclusive_scan ||
          intrin->intrinsic == nir_intrinsic_exclusive_scan;
}

/* Lowering callback for the intrinsics accepted by filter_scan_reduce().
 *
 * First, three brcst_active_ir3 steps with cluster sizes 2, 4 and 8 are
 * emitted. Each step combines the running inclusive value with the broadcast
 * result using the reduction op; for exclusive scans the same broadcast is
 * also folded into a running exclusive value that starts at the op's
 * identity. The partial results are then handed to the matching *_clusters_ir3
 * intrinsic, whose result replaces the original instruction. The data
 * argument is not used.
 */
static nir_def *
lower_scan_reduce(struct nir_builder *b, nir_instr *instr, void *data)
{
   nir_intrinsic_instr *scan = nir_instr_as_intrinsic(instr);
   const unsigned bit_size = scan->def.bit_size;
   const nir_op op = nir_intrinsic_reduction_op(scan);
   const bool is_exclusive = scan->intrinsic == nir_intrinsic_exclusive_scan;

   nir_const_value identity = nir_alu_binop_identity(op, bit_size);
   nir_def *ident = nir_build_imm(b, 1, bit_size, &identity);

   nir_def *inclusive = scan->src[0].ssa;
   nir_def *exclusive = ident;

   /* Cluster sizes 2, 4 and 8. */
   for (unsigned log2_size = 1; log2_size <= 3; log2_size++) {
      nir_def *brcst = nir_brcst_active_ir3(b, ident, inclusive,
                                            .cluster_size = 1u << log2_size);

      inclusive = nir_build_alu2(b, op, inclusive, brcst);
      if (is_exclusive)
         exclusive = nir_build_alu2(b, op, exclusive, brcst);
   }

   if (scan->intrinsic == nir_intrinsic_reduce)
      return nir_reduce_clusters_ir3(b, inclusive, .reduction_op = op);

   if (scan->intrinsic == nir_intrinsic_inclusive_scan)
      return nir_inclusive_scan_clusters_ir3(b, inclusive, .reduction_op = op);

   if (is_exclusive) {
      return nir_exclusive_scan_clusters_ir3(b, inclusive, exclusive,
                                             .reduction_op = op);
   }

   unreachable("filtered intrinsic");
}

/* Run the scan/reduce lowering (filter_scan_reduce() / lower_scan_reduce())
 * over nir. The pass is skipped, returning false, when the compiler of
 * variant v does not advertise has_getfiberid; otherwise the result of
 * nir_shader_lower_instructions() is returned.
 */
bool
ir3_nir_opt_subgroups(nir_shader *nir, struct ir3_shader_variant *v)
{
   bool progress = false;

   if (v->compiler->has_getfiberid) {
      progress = nir_shader_lower_instructions(nir, filter_scan_reduce,
                                               lower_scan_reduce, NULL);
   }

   return progress;
}
