// Copyright (c) 2021 Alastair F. Donaldson
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/fuzz/available_instructions.h"
#include "source/fuzz/fuzzer_util.h"

namespace spvtools {
namespace fuzz {

// Precomputes, for every instruction in a reachable block of the module, how
// many instructions satisfying |predicate| are available immediately before
// it.
//
// The following tables are filled in:
// - |available_globals_|: instructions from the module's types/values section
//   that satisfy |predicate|.
// - |available_params_|: per function, the parameters satisfying |predicate|.
// - |num_available_at_block_entry_|: per reachable block, the count of
//   available instructions on entry to the block.
// - |generated_by_block_|: per reachable block, the instructions in the block
//   that satisfy |predicate|, in block order.
// - |num_available_before_instruction_|: per instruction in a reachable block,
//   the count of available instructions just before it.
//
// |ir_context| is stored in |ir_context_| and is also passed as the first
// argument of every |predicate| call.
AvailableInstructions::AvailableInstructions(
    opt::IRContext* ir_context,
    const std::function<bool(opt::IRContext*, opt::Instruction*)>& predicate)
    : ir_context_(ir_context) {
  // Collect the relevant module-scope declarations.
  for (auto& global_inst : ir_context->module()->types_values()) {
    if (!predicate(ir_context, &global_inst)) {
      continue;
    }
    available_globals_.push_back(&global_inst);
  }

  // Process the functions one at a time.
  for (auto& function : *ir_context->module()) {
    // Collect the relevant parameters of this function.
    std::vector<opt::Instruction*> relevant_params;
    function.ForEachParam([&](opt::Instruction* param) {
      if (predicate(ir_context, param)) {
        relevant_params.push_back(param);
      }
    });

    // On entry to the function's first block, exactly the relevant globals
    // and the relevant parameters are available.
    const auto num_at_function_entry = static_cast<uint32_t>(
        relevant_params.size() + available_globals_.size());

    auto dominators = ir_context->GetDominatorAnalysis(&function);
    for (auto& block : function) {
      // Unreachable blocks are skipped: nothing is recorded for them or for
      // the instructions they contain.
      if (!ir_context->IsReachable(block)) {
        continue;
      }

      // Work out how many instructions are available on entry to |block|.
      uint32_t num_at_entry = num_at_function_entry;
      if (&block != &*function.begin()) {
        // A reachable non-entry block has an immediate dominator. Everything
        // available on entry to that dominator, plus everything the dominator
        // itself generates, is available on entry to |block|. The asserts
        // encode the expectation that the dominator was visited earlier in
        // this loop.
        auto idom = dominators->ImmediateDominator(&block);
        assert(idom != nullptr &&
               "The block is reachable so should have an immediate dominator.");
        assert(generated_by_block_.count(idom) != 0 &&
               "Immediate dominator should have already been processed.");
        assert(num_available_at_block_entry_.count(idom) != 0 &&
               "Immediate dominator should have already been processed.");
        num_at_entry =
            static_cast<uint32_t>(generated_by_block_.at(idom).size()) +
            num_available_at_block_entry_.at(idom);
      }
      num_available_at_block_entry_.insert({&block, num_at_entry});
      assert(num_available_at_block_entry_.count(&block) != 0 &&
             "Block should have already been processed.");

      // Walk the block's instructions in order. Before each instruction the
      // available count is the block-entry count plus the number of relevant
      // instructions seen so far in this block.
      std::vector<opt::Instruction*> generated;
      for (auto& inst : block) {
        num_available_before_instruction_.insert(
            {&inst, num_at_entry + static_cast<uint32_t>(generated.size())});
        if (predicate(ir_context, &inst)) {
          // |inst| itself becomes available to the instructions after it.
          generated.push_back(&inst);
        }
      }
      generated_by_block_.emplace(&block, std::move(generated));
    }

    available_params_.emplace(&function, std::move(relevant_params));
  }
}

// Returns an AvailableBeforeInstruction object built from this object and
// |inst|. |inst| must have an entry in |num_available_before_instruction_|;
// this is checked by assertion only.
AvailableInstructions::AvailableBeforeInstruction
AvailableInstructions::GetAvailableBeforeInstruction(
    opt::Instruction* inst) const {
  assert(num_available_before_instruction_.find(inst) !=
             num_available_before_instruction_.end() &&
         "Availability can only be queried for reachable instructions.");
  return {*this, inst};
}

// Records |available_instructions| and |inst| in the corresponding members; no
// other work is done at construction time.
// NOTE(review): |available_instructions_| is presumably a reference member
// (its declaration is in the header, not visible here); if so, the
// AvailableInstructions object must outlive this object - confirm.
AvailableInstructions::AvailableBeforeInstruction::AvailableBeforeInstruction(
    const AvailableInstructions& available_instructions, opt::Instruction* inst)
    : available_instructions_(available_instructions), inst_(inst) {}

// Returns the number of instructions available before |inst_|, as recorded in
// the parent object's |num_available_before_instruction_| table.
uint32_t AvailableInstructions::AvailableBeforeInstruction::size() const {
  const auto& counts_before =
      available_instructions_.num_available_before_instruction_;
  return counts_before.at(inst_);
}

// Returns true if and only if size() reports that no instructions are
// available before |inst_|.
bool AvailableInstructions::AvailableBeforeInstruction::empty() const {
  const uint32_t num_available = size();
  return num_available == 0;
}

// Returns the available instruction at position |index|, which must be less
// than size() (checked by assertion only).
//
// Indices are laid out in three consecutive regions:
//   1. the relevant globals;
//   2. the relevant parameters of the function containing |inst_|;
//   3. the relevant instructions generated by the blocks on the dominator
//      chain that ends at the block containing |inst_|.
//
// Every resolved index is stored in |index_cache| so that repeated queries
// are answered directly from the cache.
opt::Instruction* AvailableInstructions::AvailableBeforeInstruction::operator[](
    uint32_t index) const {
  assert(index < size() && "Index out of bounds.");

  // Fast path: this index was resolved (or pre-populated) by an earlier query.
  {
    auto hit = index_cache.find(index);
    if (hit != index_cache.end()) {
      return hit->second;
    }
  }

  // Region 1: globals.
  const auto& globals = available_instructions_.available_globals_;
  if (index < globals.size()) {
    auto global = globals[index];
    index_cache.insert({index, global});
    return global;
  }

  auto block = available_instructions_.ir_context_->get_instr_block(inst_);
  auto function = block->GetParent();

  // Region 2: parameters of the enclosing function, offset by the number of
  // globals.
  const auto& params = available_instructions_.available_params_.at(function);
  if (index < globals.size() + params.size()) {
    auto param = params[index - globals.size()];
    index_cache.insert({index, param});
    return param;
  }

  // Region 3: the costly case that motivates the cache. Starting at the block
  // containing |inst_|, climb the dominator tree until reaching the block
  // whose generated instructions cover |index|.
  auto dominators =
      available_instructions_.ir_context_->GetDominatorAnalysis(function);
  auto* current = block;
  while (true) {
    const uint32_t entry_count =
        available_instructions_.num_available_at_block_entry_.at(current);
    if (index_cache.find(entry_count) == index_cache.end()) {
      // No cache entry exists yet for the first index belonging to |current|,
      // so record every instruction it generates. Later queries that land in
      // this block are then served by the fast path above.
      const auto& generated =
          available_instructions_.generated_by_block_.at(current);
      uint32_t slot = entry_count;
      for (auto generated_inst : generated) {
        index_cache.insert({slot, generated_inst});
        ++slot;
      }
    }
    if (index >= entry_count) {
      // |index| belongs to |current|, whose instructions are now cached.
      return index_cache.at(index);
    }
    assert(current != &*function->begin() &&
           "By construction we should find a block associated with the index.");
    current = dominators->ImmediateDominator(current);
  }

  assert(false && "Unreachable.");
  return nullptr;
}

}  // namespace fuzz
}  // namespace spvtools
