// Copyright (c) 2023 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/opt/trim_capabilities_pass.h"

#include <algorithm>
#include <array>
#include <cassert>
#include <functional>
#include <optional>
#include <queue>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "source/enum_set.h"
#include "source/enum_string_mapping.h"
#include "source/opt/ir_context.h"
#include "source/opt/reflect.h"
#include "source/spirv_target_env.h"
#include "source/util/string_utils.h"

namespace spvtools {
namespace opt {

namespace {
// In-operand indices passed to Instruction::GetSingleWordInOperand() by the
// helpers and handlers of this file.

// OpTypeFloat: operand holding the bit width (compared against 16 and 64).
constexpr uint32_t kOpTypeFloatSizeIndex = 0;
// OpTypePointer: operand holding the storage class.
constexpr uint32_t kOpTypePointerStorageClassIndex = 0;
// OpTypeMatrix/OpTypeVector/OpTypeArray/OpTypeRuntimeArray: operand holding
// the id of the sub-type followed by the type walk.
constexpr uint32_t kTypeArrayTypeIndex = 0;
// OpTypeInt/OpTypeFloat: operand holding the bit width (compared against 16).
constexpr uint32_t kOpTypeScalarBitWidthIndex = 0;
// OpTypePointer: operand holding the id of the pointed-to type.
constexpr uint32_t kTypePointerTypeIdInIndex = 1;
// OpTypeInt: operand holding the bit width (compared against 16 and 64).
constexpr uint32_t kOpTypeIntSizeIndex = 0;
// OpTypeImage: the operands below are expressed relative to the Dim operand.
constexpr uint32_t kOpTypeImageDimIndex = 1;
constexpr uint32_t kOpTypeImageArrayedIndex = kOpTypeImageDimIndex + 2;
constexpr uint32_t kOpTypeImageMSIndex = kOpTypeImageArrayedIndex + 1;
constexpr uint32_t kOpTypeImageSampledIndex = kOpTypeImageMSIndex + 1;
constexpr uint32_t kOpTypeImageFormatIndex = kOpTypeImageSampledIndex + 1;
// OpImageRead/OpImageSparseRead: operand holding the id of the image read.
constexpr uint32_t kOpImageReadImageIndex = 0;
constexpr uint32_t kOpImageSparseReadImageIndex = 0;

// DFS visit of the type defined by `instruction`.
//
// `condition` is called with each type reached by the walk:
//  - if it returns true, the sub-types of the current node are visited
//    (pointee of an OpTypePointer, sub-type of an OpTypeMatrix/OpTypeVector/
//    OpTypeArray/OpTypeRuntimeArray, every in-operand of an OpTypeStruct).
//  - if it returns false, the sub-types of the current node are ignored.
//
// Each id is processed at most once. Forward pointer declarations allow a type
// to reference itself (directly or not), so the same id can be reached again
// while walking; without tracking visited ids the walk would never terminate
// on such a type.
template <class UnaryPredicate>
static void DFSWhile(const Instruction* instruction, UnaryPredicate condition) {
  std::stack<uint32_t> instructions_to_visit;
  std::unordered_set<uint32_t> visited_ids;
  instructions_to_visit.push(instruction->result_id());
  const auto* def_use_mgr = instruction->context()->get_def_use_mgr();

  while (!instructions_to_visit.empty()) {
    const uint32_t current_id = instructions_to_visit.top();
    instructions_to_visit.pop();

    // Skip ids that were already handled: this is what breaks cycles.
    if (!visited_ids.insert(current_id).second) {
      continue;
    }

    const Instruction* item = def_use_mgr->GetDef(current_id);
    if (!condition(item)) {
      continue;
    }

    if (item->opcode() == spv::Op::OpTypePointer) {
      instructions_to_visit.push(
          item->GetSingleWordInOperand(kTypePointerTypeIdInIndex));
      continue;
    }

    if (item->opcode() == spv::Op::OpTypeMatrix ||
        item->opcode() == spv::Op::OpTypeVector ||
        item->opcode() == spv::Op::OpTypeArray ||
        item->opcode() == spv::Op::OpTypeRuntimeArray) {
      instructions_to_visit.push(
          item->GetSingleWordInOperand(kTypeArrayTypeIndex));
      continue;
    }

    if (item->opcode() == spv::Op::OpTypeStruct) {
      item->ForEachInOperand([&instructions_to_visit](const uint32_t* op_id) {
        instructions_to_visit.push(*op_id);
      });
      continue;
    }
  }
}

// Tests `predicate` against the type defined by `instruction` and the types it
// is built from (`instruction` must be an OpType* instruction).
// Returns `true` as soon as one call to `predicate` returns true; the remaining
// types are then no longer submitted to `predicate`.
template <class UnaryPredicate>
static bool AnyTypeOf(const Instruction* instruction,
                      UnaryPredicate predicate) {
  assert(IsTypeInst(instruction->opcode()) &&
         "AnyTypeOf called with a non-type instruction.");

  bool matched = false;
  DFSWhile(instruction, [&matched, predicate](const Instruction* type) {
    // Only query the predicate while no match has been recorded.
    if (!matched) {
      matched = predicate(type);
    }
    // Keep descending only as long as nothing matched.
    return !matched;
  });
  return matched;
}

static bool is16bitType(const Instruction* instruction) {
  if (instruction->opcode() != spv::Op::OpTypeInt &&
      instruction->opcode() != spv::Op::OpTypeFloat) {
    return false;
  }

  return instruction->GetSingleWordInOperand(kOpTypeScalarBitWidthIndex) == 16;
}

static bool Has16BitCapability(const FeatureManager* feature_manager) {
  const CapabilitySet& capabilities = feature_manager->GetCapabilities();
  return capabilities.contains(spv::Capability::Float16) ||
         capabilities.contains(spv::Capability::Int16);
}

}  // namespace

// ============== Begin opcode handler implementations. =======================
//
// Adding support for a new capability should only require adding a new handler,
// and updating the
// kSupportedCapabilities/kUntouchableCapabilities/kForbiddenCapabilities lists.
//
// Handler names use the following convention:
//  Handler_<Opcode>_<Capability>()

static std::optional<spv::Capability> Handler_OpTypeFloat_Float16(
    const Instruction* instruction) {
  assert(instruction->opcode() == spv::Op::OpTypeFloat &&
         "This handler only support OpTypeFloat opcodes.");

  const uint32_t size =
      instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
  return size == 16 ? std::optional(spv::Capability::Float16) : std::nullopt;
}

static std::optional<spv::Capability> Handler_OpTypeFloat_Float64(
    const Instruction* instruction) {
  assert(instruction->opcode() == spv::Op::OpTypeFloat &&
         "This handler only support OpTypeFloat opcodes.");

  const uint32_t size =
      instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
  return size == 64 ? std::optional(spv::Capability::Float64) : std::nullopt;
}

static std::optional<spv::Capability>
Handler_OpTypePointer_StorageInputOutput16(const Instruction* instruction) {
  assert(instruction->opcode() == spv::Op::OpTypePointer &&
         "This handler only support OpTypePointer opcodes.");

  // This capability is only required if the variable has an Input/Output
  // storage class.
  spv::StorageClass storage_class = spv::StorageClass(
      instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
  if (storage_class != spv::StorageClass::Input &&
      storage_class != spv::StorageClass::Output) {
    return std::nullopt;
  }

  if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
    return std::nullopt;
  }

  return AnyTypeOf(instruction, is16bitType)
             ? std::optional(spv::Capability::StorageInputOutput16)
             : std::nullopt;
}

static std::optional<spv::Capability>
Handler_OpTypePointer_StoragePushConstant16(const Instruction* instruction) {
  assert(instruction->opcode() == spv::Op::OpTypePointer &&
         "This handler only support OpTypePointer opcodes.");

  // This capability is only required if the variable has a PushConstant storage
  // class.
  spv::StorageClass storage_class = spv::StorageClass(
      instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
  if (storage_class != spv::StorageClass::PushConstant) {
    return std::nullopt;
  }

  if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
    return std::nullopt;
  }

  return AnyTypeOf(instruction, is16bitType)
             ? std::optional(spv::Capability::StoragePushConstant16)
             : std::nullopt;
}

static std::optional<spv::Capability>
Handler_OpTypePointer_StorageUniformBufferBlock16(
    const Instruction* instruction) {
  assert(instruction->opcode() == spv::Op::OpTypePointer &&
         "This handler only support OpTypePointer opcodes.");

  // This capability is only required if the variable has a Uniform storage
  // class.
  spv::StorageClass storage_class = spv::StorageClass(
      instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
  if (storage_class != spv::StorageClass::Uniform) {
    return std::nullopt;
  }

  if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
    return std::nullopt;
  }

  const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
  const bool matchesCondition =
      AnyTypeOf(instruction, [decoration_mgr](const Instruction* item) {
        if (!decoration_mgr->HasDecoration(item->result_id(),
                                           spv::Decoration::BufferBlock)) {
          return false;
        }

        return AnyTypeOf(item, is16bitType);
      });

  return matchesCondition
             ? std::optional(spv::Capability::StorageUniformBufferBlock16)
             : std::nullopt;
}

static std::optional<spv::Capability> Handler_OpTypePointer_StorageUniform16(
    const Instruction* instruction) {
  assert(instruction->opcode() == spv::Op::OpTypePointer &&
         "This handler only support OpTypePointer opcodes.");

  // This capability is only required if the variable has a Uniform storage
  // class.
  spv::StorageClass storage_class = spv::StorageClass(
      instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
  if (storage_class != spv::StorageClass::Uniform) {
    return std::nullopt;
  }

  const auto* feature_manager = instruction->context()->get_feature_mgr();
  if (!Has16BitCapability(feature_manager)) {
    return std::nullopt;
  }

  const bool hasBufferBlockCapability =
      feature_manager->GetCapabilities().contains(
          spv::Capability::StorageUniformBufferBlock16);
  const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
  bool found16bitType = false;

  DFSWhile(instruction, [decoration_mgr, hasBufferBlockCapability,
                         &found16bitType](const Instruction* item) {
    if (found16bitType) {
      return false;
    }

    if (hasBufferBlockCapability &&
        decoration_mgr->HasDecoration(item->result_id(),
                                      spv::Decoration::BufferBlock)) {
      return false;
    }

    if (is16bitType(item)) {
      found16bitType = true;
      return false;
    }

    return true;
  });

  return found16bitType ? std::optional(spv::Capability::StorageUniform16)
                        : std::nullopt;
}

static std::optional<spv::Capability> Handler_OpTypeInt_Int16(
    const Instruction* instruction) {
  assert(instruction->opcode() == spv::Op::OpTypeInt &&
         "This handler only support OpTypeInt opcodes.");

  const uint32_t size =
      instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
  return size == 16 ? std::optional(spv::Capability::Int16) : std::nullopt;
}

static std::optional<spv::Capability> Handler_OpTypeInt_Int64(
    const Instruction* instruction) {
  assert(instruction->opcode() == spv::Op::OpTypeInt &&
         "This handler only support OpTypeInt opcodes.");

  const uint32_t size =
      instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
  return size == 64 ? std::optional(spv::Capability::Int64) : std::nullopt;
}

static std::optional<spv::Capability> Handler_OpTypeImage_ImageMSArray(
    const Instruction* instruction) {
  assert(instruction->opcode() == spv::Op::OpTypeImage &&
         "This handler only support OpTypeImage opcodes.");

  const uint32_t arrayed =
      instruction->GetSingleWordInOperand(kOpTypeImageArrayedIndex);
  const uint32_t ms = instruction->GetSingleWordInOperand(kOpTypeImageMSIndex);
  const uint32_t sampled =
      instruction->GetSingleWordInOperand(kOpTypeImageSampledIndex);

  return arrayed == 1 && sampled == 2 && ms == 1
             ? std::optional(spv::Capability::ImageMSArray)
             : std::nullopt;
}

static std::optional<spv::Capability>
Handler_OpImageRead_StorageImageReadWithoutFormat(
    const Instruction* instruction) {
  assert(instruction->opcode() == spv::Op::OpImageRead &&
         "This handler only support OpImageRead opcodes.");
  const auto* def_use_mgr = instruction->context()->get_def_use_mgr();

  const uint32_t image_index =
      instruction->GetSingleWordInOperand(kOpImageReadImageIndex);
  const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
  const Instruction* type = def_use_mgr->GetDef(type_index);
  const uint32_t dim = type->GetSingleWordInOperand(kOpTypeImageDimIndex);
  const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);

  const bool is_unknown = spv::ImageFormat(format) == spv::ImageFormat::Unknown;
  const bool requires_capability_for_unknown =
      spv::Dim(dim) != spv::Dim::SubpassData;
  return is_unknown && requires_capability_for_unknown
             ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
             : std::nullopt;
}

static std::optional<spv::Capability>
Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
    const Instruction* instruction) {
  assert(instruction->opcode() == spv::Op::OpImageSparseRead &&
         "This handler only support OpImageSparseRead opcodes.");
  const auto* def_use_mgr = instruction->context()->get_def_use_mgr();

  const uint32_t image_index =
      instruction->GetSingleWordInOperand(kOpImageSparseReadImageIndex);
  const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
  const Instruction* type = def_use_mgr->GetDef(type_index);
  const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);

  return spv::ImageFormat(format) == spv::ImageFormat::Unknown
             ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
             : std::nullopt;
}

// Opcode of interest to determine capabilities requirements.
//
// Each entry associates an opcode with one handler. An opcode may be listed
// several times, once per handler that must run for it. Every
// (opcode, handler) pair is listed exactly once, so a handler is not run twice
// for the same instruction.
constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 11> kOpcodeHandlers{{
    // clang-format off
    {spv::Op::OpImageRead,         Handler_OpImageRead_StorageImageReadWithoutFormat},
    {spv::Op::OpImageSparseRead,   Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
    {spv::Op::OpTypeFloat,         Handler_OpTypeFloat_Float16 },
    {spv::Op::OpTypeFloat,         Handler_OpTypeFloat_Float64 },
    {spv::Op::OpTypeImage,         Handler_OpTypeImage_ImageMSArray},
    {spv::Op::OpTypeInt,           Handler_OpTypeInt_Int16 },
    {spv::Op::OpTypeInt,           Handler_OpTypeInt_Int64 },
    {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageInputOutput16},
    {spv::Op::OpTypePointer,       Handler_OpTypePointer_StoragePushConstant16},
    {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniform16},
    {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniformBufferBlock16},
    // clang-format on
}};

// ==============  End opcode handler implementations.  =======================

namespace {
// Returns the set of extensions that `grammar` lists for the capabilities in
// `capabilities`. Capabilities unknown to the grammar are skipped.
ExtensionSet getExtensionsRelatedTo(const CapabilitySet& capabilities,
                                    const AssemblyGrammar& grammar) {
  ExtensionSet related_extensions;
  for (auto capability : capabilities) {
    const spv_operand_desc_t* entry = nullptr;
    const spv_result_t status =
        grammar.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
                              static_cast<uint32_t>(capability), &entry);
    if (status == SPV_SUCCESS) {
      for (uint32_t index = 0; index < entry->numExtensions; ++index) {
        related_extensions.insert(entry->extensions[index]);
      }
    }
  }

  return related_extensions;
}
}  // namespace

// Fills the per-instance capability sets from the class-scope lists
// (kSupportedCapabilities, kForbiddenCapabilities, kUntouchableCapabilities)
// and the handler lookup from the file-scope kOpcodeHandlers table.
TrimCapabilitiesPass::TrimCapabilitiesPass()
    : supportedCapabilities_(kSupportedCapabilities.cbegin(),
                             kSupportedCapabilities.cend()),
      forbiddenCapabilities_(kForbiddenCapabilities.cbegin(),
                             kForbiddenCapabilities.cend()),
      untouchableCapabilities_(kUntouchableCapabilities.cbegin(),
                               kUntouchableCapabilities.cend()),
      opcodeHandlers_(kOpcodeHandlers.cbegin(), kOpcodeHandlers.cend()) {}

// Adds to `capabilities` and `extensions` the supported requirements that the
// grammar lists for `opcode`. Opcodes unknown to the grammar add nothing.
void TrimCapabilitiesPass::addInstructionRequirementsForOpcode(
    spv::Op opcode, CapabilitySet* capabilities,
    ExtensionSet* extensions) const {
  switch (opcode) {
    // Ignoring OpBeginInvocationInterlockEXT and OpEndInvocationInterlockEXT
    // because they have three possible capabilities, only one of which is
    // needed.
    case spv::Op::OpBeginInvocationInterlockEXT:
    case spv::Op::OpEndInvocationInterlockEXT:
      return;
    default:
      break;
  }

  const spv_opcode_desc_t* entry = nullptr;
  if (context()->grammar().lookupOpcode(opcode, &entry) != SPV_SUCCESS) {
    return;
  }

  addSupportedCapabilitiesToSet(entry, capabilities);
  addSupportedExtensionsToSet(entry, extensions);
}

// Adds to `capabilities` and `extensions` the supported requirements implied
// by a single instruction operand.
//
// Operands that are not exactly one word long, literal strings, ids and result
// ids add nothing. Other operands are looked up in the grammar, either as a
// single value or, for mask operands, one set bit at a time.
void TrimCapabilitiesPass::addInstructionRequirementsForOperand(
    const Operand& operand, CapabilitySet* capabilities,
    ExtensionSet* extensions) const {
  // No supported capability relies on a 2+-word operand.
  if (operand.words.size() != 1) {
    return;
  }

  // No supported capability relies on a literal string operand or an ID.
  if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING ||
      operand.type == SPV_OPERAND_TYPE_ID ||
      operand.type == SPV_OPERAND_TYPE_RESULT_ID) {
    return;
  }

  // If the Vulkan memory model is declared and any instruction uses Device
  // scope, the VulkanMemoryModelDeviceScope capability must be declared. This
  // rule cannot be covered by the grammar, so must be checked explicitly.
  // NOTE: the value of the scope is not inspected here; any scope-id operand
  // found in a module using the Vulkan memory model adds the capability.
  if (operand.type == SPV_OPERAND_TYPE_SCOPE_ID) {
    const Instruction* memory_model = context()->GetMemoryModel();
    if (memory_model && memory_model->GetSingleWordInOperand(1u) ==
                            uint32_t(spv::MemoryModel::Vulkan)) {
      capabilities->insert(spv::Capability::VulkanMemoryModelDeviceScope);
    }
  }

  // case 1: Operand is a single value, can directly lookup.
  if (!spvOperandIsConcreteMask(operand.type)) {
    const spv_operand_desc_t* desc = {};
    auto result = context()->grammar().lookupOperand(operand.type,
                                                     operand.words[0], &desc);
    if (result != SPV_SUCCESS) {
      return;
    }
    addSupportedCapabilitiesToSet(desc, capabilities);
    addSupportedExtensionsToSet(desc, extensions);
    return;
  }

  // case 2: operand can be a bitmask, we need to decompose the lookup.
  for (uint32_t i = 0; i < 32; ++i) {
    // The probed bit is built from an unsigned 1: shifting a signed int into
    // bit 31 is not portable before C++20.
    const uint32_t mask = (uint32_t{1} << i) & operand.words[0];
    if (!mask) {
      continue;
    }

    const spv_operand_desc_t* desc = {};
    auto result = context()->grammar().lookupOperand(operand.type, mask, &desc);
    if (result != SPV_SUCCESS) {
      // This bit is unknown to the grammar; the other bits are still checked.
      continue;
    }

    addSupportedCapabilitiesToSet(desc, capabilities);
    addSupportedExtensionsToSet(desc, extensions);
  }
}

// Adds to `capabilities` and `extensions` everything `instruction` requires:
// what its opcode requires, what each of its operands requires, and what the
// handlers registered for its opcode report.
void TrimCapabilitiesPass::addInstructionRequirements(
    Instruction* instruction, CapabilitySet* capabilities,
    ExtensionSet* extensions) const {
  const spv::Op opcode = instruction->opcode();

  // Ignoring OpCapability and OpExtension instructions.
  if (opcode == spv::Op::OpCapability || opcode == spv::Op::OpExtension) {
    return;
  }

  // First case: the opcode itself is gated by a capability.
  addInstructionRequirementsForOpcode(opcode, capabilities, extensions);

  // Second case: one of the opcode operand is gated by a capability.
  for (uint32_t index = 0, count = instruction->NumOperands(); index < count;
       ++index) {
    addInstructionRequirementsForOperand(instruction->GetOperand(index),
                                         capabilities, extensions);
  }

  // Last case: some complex logic needs to be run to determine capabilities.
  const auto handlers = opcodeHandlers_.equal_range(opcode);
  for (auto it = handlers.first; it != handlers.second; ++it) {
    if (const auto capability = it->second(instruction)) {
      capabilities->insert(*capability);
    }
  }
}

void TrimCapabilitiesPass::AddExtensionsForOperand(
    const spv_operand_type_t type, const uint32_t value,
    ExtensionSet* extensions) const {
  const spv_operand_desc_t* desc = nullptr;
  spv_result_t result = context()->grammar().lookupOperand(type, value, &desc);
  if (result != SPV_SUCCESS) {
    return;
  }
  addSupportedExtensionsToSet(desc, extensions);
}

std::pair<CapabilitySet, ExtensionSet>
TrimCapabilitiesPass::DetermineRequiredCapabilitiesAndExtensions() const {
  CapabilitySet required_capabilities;
  ExtensionSet required_extensions;

  get_module()->ForEachInst([&](Instruction* instruction) {
    addInstructionRequirements(instruction, &required_capabilities,
                               &required_extensions);
  });

  for (auto capability : required_capabilities) {
    AddExtensionsForOperand(SPV_OPERAND_TYPE_CAPABILITY,
                            static_cast<uint32_t>(capability),
                            &required_extensions);
  }

#if !defined(NDEBUG)
  // Debug only. We check the outputted required capabilities against the
  // supported capabilities list. The supported capabilities list is useful for
  // API users to quickly determine if they can use the pass or not. But this
  // list has to remain up-to-date with the pass code. If we can detect a
  // capability as required, but it's not listed, it means the list is
  // out-of-sync. This method is not ideal, but should cover most cases.
  {
    for (auto capability : required_capabilities) {
      assert(supportedCapabilities_.contains(capability) &&
             "Module is using a capability that is not listed as supported.");
    }
  }
#endif

  return std::make_pair(std::move(required_capabilities),
                        std::move(required_extensions));
}

// Removes from the module the declared capabilities that are supported by the
// pass, not untouchable, and absent from `required_capabilities`.
// Returns SuccessWithChange if at least one capability was selected for
// removal, SuccessWithoutChange otherwise.
Pass::Status TrimCapabilitiesPass::TrimUnrequiredCapabilities(
    const CapabilitySet& required_capabilities) const {
  const FeatureManager* features = context()->get_feature_mgr();

  // The removal candidates are gathered first and only removed once the
  // iteration over the declared capabilities is over.
  CapabilitySet removable;
  for (auto capability : features->GetCapabilities()) {
    // Kept when: it cannot be safely removed, the pass does not support it,
    // or the module requires it.
    const bool must_keep = untouchableCapabilities_.contains(capability) ||
                           !supportedCapabilities_.contains(capability) ||
                           required_capabilities.contains(capability);
    if (!must_keep) {
      removable.insert(capability);
    }
  }

  for (auto capability : removable) {
    context()->RemoveCapability(capability);
  }

  if (removable.size() == 0) {
    return Pass::Status::SuccessWithoutChange;
  }
  return Pass::Status::SuccessWithChange;
}

// Removes from the module the extensions that are related to the supported
// capabilities but absent from `required_extensions`.
// Returns SuccessWithChange if at least one removal was reported by the
// context, SuccessWithoutChange otherwise.
Pass::Status TrimCapabilitiesPass::TrimUnrequiredExtensions(
    const ExtensionSet& required_extensions) const {
  const auto candidates =
      getExtensionsRelatedTo(supportedCapabilities_, context()->grammar());

  bool changed = false;
  for (auto extension : candidates) {
    if (!required_extensions.contains(extension) &&
        context()->RemoveExtension(extension)) {
      changed = true;
    }
  }

  if (changed) {
    return Pass::Status::SuccessWithChange;
  }
  return Pass::Status::SuccessWithoutChange;
}

// Returns true if the module declares at least one forbidden capability.
bool TrimCapabilitiesPass::HasForbiddenCapabilities() const {
  // EnumSet.HasAnyOf returns `true` if the given set is empty, hence the size
  // check before querying the declared capabilities.
  return forbiddenCapabilities_.size() != 0 &&
         context()->get_feature_mgr()->GetCapabilities().HasAnyOf(
             forbiddenCapabilities_);
}

// Pass entry point. Leaves the module untouched when it declares a forbidden
// capability; otherwise computes the required capabilities/extensions and
// trims the unrequired ones. Reports whether the module was changed.
Pass::Status TrimCapabilitiesPass::Process() {
  if (HasForbiddenCapabilities()) {
    return Status::SuccessWithoutChange;
  }

  const auto requirements = DetermineRequiredCapabilitiesAndExtensions();

  // Capabilities are trimmed first, extensions second.
  const Pass::Status capabilities_status =
      TrimUnrequiredCapabilities(requirements.first);
  const Pass::Status extensions_status =
      TrimUnrequiredExtensions(requirements.second);

  const bool changed =
      capabilities_status == Pass::Status::SuccessWithChange ||
      extensions_status == Pass::Status::SuccessWithChange;
  return changed ? Pass::Status::SuccessWithChange
                 : Pass::Status::SuccessWithoutChange;
}

}  // namespace opt
}  // namespace spvtools
