/*
 * Copyright © 2020 Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "brw_kernel.h"
#include "brw_nir.h"
#include "elk/elk_nir_options.h"
#include "intel_nir.h"

#include "intel_nir.h"
#include "nir_clc_helpers.h"
#include "compiler/nir/nir_builder.h"
#include "compiler/spirv/nir_spirv.h"
#include "compiler/spirv/spirv_info.h"
#include "dev/intel_debug.h"
#include "util/u_atomic.h"
#include "util/u_dynarray.h"

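/* Return the libclc helper shader, loading and caching it on the compiler
 * on first use.  If two threads race to load it, the winner of the atomic
 * cmpxchg owns the shader and the loser frees its copy.
 */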
static const nir_shader *
load_clc_shader(struct brw_compiler *compiler, struct disk_cache *disk_cache,
                const nir_shader_compiler_options *nir_options,
                const struct spirv_to_nir_options *spirv_options)
{
   if (compiler->clc_shader)
      return compiler->clc_shader;

   nir_shader *nir = nir_load_libclc_shader(64, disk_cache,
                                            spirv_options, nir_options,
                                            disk_cache != NULL);
   if (nir == NULL)
      return NULL;

   const nir_shader *old_nir =
      p_atomic_cmpxchg(&compiler->clc_shader, NULL, nir);
   if (old_nir == NULL) {
      /* We won the race */
      ralloc_steal(compiler, nir);
      return nir;
   } else {
      /* Someone else built the shader first */
      ralloc_free(nir);
      return old_nir;
   }
}

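/* Create an empty implementation for the given function and return a
 * builder pointing at the top of it.
 */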
static nir_builder
builder_init_new_impl(nir_function *func)
{
   nir_function_impl *impl = nir_function_impl_create(func);
   return nir_builder_at(nir_before_impl(impl));
}

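/* Fill out the body of an atomic built-in with a deref atomic.  The first
 * parameter is a pointer used to return the result; the remaining
 * parameters map to the sources of the atomic intrinsic, with the first
 * source cast to a deref in the given variable mode.
 */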
static void
implement_atomic_builtin(nir_function *func, nir_atomic_op atomic_op,
                         enum glsl_base_type data_base_type,
                         nir_variable_mode mode)
{
   nir_builder b = builder_init_new_impl(func);
   const struct glsl_type *data_type = glsl_scalar_type(data_base_type);

   unsigned p = 0;

   nir_deref_instr *ret =
      nir_build_deref_cast(&b, nir_load_param(&b, p++),
                           nir_var_function_temp, data_type, 0);

   nir_intrinsic_op op = nir_intrinsic_deref_atomic;
   nir_intrinsic_instr *atomic = nir_intrinsic_instr_create(b.shader, op);
   nir_intrinsic_set_atomic_op(atomic, atomic_op);

   for (unsigned i = 0; i < nir_intrinsic_infos[op].num_srcs; i++) {
      nir_def *src = nir_load_param(&b, p++);
      if (i == 0) {
         /* The first source is our deref */
         assert(nir_intrinsic_infos[op].src_components[i] == -1);
         src = &nir_build_deref_cast(&b, src, mode, data_type, 0)->def;
      }
      atomic->src[i] = nir_src_for_ssa(src);
   }

   nir_def_init_for_type(&atomic->instr, &atomic->def, data_type);

   nir_builder_instr_insert(&b, &atomic->instr);
   nir_store_deref(&b, ret, &atomic->def, ~0);
}

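/* Fill out the body of intel_sub_group_ballot() with a NIR ballot
 * intrinsic.  The first parameter is a pointer used to return the result
 * and the second is the ballot condition.
 */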
static void
implement_sub_group_ballot_builtin(nir_function *func)
{
   nir_builder b = builder_init_new_impl(func);
   nir_deref_instr *ret =
      nir_build_deref_cast(&b, nir_load_param(&b, 0),
                           nir_var_function_temp, glsl_uint_type(), 0);
   nir_def *cond = nir_load_param(&b, 1);

   nir_intrinsic_instr *ballot =
      nir_intrinsic_instr_create(b.shader, nir_intrinsic_ballot);
   ballot->src[0] = nir_src_for_ssa(cond);
   ballot->num_components = 1;
   nir_def_init(&ballot->instr, &ballot->def, 1, 32);
   nir_builder_instr_insert(&b, &ballot->instr);

   nir_store_deref(&b, ret, &ballot->def, ~0);
}

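/* Provide NIR implementations for the built-ins that libclc does not
 * cover, matched by their mangled OpenCL names.
 */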
static bool
implement_intel_builtins(nir_shader *nir)
{
   bool progress = false;

   nir_foreach_function(func, nir) {
      if (strcmp(func->name, "_Z10atomic_minPU3AS1Vff") == 0) {
         /* float atomic_min(volatile __global float *p, float val) */
         implement_atomic_builtin(func, nir_atomic_op_fmin,
                                  GLSL_TYPE_FLOAT, nir_var_mem_global);
         progress = true;
      } else if (strcmp(func->name, "_Z10atomic_maxPU3AS1Vff") == 0) {
         /* float atomic_max(volatile __global float *p, float val) */
         implement_atomic_builtin(func, nir_atomic_op_fmax,
                                  GLSL_TYPE_FLOAT, nir_var_mem_global);
         progress = true;
      } else if (strcmp(func->name, "_Z10atomic_minPU3AS3Vff") == 0) {
         /* float atomic_min(volatile __local float *, float) */
         implement_atomic_builtin(func, nir_atomic_op_fmin,
                                  GLSL_TYPE_FLOAT, nir_var_mem_shared);
         progress = true;
      } else if (strcmp(func->name, "_Z10atomic_maxPU3AS3Vff") == 0) {
         /* float atomic_max(volatile __local float *, float) */
         implement_atomic_builtin(func, nir_atomic_op_fmax,
                                  GLSL_TYPE_FLOAT, nir_var_mem_shared);
         progress = true;
      } else if (strcmp(func->name, "intel_sub_group_ballot") == 0) {
         implement_sub_group_ballot_builtin(func);
         progress = true;
      }
   }

   nir_shader_preserve_all_metadata(nir);

   return progress;
}

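/* Lower kernel-only intrinsics to uniform loads.  The uniform space is
 * laid out with struct brw_kernel_sysvals at offset 0, followed by the
 * kernel arguments.
 */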
static bool
lower_kernel_intrinsics(nir_shader *nir)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);

   bool progress = false;

   unsigned kernel_sysvals_start = 0;
   unsigned kernel_arg_start = sizeof(struct brw_kernel_sysvals);
   nir->num_uniforms += kernel_arg_start;

   nir_builder b = nir_builder_create(impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         switch (intrin->intrinsic) {
         case nir_intrinsic_load_kernel_input: {
            b.cursor = nir_instr_remove(&intrin->instr);

            nir_intrinsic_instr *load =
               nir_intrinsic_instr_create(nir, nir_intrinsic_load_uniform);
            load->num_components = intrin->num_components;
            load->src[0] = nir_src_for_ssa(nir_u2u32(&b, intrin->src[0].ssa));
            nir_intrinsic_set_base(load, kernel_arg_start);
            nir_intrinsic_set_range(load, nir->num_uniforms);
            nir_def_init(&load->instr, &load->def,
                         intrin->def.num_components,
                         intrin->def.bit_size);
            nir_builder_instr_insert(&b, &load->instr);

            nir_def_rewrite_uses(&intrin->def, &load->def);
            progress = true;
            break;
         }

         case nir_intrinsic_load_constant_base_ptr: {
            b.cursor = nir_instr_remove(&intrin->instr);
            nir_def *const_data_base_addr = nir_pack_64_2x32_split(&b,
               nir_load_reloc_const_intel(&b, BRW_SHADER_RELOC_CONST_DATA_ADDR_LOW),
               nir_load_reloc_const_intel(&b, BRW_SHADER_RELOC_CONST_DATA_ADDR_HIGH));
            nir_def_rewrite_uses(&intrin->def, const_data_base_addr);
            progress = true;
            break;
         }

         case nir_intrinsic_load_num_workgroups: {
            b.cursor = nir_instr_remove(&intrin->instr);

            nir_intrinsic_instr *load =
               nir_intrinsic_instr_create(nir, nir_intrinsic_load_uniform);
            load->num_components = 3;
            load->src[0] = nir_src_for_ssa(nir_imm_int(&b, 0));
            nir_intrinsic_set_base(load, kernel_sysvals_start +
               offsetof(struct brw_kernel_sysvals, num_work_groups));
            nir_intrinsic_set_range(load, 3 * 4);
            nir_def_init(&load->instr, &load->def, 3, 32);
            nir_builder_instr_insert(&b, &load->instr);
            nir_def_rewrite_uses(&intrin->def, &load->def);
            progress = true;
            break;
         }

         default:
            break;
         }
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_control_flow);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}

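/* SPIR-V capabilities we accept when compiling OpenCL kernels */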
static const struct spirv_capabilities spirv_caps = {
   .Addresses = true,
   .Float16 = true,
   .Float64 = true,
   .Groups = true,
   .StorageImageWriteWithoutFormat = true,
   .Int8 = true,
   .Int16 = true,
   .Int64 = true,
   .Int64Atomics = true,
   .Kernel = true,
   .Linkage = true, /* We receive a linked kernel from clc */
   .DenormFlushToZero = true,
   .DenormPreserve = true,
   .SignedZeroInfNanPreserve = true,
   .RoundingModeRTE = true,
   .RoundingModeRTZ = true,
   .GenericPointer = true,
   .GroupNonUniform = true,
   .GroupNonUniformArithmetic = true,
   .GroupNonUniformClustered = true,
   .GroupNonUniformBallot = true,
   .GroupNonUniformQuad = true,
   .GroupNonUniformShuffle = true,
   .GroupNonUniformVote = true,
   .SubgroupDispatch = true,

   .SubgroupShuffleINTEL = true,
   .SubgroupBufferBlockIOINTEL = true,
};

bool
brw_kernel_from_spirv(struct brw_compiler *compiler,
                      struct disk_cache *disk_cache,
                      struct brw_kernel *kernel,
                      void *log_data, void *mem_ctx,
                      const uint32_t *spirv, size_t spirv_size,
                      const char *entrypoint_name,
                      char **error_str)
{
   const struct intel_device_info *devinfo = compiler->devinfo;
   const nir_shader_compiler_options *nir_options =
      compiler->nir_options[MESA_SHADER_KERNEL];

   struct spirv_to_nir_options spirv_options = {
      .environment = NIR_SPIRV_OPENCL,
      .capabilities = &spirv_caps,
      .printf = true,
      .shared_addr_format = nir_address_format_62bit_generic,
      .global_addr_format = nir_address_format_62bit_generic,
      .temp_addr_format = nir_address_format_62bit_generic,
      .constant_addr_format = nir_address_format_64bit_global,
   };

   spirv_options.clc_shader = load_clc_shader(compiler, disk_cache,
                                              nir_options, &spirv_options);
   if (spirv_options.clc_shader == NULL) {
      fprintf(stderr, "ERROR: libclc shader missing."
              " Consider installing the libclc package\n");
      abort();
   }

   assert(spirv_size % 4 == 0);
   nir_shader *nir =
      spirv_to_nir(spirv, spirv_size / 4, NULL, 0, MESA_SHADER_KERNEL,
                   entrypoint_name, &spirv_options, nir_options);
   nir_validate_shader(nir, "after spirv_to_nir");
   nir_validate_ssa_dominance(nir, "after spirv_to_nir");
   ralloc_steal(mem_ctx, nir);
   nir->info.name = ralloc_strdup(nir, entrypoint_name);

   if (INTEL_DEBUG(DEBUG_CS)) {
      /* Re-index SSA defs so we print more sensible numbers. */
      nir_foreach_function_impl(impl, nir) {
         nir_index_ssa_defs(impl);
      }

      fprintf(stderr, "NIR (from SPIR-V) for kernel\n");
      nir_print_shader(nir, stderr);
   }

   nir_lower_printf_options printf_opts = {
      .ptr_bit_size               = 64,
      .use_printf_base_identifier = true,
   };
   NIR_PASS_V(nir, nir_lower_printf, &printf_opts);

   NIR_PASS_V(nir, implement_intel_builtins);
   NIR_PASS_V(nir, nir_link_shader_functions, spirv_options.clc_shader);

   /* We have to lower away local constant initializers right before we
    * inline functions.  That way they get properly initialized at the top
    * of the function and not at the top of its caller.
    */
   NIR_PASS_V(nir, nir_lower_variable_initializers, nir_var_function_temp);
   NIR_PASS_V(nir, nir_lower_returns);
   NIR_PASS_V(nir, nir_inline_functions);
   NIR_PASS_V(nir, nir_copy_prop);
   NIR_PASS_V(nir, nir_opt_deref);

   /* Pick off the single entrypoint that we want */
   nir_remove_non_entrypoints(nir);

   /* Now that we've deleted all but the main function, we can go ahead and
    * lower the rest of the constant initializers.  We do this here so that
    * nir_remove_dead_variables and split_per_member_structs below see the
    * corresponding stores.
    */
   NIR_PASS_V(nir, nir_lower_variable_initializers, ~0);

   /* LLVM loves to take advantage of the fact that vec3s in OpenCL are
    * 16B aligned and so it can just read/write them as vec4s.  This
    * results in a LOT of vec4->vec3 casts on loads and stores.  One
    * solution to this problem is to get rid of all vec3 variables.
    */
   NIR_PASS_V(nir, nir_lower_vec3_to_vec4,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global |
              nir_var_mem_constant);

   /* We assign explicit types early so that the optimizer can take advantage
    * of that information and hopefully get rid of some of our memcpys.
    */
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_uniform |
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global,
              glsl_get_cl_type_size_align);

   struct brw_nir_compiler_opts opts = {};
   brw_preprocess_nir(compiler, nir, &opts);

   int max_arg_idx = -1;
   nir_foreach_uniform_variable(var, nir) {
      assert(var->data.location < 256);
      max_arg_idx = MAX2(max_arg_idx, var->data.location);
   }

   kernel->args_size = nir->num_uniforms;
   kernel->arg_count = max_arg_idx + 1;

   /* No bindings */
   struct brw_kernel_arg_desc *args =
      rzalloc_array(mem_ctx, struct brw_kernel_arg_desc, kernel->arg_count);
   kernel->args = args;

   nir_foreach_uniform_variable(var, nir) {
      struct brw_kernel_arg_desc arg_desc = {
         .offset = var->data.driver_location,
         .size = glsl_get_explicit_size(var->type, false),
      };
      assert(arg_desc.offset + arg_desc.size <= nir->num_uniforms);

      assert(var->data.location >= 0);
      args[var->data.location] = arg_desc;
   }

   NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_all, NULL);

   /* Lower again, this time after dead-variable removal, to get more
    * compact variable layouts.
    */
   nir->global_mem_size = 0;
   nir->scratch_size = 0;
   nir->info.shared_size = 0;
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global | nir_var_mem_constant,
              glsl_get_cl_type_size_align);
   if (nir->constant_data_size > 0) {
      assert(nir->constant_data == NULL);
      nir->constant_data = rzalloc_size(nir, nir->constant_data_size);
      nir_gather_explicit_io_initializers(nir, nir->constant_data,
                                          nir->constant_data_size,
                                          nir_var_mem_constant);
   }

   if (INTEL_DEBUG(DEBUG_CS)) {
      /* Re-index SSA defs so we print more sensible numbers. */
      nir_foreach_function_impl(impl, nir) {
         nir_index_ssa_defs(impl);
      }

      fprintf(stderr, "NIR (before I/O lowering) for kernel\n");
      nir_print_shader(nir, stderr);
   }

   NIR_PASS_V(nir, nir_lower_memcpy);

   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_constant,
              nir_address_format_64bit_global);

   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_uniform,
              nir_address_format_32bit_offset_as_64bit);

   NIR_PASS_V(nir, nir_lower_explicit_io,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global,
              nir_address_format_62bit_generic);

   NIR_PASS_V(nir, nir_lower_convert_alu_types, NULL);

   NIR_PASS_V(nir, brw_nir_lower_cs_intrinsics, devinfo, NULL);
   NIR_PASS_V(nir, lower_kernel_intrinsics);

   struct brw_cs_prog_key key = { };

   memset(&kernel->prog_data, 0, sizeof(kernel->prog_data));
   kernel->prog_data.base.nr_params = DIV_ROUND_UP(nir->num_uniforms, 4);

   struct brw_compile_cs_params params = {
      .base = {
         .nir = nir,
         .stats = kernel->stats,
         .log_data = log_data,
         .mem_ctx = mem_ctx,
      },
      .key = &key,
      .prog_data = &kernel->prog_data,
   };

   kernel->code = brw_compile_cs(compiler, &params);

   if (error_str)
      *error_str = params.base.error_str;

   return kernel->code != NULL;
}

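/* Find the constant-offset scratch store whose bytes fully cover
 * [read_offset, read_offset + read_size) and return the SSA value it
 * stored.
 */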
static nir_def *
rebuild_value_from_store(struct util_dynarray *stores,
                         nir_def *value, unsigned read_offset)
{
   unsigned read_size = value->num_components * value->bit_size / 8;

   util_dynarray_foreach(stores, nir_intrinsic_instr *, _store) {
      nir_intrinsic_instr *store = *_store;

      unsigned write_offset = nir_src_as_uint(store->src[1]);
      unsigned write_size = nir_src_num_components(store->src[0]) *
                            nir_src_bit_size(store->src[0]) / 8;
      if (write_offset <= read_offset &&
          (write_offset + write_size) >= (read_offset + read_size)) {
         assert(nir_block_dominates(store->instr.block, value->parent_instr->block));
         assert(write_size == read_size);
         return store->src[0].ssa;
      }
   }
   unreachable("Matching scratch store not found");
}

/**
 * Remove temporary variables that are stored to scratch only to be
 * reloaded immediately, remapping each load to the stored SSA value.
 *
 * This workaround is only meant to be applied to shaders in
 * src/intel/shaders where we know there should be no issue.  More complex
 * cases might not work with this approach.
 */
static bool
nir_remove_llvm17_scratch(nir_shader *nir)
{
   struct util_dynarray scratch_stores;
   void *mem_ctx = ralloc_context(NULL);

   util_dynarray_init(&scratch_stores, mem_ctx);

   nir_foreach_function_impl(func, nir) {
      nir_foreach_block(block, func) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

            if (intrin->intrinsic != nir_intrinsic_store_scratch)
               continue;

            nir_const_value *offset = nir_src_as_const_value(intrin->src[1]);
            if (offset != NULL) {
               util_dynarray_append(&scratch_stores, nir_intrinsic_instr *, intrin);
            }
         }
      }
   }

   bool progress = false;
   if (util_dynarray_num_elements(&scratch_stores, nir_intrinsic_instr *) > 0) {
      nir_foreach_function_impl(func, nir) {
         nir_foreach_block(block, func) {
            nir_foreach_instr_safe(instr, block) {
               if (instr->type != nir_instr_type_intrinsic)
                  continue;

               nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

               if (intrin->intrinsic != nir_intrinsic_load_scratch)
                  continue;

               nir_const_value *offset = nir_src_as_const_value(intrin->src[0]);
               if (offset == NULL)
                  continue;

               nir_def_replace(&intrin->def,
                               rebuild_value_from_store(&scratch_stores,
                                                        &intrin->def,
                                                        nir_src_as_uint(intrin->src[0])));

               progress = true;
            }
         }
      }
   }

   util_dynarray_foreach(&scratch_stores, nir_intrinsic_instr *, _store) {
      nir_intrinsic_instr *store = *_store;
      nir_instr_remove(&store->instr);
   }

   /* Quick sanity check */
   assert(util_dynarray_num_elements(&scratch_stores, nir_intrinsic_instr *) == 0 ||
          progress);

   ralloc_free(mem_ctx);

   return progress;
}

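/* Run a small optimization loop before and after removing LLVM 17's
 * redundant scratch round-trips: the first loop constant-folds the scratch
 * offsets that nir_remove_llvm17_scratch() relies on, and the second one
 * cleans up the now-dead code.
 */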
static void
cleanup_llvm17_scratch(nir_shader *nir)
{
   {
      bool progress;
      do {
         progress = false;
         NIR_PASS(progress, nir, nir_copy_prop);
         NIR_PASS(progress, nir, nir_opt_dce);
         NIR_PASS(progress, nir, nir_opt_constant_folding);
         NIR_PASS(progress, nir, nir_opt_cse);
         NIR_PASS(progress, nir, nir_opt_algebraic);
      } while (progress);
   }

   nir_remove_llvm17_scratch(nir);

   {
      bool progress;
      do {
         progress = false;
         NIR_PASS(progress, nir, nir_copy_prop);
         NIR_PASS(progress, nir, nir_opt_dce);
         NIR_PASS(progress, nir, nir_opt_constant_folding);
         NIR_PASS(progress, nir, nir_opt_cse);
         NIR_PASS(progress, nir, nir_opt_algebraic);
      } while (progress);
   }
}

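/* Translate an OpenCL SPIR-V library into NIR for the given gfx version
 * and run the common lowering and optimization passes.  The returned
 * shader is owned by mem_ctx.  llvm17_wa enables the scratch workaround
 * above.
 */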
nir_shader *
brw_nir_from_spirv(void *mem_ctx, unsigned gfx_version, const uint32_t *spirv,
                   size_t spirv_size, bool llvm17_wa)
{
   struct spirv_to_nir_options spirv_options = {
      .environment = NIR_SPIRV_OPENCL,
      .capabilities = &spirv_caps,
      .printf = true,
      .shared_addr_format = nir_address_format_62bit_generic,
      .global_addr_format = nir_address_format_62bit_generic,
      .temp_addr_format = nir_address_format_62bit_generic,
      .constant_addr_format = nir_address_format_64bit_global,
      .create_library = true,
   };

   assert(spirv_size % 4 == 0);

   assert(gfx_version);
   const nir_shader_compiler_options *nir_options =
      gfx_version >= 9 ? &brw_scalar_nir_options
                       : &elk_scalar_nir_options;

   nir_shader *nir =
      spirv_to_nir(spirv, spirv_size / 4, NULL, 0, MESA_SHADER_KERNEL,
                   "library", &spirv_options, nir_options);
   nir_validate_shader(nir, "after spirv_to_nir");
   nir_validate_ssa_dominance(nir, "after spirv_to_nir");
   ralloc_steal(mem_ctx, nir);
   nir->info.name = ralloc_strdup(nir, "library");

   if (INTEL_DEBUG(DEBUG_CS)) {
      /* Re-index SSA defs so we print more sensible numbers. */
      nir_foreach_function_impl(impl, nir) {
         nir_index_ssa_defs(impl);
      }

      fprintf(stderr, "NIR (from SPIR-V) for kernel\n");
      nir_print_shader(nir, stderr);
   }

   nir_lower_printf_options printf_opts = {
      .ptr_bit_size               = 64,
      .use_printf_base_identifier = true,
   };
   NIR_PASS_V(nir, nir_lower_printf, &printf_opts);

   NIR_PASS_V(nir, implement_intel_builtins);
   NIR_PASS_V(nir, nir_link_shader_functions, spirv_options.clc_shader);

   /* Lower the initializers for everything except local (shader_temp and
    * function_temp) variables now.  Local constant initializers are lowered
    * right before function inlining (below) so that they get properly
    * initialized at the top of the function and not at the top of its
    * caller.
    */
   NIR_PASS_V(nir, nir_lower_variable_initializers, ~(nir_var_shader_temp |
                                                      nir_var_function_temp));
   NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_uniform | nir_var_mem_ubo |
              nir_var_mem_constant | nir_var_function_temp | nir_var_image, NULL);
   {
      bool progress;
      do {
         progress = false;
         NIR_PASS(progress, nir, nir_copy_prop);
         NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
         NIR_PASS(progress, nir, nir_opt_deref);
         NIR_PASS(progress, nir, nir_opt_dce);
         NIR_PASS(progress, nir, nir_opt_undef);
         NIR_PASS(progress, nir, nir_opt_constant_folding);
         NIR_PASS(progress, nir, nir_opt_cse);
         NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
         NIR_PASS(progress, nir, nir_opt_algebraic);
      } while (progress);
   }

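   /* Now lower the local constant initializers, right before we inline
    * functions, so that they get initialized at the top of the function
    * and not at the top of its caller.
    */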
   NIR_PASS_V(nir, nir_lower_variable_initializers, nir_var_function_temp);
   NIR_PASS_V(nir, nir_lower_returns);
   NIR_PASS_V(nir, nir_inline_functions);

   assert(nir->scratch_size == 0);
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_function_temp,
              glsl_get_cl_type_size_align);

   {
      bool progress;
      do {
         progress = false;
         NIR_PASS(progress, nir, nir_copy_prop);
         NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
         NIR_PASS(progress, nir, nir_opt_deref);
         NIR_PASS(progress, nir, nir_opt_dce);
         NIR_PASS(progress, nir, nir_opt_undef);
         NIR_PASS(progress, nir, nir_opt_constant_folding);
         NIR_PASS(progress, nir, nir_opt_cse);
         NIR_PASS(progress, nir, nir_split_var_copies);
         NIR_PASS(progress, nir, nir_lower_var_copies);
         NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
         NIR_PASS(progress, nir, nir_opt_algebraic);
         NIR_PASS(progress, nir, nir_opt_if, nir_opt_if_optimize_phi_true_false);
         NIR_PASS(progress, nir, nir_opt_dead_cf);
         NIR_PASS(progress, nir, nir_opt_remove_phis);
         NIR_PASS(progress, nir, nir_opt_peephole_select, 8, true, true);
         NIR_PASS(progress, nir, nir_lower_vec3_to_vec4, nir_var_mem_generic | nir_var_uniform);
         NIR_PASS(progress, nir, nir_opt_memcpy);
      } while (progress);
   }

   NIR_PASS_V(nir, nir_scale_fdiv);

   NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_uniform | nir_var_mem_ubo |
              nir_var_mem_constant | nir_var_function_temp | nir_var_image, NULL);

   NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_mem_shared | nir_var_function_temp, NULL);

   nir->scratch_size = 0;
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_mem_shared | nir_var_function_temp | nir_var_shader_temp |
              nir_var_mem_global | nir_var_mem_constant,
              glsl_get_cl_type_size_align);

   /* Lower memcpy; this needs to wait until types are sized. */
   {
      bool progress;
      do {
         progress = false;
         NIR_PASS(progress, nir, nir_opt_memcpy);
         NIR_PASS(progress, nir, nir_copy_prop);
         NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
         NIR_PASS(progress, nir, nir_opt_deref);
         NIR_PASS(progress, nir, nir_opt_dce);
         NIR_PASS(progress, nir, nir_split_var_copies);
         NIR_PASS(progress, nir, nir_lower_var_copies);
         NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
         NIR_PASS(progress, nir, nir_opt_constant_folding);
         NIR_PASS(progress, nir, nir_opt_cse);
      } while (progress);
   }
   NIR_PASS_V(nir, nir_lower_memcpy);

   NIR_PASS_V(nir, nir_lower_explicit_io,
              nir_var_mem_shared | nir_var_function_temp | nir_var_shader_temp | nir_var_uniform,
              nir_address_format_32bit_offset_as_64bit);

   NIR_PASS_V(nir, nir_lower_system_values);

   /* Hopefully we can drop this once lower_vars_to_ssa has improved to not
    * lower everything to scratch.
    */
   if (llvm17_wa)
      cleanup_llvm17_scratch(nir);

   /* Lower again, this time after dead-variable removal, to get more
    * compact variable layouts.
    */
   nir->global_mem_size = 0;
   nir->scratch_size = 0;
   nir->info.shared_size = 0;
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_mem_shared | nir_var_mem_global | nir_var_mem_constant,
              glsl_get_cl_type_size_align);
   if (nir->constant_data_size > 0) {
      assert(nir->constant_data == NULL);
      nir->constant_data = rzalloc_size(nir, nir->constant_data_size);
      nir_gather_explicit_io_initializers(nir, nir->constant_data,
                                          nir->constant_data_size,
                                          nir_var_mem_constant);
   }

   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_constant,
              nir_address_format_64bit_global);

   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_uniform,
              nir_address_format_32bit_offset_as_64bit);

   NIR_PASS_V(nir, nir_lower_explicit_io,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global,
              nir_address_format_62bit_generic);

   if (INTEL_DEBUG(DEBUG_CS)) {
      /* Re-index SSA defs so we print more sensible numbers. */
      nir_foreach_function_impl(impl, nir) {
         nir_index_ssa_defs(impl);
      }

      fprintf(stderr, "NIR (before I/O lowering) for kernel\n");
      nir_print_shader(nir, stderr);
   }

   return nir;
}
