// Copyright (c) 2020 The Khronos Group Inc.
// Copyright (c) 2020 Valve Corporation
// Copyright (c) 2020 LunarG Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef LIBSPIRV_OPT_INST_DEBUG_PRINTF_PASS_H_
#define LIBSPIRV_OPT_INST_DEBUG_PRINTF_PASS_H_

#include "instrument_pass.h"

namespace spvtools {
namespace opt {

// This class/pass is designed to support the debug printf GPU-assisted layer
// of https://github.com/KhronosGroup/Vulkan-ValidationLayers. Its internal and
// external design may change as the layer evolves.
class InstDebugPrintfPass : public InstrumentPass {
 public:
  // For test harness only: uses a fixed descriptor set (7) and shader id (23).
  InstDebugPrintfPass() : InstrumentPass(7, 23, false, false) {}
  // For all other interfaces: instrument using descriptor set |desc_set| and
  // identify the instrumented shader in every record with |shader_id|.
  InstDebugPrintfPass(uint32_t desc_set, uint32_t shader_id)
      : InstrumentPass(desc_set, shader_id, false, false) {}

  ~InstDebugPrintfPass() override = default;

  // See optimizer.hpp for pass user documentation.
  Status Process() override;

  // Return the name of this pass, "inst-printf-pass".
  const char* name() const override { return "inst-printf-pass"; }

 private:
  // Gen code into |builder| to write |field_value_id| into debug output
  // buffer at |base_offset_id| + |field_offset|.
  void GenDebugOutputFieldCode(uint32_t base_offset_id, uint32_t field_offset,
                               uint32_t field_value_id,
                               InstructionBuilder* builder);

  // Generate instructions in |builder| which will atomically fetch and
  // increment the size of the debug output buffer stream of the current
  // validation and write a record to the end of the stream, if enough space
  // in the buffer remains. The record will contain |shader_id| and the
  // instruction index given by |instruction_idx_id|, which together identify
  // the instruction that generated the record. Finally, the record will
  // contain the validation-specific data in |validation_ids|. For this pass,
  // that data is the format string id followed by the printf operand values
  // (see GenDebugPrintfCode).
  //
  // The output buffer binding written to by the code generated by the function
  // is determined by the validation id specified when each specific
  // instrumentation pass is created.
  // NOTE(review): the constructors of this pass take no validation id; the
  // binding is obtained via GetOutputBufferBinding(). Confirm in the .cpp how
  // it is chosen.
  //
  // The output buffer is a sequence of 32-bit values with the following
  // format (where all elements are unsigned 32-bit unless otherwise noted):
  //
  //     Size
  //     Record0
  //     Record1
  //     Record2
  //     ...
  //
  // Size is the number of 32-bit values that have been written, or that were
  // attempted to be written, to the output buffer, excluding the Size word
  // itself. It is initialized to 0. Because attempted writes are counted too,
  // this field can exceed the actual size of the buffer.
  //
  // Each Record* is a variable-length sequence of 32-bit values with the
  // following format defined using static const offsets in the .cpp file:
  //
  //     Record Size
  //     Shader ID
  //     Instruction Index
  //     ...
  //     Validation Error Code
  //     Validation-specific Word 0
  //     Validation-specific Word 1
  //     Validation-specific Word 2
  //     ...
  //
  // Each record consists of two subsections: members common across all
  // validations and members specific to a validation.
  //
  // The Record Size is the number of 32-bit words in the record, including
  // the Record Size word.
  //
  // Shader ID is a value that identifies which shader has generated the
  // record. It is passed when the instrumentation pass is created.
  //
  // The Instruction Index is the position, within the SPIR-V module, of the
  // instruction that generated the record.
  //
  // The Validation Error Code specifies the exact error which has occurred.
  // These are enumerated with the kInstError* static consts. This allows
  // multiple validation layers to use the same, single output buffer.
  // NOTE(review): confirm in the .cpp whether records written by this pass
  // actually contain this word.
  //
  // The Validation-specific Words are a validation-specific number of 32-bit
  // words which give further information about the record. For this pass
  // they are described in the comment on GenDebugPrintfCode.
  //
  // Because the code that is generated checks against the size of the buffer
  // before writing, the size of the debug out buffer can be used by the
  // validation layer to control the number of records that are written.
  void GenDebugStreamWrite(uint32_t shader_id, uint32_t instruction_idx_id,
                           const std::vector<uint32_t>& validation_ids,
                           InstructionBuilder* builder);

  // Return the id of the stream-write output function that takes
  // |val_spec_param_cnt| validation-specific uint32 parameters, defining that
  // function first if it doesn't exist yet.
  uint32_t GetStreamWriteFunctionId(uint32_t val_spec_param_cnt);

  // Generate instructions for OpDebugPrintf.
  //
  // If |ref_inst_itr| is an OpDebugPrintf, return in |new_blocks| the result
  // of replacing it with buffer write instructions within its block at
  // |ref_block_itr|. The instructions write a record to the printf output
  // buffer stream that identifies the shader and the OpDebugPrintf
  // instruction (see the record format documented on GenDebugStreamWrite),
  // and the OpDebugPrintf is removed. The block at |ref_block_itr| can then
  // be replaced with the block in |new_blocks|. Besides the buffer writes,
  // this block will comprise all instructions preceding and following
  // |ref_inst_itr|.
  //
  // This function is designed to be passed to
  // InstrumentPass::InstProcessEntryPointCallTree(), which applies the
  // function to each instruction in a module and replaces the instruction
  // if warranted.
  //
  // This instrumentation function utilizes GenDebugStreamWrite() to write its
  // records. The validation-specific part of each record will consist of a
  // uint32 which is the id of the format string, followed by a sequence of
  // uint32s representing the values of the remaining operands of the
  // DebugPrintf.
  void GenDebugPrintfCode(BasicBlock::iterator ref_inst_itr,
                          UptrVectorIterator<BasicBlock> ref_block_itr,
                          std::vector<std::unique_ptr<BasicBlock>>* new_blocks);

  // Generate a sequence of uint32 instructions in |builder| (if necessary)
  // representing the value of |val_inst|, which must be a buffer pointer, a
  // uint64, or a scalar or vector of type uint32, float32 or float16. Append
  // the ids of all values to the end of |val_ids|.
  void GenOutputValues(Instruction* val_inst, std::vector<uint32_t>* val_ids,
                       InstructionBuilder* builder);

  // Generate instructions to write a record containing the operands of
  // |printf_inst| to the printf buffer, adding new code to the end of the
  // last block in |new_blocks|. The OpDebugPrintf instruction |printf_inst|
  // is then killed.
  void GenOutputCode(Instruction* printf_inst,
                     std::vector<std::unique_ptr<BasicBlock>>* new_blocks);

  // Create and return a new instruction that names the function or global
  // variable |id| with |name_str|. Names will be prefixed to identify which
  // instrumentation pass generated them. The caller takes ownership of the
  // returned instruction.
  std::unique_ptr<Instruction> NewGlobalName(uint32_t id,
                                             const std::string& name_str);

  // Create and return a new instruction that names member |member_index| of
  // the structure |id| with |name_str|. The caller takes ownership of the
  // returned instruction.
  std::unique_ptr<Instruction> NewMemberName(uint32_t id, uint32_t member_index,
                                             const std::string& name_str);

  // Return the id of the debug output buffer variable.
  uint32_t GetOutputBufferId();

  // Return the id of the pointer type used to access an output buffer element
  // (the type whose id is recorded in output_buffer_ptr_id_).
  uint32_t GetOutputBufferPtrId();

  // Return binding for output buffer for current validation.
  uint32_t GetOutputBufferBinding();

  // Initialize state for instrumenting debug printf.
  void InitializeInstDebugPrintf();

  // Apply GenDebugPrintfCode to every instruction in module.
  Pass::Status ProcessImpl();

  // Id of the extended instruction set that provides DebugPrintf (presumably
  // the result id of its OpExtInstImport; confirm in .cpp). 0 until assigned.
  uint32_t ext_inst_printf_id_{0};

  // id for output buffer variable; 0 until assigned.
  uint32_t output_buffer_id_{0};

  // ptr type id for output buffer element; 0 until assigned.
  uint32_t output_buffer_ptr_id_{0};
};

}  // namespace opt
}  // namespace spvtools

#endif  // LIBSPIRV_OPT_INST_DEBUG_PRINTF_PASS_H_
