/*
 * Copyright 2023 Advanced Micro Devices, Inc.
 *
 * SPDX-License-Identifier: MIT
 */

#include "nir_opt_varyings_test.h"

/* Fixture used by every TEST_F in this file. It adds no members of its own;
 * everything the tests call comes from nir_opt_varyings_test.
 */
class nir_opt_varyings_test_dead_input : public nir_opt_varyings_test {
};

/*
 * Declares a test named <producer_stage>_<consumer_stage>_<slot>_<bitsize>.
 *
 * The test body only adds code to b2 (used here as the consumer shader):
 * component 0 of VARYING_SLOT_<slot> is loaded as a float of the given bit
 * size and stored to VARYING_SLOT_POS. Nothing is added to the producer.
 *
 * Expected outcome:
 *  - opt_varyings() returns nir_progress_consumer,
 *  - inputs_read, patch_inputs_read and inputs_read_16bit are all zero,
 *  - shader_contains_def() no longer finds the original load,
 *  - shader_contains_undef() finds an undef of the given bit size.
 */
#define TEST_DEAD_INPUT_TO_UNDEF(producer_stage, consumer_stage, slot, bitsize) \
TEST_F(nir_opt_varyings_test_dead_input, \
       producer_stage##_##consumer_stage##_##slot##_##bitsize) \
{ \
   create_shaders(MESA_SHADER_##producer_stage, MESA_SHADER_##consumer_stage); \
   \
   /* Consumer: POS = <slot>.x, with no store added on the producer side. */ \
   nir_def *dead_load = \
      load_input(b2, VARYING_SLOT_##slot, 0, nir_type_float##bitsize, 0, 0); \
   store_output(b2, VARYING_SLOT_POS, 0, nir_type_float##bitsize, dead_load, 0); \
   \
   const auto pass_result = opt_varyings(); \
   ASSERT_TRUE(pass_result == nir_progress_consumer); \
   \
   /* No input bit of any kind may remain set. */ \
   ASSERT_TRUE(b2->shader->info.inputs_read == 0); \
   ASSERT_TRUE(b2->shader->info.patch_inputs_read == 0); \
   ASSERT_TRUE(b2->shader->info.inputs_read_16bit == 0); \
   \
   ASSERT_FALSE(shader_contains_def(b2, dead_load)); \
   ASSERT_TRUE(shader_contains_undef(b2, bitsize)); \
}

/*
 * Declares a test named
 * <producer_stage>_<consumer_stage>_<slot>_<comp>_<bitsize>.
 *
 * The test body only adds code to b2 (used here as the consumer shader):
 * component <comp> of VARYING_SLOT_<slot> is loaded as a float of the given
 * bit size and stored to VARYING_SLOT_POS. Nothing is added to the producer.
 *
 * Expected outcome:
 *  - opt_varyings() returns nir_progress_consumer,
 *  - inputs_read, patch_inputs_read and inputs_read_16bit are all zero,
 *  - shader_contains_def() no longer finds the original load,
 *  - shader_contains_const_float() finds <value> at the given bit size.
 */
#define TEST_DEAD_INPUT_TO_CONST(producer_stage, consumer_stage, slot, comp, bitsize, value) \
TEST_F(nir_opt_varyings_test_dead_input, \
       producer_stage##_##consumer_stage##_##slot##_##comp##_##bitsize) \
{ \
   create_shaders(MESA_SHADER_##producer_stage, MESA_SHADER_##consumer_stage); \
   \
   /* Consumer: POS = <slot>[comp], with no store added on the producer side. */ \
   nir_def *dead_load = \
      load_input(b2, VARYING_SLOT_##slot, comp, nir_type_float##bitsize, 0, 0); \
   store_output(b2, VARYING_SLOT_POS, 0, nir_type_float##bitsize, dead_load, 0); \
   \
   const auto pass_result = opt_varyings(); \
   ASSERT_TRUE(pass_result == nir_progress_consumer); \
   \
   /* No input bit of any kind may remain set. */ \
   ASSERT_TRUE(b2->shader->info.inputs_read == 0); \
   ASSERT_TRUE(b2->shader->info.patch_inputs_read == 0); \
   ASSERT_TRUE(b2->shader->info.inputs_read_16bit == 0); \
   \
   ASSERT_FALSE(shader_contains_def(b2, dead_load)); \
   ASSERT_TRUE(shader_contains_const_float(b2, value, bitsize)); \
}

/*
 * Declares a test named <producer_stage>_<consumer_stage>_<slot>_<bitsize>.
 *
 * The shaders are built the same way as in TEST_DEAD_INPUT_TO_UNDEF (b2 loads
 * component 0 of VARYING_SLOT_<slot> and stores it to VARYING_SLOT_POS), but
 * here the load is expected to be left untouched:
 *  - opt_varyings() returns 0,
 *  - inputs_read is exactly VARYING_BIT_<slot>,
 *  - shader_contains_def() still finds the original load.
 */
#define TEST_DEAD_INPUT_KEPT(producer_stage, consumer_stage, slot, bitsize) \
TEST_F(nir_opt_varyings_test_dead_input, \
       producer_stage##_##consumer_stage##_##slot##_##bitsize) \
{ \
   create_shaders(MESA_SHADER_##producer_stage, MESA_SHADER_##consumer_stage); \
   \
   /* Consumer: POS = <slot>.x, with no store added on the producer side. */ \
   nir_def *kept_load = \
      load_input(b2, VARYING_SLOT_##slot, 0, nir_type_float##bitsize, 0, 0); \
   store_output(b2, VARYING_SLOT_POS, 0, nir_type_float##bitsize, kept_load, 0); \
   \
   const auto pass_result = opt_varyings(); \
   ASSERT_TRUE(pass_result == 0); \
   \
   ASSERT_TRUE(b2->shader->info.inputs_read == VARYING_BIT_##slot); \
   ASSERT_TRUE(shader_contains_def(b2, kept_load)); \
}

/*
 * Declares a test named
 * routing_<producer_stage>_<pslot>_<consumer_stage>_<cslot>_<bitsize>.
 *
 * b1 stores a load of VARYING_SLOT_POS to VARYING_SLOT_<pslot>, and b2 loads
 * VARYING_SLOT_<cslot> and stores it to VARYING_SLOT_POS, i.e. the output and
 * input slots differ.
 *
 * Expected outcome:
 *  - opt_varyings() returns 0,
 *  - b1's outputs_written has only the <pslot> bit and b2's inputs_read has
 *    only the <cslot> bit (both lowered by one slot when <cslot> is COL1),
 *  - shader_contains_def() still finds the consumer load.
 */
#define TEST_OUTPUT_INPUT_ROUTING_KEPT(producer_stage, consumer_stage, pslot, cslot, bitsize) \
TEST_F(nir_opt_varyings_test_dead_input, \
       routing_##producer_stage##_##pslot##_##consumer_stage##_##cslot##_##bitsize) \
{ \
   create_shaders(MESA_SHADER_##producer_stage, MESA_SHADER_##consumer_stage); \
   \
   /* Producer: <pslot> = POS.x */ \
   nir_def *producer_value = \
      load_input(b1, VARYING_SLOT_POS, 0, nir_type_float##bitsize, 0, 0); \
   store_output(b1, VARYING_SLOT_##pslot, 0, nir_type_float##bitsize, \
                producer_value, 0); \
   \
   /* Consumer: POS = <cslot>.x */ \
   nir_def *routed_load = \
      load_input(b2, VARYING_SLOT_##cslot, 0, nir_type_float##bitsize, 0, 0); \
   store_output(b2, VARYING_SLOT_POS, 0, nir_type_float##bitsize, routed_load, 0); \
   \
   /* Compaction moves COL1 to COL0, so expect both sides one slot lower. */ \
   unsigned expected_out_slot = VARYING_SLOT_##pslot; \
   unsigned expected_in_slot = VARYING_SLOT_##cslot; \
   if (expected_in_slot == VARYING_SLOT_COL1) { \
      expected_out_slot -= 1; \
      expected_in_slot -= 1; \
   } \
   \
   const auto pass_result = opt_varyings(); \
   ASSERT_TRUE(pass_result == 0); \
   \
   ASSERT_TRUE(b1->shader->info.outputs_written == \
               BITFIELD64_BIT(expected_out_slot)); \
   ASSERT_TRUE(b2->shader->info.inputs_read == \
               BITFIELD64_BIT(expected_in_slot)); \
   ASSERT_TRUE(shader_contains_def(b2, routed_load)); \
}

/* VERTEX -> TESS_CTRL: each listed slot is read by the consumer while the
 * test adds no producer store for it; every load is expected to become undef.
 */
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, POS, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, COL0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, COL1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, FOGC, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, TEX0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, PSIZ, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, BFC0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, BFC1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, CLIP_VERTEX, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, CLIP_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, CLIP_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, CULL_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, CULL_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, PRIMITIVE_ID, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, LAYER, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, VIEWPORT, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, VAR0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, VAR0, 16)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_CTRL, VAR0_16BIT, 16)

/* VERTEX -> TESS_EVAL: loads of the listed slots are expected to become
 * undef, except TESS_LEVEL_INNER and TESS_LEVEL_OUTER, which are expected to
 * be left in place.
 */
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, POS, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, COL0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, COL1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, FOGC, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, TEX0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, PSIZ, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, BFC0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, BFC1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, CLIP_VERTEX, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, CLIP_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, CLIP_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, CULL_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, CULL_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, PRIMITIVE_ID, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, LAYER, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, VIEWPORT, 32)
TEST_DEAD_INPUT_KEPT(VERTEX, TESS_EVAL, TESS_LEVEL_INNER, 32)
TEST_DEAD_INPUT_KEPT(VERTEX, TESS_EVAL, TESS_LEVEL_OUTER, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, VAR0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, VAR0, 16)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, TESS_EVAL, VAR0_16BIT, 16)

/* TESS_CTRL -> TESS_EVAL: every listed slot is expected to become undef,
 * including TESS_LEVEL_INNER/OUTER and the per-patch slot PATCH0 (32- and
 * 16-bit).
 */
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, POS, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, COL0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, COL1, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, FOGC, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, TEX0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, PSIZ, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, BFC0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, BFC1, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, CLIP_VERTEX, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, CLIP_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, CLIP_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, CULL_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, CULL_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, PRIMITIVE_ID, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, LAYER, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, VIEWPORT, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, TESS_LEVEL_INNER, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, TESS_LEVEL_OUTER, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, VAR0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, VAR0, 16)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, VAR0_16BIT, 16)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, PATCH0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_CTRL, TESS_EVAL, PATCH0, 16)

/* VERTEX -> GEOMETRY: each listed slot is read by the consumer while the
 * test adds no producer store for it; every load is expected to become undef.
 */
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, POS, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, COL0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, COL1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, FOGC, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, TEX0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, PSIZ, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, BFC0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, BFC1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, CLIP_VERTEX, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, CLIP_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, CLIP_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, CULL_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, CULL_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, PRIMITIVE_ID, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, LAYER, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, VIEWPORT, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, VAR0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, VAR0, 16)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, GEOMETRY, VAR0_16BIT, 16)

/* TESS_EVAL -> GEOMETRY: each listed slot is read by the consumer while the
 * test adds no producer store for it; every load is expected to become undef.
 */
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, POS, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, COL0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, COL1, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, FOGC, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, TEX0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, PSIZ, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, BFC0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, BFC1, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, CLIP_VERTEX, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, CLIP_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, CLIP_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, CULL_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, CULL_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, PRIMITIVE_ID, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, LAYER, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, VIEWPORT, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, VAR0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, VAR0, 16)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, GEOMETRY, VAR0_16BIT, 16)

/* VERTEX -> FRAGMENT. Expectations differ per slot:
 *  - undef: COL0/1, FOGC, CLIP_DIST0/1, CULL_DIST0/1, VAR0, VAR0_16BIT
 *  - kept: TEX0 (component 0), PRIMITIVE_ID, PNTC
 *  - constant: TEX0 component 2 -> 0, TEX0 component 3 -> 1,
 *              LAYER -> 0, VIEWPORT -> 0
 */
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, FRAGMENT, COL0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, FRAGMENT, COL1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, FRAGMENT, FOGC, 32)
TEST_DEAD_INPUT_KEPT(VERTEX, FRAGMENT, TEX0, 32)
TEST_DEAD_INPUT_TO_CONST(VERTEX, FRAGMENT, TEX0, 2, 32, 0)
TEST_DEAD_INPUT_TO_CONST(VERTEX, FRAGMENT, TEX0, 3, 32, 1)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, FRAGMENT, CLIP_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, FRAGMENT, CLIP_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, FRAGMENT, CULL_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, FRAGMENT, CULL_DIST1, 32)
TEST_DEAD_INPUT_KEPT(VERTEX, FRAGMENT, PRIMITIVE_ID, 32)
TEST_DEAD_INPUT_TO_CONST(VERTEX, FRAGMENT, LAYER, 0, 32, 0)
TEST_DEAD_INPUT_TO_CONST(VERTEX, FRAGMENT, VIEWPORT, 0, 32, 0)
TEST_DEAD_INPUT_KEPT(VERTEX, FRAGMENT, PNTC, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, FRAGMENT, VAR0, 32)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, FRAGMENT, VAR0, 16)
TEST_DEAD_INPUT_TO_UNDEF(VERTEX, FRAGMENT, VAR0_16BIT, 16)

/* TESS_EVAL -> FRAGMENT. Expectations differ per slot:
 *  - undef: COL0/1, FOGC, CLIP_DIST0/1, CULL_DIST0/1, VAR0, VAR0_16BIT
 *  - kept: TEX0 (component 0), PRIMITIVE_ID, PNTC
 *  - constant: TEX0 component 2 -> 0, TEX0 component 3 -> 1,
 *              LAYER -> 0, VIEWPORT -> 0
 */
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, FRAGMENT, COL0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, FRAGMENT, COL1, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, FRAGMENT, FOGC, 32)
TEST_DEAD_INPUT_KEPT(TESS_EVAL, FRAGMENT, TEX0, 32)
TEST_DEAD_INPUT_TO_CONST(TESS_EVAL, FRAGMENT, TEX0, 2, 32, 0)
TEST_DEAD_INPUT_TO_CONST(TESS_EVAL, FRAGMENT, TEX0, 3, 32, 1)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, FRAGMENT, CLIP_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, FRAGMENT, CLIP_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, FRAGMENT, CULL_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, FRAGMENT, CULL_DIST1, 32)
TEST_DEAD_INPUT_KEPT(TESS_EVAL, FRAGMENT, PRIMITIVE_ID, 32)
TEST_DEAD_INPUT_TO_CONST(TESS_EVAL, FRAGMENT, LAYER, 0, 32, 0)
TEST_DEAD_INPUT_TO_CONST(TESS_EVAL, FRAGMENT, VIEWPORT, 0, 32, 0)
TEST_DEAD_INPUT_KEPT(TESS_EVAL, FRAGMENT, PNTC, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, FRAGMENT, VAR0, 32)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, FRAGMENT, VAR0, 16)
TEST_DEAD_INPUT_TO_UNDEF(TESS_EVAL, FRAGMENT, VAR0_16BIT, 16)

/* GEOMETRY -> FRAGMENT. Expectations differ per slot:
 *  - undef: COL0/1, FOGC, CLIP_DIST0/1, CULL_DIST0/1, PRIMITIVE_ID, VAR0,
 *           VAR0_16BIT
 *  - kept: TEX0 (component 0), PNTC
 *  - constant: TEX0 component 2 -> 0, TEX0 component 3 -> 1,
 *              LAYER -> 0, VIEWPORT -> 0
 * Note that PRIMITIVE_ID is expected to become undef with a GEOMETRY
 * producer, whereas the VERTEX and TESS_EVAL producer tests expect it kept.
 */
TEST_DEAD_INPUT_TO_UNDEF(GEOMETRY, FRAGMENT, COL0, 32)
TEST_DEAD_INPUT_TO_UNDEF(GEOMETRY, FRAGMENT, COL1, 32)
TEST_DEAD_INPUT_TO_UNDEF(GEOMETRY, FRAGMENT, FOGC, 32)
TEST_DEAD_INPUT_KEPT(GEOMETRY, FRAGMENT, TEX0, 32)
TEST_DEAD_INPUT_TO_CONST(GEOMETRY, FRAGMENT, TEX0, 2, 32, 0)
TEST_DEAD_INPUT_TO_CONST(GEOMETRY, FRAGMENT, TEX0, 3, 32, 1)
TEST_DEAD_INPUT_TO_UNDEF(GEOMETRY, FRAGMENT, CLIP_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(GEOMETRY, FRAGMENT, CLIP_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(GEOMETRY, FRAGMENT, CULL_DIST0, 32)
TEST_DEAD_INPUT_TO_UNDEF(GEOMETRY, FRAGMENT, CULL_DIST1, 32)
TEST_DEAD_INPUT_TO_UNDEF(GEOMETRY, FRAGMENT, PRIMITIVE_ID, 32)
TEST_DEAD_INPUT_TO_CONST(GEOMETRY, FRAGMENT, LAYER, 0, 32, 0)
TEST_DEAD_INPUT_TO_CONST(GEOMETRY, FRAGMENT, VIEWPORT, 0, 32, 0)
TEST_DEAD_INPUT_KEPT(GEOMETRY, FRAGMENT, PNTC, 32)
TEST_DEAD_INPUT_TO_UNDEF(GEOMETRY, FRAGMENT, VAR0, 32)
TEST_DEAD_INPUT_TO_UNDEF(GEOMETRY, FRAGMENT, VAR0, 16)
TEST_DEAD_INPUT_TO_UNDEF(GEOMETRY, FRAGMENT, VAR0_16BIT, 16)

/* MESH -> FRAGMENT: PNTC is expected to be kept; VAR0 (32- and 16-bit) and
 * VAR0_16BIT are expected to become undef.
 */
TEST_DEAD_INPUT_KEPT(MESH, FRAGMENT, PNTC, 32)
TEST_DEAD_INPUT_TO_UNDEF(MESH, FRAGMENT, VAR0, 32)
TEST_DEAD_INPUT_TO_UNDEF(MESH, FRAGMENT, VAR0, 16)
TEST_DEAD_INPUT_TO_UNDEF(MESH, FRAGMENT, VAR0_16BIT, 16)

/* Producer writes BFCn while the FRAGMENT consumer reads COLn: the pass is
 * expected to report no progress and to keep both the output and the input
 * (the BFC1/COL1 cases are checked one slot lower, see the macro).
 */
TEST_OUTPUT_INPUT_ROUTING_KEPT(VERTEX, FRAGMENT, BFC0, COL0, 32)
TEST_OUTPUT_INPUT_ROUTING_KEPT(VERTEX, FRAGMENT, BFC1, COL1, 32)
TEST_OUTPUT_INPUT_ROUTING_KEPT(TESS_EVAL, FRAGMENT, BFC0, COL0, 32)
TEST_OUTPUT_INPUT_ROUTING_KEPT(TESS_EVAL, FRAGMENT, BFC1, COL1, 32)
TEST_OUTPUT_INPUT_ROUTING_KEPT(GEOMETRY, FRAGMENT, BFC0, COL0, 32)
TEST_OUTPUT_INPUT_ROUTING_KEPT(GEOMETRY, FRAGMENT, BFC1, COL1, 32)

}
