/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "QLSTM.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "CpuExecutor.h"
#include "OperationsExecutionUtils.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#include "QuantUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {
namespace qlstm {

namespace {

// Reports whether the (possibly optional) input operand at index `tensor` is present.
// Presence is determined solely by whether the execution context exposes a non-null
// input buffer for that index.
inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
    const auto* const buffer = context->getInputBuffer(tensor);
    return buffer != nullptr;
}

}  // namespace

// Validates the QLSTM operand shapes and propagates output shapes.
//
// Derived sizes used throughout:
//   batchSize  = input dim 0, inputSize = input dim 1 (input must be rank 2)
//   numUnits   = input-to-output weights dim 0
//   outputSize = recurrent-to-output weights dim 1
// CIFG mode is in effect when the input-to-input weights are absent; in that case the
// remaining input-gate parameters (recurrent-to-input weights, input gate bias, input
// layer norm weights) must be absent too.
//
// Each output keeps its own declared operand type/quantization and only inherits its
// dimensions from the corresponding state input.
//
// Returns false as soon as any check fails (via NN_RET_CHECK*), otherwise the combined
// result of setting the three output shapes.
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted
    const std::vector<int> requiredTensorInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kPrevOutputTensor,
            kPrevCellStateTensor,
    };
    for (const int tensor : requiredTensorInputs) {
        NN_RET_CHECK(!context->isOmittedInput(tensor))
                << "required input " << tensor << " is omitted";
    }

    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 2u) << "Invalid input tensor rank: " << inputRank;

    const uint32_t batchSize = getSizeOfDimension(inputShape, 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 1);

    // numUnits and outputSize are taken from the output-gate weights; every other
    // weight/bias tensor is checked against them below.
    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numUnits = getSizeOfDimension(inputToOutputShape, 0);

    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numUnits);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    // Input-to-gate weights: [numUnits, inputSize].
    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    // Recurrent-to-gate weights: [numUnits, outputSize].
    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // Make sure the input-gate's parameters are either all present (non-CIFG) or
    // not at all (CIFG).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    // Peephole (cell-to-gate) weights: [numUnits].
    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numUnits);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numUnits);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numUnits);
    }

    // Making sure the peephole weights are there all or none. Under CIFG there is no input
    // gate, so the cell-to-input weights are not required for the "all" case.
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    // Gate biases: [numUnits]. The input gate bias exists iff CIFG is not used.
    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numUnits);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numUnits);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numUnits);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numUnits);

    // Optional projection: weights [outputSize, numUnits], bias [outputSize].
    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numUnits);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    // Previous states: output state [batchSize, outputSize], cell state [batchSize, numUnits].
    const Shape outputStateShape = context->getInputShape(kPrevOutputTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kPrevCellStateTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numUnits);

    // Optional layer normalization weights: [numUnits].
    if (hasTensor(context, kInputLayerNormTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kForgetLayerNormTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kCellLayerNormTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kOutputLayerNormTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numUnits);
    }

    // Layer norm weights are all-or-none; under CIFG the input layer norm must be absent and
    // only the forget/cell/output ones participate in the all-or-none rule.
    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg = (hasTensor(context, kForgetLayerNormTensor) &&
                                                    hasTensor(context, kCellLayerNormTensor) &&
                                                    hasTensor(context, kOutputLayerNormTensor)) ||
                                                   (!hasTensor(context, kForgetLayerNormTensor) &&
                                                    !hasTensor(context, kCellLayerNormTensor) &&
                                                    !hasTensor(context, kOutputLayerNormTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone = (hasTensor(context, kInputLayerNormTensor) &&
                                                hasTensor(context, kForgetLayerNormTensor) &&
                                                hasTensor(context, kCellLayerNormTensor) &&
                                                hasTensor(context, kOutputLayerNormTensor)) ||
                                               (!hasTensor(context, kInputLayerNormTensor) &&
                                                !hasTensor(context, kForgetLayerNormTensor) &&
                                                !hasTensor(context, kCellLayerNormTensor) &&
                                                !hasTensor(context, kOutputLayerNormTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

    // Each output starts from its OWN declared shape (type, scale, zero point) and only takes
    // its dimensions from the matching state input. Deriving the output-state-out shape from
    // the kOutputTensor operand would stamp that operand's quantization onto a different
    // operand, which breaks when the two outputs are not quantized identically.
    Shape outputStateOutShape = context->getOutputShape(kOutputStateOutTensor);
    outputStateOutShape.dimensions = outputStateShape.dimensions;

    Shape cellStateOutShape = context->getOutputShape(kCellStateOutTensor);
    cellStateOutShape.dimensions = cellStateShape.dimensions;

    Shape outputShape = context->getOutputShape(kOutputTensor);
    outputShape.dimensions = outputStateShape.dimensions;

    return context->setOutputShape(kOutputStateOutTensor, outputStateOutShape) &&
           context->setOutputShape(kCellStateOutTensor, cellStateOutShape) &&
           context->setOutputShape(kOutputTensor, outputShape);
}

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
bool execute(IOperationExecutionContext* context) {
    // Gets the inputs.
    const Shape inputShape = context->getInputShape(kInputTensor);
    const Shape inputToInputWeightsShape = context->getInputShape(kInputToInputWeightsTensor);
    const Shape recurrentToInputWeightsShape =
            context->getInputShape(kRecurrentToInputWeightsTensor);
    const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
    const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor);
    const Shape inputToForgetWeightsShape = context->getInputShape(kInputToForgetWeightsTensor);
    const Shape recurrentToForgetWeightsShape =
            context->getInputShape(kRecurrentToForgetWeightsTensor);
    const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
    const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor);
    const Shape inputToCellWeightsShape = context->getInputShape(kInputToCellWeightsTensor);
    const Shape recurrentToCellWeightsShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor);
    const Shape inputToOutputWeightsShape = context->getInputShape(kInputToOutputWeightsTensor);
    const Shape recurrentToOutputWeightsShape =
            context->getInputShape(kRecurrentToOutputWeightsTensor);
    const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
    const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor);
    const Shape projectionWeightsShape = context->getInputShape(kProjectionWeightsTensor);
    const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor);
    const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor);

    const uint32_t batchSize = inputShape.dimensions[0];
    const uint32_t inputSize = inputShape.dimensions[1];
    const uint32_t numUnits = inputToOutputWeightsShape.dimensions[0];
    const uint32_t outputSize = recurrentToOutputWeightsShape.dimensions[1];

    const float cellClip = context->getInputValue<float>(kCellClip);
    const float projectionClip = context->getInputValue<float>(kProjectionClip);
    const float inputIntermediateScale = context->getInputValue<float>(kInputIntermediateScale);
    const float forgetIntermediateScale = context->getInputValue<float>(kForgetIntermediateScale);
    const float cellIntermediateScale = context->getInputValue<float>(kCellIntermediateScale);
    const float outputIntermediateScale = context->getInputValue<float>(kOutputIntermediateScale);
    const int8_t hiddenStateZeroPoint = context->getInputValue<int8_t>(kHiddenStateZeroPoint);
    const float hiddenStateScale = context->getInputValue<float>(kHiddenStateScale);

    const int8_t* inputBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputTensor));

    const int8_t* inputToInputWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToInputWeightsTensor));
    const bool useCifg = (inputToInputWeightsBuffer == nullptr);
    const int8_t* recurrentToInputWeightsBuffer = reinterpret_cast<const int8_t*>(
            context->getInputBuffer(kRecurrentToInputWeightsTensor));
    const int16_t* cellToInputBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToInputWeightsTensor));
    const int16_t* inputLayerNormBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kInputLayerNormTensor));
    const int32_t* inputBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kInputGateBiasTensor));

    const int8_t* inputToForgetWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToForgetWeightsTensor));
    const int8_t* recurrentToForgetWeightsBuffer = reinterpret_cast<const int8_t*>(
            context->getInputBuffer(kRecurrentToForgetWeightsTensor));
    const int16_t* cellToForgetBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToForgetWeightsTensor));
    const int16_t* forgetLayerNormBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kForgetLayerNormTensor));
    const int32_t* forgetBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kForgetGateBiasTensor));

    const int8_t* inputToCellWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToCellWeightsTensor));
    const int8_t* recurrentToCellWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kRecurrentToCellWeightsTensor));
    const int16_t* cellLayerNormBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellLayerNormTensor));
    const int32_t* cellBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kCellGateBiasTensor));

    const int8_t* inputToOutputWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToOutputWeightsTensor));
    const int8_t* recurrentToOutputWeightsBuffer = reinterpret_cast<const int8_t*>(
            context->getInputBuffer(kRecurrentToOutputWeightsTensor));
    const int16_t* cellToOutputBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToOutputWeightsTensor));
    const int16_t* outputLayerNormBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kOutputLayerNormTensor));
    const int32_t* outputBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kOutputGateBiasTensor));

    const int8_t* projectionWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kProjectionWeightsTensor));
    const int32_t* projectionBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kProjectionBiasTensor));

    const int8_t* prevOutputBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kPrevOutputTensor));
    const int16_t* prevCellStateBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kPrevCellStateTensor));

    uint8_t* outputStateBuffer =
            reinterpret_cast<uint8_t*>(context->getOutputBuffer(kOutputStateOutTensor));
    int16_t* cellStateBuffer =
            reinterpret_cast<int16_t*>(context->getOutputBuffer(kCellStateOutTensor));
    int8_t* outputBuffer = reinterpret_cast<int8_t*>(context->getOutputBuffer(kOutputTensor));

    // Calculates and decomposes effective scales.
    // This is for optimizing the matmul calculation.
    int cellShift;
    NN_RET_CHECK(CheckedLog2(prevCellStateShape.scale, &cellShift));
    NN_RET_CHECK(cellShift <= -9);

    int32_t inputToInputEffectiveScaleA;
    int32_t inputToInputEffectiveScaleB;
    int32_t recurrentToInputEffectiveScaleA;
    int32_t recurrentToInputEffectiveScaleB;
    int32_t cellToInputEffectiveScaleA;
    int32_t cellToInputEffectiveScaleB;
    if (!useCifg) {
        const float inputToInputEffectiveScale =
                inputToInputWeightsShape.scale * inputShape.scale / inputIntermediateScale;
        NN_RET_CHECK(QuantizeMultiplier(inputToInputEffectiveScale, &inputToInputEffectiveScaleA,
                                        &inputToInputEffectiveScaleB));
        const float recurrentToInputEffectiveScale =
                recurrentToInputWeightsShape.scale * prevOutputShape.scale / inputIntermediateScale;
        NN_RET_CHECK(QuantizeMultiplier(recurrentToInputEffectiveScale,
                                        &recurrentToInputEffectiveScaleA,
                                        &recurrentToInputEffectiveScaleB));
        if (cellToInputBuffer != nullptr) {
            const float cellToInputEffectiveScale =
                    std::pow(2, cellShift) * cellToInputShape.scale / inputIntermediateScale;
            NN_RET_CHECK(QuantizeMultiplier(cellToInputEffectiveScale, &cellToInputEffectiveScaleA,
                                            &cellToInputEffectiveScaleB));
        }
    }

    int32_t inputLayerNormScaleA;
    int32_t inputLayerNormScaleB;
    if (inputLayerNormBuffer != nullptr) {
        NN_RET_CHECK(QuantizeMultiplier(inputLayerNormShape.scale, &inputLayerNormScaleA,
                                        &inputLayerNormScaleB));
    }

    const float inputToForgetEffectiveScale =
            inputToForgetWeightsShape.scale * inputShape.scale / forgetIntermediateScale;
    int32_t inputToForgetEffectiveScaleA;
    int32_t inputToForgetEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(inputToForgetEffectiveScale, &inputToForgetEffectiveScaleA,
                                    &inputToForgetEffectiveScaleB));
    const float recurrentToForgetEffectiveScale =
            recurrentToForgetWeightsShape.scale * prevOutputShape.scale / forgetIntermediateScale;
    int32_t recurrentToForgetEffectiveScaleA;
    int32_t recurrentToForgetEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(recurrentToForgetEffectiveScale,
                                    &recurrentToForgetEffectiveScaleA,
                                    &recurrentToForgetEffectiveScaleB));
    int32_t cellToForgetEffectiveScaleA;
    int32_t cellToForgetEffectiveScaleB;
    if (cellToForgetBuffer != nullptr) {
        const float cellToForgetEffectiveScale =
                std::pow(2, cellShift) * cellToForgetShape.scale / forgetIntermediateScale;
        NN_RET_CHECK(QuantizeMultiplier(cellToForgetEffectiveScale, &cellToForgetEffectiveScaleA,
                                        &cellToForgetEffectiveScaleB));
    }
    int32_t forgetLayerNormScaleA;
    int32_t forgetLayerNormScaleB;
    if (forgetLayerNormBuffer != nullptr) {
        NN_RET_CHECK(QuantizeMultiplier(forgetLayerNormShape.scale, &forgetLayerNormScaleA,
                                        &forgetLayerNormScaleB));
    }

    const float inputToCellEffectiveScale =
            inputToCellWeightsShape.scale * inputShape.scale / cellIntermediateScale;
    int32_t inputToCellEffectiveScaleA;
    int32_t inputToCellEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(inputToCellEffectiveScale, &inputToCellEffectiveScaleA,
                                    &inputToCellEffectiveScaleB));
    const float recurrentToCellEffectiveScale =
            recurrentToCellWeightsShape.scale * prevOutputShape.scale / cellIntermediateScale;
    int32_t recurrentToCellEffectiveScaleA;
    int32_t recurrentToCellEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(recurrentToCellEffectiveScale, &recurrentToCellEffectiveScaleA,
                                    &recurrentToCellEffectiveScaleB));

    int32_t cellLayerNormScaleA;
    int32_t cellLayerNormScaleB;
    if (cellLayerNormBuffer != nullptr) {
        NN_RET_CHECK(QuantizeMultiplier(cellLayerNormShape.scale, &cellLayerNormScaleA,
                                        &cellLayerNormScaleB));
    }

    const float inputToOutputEffectiveScale =
            inputToOutputWeightsShape.scale * inputShape.scale / outputIntermediateScale;
    int32_t inputToOutputEffectiveScaleA;
    int32_t inputToOutputEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(inputToOutputEffectiveScale, &inputToOutputEffectiveScaleA,
                                    &inputToOutputEffectiveScaleB));
    const float recurrentToOutputEffectiveScale =
            recurrentToOutputWeightsShape.scale * prevOutputShape.scale / outputIntermediateScale;
    int32_t recurrentToOutputEffectiveScaleA;
    int32_t recurrentToOutputEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(recurrentToOutputEffectiveScale,
                                    &recurrentToOutputEffectiveScaleA,
                                    &recurrentToOutputEffectiveScaleB));
    int32_t cellToOutputEffectiveScaleA;
    int32_t cellToOutputEffectiveScaleB;
    if (cellToOutputBuffer != nullptr) {
        const float cellToOutputEffectiveScale =
                std::pow(2, cellShift) * cellToOutputShape.scale / outputIntermediateScale;
        NN_RET_CHECK(QuantizeMultiplier(cellToOutputEffectiveScale, &cellToOutputEffectiveScaleA,
                                        &cellToOutputEffectiveScaleB));
    }
    int32_t outputLayerNormScaleA;
    int32_t outputLayerNormScaleB;
    if (outputLayerNormBuffer != nullptr) {
        NN_RET_CHECK(QuantizeMultiplier(outputLayerNormShape.scale, &outputLayerNormScaleA,
                                        &outputLayerNormScaleB));
    }

    const float hiddenStateEffectiveScale = std::pow(2, -15) / hiddenStateScale * std::pow(2, -15);
    int32_t hiddenStateEffectiveScaleA;
    int32_t hiddenStateEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(hiddenStateEffectiveScale, &hiddenStateEffectiveScaleA,
                                    &hiddenStateEffectiveScaleB));

    int32_t projectionEffectiveScaleA;
    int32_t projectionEffectiveScaleB;
    if (projectionWeightsBuffer != nullptr) {
        const float projectionEffectiveScale =
                projectionWeightsShape.scale * hiddenStateScale / prevOutputShape.scale;
        NN_RET_CHECK(QuantizeMultiplier(projectionEffectiveScale, &projectionEffectiveScaleA,
                                        &projectionEffectiveScaleB));
    }

    // Calculates quantized clipping parameters.
    int16_t quantizedCellClip = 0;
    if (cellClip > 0.0) {
        quantizedCellClip = static_cast<int32_t>(
                std::min(std::max(cellClip / prevCellStateShape.scale, -32768.0f), 32767.0f));
    }
    int8_t quantizedProjectionClip = 0;
    if (projectionClip > 0.0) {
        quantizedProjectionClip = static_cast<int32_t>(
                std::min(std::max(projectionClip / projectionWeightsShape.scale, -128.0f), 127.0f));
    }

    // Calculates effective bias.
    // This is for optimizing the matmul calculation.
    std::unique_ptr<int32_t[]> inputToInputEffectiveBias;
    std::unique_ptr<int32_t[]> recurrentToInputEffectiveBias;
    if (!useCifg) {
        NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
                -inputShape.offset, inputToInputWeightsBuffer, inputToInputWeightsShape,
                /*bias=*/nullptr, &inputToInputEffectiveBias));
        NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
                -prevOutputShape.offset, recurrentToInputWeightsBuffer,
                recurrentToInputWeightsShape,
                /*bias=*/nullptr, &recurrentToInputEffectiveBias));
    }

    std::unique_ptr<int32_t[]> inputToForgetEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -inputShape.offset, inputToForgetWeightsBuffer, inputToForgetWeightsShape,
            /*bias=*/nullptr, &inputToForgetEffectiveBias));
    std::unique_ptr<int32_t[]> recurrentToForgetEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -prevOutputShape.offset, recurrentToForgetWeightsBuffer, recurrentToForgetWeightsShape,
            /*bias=*/nullptr, &recurrentToForgetEffectiveBias));

    std::unique_ptr<int32_t[]> inputToCellEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -inputShape.offset, inputToCellWeightsBuffer, inputToCellWeightsShape,
            /*bias=*/nullptr, &inputToCellEffectiveBias));
    std::unique_ptr<int32_t[]> recurrentToCellEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -prevOutputShape.offset, recurrentToCellWeightsBuffer, recurrentToCellWeightsShape,
            /*bias=*/nullptr, &recurrentToCellEffectiveBias));

    std::unique_ptr<int32_t[]> inputToOutputEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -inputShape.offset, inputToOutputWeightsBuffer, inputToOutputWeightsShape,
            /*bias=*/nullptr, &inputToOutputEffectiveBias));
    std::unique_ptr<int32_t[]> recurrentToOutputEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -prevOutputShape.offset, recurrentToOutputWeightsBuffer, recurrentToOutputWeightsShape,
            /*bias=*/nullptr, &recurrentToOutputEffectiveBias));

    std::unique_ptr<int32_t[]> projectionEffectiveBias;
    if (projectionBiasBuffer != nullptr) {
        NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
                hiddenStateZeroPoint, projectionWeightsBuffer, projectionWeightsShape,
                projectionBiasBuffer, &projectionEffectiveBias));
    }

    // Temporary buffers.
    std::vector<int16_t> inputGateBuffer(batchSize * numUnits);
    std::vector<int16_t> forgetGateBuffer(batchSize * numUnits);
    std::vector<int16_t> cellGateBuffer(batchSize * numUnits);
    std::vector<int16_t> outputGateBuffer(batchSize * numUnits);
    std::vector<int8_t> buffer8(batchSize * numUnits);

    // To avoid overflow when calculating layer norm.
    const int32_t inputInvLargeValue =
            std::min(1, static_cast<int32_t>(10000 * inputLayerNormShape.scale));
    const int32_t forgetInvLargeValue =
            std::min(1, static_cast<int32_t>(10000 * forgetLayerNormShape.scale));
    const int32_t cellInvLargeValue =
            std::min(1, static_cast<int32_t>(10000 * cellLayerNormShape.scale));
    const int32_t outputInvLargeValue =
            std::min(1, static_cast<int32_t>(10000 * outputLayerNormShape.scale));

    // Forget gate.
    MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToForgetEffectiveBias.get(),
                                        inputToForgetWeightsBuffer, inputToForgetEffectiveScaleA,
                                        inputToForgetEffectiveScaleB, batchSize, inputSize,
                                        numUnits,
                                        /*outputZeroPoint=*/0, forgetGateBuffer.data());
    MatrixBatchVectorMultiplyAccumulate(
            prevOutputBuffer, recurrentToForgetEffectiveBias.get(), recurrentToForgetWeightsBuffer,
            recurrentToForgetEffectiveScaleA, recurrentToForgetEffectiveScaleB, batchSize,
            outputSize, numUnits,
            /*outputZeroPoint=*/0, forgetGateBuffer.data());
    if (cellToForgetBuffer != nullptr) {
        VectorBatchVectorCwiseProductAccumulate(
                cellToForgetBuffer, outputSize, cellStateBuffer, batchSize,
                cellToForgetEffectiveScaleA, cellToForgetEffectiveScaleB, forgetGateBuffer.data());
    }
    if (forgetLayerNormBuffer != nullptr) {
        ApplyLayerNorm(forgetGateBuffer.data(), forgetLayerNormBuffer, forgetBiasBuffer,
                       forgetLayerNormScaleA, forgetLayerNormScaleB, forgetInvLargeValue, batchSize,
                       numUnits, forgetGateBuffer.data());
    }
    ApplySigmoid(forgetGateBuffer.data(), batchSize, numUnits, forgetGateBuffer.data());

    // Modulation gate.
    MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToCellEffectiveBias.get(),
                                        inputToCellWeightsBuffer, inputToCellEffectiveScaleA,
                                        inputToCellEffectiveScaleB, batchSize, inputSize, numUnits,
                                        /*outputZeroPoint=*/0, cellGateBuffer.data());
    MatrixBatchVectorMultiplyAccumulate(
            prevOutputBuffer, recurrentToCellEffectiveBias.get(), recurrentToCellWeightsBuffer,
            recurrentToCellEffectiveScaleA, recurrentToCellEffectiveScaleB, batchSize, outputSize,
            numUnits,
            /*outputZeroPoint=*/0, cellGateBuffer.data());
    if (cellLayerNormBuffer != nullptr) {
        ApplyLayerNorm(cellGateBuffer.data(), cellLayerNormBuffer, cellBiasBuffer,
                       cellLayerNormScaleA, cellLayerNormScaleB, cellInvLargeValue, batchSize,
                       numUnits, cellGateBuffer.data());
    }
    ApplyTanh<3>(cellGateBuffer.data(), batchSize, numUnits, cellGateBuffer.data());

    // Input gate.
    if (useCifg) {
        Sub1Vector(forgetGateBuffer.data(), batchSize * numUnits, inputGateBuffer.data());
    } else {
        MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToInputEffectiveBias.get(),
                                            inputToInputWeightsBuffer, inputToInputEffectiveScaleA,
                                            inputToInputEffectiveScaleB, batchSize, inputSize,
                                            numUnits,
                                            /*outputZeroPoint=*/0, inputGateBuffer.data());
        MatrixBatchVectorMultiplyAccumulate(
                prevOutputBuffer, recurrentToInputEffectiveBias.get(),
                recurrentToInputWeightsBuffer, recurrentToInputEffectiveScaleA,
                recurrentToInputEffectiveScaleB, batchSize, outputSize, numUnits,
                /*outputZeroPoint=*/0, inputGateBuffer.data());
        if (cellToInputBuffer != nullptr) {
            VectorBatchVectorCwiseProductAccumulate(
                    cellToInputBuffer, outputSize, cellStateBuffer, batchSize,
                    cellToInputEffectiveScaleA, cellToInputEffectiveScaleB, inputGateBuffer.data());
        }
        if (inputLayerNormBuffer != nullptr) {
            ApplyLayerNorm(inputGateBuffer.data(), inputLayerNormBuffer, inputBiasBuffer,
                           inputLayerNormScaleA, inputLayerNormScaleB, inputInvLargeValue,
                           batchSize, numUnits, inputGateBuffer.data());
        }
        ApplySigmoid(inputGateBuffer.data(), batchSize, numUnits, inputGateBuffer.data());
    }

    // Cell.
    CwiseMul(forgetGateBuffer.data(), prevCellStateBuffer, batchSize, numUnits,
             /*shift=*/15, forgetGateBuffer.data());
    CwiseMul(inputGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, 30 + cellShift,
             cellGateBuffer.data());
    CwiseAdd(forgetGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, cellStateBuffer);
    if (quantizedCellClip > 0) {
        CwiseClipping(cellStateBuffer, quantizedCellClip, batchSize, numUnits);
    }

    // Output gate.
    MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToOutputEffectiveBias.get(),
                                        inputToOutputWeightsBuffer, inputToOutputEffectiveScaleA,
                                        inputToOutputEffectiveScaleB, batchSize, inputSize,
                                        numUnits,
                                        /*outputZeroPoint=*/0, outputGateBuffer.data());
    MatrixBatchVectorMultiplyAccumulate(
            prevOutputBuffer, recurrentToOutputEffectiveBias.get(), recurrentToOutputWeightsBuffer,
            recurrentToOutputEffectiveScaleA, recurrentToOutputEffectiveScaleB, batchSize,
            outputSize, numUnits,
            /*outputZeroPoint=*/0, outputGateBuffer.data());
    if (cellToOutputBuffer != nullptr) {
        VectorBatchVectorCwiseProductAccumulate(
                cellToOutputBuffer, outputSize, cellStateBuffer, batchSize,
                cellToOutputEffectiveScaleA, cellToOutputEffectiveScaleB, outputGateBuffer.data());
    }
    if (outputLayerNormBuffer != nullptr) {
        ApplyLayerNorm(outputGateBuffer.data(), outputLayerNormBuffer, outputBiasBuffer,
                       outputLayerNormScaleA, outputLayerNormScaleB, outputInvLargeValue, batchSize,
                       numUnits, outputGateBuffer.data());
    }
    ApplySigmoid(outputGateBuffer.data(), batchSize, numUnits, outputGateBuffer.data());

    // Hidden.
    ApplyTanh(cellShift + 15, cellStateBuffer, batchSize, numUnits, inputGateBuffer.data());
    CwiseMul(outputGateBuffer.data(), inputGateBuffer.data(), hiddenStateEffectiveScaleA,
             hiddenStateEffectiveScaleB, batchSize, numUnits, hiddenStateZeroPoint, buffer8.data());

    // Projection.
    if (projectionWeightsBuffer != nullptr) {
        memset(outputBuffer, 0, batchSize * outputSize * sizeof(int8_t));
        MatrixBatchVectorMultiplyAccumulate(buffer8.data(), projectionEffectiveBias.get(),
                                            projectionWeightsBuffer, projectionEffectiveScaleA,
                                            projectionEffectiveScaleB, batchSize, numUnits,
                                            outputSize, prevOutputShape.offset, outputBuffer);
        if (quantizedProjectionClip > 0) {
            CwiseClipping(outputBuffer, quantizedProjectionClip, batchSize, outputSize);
        }
    } else {
        std::copy_n(buffer8.data(), batchSize * outputSize, outputBuffer);
    }

    // Copy output to output state out.
    for (unsigned int i = 0; i < batchSize * outputSize; ++i) {
        outputStateBuffer[i] = outputBuffer[i];
    }

    return true;
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace qlstm

// Registers the QUANTIZED_LSTM operation, pairing qlstm::prepare (input checks
// and output shape setup) with qlstm::execute (the quantized LSTM step).
//
// allowOmittedOperand = true is required because many QLSTM inputs are
// optional. These include the input-gate weights/bias when CIFG is used, the
// peephole (cell-to-gate) weights, the layer-norm weights, and the projection
// weights/bias. execute() checks each of these for nullptr before use.
// prepare() separately rejects omission of the mandatory subset (input,
// forget/cell/output weights and biases, previous output, and previous cell
// state).
//
// NOTE(review): qlstm::execute is only defined under
// NN_INCLUDE_CPU_IMPLEMENTATION. Presumably this macro discards the execute
// argument when that is not defined; confirm against the macro definition.
// "DEFAULT_VALIDATION" presumably means the framework's generic operand
// validation is used instead of a QLSTM-specific validate(); verify.
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(QUANTIZED_LSTM, qlstm::prepare, qlstm::execute,
                                         .allowOmittedOperand = true);

}  // namespace nn
}  // namespace android
