/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#pragma once

#include "src/core/NEON/kernels/assembly/winograd.hpp"
#include <memory>
#include <string>

namespace arm_conv {
namespace winograd {

enum class MethodConstraints
{
  None,
  RequiresSVE  = 0x1,
  RequiresSVE2 = 0x2,
  RequiresSME  = 0x4,
  RequiresSME2 = 0x8,
  LargerShape  = 0x10, // Input tensor shape is larger than the output transform tile shape.
};

constexpr inline bool operator!(const MethodConstraints &c)
{
  return c == MethodConstraints::None;
}

constexpr inline MethodConstraints operator|(const MethodConstraints &a, const MethodConstraints &b)
{
  return static_cast<MethodConstraints>(static_cast<unsigned int>(a) | static_cast<unsigned int>(b));
}

constexpr inline MethodConstraints operator&(const MethodConstraints &a, const MethodConstraints &b)
{
  return static_cast<MethodConstraints>(static_cast<unsigned int>(a) & static_cast<unsigned int>(b));
}

inline bool constraints_met(const MethodConstraints &c, const CPUInfo *ci, const ConvolutionArgs &, const WinogradConfig *)
{
  return (
    (!(c & MethodConstraints::RequiresSVE) || (ci->has_sve())) &&
    (!(c & MethodConstraints::RequiresSVE2) || (ci->has_sve2())) &&
    (!(c & MethodConstraints::RequiresSME) || (ci->has_sme())) &&
    (!(c & MethodConstraints::RequiresSME2) || (ci->has_sme2()))
    // Add further constraints here
  );
}

inline bool output_transform_constraints_met(const output_transform::ITransform *transform, const MethodConstraints &c, const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg)
{
  return (
    constraints_met(c, ci, conv_args, cfg) &&
    (!(c & MethodConstraints::LargerShape) || (conv_args.input_shape.rows > transform->get_output_rows() && conv_args.input_shape.cols > transform->get_output_cols()))
  );
}

namespace weight_transform {

template <typename TIn, typename TOut=TIn>
struct TransformImplementation
{
  std::unique_ptr<const ITransform> transform;
  MethodConstraints constraints;

  TransformImplementation(const ITransform *transform, const MethodConstraints &constraints = MethodConstraints::None)
  : transform(transform), constraints(constraints)
  {
  }
};

template <typename TIn, typename TOut=TIn>
const TransformImplementation<TIn, TOut> *implementation_list(void);

}  // namespace weight_transform

namespace input_transform
{

template <typename TIn, typename TOut=TIn>
struct TransformImplementation
{
  std::unique_ptr<const ITransform> transform;
  MethodConstraints constraints;

  TransformImplementation(const ITransform *transform, const MethodConstraints &constraints = MethodConstraints::None)
  : transform(transform), constraints(constraints)
  {
  }
};

template <typename TIn, typename TOut=TIn>
const TransformImplementation<TIn, TOut> *implementation_list(void);

}  // namespace input_transform

namespace output_transform
{

template <typename TIn, typename TOut=TIn>
struct TransformImplementation
{
  std::unique_ptr<const ITransform> transform;
  MethodConstraints constraints;

  TransformImplementation(const ITransform *transform, const MethodConstraints &constraints = MethodConstraints::None)
  : transform(transform), constraints(constraints)
  {
  }
};

template <typename TIn, typename TOut=TIn>
const TransformImplementation<TIn, TOut> *implementation_list(void);

}  // namespace output_transform

namespace{

template <typename T>
constexpr T iceildiv(T num, T den)
{
  return (num + den - 1) / den;
}

template <typename T>
constexpr T iroundup(T num, T den)
{
  return den * iceildiv(num, den);
}

}

template <typename TWeight, typename TWinogradIn>
inline std::vector<const weight_transform::ITransform *> get_weight_transforms(
  const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg
)
{
  // Get target inner tile size
  const auto target_inner_tile_rows = cfg->output_rows == 0 ? 0 : (conv_args.kernel_shape.rows + cfg->output_rows - 1);
  const auto target_inner_tile_cols = cfg->output_cols == 0 ? 0 : (conv_args.kernel_shape.cols + cfg->output_cols - 1);

  std::vector<const weight_transform::ITransform *> weight_transforms;
  for (auto impl = weight_transform::implementation_list<TWeight, TWinogradIn>();
       impl->transform.get() != nullptr; impl++)
  {
    // If this transform supports the requested kernel size, then add it to the
    // list of weight transforms.
    if (
      constraints_met(impl->constraints, ci, conv_args,  cfg) &&
      impl->transform->get_kernel_rows() == conv_args.kernel_shape.rows &&
      impl->transform->get_kernel_cols() == conv_args.kernel_shape.cols &&
      (target_inner_tile_rows == 0 || target_inner_tile_rows == impl->transform->get_transformed_tile_rows()) &&
      (target_inner_tile_cols == 0 || target_inner_tile_cols == impl->transform->get_transformed_tile_cols()) &&
      (cfg->weight_transform_filter == "" || std::strstr(impl->transform->get_name().c_str(), cfg->weight_transform_filter.c_str()))
    )
    {
      weight_transforms.push_back(impl->transform.get());
    }
  }

  return weight_transforms;
}

template <typename TIn, typename TWinogradIn>
inline std::vector<const input_transform::ITransform *> get_input_transforms(
  const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg
)
{
  // Get target inner tile size
  const auto target_inner_tile_rows = cfg->output_rows == 0 ? 0 : (conv_args.kernel_shape.rows + cfg->output_rows - 1);
  const auto target_inner_tile_cols = cfg->output_cols == 0 ? 0 : (conv_args.kernel_shape.cols + cfg->output_cols - 1);

  std::vector<const input_transform::ITransform *> input_transforms;
  for (auto impl = input_transform::implementation_list<TIn, TWinogradIn>();
       impl->transform.get() != nullptr; impl++)
  {
    if(
      constraints_met(impl->constraints, ci, conv_args,  cfg) &&
      (target_inner_tile_rows == 0 || target_inner_tile_rows == impl->transform->get_input_rows()) &&
      (target_inner_tile_cols == 0 || target_inner_tile_cols == impl->transform->get_input_cols()) &&
      (cfg->input_transform_filter == "" || std::strstr(impl->transform->get_name().c_str(), cfg->input_transform_filter.c_str()))
    )
    {
      input_transforms.push_back(impl->transform.get());
    }
  }

  return input_transforms;
}

template <typename TWinogradOut, typename TOut>
inline std::vector<const output_transform::ITransform *> get_output_transforms(
  const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg
)
{
  std::vector<const output_transform::ITransform *> output_transforms;
  for (auto impl = output_transform::implementation_list<TWinogradOut, TOut>();
       impl->transform.get() != nullptr; impl++)
  {
    if(
      output_transform_constraints_met(impl->transform.get(), impl->constraints, ci, conv_args,  cfg) &&
      impl->transform->get_kernel_rows() == conv_args.kernel_shape.rows &&
      impl->transform->get_kernel_cols() == conv_args.kernel_shape.cols &&
      (cfg->output_rows == 0 || cfg->output_rows == impl->transform->get_output_rows()) &&
      (cfg->output_cols == 0 || cfg->output_cols == impl->transform->get_output_cols()) &&
      (cfg->output_transform_filter == "" || std::strstr(impl->transform->get_name().c_str(), cfg->output_transform_filter.c_str()))
    )
    {
      output_transforms.push_back(impl->transform.get());
    }
  }

  return output_transforms;
}

template <typename TIn, typename TWeight, typename TOut, typename TWinogradIn, typename TWinogradOut>
bool get_implementation(
  WinogradImpl &dest,  // Destination for the selected implementation
  const CPUInfo *ci,
  const ConvolutionArgs &conv_args,
  int max_threads,
  bool fast_mode,
  const WinogradConfig *cfg,
  const arm_gemm::GemmConfig *gemm_cfg
)
{
  // Get vectors of valid weight, input and output transforms; then select the
  // combination which produces the biggest output tile.
  const auto weight_transforms = get_weight_transforms<TWeight, TWinogradIn>(ci, conv_args, cfg);
  const auto input_transforms = get_input_transforms<TIn, TWinogradIn>(ci, conv_args, cfg);
  const auto output_transforms = get_output_transforms<TWinogradOut, TOut>(ci, conv_args, cfg);

  // Now attempt to select a complete set of Winograd transformations which can
  // solve the problem. Work backwards from the output transform to find
  // matching input implementations.
  bool success = false;
  for (auto output_transform = output_transforms.cbegin();
       !success && output_transform != output_transforms.cend();
       output_transform++)
  {
    // Look for matching weight transforms, if we find one then we look for
    // matching input transforms.
    for (auto weight_transform = weight_transforms.cbegin();
         !success && weight_transform != weight_transforms.cend();
         weight_transform++)
    {
      // If this weight transform is compatible, then look for a matching input
      // transform
      if ((*output_transform)->get_input_rows() == (*weight_transform)->get_transformed_tile_rows() &&
          (*output_transform)->get_input_cols() == (*weight_transform)->get_transformed_tile_cols())
      {
        for (auto input_transform = input_transforms.cbegin();
             !success && input_transform != input_transforms.cend();
             input_transform++)
        {
          // If the input transform is suitable, then set the configuration and
          // indicate success.
          if ((*input_transform)->get_input_rows() == (*output_transform)->get_input_rows() &&
              (*input_transform)->get_input_cols() == (*output_transform)->get_input_cols())
          {
            dest.output_transform = *output_transform;
            dest.input_transform = *input_transform;
            dest.weight_transform = *weight_transform;
            success = true;
          }
        }
      }
    }
  }

  if (!success)
  {
    return false;
  }

  // If we're able to construct the Winograd elements, then specify the GEMM
  // arguments required to perform the multiply-accumulate step of the
  // convolution.
  const auto n_output_row_tiles = iceildiv(conv_args.output_shape.rows, dest.output_transform->get_output_rows());
  const auto n_output_col_tiles = iceildiv(conv_args.output_shape.cols, dest.output_transform->get_output_cols());
  const auto n_output_patches = n_output_row_tiles * n_output_col_tiles;

  const int n_multis = dest.input_transform->get_input_rows() *
                       dest.input_transform->get_input_cols();

  dest.gemm_args.reset(new arm_gemm::GemmArgs(
    ci,
    n_output_patches,  // M
    conv_args.n_output_channels,  // N
    conv_args.n_input_channels,  // K
    1,  // K-sections
    conv_args.n_batches,  // # Batches
    n_multis,
    false,  // Indirect input
    {},  // No activation
    max_threads,
    fast_mode,
    gemm_cfg
  ));

  // Also provide hints for the Winograd memory layout
  auto &ws = dest.winograd_spec;
  ws.weight_ld_row = iroundup(conv_args.n_output_channels, 4u);
  ws.weight_ld_matrix = conv_args.n_input_channels * ws.weight_ld_row;
  ws.weight_matrix_size_bytes = n_multis * ws.weight_ld_matrix * sizeof(TWinogradIn);

  ws.input_ld_row = iroundup(conv_args.n_input_channels, 4u);
  ws.input_ld_matrix = iroundup(n_output_patches, 4u) * ws.input_ld_row;
  ws.input_ld_batch = n_multis * ws.input_ld_matrix;
  ws.input_matrix_size_bytes = conv_args.n_batches * ws.input_ld_batch * sizeof(TWinogradIn);

  ws.output_ld_row = ws.weight_ld_row;
  ws.output_ld_matrix = n_output_patches * ws.output_ld_row;
  ws.output_ld_batch = n_multis * ws.output_ld_matrix;
  ws.output_matrix_size_bytes = conv_args.n_batches * ws.output_ld_batch * sizeof(TWinogradOut);

  return true;
}

}  // namespace winograd
}  // namespace arm_conv
