// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/opt/reduce_load_size.h"

#include <set>
#include <vector>

#include "source/opt/instruction.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
#include "source/util/bit_vector.h"

namespace spvtools {
namespace opt {
namespace {
// In-operand index of the composite id on an OpCompositeExtract.
constexpr uint32_t kExtractCompositeIdInIdx{0};
// In-operand index of the storage class on an OpVariable.
constexpr uint32_t kVariableStorageClassInIdx{0};
// In-operand index of the pointer on an OpLoad.
constexpr uint32_t kLoadPointerInIdx{0};
}  // namespace

// Visits every instruction of every function in the module and, for each
// OpCompositeExtract that ShouldReplaceExtract() approves, attempts the
// rewrite performed by ReplaceExtract().
//
// Returns Status::SuccessWithChange if at least one extract was replaced and
// Status::SuccessWithoutChange otherwise.
Pass::Status ReduceLoadSize::Process() {
  bool changed = false;

  auto try_reduce = [&changed, this](Instruction* candidate) {
    if (candidate->opcode() != spv::Op::OpCompositeExtract) {
      return;
    }
    if (!ShouldReplaceExtract(candidate)) {
      return;
    }
    if (ReplaceExtract(candidate)) {
      changed = true;
    }
  };

  for (auto& function : *get_module()) {
    function.ForEachInst(try_reduce);
  }

  if (changed) {
    return Status::SuccessWithChange;
  }
  return Status::SuccessWithoutChange;
}

// Rewrites |inst|, an OpCompositeExtract, into an access chain plus a load of
// just the extracted element.
//
// Nothing is changed, and false is returned, when:
//   - the composite operand of |inst| is not defined by an OpLoad,
//   - the loaded value is a vector or a matrix,
//   - the base address of the load is missing or is not an OpVariable, or
//   - the variable's storage class is not Uniform, UniformConstant or Input.
//
// Otherwise every use of |inst| is redirected to the new load, |inst| is
// killed, and true is returned.
bool ReduceLoadSize::ReplaceExtract(Instruction* inst) {
  assert(inst->opcode() == spv::Op::OpCompositeExtract &&
         "Wrong opcode.  Should be OpCompositeExtract.");
  analysis::DefUseManager* def_use = context()->get_def_use_mgr();
  analysis::TypeManager* types = context()->get_type_mgr();
  analysis::ConstantManager* constants = context()->get_constant_mgr();

  // The extract must read directly from the result of a load.
  Instruction* load =
      def_use->GetDef(inst->GetSingleWordInOperand(kExtractCompositeIdInIdx));
  if (load->opcode() != spv::Op::OpLoad) return false;

  // Loads of vectors and matrices are left untouched.
  const auto loaded_kind = types->GetType(load->type_id())->kind();
  if (loaded_kind == analysis::Type::kVector ||
      loaded_kind == analysis::Type::kMatrix) {
    return false;
  }

  // The load has to be rooted at a variable.
  Instruction* base = load->GetBaseAddress();
  if (base == nullptr) return false;
  if (base->opcode() != spv::Op::OpVariable) return false;

  // Only these three storage classes are handled.
  const auto storage_class = static_cast<spv::StorageClass>(
      base->GetSingleWordInOperand(kVariableStorageClassInIdx));
  const bool eligible_storage =
      storage_class == spv::StorageClass::Uniform ||
      storage_class == spv::StorageClass::UniformConstant ||
      storage_class == spv::StorageClass::Input;
  if (!eligible_storage) return false;

  // Anchor the builder on the original load rather than on the extract: the
  // storage may have been written to between the load and the extract, so a
  // load emitted at the extract's position could observe a different value.
  InstructionBuilder builder(
      inst->context(), load,
      IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDefUse);

  const uint32_t element_ptr_type_id =
      types->FindPointerToType(inst->type_id(), storage_class);
  assert(element_ptr_type_id != 0 &&
         "We did not find the pointer type that we need.");

  // Each literal index of the extract becomes the id of a 32-bit unsigned
  // integer constant that is used as an index of the access chain.
  analysis::Integer uint_desc(32, false);
  const analysis::Type* uint32_type = types->GetRegisteredType(&uint_desc);
  std::vector<uint32_t> index_ids;
  for (uint32_t operand = 1, count = inst->NumInOperands(); operand < count;
       ++operand) {
    const analysis::Constant* literal = constants->GetConstant(
        uint32_type, {inst->GetSingleWordInOperand(operand)});
    index_ids.push_back(
        constants->GetDefiningInstruction(literal)->result_id());
  }

  // The access chain starts from the same pointer the original load used.
  Instruction* element_ptr = builder.AddAccessChain(
      element_ptr_type_id, load->GetSingleWordInOperand(kLoadPointerInIdx),
      index_ids);
  Instruction* element_load =
      builder.AddLoad(inst->type_id(), element_ptr->result_id());

  context()->ReplaceAllUsesWith(inst->result_id(), element_load->result_id());
  context()->KillInst(inst);
  return true;
}

// Decides whether the OpCompositeExtract |inst| is worth replacing.
//
// Returns false immediately when the composite operand of |inst| is not
// defined by an OpLoad.  Otherwise the decision is made once per load and
// stored in |should_replace_cache_|, keyed by the load's result id:
//   - if any user of the load, other than a common debug instruction, is not
//     an OpCompositeExtract with at least one index, the answer is false;
//   - else if |replacement_threshold_| is at least 1.0, the answer is true;
//   - else the answer is true when the number of distinct first indices
//     extracted, divided by the number of elements in the loaded type, is
//     below |replacement_threshold_|.
bool ReduceLoadSize::ShouldReplaceExtract(Instruction* inst) {
  analysis::DefUseManager* def_use = context()->get_def_use_mgr();
  const uint32_t source_id =
      inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  Instruction* load = def_use->GetDef(source_id);
  if (load->opcode() != spv::Op::OpLoad) return false;

  // Reuse an earlier decision made for the same load.
  const auto cache_entry = should_replace_cache_.find(load->result_id());
  if (cache_entry != should_replace_cache_.end()) return cache_entry->second;

  // Collect the first index of every extract that reads the load.  The walk
  // stops, and reports false, at the first user that is neither a common
  // debug instruction nor an extract carrying an index.
  std::set<uint32_t> extracted_indices;
  const bool only_indexed_extracts = def_use->WhileEachUser(
      load, [&extracted_indices](Instruction* user) {
        if (user->IsCommonDebugInstr()) return true;
        const bool indexed_extract =
            user->opcode() == spv::Op::OpCompositeExtract &&
            user->NumInOperands() != 1;
        if (!indexed_extract) return false;
        extracted_indices.insert(user->GetSingleWordInOperand(1));
        return true;
      });

  bool verdict = false;
  if (only_indexed_extracts) {
    if (replacement_threshold_ >= 1.0) {
      verdict = true;
    } else {
      analysis::ConstantManager* constants = context()->get_constant_mgr();
      analysis::TypeManager* types = context()->get_type_mgr();
      analysis::Type* loaded_type = types->GetType(load->type_id());

      // Number of elements in the loaded value; stays 1 for any kind other
      // than struct or array.
      uint32_t element_count = 1;
      switch (loaded_type->kind()) {
        case analysis::Type::kStruct:
          element_count = static_cast<uint32_t>(
              loaded_type->AsStruct()->element_types().size());
          break;
        case analysis::Type::kArray: {
          const analysis::Constant* length = constants->FindDeclaredConstant(
              loaded_type->AsArray()->LengthId());
          if (length) {
            assert(length->AsIntConstant());
            element_count = length->GetU32();
          } else {
            // The length is a spec constant, so it is unknown at this time.
            // Assume it is very large.
            element_count = UINT32_MAX;
          }
        } break;
        default:
          break;
      }

      const double used_fraction =
          static_cast<double>(extracted_indices.size()) /
          static_cast<double>(element_count);
      verdict = used_fraction < replacement_threshold_;
    }
  }

  should_replace_cache_[load->result_id()] = verdict;
  return verdict;
}

}  // namespace opt
}  // namespace spvtools
