/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fuzzing/operation_signatures/OperationSignatureUtils.h"

namespace android {
namespace nn {
namespace fuzzing_test {

// For pooling ops with explicit padding (AVERAGE_POOL_2D, L2_POOL_2D, MAX_POOL_2D).
//
// Operands read by this constructor:
//   inputs[0]     4-D input tensor
//   inputs[1..4]  padding left, right, top, bottom
//   inputs[5..6]  stride width, height
//   inputs[7..8]  filter width, height (random variables)
//   inputs[10]    optional layout flag; when present and true the tensors are NCHW,
//                 otherwise NHWC
// inputs[9] is not read here (presumably the fused activation code -- confirm).
//
// Every input dimension is left free. The output copies batch and channel from the input,
// and explicitPadding() derives its height and width. Finally the output is given the same
// quantization as the input via setSameQuantization().
static void poolingExplicitOpConstructor(TestOperandType, uint32_t rank, RandomOperation* op) {
    NN_FUZZER_CHECK(rank == 4);

    // Parameters
    const int32_t paddingLeft = op->inputs[1]->value<int32_t>();
    const int32_t paddingRight = op->inputs[2]->value<int32_t>();
    const int32_t paddingTop = op->inputs[3]->value<int32_t>();
    const int32_t paddingBottom = op->inputs[4]->value<int32_t>();
    const int32_t strideWidth = op->inputs[5]->value<int32_t>();
    const int32_t strideHeight = op->inputs[6]->value<int32_t>();
    auto filterWidth = op->inputs[7]->value<RandomVariable>();
    auto filterHeight = op->inputs[8]->value<RandomVariable>();

    // The layout flag (inputs[10]) only exists in the signatures built by
    // DEFINE_POOLING_WITH_LAYOUT_SIGNATURE; without it the layout is NHWC.
    const bool useNchw = op->inputs.size() > 10 && op->inputs[10]->value<bool8>();
    const int heightIndex = useNchw ? 2 : 1;
    const int widthIndex = useNchw ? 3 : 2;
    const int channelIndex = useNchw ? 1 : 3;

    // Input: NHWC [batch, height_in, width_in, channel] or
    //        NCHW [batch, channel, height_in, width_in]; all four dimensions are free.
    op->inputs[0]->dimensions = {RandomVariableType::FREE, RandomVariableType::FREE,
                                 RandomVariableType::FREE, RandomVariableType::FREE};

    // Output: same layout as the input, with height_out/width_out in the spatial axes.
    op->outputs[0]->dimensions.resize(4);

    // Batch (always axis 0) and channel pass through unchanged.
    op->outputs[0]->dimensions[0] = op->inputs[0]->dimensions[0];
    op->outputs[0]->dimensions[channelIndex] = op->inputs[0]->dimensions[channelIndex];

    // height
    explicitPadding(op->inputs[0]->dimensions[heightIndex], filterHeight, strideHeight,
                    /*dilation=*/1, paddingTop, paddingBottom,
                    &op->outputs[0]->dimensions[heightIndex]);

    // width
    explicitPadding(op->inputs[0]->dimensions[widthIndex], filterWidth, strideWidth, /*dilation=*/1,
                    paddingLeft, paddingRight, &op->outputs[0]->dimensions[widthIndex]);

    setSameQuantization(op->outputs[0], op->inputs[0]);
}

// For pooling ops with implicit padding (AVERAGE_POOL_2D, L2_POOL_2D, MAX_POOL_2D).
//
// Operands read by this constructor:
//   inputs[0]     4-D input tensor
//   inputs[1]     padding scheme (the signatures in this file choose 1 or 2; presumably the
//                 NNAPI PaddingCode SAME/VALID values -- confirm against implicitPadding())
//   inputs[2..3]  stride width, height
//   inputs[4..5]  filter width, height (random variables)
//   inputs[7]     optional layout flag; when present and true the tensors are NCHW,
//                 otherwise NHWC
// inputs[6] is not read here (presumably the fused activation code -- confirm).
//
// Every input dimension is left free. The output copies batch and channel from the input,
// and implicitPadding() derives its height and width. Finally the output is given the same
// quantization as the input via setSameQuantization().
static void poolingImplicitOpConstructor(TestOperandType, uint32_t rank, RandomOperation* op) {
    NN_FUZZER_CHECK(rank == 4);

    // Parameters
    const int32_t paddingScheme = op->inputs[1]->value<int32_t>();
    const int32_t strideWidth = op->inputs[2]->value<int32_t>();
    const int32_t strideHeight = op->inputs[3]->value<int32_t>();
    auto filterWidth = op->inputs[4]->value<RandomVariable>();
    auto filterHeight = op->inputs[5]->value<RandomVariable>();

    // The layout flag (inputs[7]) only exists in the signatures built by
    // DEFINE_POOLING_WITH_LAYOUT_SIGNATURE; without it the layout is NHWC.
    const bool useNchw = op->inputs.size() > 7 && op->inputs[7]->value<bool8>();
    const int heightIndex = useNchw ? 2 : 1;
    const int widthIndex = useNchw ? 3 : 2;
    const int channelIndex = useNchw ? 1 : 3;

    // Input: NHWC [batch, height_in, width_in, channel] or
    //        NCHW [batch, channel, height_in, width_in]; all four dimensions are free.
    op->inputs[0]->dimensions = {RandomVariableType::FREE, RandomVariableType::FREE,
                                 RandomVariableType::FREE, RandomVariableType::FREE};

    // Output: same layout as the input, with height_out/width_out in the spatial axes.
    op->outputs[0]->dimensions.resize(4);

    // Batch (always axis 0) and channel pass through unchanged.
    op->outputs[0]->dimensions[0] = op->inputs[0]->dimensions[0];
    op->outputs[0]->dimensions[channelIndex] = op->inputs[0]->dimensions[channelIndex];

    // height and width
    implicitPadding(op->inputs[0]->dimensions[heightIndex], filterHeight, strideHeight,
                    /*dilation=*/1, paddingScheme, &op->outputs[0]->dimensions[heightIndex]);
    implicitPadding(op->inputs[0]->dimensions[widthIndex], filterWidth, strideWidth,
                    /*dilation=*/1, paddingScheme, &op->outputs[0]->dimensions[widthIndex]);

    setSameQuantization(op->outputs[0], op->inputs[0]);
}

// Defines two operation signatures for pooling op `op` at HAL version `ver`, supporting the
// tensor operand types passed as the variadic arguments. Neither variant has a layout input,
// so the constructors treat the tensors as NHWC.
//
//   op_explicit_ver  (constructor: poolingExplicitOpConstructor), inputs:
//     [0] input, [1..4] padding left/right/top/bottom and [5..6] stride width/height
//     (all PARAMETER_RANGE(INT32, 1, 3)), [7..8] filter width/height (RANDOM_INT_RANGE(1, 4)),
//     [9] INT32 chosen from {0, 1, 2, 3} (presumably the fused activation code -- confirm).
//
//   op_implicit_ver  (constructor: poolingImplicitOpConstructor), inputs:
//     [0] input, [1] padding scheme chosen from {1, 2}, [2..3] stride width/height
//     (PARAMETER_RANGE(INT32, 1, 3)), [4..5] filter width/height (RANDOM_INT_RANGE(1, 4)),
//     [6] INT32 chosen from {0, 1, 2, 3} (presumably the fused activation code -- confirm).
//
// Note: comments must stay outside the macro body; a `//` comment on a backslash-continued
// line would also swallow the following line.
#define DEFINE_POOLING_SIGNATURE(op, ver, ...)                                    \
    DEFINE_OPERATION_SIGNATURE(op##_explicit_##ver){                              \
            .opType = TestOperationType::op,                                      \
            .supportedDataTypes = {__VA_ARGS__},                                  \
            .supportedRanks = {4},                                                \
            .version = TestHalVersion::ver,                                       \
            .inputs =                                                             \
                    {                                                             \
                            INPUT_DEFAULT,                                        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            RANDOM_INT_RANGE(1, 4),                               \
                            RANDOM_INT_RANGE(1, 4),                               \
                            PARAMETER_CHOICE(TestOperandType::INT32, 0, 1, 2, 3), \
                    },                                                            \
            .outputs = {OUTPUT_DEFAULT},                                          \
            .constructor = poolingExplicitOpConstructor};                         \
    DEFINE_OPERATION_SIGNATURE(op##_implicit_##ver){                              \
            .opType = TestOperationType::op,                                      \
            .supportedDataTypes = {__VA_ARGS__},                                  \
            .supportedRanks = {4},                                                \
            .version = TestHalVersion::ver,                                       \
            .inputs =                                                             \
                    {                                                             \
                            INPUT_DEFAULT,                                        \
                            PARAMETER_CHOICE(TestOperandType::INT32, 1, 2),       \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            RANDOM_INT_RANGE(1, 4),                               \
                            RANDOM_INT_RANGE(1, 4),                               \
                            PARAMETER_CHOICE(TestOperandType::INT32, 0, 1, 2, 3), \
                    },                                                            \
            .outputs = {OUTPUT_DEFAULT},                                          \
            .constructor = poolingImplicitOpConstructor};

// HAL 1.0 signatures without a layout input: TENSOR_FLOAT32 for all three pooling ops, plus
// TENSOR_QUANT8_ASYMM for AVERAGE_POOL_2D and MAX_POOL_2D (not for L2_POOL_2D).
DEFINE_POOLING_SIGNATURE(AVERAGE_POOL_2D, V1_0, TestOperandType::TENSOR_FLOAT32,
                         TestOperandType::TENSOR_QUANT8_ASYMM);
DEFINE_POOLING_SIGNATURE(L2_POOL_2D, V1_0, TestOperandType::TENSOR_FLOAT32);
DEFINE_POOLING_SIGNATURE(MAX_POOL_2D, V1_0, TestOperandType::TENSOR_FLOAT32,
                         TestOperandType::TENSOR_QUANT8_ASYMM);

// HAL 1.2 adds TENSOR_FLOAT16 to the layout-less forms of all three pooling ops.
DEFINE_POOLING_SIGNATURE(AVERAGE_POOL_2D, V1_2, TestOperandType::TENSOR_FLOAT16);
DEFINE_POOLING_SIGNATURE(L2_POOL_2D, V1_2, TestOperandType::TENSOR_FLOAT16);
DEFINE_POOLING_SIGNATURE(MAX_POOL_2D, V1_2, TestOperandType::TENSOR_FLOAT16);

// Same as DEFINE_POOLING_SIGNATURE, but each variant carries one extra trailing BOOL
// parameter chosen from {true, false}: the data-layout flag. The constructors read it and
// treat true as NCHW and false as NHWC. The generated signature names are
// op_explicit_layout_ver and op_implicit_layout_ver.
//
//   explicit (poolingExplicitOpConstructor): inputs [0..9] as in the layout-less explicit
//     variant, then [10] layout flag.
//   implicit (poolingImplicitOpConstructor): inputs [0..6] as in the layout-less implicit
//     variant, then [7] layout flag.
//
// Note: comments must stay outside the macro body; a `//` comment on a backslash-continued
// line would also swallow the following line.
#define DEFINE_POOLING_WITH_LAYOUT_SIGNATURE(op, ver, ...)                        \
    DEFINE_OPERATION_SIGNATURE(op##_explicit_layout_##ver){                       \
            .opType = TestOperationType::op,                                      \
            .supportedDataTypes = {__VA_ARGS__},                                  \
            .supportedRanks = {4},                                                \
            .version = TestHalVersion::ver,                                       \
            .inputs =                                                             \
                    {                                                             \
                            INPUT_DEFAULT,                                        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            RANDOM_INT_RANGE(1, 4),                               \
                            RANDOM_INT_RANGE(1, 4),                               \
                            PARAMETER_CHOICE(TestOperandType::INT32, 0, 1, 2, 3), \
                            PARAMETER_CHOICE(TestOperandType::BOOL, true, false), \
                    },                                                            \
            .outputs = {OUTPUT_DEFAULT},                                          \
            .constructor = poolingExplicitOpConstructor};                         \
    DEFINE_OPERATION_SIGNATURE(op##_implicit_layout_##ver){                       \
            .opType = TestOperationType::op,                                      \
            .supportedDataTypes = {__VA_ARGS__},                                  \
            .supportedRanks = {4},                                                \
            .version = TestHalVersion::ver,                                       \
            .inputs =                                                             \
                    {                                                             \
                            INPUT_DEFAULT,                                        \
                            PARAMETER_CHOICE(TestOperandType::INT32, 1, 2),       \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            PARAMETER_RANGE(TestOperandType::INT32, 1, 3),        \
                            RANDOM_INT_RANGE(1, 4),                               \
                            RANDOM_INT_RANGE(1, 4),                               \
                            PARAMETER_CHOICE(TestOperandType::INT32, 0, 1, 2, 3), \
                            PARAMETER_CHOICE(TestOperandType::BOOL, true, false), \
                    },                                                            \
            .outputs = {OUTPUT_DEFAULT},                                          \
            .constructor = poolingImplicitOpConstructor};

// HAL 1.2 signatures with the layout input: TENSOR_FLOAT32 and TENSOR_FLOAT16 for all three
// pooling ops, plus TENSOR_QUANT8_ASYMM for AVERAGE_POOL_2D and MAX_POOL_2D.
DEFINE_POOLING_WITH_LAYOUT_SIGNATURE(AVERAGE_POOL_2D, V1_2, TestOperandType::TENSOR_FLOAT32,
                                     TestOperandType::TENSOR_FLOAT16,
                                     TestOperandType::TENSOR_QUANT8_ASYMM);
DEFINE_POOLING_WITH_LAYOUT_SIGNATURE(L2_POOL_2D, V1_2, TestOperandType::TENSOR_FLOAT32,
                                     TestOperandType::TENSOR_FLOAT16);
DEFINE_POOLING_WITH_LAYOUT_SIGNATURE(MAX_POOL_2D, V1_2, TestOperandType::TENSOR_FLOAT32,
                                     TestOperandType::TENSOR_FLOAT16,
                                     TestOperandType::TENSOR_QUANT8_ASYMM);
// HAL 1.3 adds TENSOR_QUANT8_ASYMM_SIGNED (layout form only) for AVERAGE_POOL_2D and
// MAX_POOL_2D; no such signature is defined here for L2_POOL_2D.
DEFINE_POOLING_WITH_LAYOUT_SIGNATURE(AVERAGE_POOL_2D, V1_3,
                                     TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED);
DEFINE_POOLING_WITH_LAYOUT_SIGNATURE(MAX_POOL_2D, V1_3,
                                     TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED);

}  // namespace fuzzing_test
}  // namespace nn
}  // namespace android
