// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/operator.h>
#include <xnnpack/log.h>
#include <xnnpack/params.h>


static enum xnn_status create_channel_shuffle_nc(
  size_t groups,
  size_t group_channels,
  size_t input_stride,
  size_t output_stride,
  uint32_t flags,
  enum xnn_operator_type operator_type,
  xnn_operator_t* channel_shuffle_op_out)
{
  xnn_operator_t channel_shuffle_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (groups <= 1) {
    xnn_log_error(
      "failed to create %s operator with %zu groups: at least two groups required",
      xnn_operator_type_to_string(operator_type), groups);
    goto error;
  }

  if (group_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu group channels: number of group channels must be non-zero",
      xnn_operator_type_to_string(operator_type), group_channels);
    goto error;
  }

  const size_t channels = groups * group_channels;
  if (input_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of channels (%zux%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, groups, group_channels);
    goto error;
  }

  if (output_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of channels (%zux%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, groups, group_channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  channel_shuffle_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (channel_shuffle_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  channel_shuffle_op->groups = groups;
  channel_shuffle_op->group_channels = group_channels;
  channel_shuffle_op->input_pixel_stride = input_stride;
  channel_shuffle_op->output_pixel_stride = output_stride;

  channel_shuffle_op->type = operator_type;
  channel_shuffle_op->flags = flags;

  channel_shuffle_op->state = xnn_run_state_invalid;

  *channel_shuffle_op_out = channel_shuffle_op;
  return xnn_status_success;

error:
  xnn_delete_operator(channel_shuffle_op);
  return status;
}


enum xnn_status xnn_create_channel_shuffle_nc_x8(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* channel_shuffle_op_out)
{
  return create_channel_shuffle_nc(
    groups,
    group_channels,
    input_stride,
    output_stride,
    flags,
    xnn_operator_type_channel_shuffle_nc_x8,
    channel_shuffle_op_out);
}

enum xnn_status xnn_create_channel_shuffle_nc_x32(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* channel_shuffle_op_out)
{
  return create_channel_shuffle_nc(
    groups,
    group_channels,
    input_stride,
    output_stride,
    flags,
    xnn_operator_type_channel_shuffle_nc_x32,
    channel_shuffle_op_out);
}

static enum xnn_status setup_channel_shuffle_nc(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    uint32_t log2_element_size,
    const struct zip_parameters zip[restrict XNN_MIN_ELEMENTS(1)])
{
  channel_shuffle_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(channel_shuffle_op->type));
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    channel_shuffle_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  channel_shuffle_op->batch_size = batch_size;
  channel_shuffle_op->input = input;
  channel_shuffle_op->output = output;

  const size_t groups = channel_shuffle_op->groups;
  channel_shuffle_op->context.channel_shuffle = (struct channel_shuffle_context) {
    .x = input,
    .x_stride = channel_shuffle_op->input_pixel_stride << log2_element_size,
    .y = output,
    .y_stride = channel_shuffle_op->output_pixel_stride << log2_element_size,
    .n = channel_shuffle_op->group_channels << log2_element_size,
    .m = groups,
  };
  channel_shuffle_op->compute.type = xnn_parallelization_type_1d;
  channel_shuffle_op->compute.range[0] = batch_size;
  switch (groups) {
    case 2:
      channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
      channel_shuffle_op->context.channel_shuffle.fixed_ukernel = zip->x2;
      break;
    case 3:
      channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
      channel_shuffle_op->context.channel_shuffle.fixed_ukernel = zip->x3;
      break;
    case 4:
      channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
      channel_shuffle_op->context.channel_shuffle.fixed_ukernel = zip->x4;
      break;
    default:
      channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_variable;
      channel_shuffle_op->context.channel_shuffle.variable_ukernel = zip->xm;
      break;
    case 0:
    case 1:
      XNN_UNREACHABLE;
  }
  channel_shuffle_op->state = xnn_run_state_ready;

  return xnn_status_success;
}

enum xnn_status xnn_setup_channel_shuffle_nc_x8(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  if (channel_shuffle_op->type != xnn_operator_type_channel_shuffle_nc_x8) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_channel_shuffle_nc_x8),
      xnn_operator_type_to_string(channel_shuffle_op->type));
    return xnn_status_invalid_parameter;
  }

  return setup_channel_shuffle_nc(
    channel_shuffle_op,
    batch_size,
    input,
    output,
    0 /* log2(sizeof(element)) = log2(sizeof(uint8_t)) */,
    &xnn_params.x8.zip);
}

enum xnn_status xnn_setup_channel_shuffle_nc_x32(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  if (channel_shuffle_op->type != xnn_operator_type_channel_shuffle_nc_x32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_channel_shuffle_nc_x32),
      xnn_operator_type_to_string(channel_shuffle_op->type));
    return xnn_status_invalid_parameter;
  }

  return setup_channel_shuffle_nc(
    channel_shuffle_op,
    batch_size,
    input,
    output,
    2 /* log2(sizeof(element)) = log2(sizeof(uint32_t)) */,
    &xnn_params.x32.zip);
}
