/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "BidirectionalSequenceRNN.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "OperationResolver.h"
#include "RNN.h"

namespace android {
namespace nn {
namespace bidirectional_sequence_rnn {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

// Copies a rank-3 tensor from `input` to `output` with its first two
// dimensions exchanged: element [a][b][k] of the input is written to
// position [b][a][k] of the output. The innermost dimension is copied
// verbatim. `inputShape` describes the layout of `input`; `output` must have
// room for the same number of elements.
template <typename T>
void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
    const uint32_t outerCount = getSizeOfDimension(inputShape, 0);
    const uint32_t innerCount = getSizeOfDimension(inputShape, 1);
    const uint32_t rowLength = getSizeOfDimension(inputShape, 2);

    for (uint32_t outer = 0; outer < outerCount; ++outer) {
        for (uint32_t inner = 0; inner < innerCount; ++inner) {
            // Start of the contiguous innermost row in the source and in the
            // transposed destination.
            const uint32_t srcBase = outer * innerCount * rowLength + inner * rowLength;
            const uint32_t dstBase = inner * outerCount * rowLength + outer * rowLength;
            for (uint32_t k = 0; k < rowLength; ++k) {
                output[dstBase + k] = input[srcBase + k];
            }
        }
    }
}

// Returns a copy of `input` with its leading dimension dropped; every other
// field of the Shape is carried over unchanged. For example dimensions
// {t, b, n} become {b, n}.
//
// A shape that has no dimensions is returned as-is, instead of computing
// `size() - 1` on an empty vector (which would wrap around to a huge value).
Shape removeFirstDim(const Shape& input) {
    Shape output = input;
    if (!output.dimensions.empty()) {
        output.dimensions.erase(output.dimensions.begin());
    }
    return output;
}

// How the auxiliary tensors (auxiliary input and forward/backward auxiliary
// weights) are wired into the two networks.
enum class LinkingMode {
    // No auxiliary tensors are provided at all.
    NO_LINKING = 0,
    // An auxiliary input is provided without auxiliary weights; it is used as
    // the regular input of the backward network.
    PARALLEL_LINKING = 1,
    // An auxiliary input is provided together with both auxiliary weight
    // tensors; the auxiliary input is multiplied by those weights.
    CROSS_LINKING = 2,
};

// Determines the linking mode from which of the optional auxiliary tensors
// are present on `context` and stores it in `*linkingMode`.
//
// Supported combinations:
//   - no auxiliary input, no auxiliary weights        -> NO_LINKING
//   - auxiliary input only                            -> PARALLEL_LINKING
//   - auxiliary input and both auxiliary weights      -> CROSS_LINKING
// Any other combination is reported through NN_RET_CHECK_FAIL and
// `*linkingMode` is left untouched.
bool getLinkingMode(IOperationExecutionContext* context, LinkingMode* linkingMode) {
    const bool auxInputPresent = !context->isOmittedInput(kAuxInputTensor);
    const bool fwAuxWeightsPresent = !context->isOmittedInput(kFwAuxWeightsTensor);
    const bool bwAuxWeightsPresent = !context->isOmittedInput(kBwAuxWeightsTensor);

    const bool noAuxWeights = !fwAuxWeightsPresent && !bwAuxWeightsPresent;
    const bool bothAuxWeights = fwAuxWeightsPresent && bwAuxWeightsPresent;

    if (!auxInputPresent && noAuxWeights) {
        *linkingMode = LinkingMode::NO_LINKING;
        return true;
    }
    if (auxInputPresent && noAuxWeights) {
        *linkingMode = LinkingMode::PARALLEL_LINKING;
        return true;
    }
    if (auxInputPresent && bothAuxWeights) {
        *linkingMode = LinkingMode::CROSS_LINKING;
        return true;
    }

    NN_RET_CHECK_FAIL()
            << "Unsupported auxiliary tensors configuration for BIDIRECTIONAL_SEQUENCE_RNN.";
}

// Runs the bidirectional sequence RNN for element type T.
//
// The forward network walks the time steps from first to last, the backward
// network from last to first; each step is delegated to RNN::RNNStep. When the
// time-major parameter is false, the inputs are transposed into temporary
// time-major buffers first and the outputs are transposed back at the end.
// When the merge-outputs parameter is true, both networks write into the
// forward output tensor (forward units first, backward units after them).
//
// In PARALLEL_LINKING mode the backward network consumes the auxiliary input
// instead of the regular input, so the backward pass must use the auxiliary
// input's innermost size and per-step shape (prepare() allows that size to
// differ from the regular input size in this mode).
//
// Returns false if the auxiliary tensor configuration is not supported,
// true otherwise.
template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);

    const T* fwWeights = context->getInputBuffer<T>(kFwWeightsTensor);
    Shape fwWeightsShape = context->getInputShape(kFwWeightsTensor);
    const T* fwRecurrentWeights = context->getInputBuffer<T>(kFwRecurrentWeightsTensor);
    Shape fwRecurrentWeightsShape = context->getInputShape(kFwRecurrentWeightsTensor);
    const T* fwBias = context->getInputBuffer<T>(kFwBiasTensor);
    const T* fwHiddenState = context->getInputBuffer<T>(kFwHiddenStateTensor);

    const T* bwWeights = context->getInputBuffer<T>(kBwWeightsTensor);
    Shape bwWeightsShape = context->getInputShape(kBwWeightsTensor);
    const T* bwRecurrentWeights = context->getInputBuffer<T>(kBwRecurrentWeightsTensor);
    Shape bwRecurrentWeightsShape = context->getInputShape(kBwRecurrentWeightsTensor);
    const T* bwBias = context->getInputBuffer<T>(kBwBiasTensor);
    const T* bwHiddenState = context->getInputBuffer<T>(kBwHiddenStateTensor);

    // Auxiliary buffers are only fetched for the linking modes that use them.
    const T* auxInput = nullptr;
    const T* fwAuxWeights = nullptr;
    const T* bwAuxWeights = nullptr;
    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
        fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor);
        bwAuxWeights = context->getInputBuffer<T>(kBwAuxWeightsTensor);
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
    }
    const bool hasAuxInput = (linkingMode == LinkingMode::CROSS_LINKING ||
                              linkingMode == LinkingMode::PARALLEL_LINKING);
    const bool hasAuxWeights = (linkingMode == LinkingMode::CROSS_LINKING);
    Shape auxInputShape = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeightsShape = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeightsShape = context->getInputShape(kBwAuxWeightsTensor);

    const int32_t activation = context->getInputValue<int32_t>(kActivationParam);
    const bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);

    T* fwOutput = context->getOutputBuffer<T>(kFwOutputTensor);
    Shape fwOutputShape = context->getOutputShape(kFwOutputTensor);
    T* bwOutput = nullptr;
    Shape bwOutputShape;
    if (!mergeOutputs) {
        bwOutputShape = context->getOutputShape(kBwOutputTensor);
        bwOutput = context->getOutputBuffer<T>(kBwOutputTensor);
    }

    // If the input tensors are not in time major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> auxInputTransposed;
    std::vector<T> fwOutputTransposed;
    std::vector<T> bwOutputTransposed;
    if (!timeMajor) {
        // First, resize temporary buffers to accommodate for transposed tensors.
        inputTransposed.resize(getNumberOfElements(inputShape));
        if (hasAuxInput) {
            auxInputTransposed.resize(getNumberOfElements(auxInputShape));
        }
        fwOutputTransposed.resize(getNumberOfElements(fwOutputShape));
        if (!mergeOutputs) {
            bwOutputTransposed.resize(getNumberOfElements(bwOutputShape));
        }

        // Transpose the input tensors.
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        if (hasAuxInput) {
            transposeFirstTwoDims(auxInput, auxInputShape, auxInputTransposed.data());
        }

        // Change input and output pointers to the temporary buffers.
        input = inputTransposed.data();
        if (hasAuxInput) {
            auxInput = auxInputTransposed.data();
        }
        fwOutput = fwOutputTransposed.data();
        if (!mergeOutputs) {
            bwOutput = bwOutputTransposed.data();
        }

        // Swap the first two dimensions in the Shapes to reflect the
        // transposition.
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        if (hasAuxInput) {
            std::swap(auxInputShape.dimensions[0], auxInputShape.dimensions[1]);
        }
        std::swap(fwOutputShape.dimensions[0], fwOutputShape.dimensions[1]);
        if (!mergeOutputs) {
            std::swap(bwOutputShape.dimensions[0], bwOutputShape.dimensions[1]);
        }
    }

    // From here on every sequence tensor is laid out as [time, batch, size].
    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    uint32_t auxInputSize = 0;
    if (hasAuxInput) {
        auxInputSize = getSizeOfDimension(auxInputShape, 2);
    }
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeightsShape, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeightsShape, 0);

    // Shapes of a single time step, i.e. with the time dimension removed.
    Shape fixedTimeInputShape = removeFirstDim(inputShape);
    Shape fixedTimeAuxInputShape = auxInputShape;
    if (hasAuxInput) {
        fixedTimeAuxInputShape = removeFirstDim(auxInputShape);
    }

    // Select the tensor fed to the backward network. With parallel linking it
    // is the auxiliary input, so its own innermost size and per-step shape
    // must be used for the backward pass rather than those of the regular
    // input.
    const T* bwInput = input;
    uint32_t bwInputSize = inputSize;
    Shape fixedTimeBwInputShape = fixedTimeInputShape;
    if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        bwInput = auxInput;
        bwInputSize = auxInputSize;
        fixedTimeBwInputShape = fixedTimeAuxInputShape;
        auxInput = nullptr;
    }

    const bool outputState = (context->getNumOutputs() == kNumOutputsWithState ||
                              context->getNumOutputs() == kNumOutputsMergedWithState);
    T* fwOutputHiddenState = nullptr;
    T* bwOutputHiddenState = nullptr;
    // Create an additional buffer to store a hidden state between steps.
    std::vector<T> tempHiddenState;
    if (outputState) {
        // With merged outputs there is no separate backward output tensor, so
        // the hidden state outputs sit one index earlier.
        const int delta = mergeOutputs ? 1 : 0;
        fwOutputHiddenState = context->getOutputBuffer<T>(kFwOutputHiddenStateTensor - delta);
        bwOutputHiddenState = context->getOutputBuffer<T>(kBwOutputHiddenStateTensor - delta);
    } else {
        // The two passes run one after the other, so a single scratch buffer
        // large enough for either network is shared between them.
        tempHiddenState.resize(std::max(batchSize * fwNumUnits, batchSize * bwNumUnits));
        fwOutputHiddenState = tempHiddenState.data();
        bwOutputHiddenState = tempHiddenState.data();
    }

    // Forward pass
    const uint32_t fwOutputBatchStride = mergeOutputs ? (fwNumUnits + bwNumUnits) : fwNumUnits;
    for (uint32_t i = 0; i < maxTime; ++i) {
        // Offsets are computed in size_t so that large tensors do not wrap
        // around in 32-bit arithmetic.
        const size_t step = i;
        const T* inputBatchPtr = input + step * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + step * batchSize * auxInputSize;
        }
        T* fwOutputBatchPtr = fwOutput + step * batchSize * fwOutputBatchStride;

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, fwHiddenState, fwBias, fwWeights, fwWeightsShape,
                        fwAuxWeights, fwAuxWeightsShape, fwRecurrentWeights,
                        fwRecurrentWeightsShape, activation, fwOutputBatchStride,
                        /*outputBatchOffset=*/0, fwOutputBatchPtr, fwOutputHiddenState);

        // After the first step the state is carried in the output state buffer.
        fwHiddenState = fwOutputHiddenState;
    }

    // Backward pass: visit the time steps from last to first.
    for (uint32_t remaining = maxTime; remaining > 0; --remaining) {
        const size_t step = remaining - 1;
        const T* inputBatchPtr = bwInput + step * batchSize * bwInputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + step * batchSize * auxInputSize;
        }
        T* bwOutputBatchPtr;
        uint32_t bwOutputBatchOffset = 0;
        uint32_t bwOutputBatchStride;
        if (mergeOutputs) {
            // Backward units are written after the forward units inside the
            // shared (forward) output tensor.
            bwOutputBatchStride = fwNumUnits + bwNumUnits;
            bwOutputBatchOffset = fwNumUnits;
            bwOutputBatchPtr = fwOutput + step * batchSize * bwOutputBatchStride;
        } else {
            bwOutputBatchStride = bwNumUnits;
            bwOutputBatchPtr = bwOutput + step * batchSize * bwOutputBatchStride;
        }

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeBwInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, bwHiddenState, bwBias, bwWeights, bwWeightsShape,
                        bwAuxWeights, bwAuxWeightsShape, bwRecurrentWeights,
                        bwRecurrentWeightsShape, activation, bwOutputBatchStride,
                        bwOutputBatchOffset, bwOutputBatchPtr, bwOutputHiddenState);

        bwHiddenState = bwOutputHiddenState;
    }

    // If the inputs were in batch major format, transpose data in temporary
    // buffers and write to the output(s).
    if (!timeMajor) {
        transposeFirstTwoDims(fwOutputTransposed.data(), fwOutputShape,
                              context->getOutputBuffer<T>(kFwOutputTensor));
        if (!mergeOutputs) {
            transposeFirstTwoDims(bwOutputTransposed.data(), bwOutputShape,
                                  context->getOutputBuffer<T>(kBwOutputTensor));
        }
    }
    return true;
}

}  // namespace

// Validates the operands of BIDIRECTIONAL_SEQUENCE_RNN and sets the shapes of
// its outputs.
//
// Validation is ordered so that nothing is read before it has been checked:
// required inputs are verified to be present before any parameter value is
// fetched, and tensor ranks are verified before any dimension is queried.
//
// Returns true if all checks pass and every output shape was set; false
// otherwise.
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted. This runs first so
    // that the parameter reads below never touch an omitted operand.
    const std::vector<int> requiredInputs = {
            kInputTensor,         kFwWeightsTensor, kFwRecurrentWeightsTensor, kFwBiasTensor,
            kFwHiddenStateTensor, kBwWeightsTensor, kBwRecurrentWeightsTensor, kBwBiasTensor,
            kBwHiddenStateTensor, kActivationParam, kTimeMajorParam,           kMergeOutputsParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    // The permitted number of outputs depends on whether outputs are merged.
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);
    const int32_t numOutputs = context->getNumOutputs();
    if (mergeOutputs) {
        NN_RET_CHECK(numOutputs == kNumOutputsMerged || numOutputs == kNumOutputsMergedWithState);
    } else {
        NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
    }

    Shape input = context->getInputShape(kInputTensor);
    Shape fwWeights = context->getInputShape(kFwWeightsTensor);
    Shape fwRecurrentWeights = context->getInputShape(kFwRecurrentWeightsTensor);
    Shape fwBias = context->getInputShape(kFwBiasTensor);
    Shape fwHiddenState = context->getInputShape(kFwHiddenStateTensor);
    Shape bwWeights = context->getInputShape(kBwWeightsTensor);
    Shape bwRecurrentWeights = context->getInputShape(kBwRecurrentWeightsTensor);
    Shape bwBias = context->getInputShape(kBwBiasTensor);
    Shape bwHiddenState = context->getInputShape(kBwHiddenStateTensor);

    Shape auxInput = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor);

    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));

    // Validate ranks before querying any individual dimension, so that every
    // dimension index used below is known to be in range.
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwRecurrentWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwBias), 1u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwHiddenState), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwRecurrentWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwBias), 1u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwHiddenState), 2u);

    // The position of the batch and time dimensions depends on the layout.
    const bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeights, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

    // Forward network dimension consistency.
    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(fwWeights, 1));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwBias, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(fwHiddenState, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwHiddenState, 1));

    // Backward network dimension consistency. With parallel linking the
    // backward weights are matched against the auxiliary input further below
    // instead of against the regular input.
    if (linkingMode != LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(bwWeights, 1));
    }
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwBias, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(bwHiddenState, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwHiddenState, 1));

    // Auxiliary tensors: ranks are checked first, then dimensions.
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3u);
        NN_RET_CHECK_EQ(getNumberOfDimensions(fwAuxWeights), 2u);
        NN_RET_CHECK_EQ(getNumberOfDimensions(bwAuxWeights), 2u);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 0), fwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 0), bwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3u);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 2), getSizeOfDimension(bwWeights, 1));
    }

    // Output shapes follow the input layout; with merged outputs the forward
    // output holds the units of both networks.
    Shape fwOutput = context->getOutputShape(kFwOutputTensor);
    fwOutput.dimensions.resize(3);
    fwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
    fwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
    fwOutput.dimensions[2] = mergeOutputs ? fwNumUnits + bwNumUnits : fwNumUnits;
    NN_RET_CHECK(context->setOutputShape(kFwOutputTensor, fwOutput));
    if (!mergeOutputs) {
        Shape bwOutput = context->getOutputShape(kBwOutputTensor);
        bwOutput.dimensions.resize(3);
        bwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
        bwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
        bwOutput.dimensions[2] = bwNumUnits;
        NN_RET_CHECK(context->setOutputShape(kBwOutputTensor, bwOutput));
    }

    // Optional hidden state outputs take the shape of the corresponding
    // hidden state inputs. With merged outputs their indices shift down by one.
    const bool outputState =
            (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState);
    if (outputState) {
        const int delta = mergeOutputs ? 1 : 0;
        NN_RET_CHECK(context->setOutputShape(kFwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kFwHiddenStateTensor)));
        NN_RET_CHECK(context->setOutputShape(kBwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kBwHiddenStateTensor)));
    }

    return true;
}

// Dispatches execution on the element type of the input tensor: _Float16 for
// TENSOR_FLOAT16 inputs, float otherwise.
//
// Returns the result of the typed implementation, so a failure inside
// executeTyped is reported to the caller rather than being masked as success.
bool execute(IOperationExecutionContext* context) {
    if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
        return executeTyped<_Float16>(context);
    }
    return executeTyped<float>(context);
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace bidirectional_sequence_rnn

// Hooks bidirectional_sequence_rnn::prepare and ::execute up as the handlers
// for the BIDIRECTIONAL_SEQUENCE_RNN operation. allowOmittedOperand is set to
// true; prepare() itself rejects omission of any required input, while the
// auxiliary tensors may legitimately be omitted (see getLinkingMode()).
// NOTE(review): the exact registration semantics are defined by the macro,
// which is declared outside this file — confirm there.
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(BIDIRECTIONAL_SEQUENCE_RNN,
                                         bidirectional_sequence_rnn::prepare,
                                         bidirectional_sequence_rnn::execute,
                                         .allowOmittedOperand = true);

}  // namespace nn
}  // namespace android
