// Copyright (c) 2019 Valve Corporation
// Copyright (c) 2019 LunarG Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef LIBSPIRV_OPT_CONVERT_TO_HALF_PASS_H_
#define LIBSPIRV_OPT_CONVERT_TO_HALF_PASS_H_

#include <cstdint>
#include <functional>
#include <unordered_set>

#include "source/opt/ir_builder.h"
#include "source/opt/pass.h"

namespace spvtools {
namespace opt {

// Converts RelaxedPrecision float32 computation to float16.
//
// Instructions of float32 base type that are decorated RelaxedPrecision (or
// that the pass adds to its relaxed set, see CloseRelaxInst) are changed to
// the equivalent float16-based instruction. Conversions are inserted for
// operands as needed, and operands of non-relaxed instructions are converted
// back to their original types. See optimizer.hpp for user documentation.
//
// NOTE(review): unless stated otherwise, the bool results of the Gen*/Process*
// helpers below are not documented in this header; presumably they report
// whether the module was modified -- confirm in convert_to_half_pass.cpp.
class ConvertToHalfPass : public Pass {
 public:
  ConvertToHalfPass() : Pass() {}

  ~ConvertToHalfPass() override = default;

  // Only def-use and the instruction-to-block mapping are reported as
  // preserved by this pass.
  IRContext::Analysis GetPreservedAnalyses() override {
    return IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping;
  }

  // See optimizer.hpp for pass user documentation.
  Status Process() override;

  const char* name() const override { return "convert-to-half-pass"; }

 private:
  // Return true if |inst| is an arithmetic, composite or phi op that can be
  // of type float16.
  bool IsArithmetic(Instruction* inst);

  // Return true if |inst| returns a scalar, vector or matrix type whose base
  // type is float with |width| bits.
  bool IsFloat(Instruction* inst, uint32_t width);

  // Presumably returns true if |inst| returns a struct type.
  // NOTE(review): inferred from the name only; confirm in the .cpp.
  bool IsStruct(Instruction* inst);

  // Return true if |inst| is decorated with RelaxedPrecision.
  bool IsDecoratedRelaxed(Instruction* inst);

  // Return true if |id| has been added to the relaxed id set.
  bool IsRelaxed(uint32_t id);

  // Add |id| to the relaxed id set.
  void AddRelaxed(uint32_t id);

  // Return true if the operands of |inst| can be relaxed.
  bool CanRelaxOpOperands(Instruction* inst);

  // Return the float type with |width| bits. Note: this returns a type
  // object, not a type id.
  analysis::Type* FloatScalarType(uint32_t width);

  // Return the type for a vector of length |v_len| whose components are
  // floats with |width| bits. Returns a type object, not a type id.
  analysis::Type* FloatVectorType(uint32_t v_len, uint32_t width);

  // Return the type for a matrix of |v_cnt| vectors, each of the same length
  // as vector type |vty_id|, whose components are floats with |width| bits.
  // Returns a type object, not a type id.
  analysis::Type* FloatMatrixType(uint32_t v_cnt, uint32_t vty_id,
                                  uint32_t width);

  // Return the id of the type equivalent to float type |ty_id| but with
  // float components of |width| bits.
  uint32_t EquivFloatTypeId(uint32_t ty_id, uint32_t width);

  // Append instructions that convert the value |*val_idp| to the equivalent
  // float type with |width| bits, then set |*val_idp| to the id of the
  // converted value. |inst| is the instruction the conversion is generated
  // for (presumably code is inserted relative to it -- confirm in the .cpp).
  void GenConvert(uint32_t* val_idp, uint32_t width, Instruction* inst);

  // Remove the RelaxedPrecision decoration of |id|.
  // NOTE(review): presumably returns true if a decoration was removed.
  bool RemoveRelaxedDecoration(uint32_t id);

  // Add |inst| to the relaxed instruction set if warranted. Specifically, if
  // it is float32 and either decorated relaxed or a composite or phi
  // instruction where all operands are relaxed or all uses are relaxed.
  // NOTE(review): presumably returns true if the relaxed set changed.
  bool CloseRelaxInst(Instruction* inst);

  // If |inst| is an arithmetic, phi, extract or convert instruction of float32
  // base type and decorated with RelaxedPrecision, change it to the equivalent
  // float16 based type instruction. Specifically, insert instructions to
  // convert all operands to float16 (if needed) and change its type to the
  // equivalent float16 type. Otherwise, insert instructions to convert its
  // operands back to their original types, if needed.
  bool GenHalfInst(Instruction* inst);

  // Gen code for relaxed arithmetic |inst|.
  bool GenHalfArith(Instruction* inst);

  // Gen code for relaxed phi |inst|. |from_width| and |to_width| are
  // presumably the float widths its operands are converted from and to.
  bool ProcessPhi(Instruction* inst, uint32_t from_width, uint32_t to_width);

  // Gen code for relaxed convert |inst|.
  bool ProcessConvert(Instruction* inst);

  // Gen code for image reference |inst| (see image_ops_ / dref_image_ops_).
  bool ProcessImageRef(Instruction* inst);

  // Process default non-relaxed |inst|.
  bool ProcessDefault(Instruction* inst);

  // If |inst| is an FConvert of a matrix type, decompose it to a series
  // of vector extracts, converts and inserts into an Undef. These are
  // generated by GenHalfInst because they are easier to manipulate, but are
  // invalid so we need to clean them up.
  bool MatConvertCleanup(Instruction* inst);

  // Call GenHalfInst on every instruction in |func|.
  // If code is generated for an instruction, replace the instruction
  // with the new instructions that are generated.
  bool ProcessFunction(Function* func);

  Pass::Status ProcessImpl();

  // Initialize state for converting to half.
  void Initialize();

  // Hash functor so spv::Op can key an unordered_set: hashes the opcode's
  // uint32_t value.
  struct hasher {
    size_t operator()(const spv::Op& op) const noexcept {
      return std::hash<uint32_t>()(static_cast<uint32_t>(op));
    }
  };

  // Set of core operations to be processed.
  std::unordered_set<spv::Op, hasher> target_ops_core_;

  // Set of 450 extension operations to be processed (presumably the
  // GLSL.std.450 extended instruction numbers).
  std::unordered_set<uint32_t> target_ops_450_;

  // Set of all sample operations, including dref and non-dref operations.
  std::unordered_set<spv::Op, hasher> image_ops_;

  // Set of only dref sample operations.
  std::unordered_set<spv::Op, hasher> dref_image_ops_;

  // Set of operations that can be marked as relaxed.
  std::unordered_set<spv::Op, hasher> closure_ops_;

  // Set of ids of all relaxed instructions.
  std::unordered_set<uint32_t> relaxed_ids_set_;

  // Ids of all converted instructions.
  std::unordered_set<uint32_t> converted_ids_;
};

}  // namespace opt
}  // namespace spvtools

#endif  // LIBSPIRV_OPT_CONVERT_TO_HALF_PASS_H_
