/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <vector>

#include "NeuralNetworksWrapper.h"
#include "RNN.h"

namespace android {
namespace nn {
namespace wrapper {

using ::testing::Each;
using ::testing::FloatNear;
using ::testing::Matcher;

namespace {

std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
                                           float max_abs_error = 1.e-5) {
    std::vector<Matcher<float>> matchers;
    matchers.reserve(values.size());
    for (const float& v : values) {
        matchers.emplace_back(FloatNear(v, max_abs_error));
    }
    return matchers;
}

// Input sequence for RNNOpTest.BlackBoxTest: 128 floats stored as consecutive
// time steps of input_size (8) values each, i.e. 16 steps. The data is not
// split per batch; the test copies each step into every batch of the model.
static float rnn_input[] = {
        0.23689353,  0.285385,     0.037029743, -0.19858193,  -0.27569133,  0.43773448,
        0.60379338,  0.35562468,   -0.69424844, -0.93421471,  -0.87287879,  0.37144363,
        -0.62476718, 0.23791671,   0.40060222,  0.1356622,    -0.99774903,  -0.98858172,
        -0.38952237, -0.47685933,  0.31073618,  0.71511042,   -0.63767755,  -0.31729108,
        0.33468103,  0.75801885,   0.30660987,  -0.37354088,  0.77002847,   -0.62747043,
        -0.68572164, 0.0069220066, 0.65791464,  0.35130811,   0.80834007,   -0.61777675,
        -0.21095741, 0.41213346,   0.73784804,  0.094794154,  0.47791874,   0.86496925,
        -0.53376222, 0.85315156,   0.10288584,  0.86684,      -0.011186242, 0.10513687,
        0.87825835,  0.59929144,   0.62827742,  0.18899453,   0.31440187,   0.99059987,
        0.87170351,  -0.35091716,  0.74861872,  0.17831337,   0.2755419,    0.51864719,
        0.55084288,  0.58982027,   -0.47443086, 0.20875752,   -0.058871567, -0.66609079,
        0.59098077,  0.73017097,   0.74604273,  0.32882881,   -0.17503482,  0.22396147,
        0.19379807,  0.29120302,   0.077113032, -0.70331609,  0.15804303,   -0.93407321,
        0.40182066,  0.036301374,  0.66521823,  0.0300982,    -0.7747041,   -0.02038002,
        0.020698071, -0.90300065,  0.62870288,  -0.23068321,  0.27531278,   -0.095755219,
        -0.712036,   -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,  0.43519354,
        0.14744234,  0.62589407,   0.1653645,   -0.10651493,  -0.045277178, 0.99032974,
        -0.88255352, -0.85147917,  0.28153265,  0.19455957,   -0.55479527,  -0.56042433,
        0.26048636,  0.84702539,   0.47587705,  -0.074295521, -0.12287641,  0.70117295,
        0.90532446,  0.89782166,   0.79817224,  0.53402734,   -0.33286154,  0.073485017,
        -0.56172788, -0.044897556, 0.89964068,  -0.067662835, 0.76863563,   0.93455386,
        -0.6324693,  -0.083922029};

// Expected per-step RNN output for rnn_input: one row of num_units (16) values
// per time step, 16 rows in total (each blank-line-separated pair of lines is
// one step). Since every batch receives the same input, the same row is
// expected for each batch. All values are non-negative, consistent with the
// ReLU activation the model under test is configured with.
static float rnn_golden_output[] = {
        0.496726,   0,        0.965996,  0,         0.0584254, 0,          0,         0.12315,
        0,          0,        0.612266,  0.456601,  0,         0.52286,    1.16099,   0.0291232,

        0,          0,        0.524901,  0,         0,         0,          0,         1.02116,
        0,          1.35762,  0,         0.356909,  0.436415,  0.0355727,  0,         0,

        0,          0,        0,         0.262335,  0,         0,          0,         1.33992,
        0,          2.9739,   0,         0,         1.31914,   2.66147,    0,         0,

        0.942568,   0,        0,         0,         0.025507,  0,          0,         0,
        0.321429,   0.569141, 1.25274,   1.57719,   0.8158,    1.21805,    0.586239,  0.25427,

        1.04436,    0,        0.630725,  0,         0.133801,  0.210693,   0.363026,  0,
        0.533426,   0,        1.25926,   0.722707,  0,         1.22031,    1.30117,   0.495867,

        0.222187,   0,        0.72725,   0,         0.767003,  0,          0,         0.147835,
        0,          0,        0,         0.608758,  0.469394,  0.00720298, 0.927537,  0,

        0.856974,   0.424257, 0,         0,         0.937329,  0,          0,         0,
        0.476425,   0,        0.566017,  0.418462,  0.141911,  0.996214,   1.13063,   0,

        0.967899,   0,        0,         0,         0.0831304, 0,          0,         1.00378,
        0,          0,        0,         1.44818,   1.01768,   0.943891,   0.502745,  0,

        0.940135,   0,        0,         0,         0,         0,          0,         2.13243,
        0,          0.71208,  0.123918,  1.53907,   1.30225,   1.59644,    0.70222,   0,

        0.804329,   0,        0.430576,  0,         0.505872,  0.509603,   0.343448,  0,
        0.107756,   0.614544, 1.44549,   1.52311,   0.0454298, 0.300267,   0.562784,  0.395095,

        0.228154,   0,        0.675323,  0,         1.70536,   0.766217,   0,         0,
        0,          0.735363, 0.0759267, 1.91017,   0.941888,  0,          0,         0,

        0,          0,        1.5909,    0,         0,         0,          0,         0.5755,
        0,          0.184687, 0,         1.56296,   0.625285,  0,          0,         0,

        0,          0,        0.0857888, 0,         0,         0,          0,         0.488383,
        0.252786,   0,        0,         0,         1.02817,   1.85665,    0,         0,

        0.00981836, 0,        1.06371,   0,         0,         0,          0,         0,
        0,          0.290445, 0.316406,  0,         0.304161,  1.25079,    0.0707152, 0,

        0.986264,   0.309201, 0,         0,         0,         0,          0,         1.64896,
        0.346248,   0,        0.918175,  0.78884,   0.524981,  1.92076,    2.07013,   0.333244,

        0.415153,   0.210318, 0,         0,         0,         0,          0,         2.02616,
        0,          0.728256, 0.84183,   0.0907453, 0.628881,  3.58099,    1.49974,   0};

}  // anonymous namespace

// X-macro over the RNN operation's float tensors that the test supplies as
// execution inputs. For each name X, BasicRNNOpModel declares a host buffer
// X_, a Set##X() setter, and binds X_ to input index RNN::k##X##Tensor.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input)                                \
    ACTION(Weights)                              \
    ACTION(RecurrentWeights)                     \
    ACTION(Bias)                                 \
    ACTION(HiddenStateIn)

// X-macro over the RNN operation's output tensors: the updated hidden state
// and the step output. For each name X, BasicRNNOpModel declares a host
// buffer X_ and binds it with setOutput(RNN::k##X##Tensor, ...).
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(HiddenStateOut)             \
    ACTION(Output)

class BasicRNNOpModel {
   public:
    BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size)
        : batches_(batches), units_(units), input_size_(size), activation_(kActivationRelu) {
        std::vector<uint32_t> inputs;

        OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_});
        inputs.push_back(model_.addOperand(&InputTy));
        OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_});
        inputs.push_back(model_.addOperand(&WeightTy));
        OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_});
        inputs.push_back(model_.addOperand(&RecurrentWeightTy));
        OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
        inputs.push_back(model_.addOperand(&BiasTy));
        OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
        inputs.push_back(model_.addOperand(&HiddenStateTy));
        OperandType ActionParamTy(Type::INT32, {});
        inputs.push_back(model_.addOperand(&ActionParamTy));

        std::vector<uint32_t> outputs;

        outputs.push_back(model_.addOperand(&HiddenStateTy));
        OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
        outputs.push_back(model_.addOperand(&OutputTy));

        Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
        HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
        HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
        Output_.insert(Output_.end(), batches_ * units_, 0.f);

        model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);

        model_.finish();
    }

#define DefineSetter(X) \
    void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);

#undef DefineSetter

    void SetInput(int offset, float* begin, float* end) {
        for (; begin != end; begin++, offset++) {
            Input_[offset] = *begin;
        }
    }

    void ResetHiddenState() {
        std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
        std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f);
    }

    const std::vector<float>& GetOutput() const { return Output_; }

    uint32_t input_size() const { return input_size_; }
    uint32_t num_units() const { return units_; }
    uint32_t num_batches() const { return batches_; }

    void Invoke() {
        ASSERT_TRUE(model_.isValid());

        HiddenStateIn_.swap(HiddenStateOut_);

        Compilation compilation(&model_);
        compilation.finish();
        Execution execution(&compilation);
#define SetInputOrWeight(X)                                                                    \
    ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
              Result::NO_ERROR);

        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);

#undef SetInputOrWeight

#define SetOutput(X)                                                                            \
    ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
              Result::NO_ERROR);

        FOR_ALL_OUTPUT_TENSORS(SetOutput);

#undef SetOutput

        ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_, sizeof(activation_)),
                  Result::NO_ERROR);

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    }

   private:
    Model model_;

    const uint32_t batches_;
    const uint32_t units_;
    const uint32_t input_size_;

    const int activation_;

#define DefineTensor(X) std::vector<float> X##_;

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
    FOR_ALL_OUTPUT_TENSORS(DefineTensor);

#undef DefineTensor
};

// End-to-end check of the RNN op with 2 batches, 16 units and input size 8.
// Every time step of rnn_input is fed identically to all batches, so each
// batch's output must match the same row of rnn_golden_output. The hidden
// state carries over between Invoke() calls, so steps run in order starting
// from a zeroed state.
TEST(RNNOpTest, BlackBoxTest) {
    BasicRNNOpModel rnn(2, 16, 8);
    rnn.SetWeights(
            {0.461459,  0.153381,    0.529743,   -0.00371218, 0.676267,    -0.211346, 0.317493,
             0.969689,  -0.343251,   0.186423,   0.398151,    0.152399,    0.448504,  0.317662,
             0.523556,  -0.323514,   0.480877,   0.333113,    -0.757714,   -0.674487, -0.643585,
             0.217766,  -0.0251462,  0.79512,    -0.595574,   -0.422444,   0.371572,  -0.452178,
             -0.556069, -0.482188,   -0.685456,  -0.727851,   0.841829,    0.551535,  -0.232336,
             0.729158,  -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,  -0.423241,
             0.548547,  -0.0152023,  -0.757482,  -0.85491,    0.251331,    -0.989183, 0.306261,
             -0.340716, 0.886103,    -0.0726757, -0.723523,   -0.784303,   0.0354295, 0.566564,
             -0.485469, -0.620498,   0.832546,   0.697884,    -0.279115,   0.294415,  -0.584313,
             0.548772,  0.0648819,   0.968726,   0.723834,    -0.0080452,  -0.350386, -0.272803,
             0.115121,  -0.412644,   -0.824713,  -0.992843,   -0.592904,   -0.417893, 0.863791,
             -0.423461, -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,  -0.639158,
             0.816969,  -0.337228,   0.659878,   0.73107,     0.754768,    -0.337042, 0.0960841,
             0.368357,  0.244191,    -0.817703,  -0.211223,   0.442012,    0.37225,   -0.623598,
             -0.405423, 0.455101,    0.673656,   -0.145345,   -0.511346,   -0.901675, -0.81252,
             -0.127006, 0.809865,    -0.721884,  0.636255,    0.868989,    -0.347973, -0.10179,
             -0.777449, 0.917274,    0.819286,   0.206218,    -0.00785118, 0.167141,  0.45872,
             0.972934,  -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057, -0.469077,
             0.277308,  0.415818});

    rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568,
                 -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178, 0.37197268,
                 0.61957061, 0.3956964, -0.37609905});

    // 0.1 * identity (16x16): each unit feeds back only into itself.
    rnn.SetRecurrentWeights(
            {0.1, 0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0.1, 0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0.1, 0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0.1, 0,   0,   0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0.1, 0,   0,   0,   0,   0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0.1, 0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0.1, 0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0, 0,   0, 0.1, 0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0, 0.1, 0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0.1, 0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0.1,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0.1, 0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0.1, 0,   0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0.1, 0,   0,   0,   0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0.1, 0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0.1});

    rnn.ResetHiddenState();

    // rnn_input holds one input_size-long vector per time step. It is NOT laid
    // out per batch (each step is copied into every batch below), so the
    // sequence length must not be divided by num_batches(); doing so would
    // silently exercise only the first half of the sequence.
    const size_t input_floats = sizeof(rnn_input) / sizeof(float);
    ASSERT_EQ(input_floats % rnn.input_size(), 0u);
    const int input_sequence_size = input_floats / rnn.input_size();

    // The golden table must provide one num_units-long row per time step.
    ASSERT_EQ(sizeof(rnn_golden_output) / sizeof(float),
              static_cast<size_t>(input_sequence_size) * rnn.num_units());

    for (int i = 0; i < input_sequence_size; i++) {
        float* step_start = rnn_input + i * rnn.input_size();
        float* step_end = step_start + rnn.input_size();
        for (uint32_t b = 0; b < rnn.num_batches(); b++) {
            ASSERT_NO_FATAL_FAILURE(rnn.SetInput(b * rnn.input_size(), step_start, step_end));
        }

        // Stop on a failed step: later steps depend on its hidden state.
        ASSERT_NO_FATAL_FAILURE(rnn.Invoke());

        const float* golden_start = rnn_golden_output + i * rnn.num_units();
        const float* golden_end = golden_start + rnn.num_units();
        std::vector<float> expected;
        expected.reserve(static_cast<size_t>(rnn.num_batches()) * rnn.num_units());
        for (uint32_t b = 0; b < rnn.num_batches(); b++) {
            expected.insert(expected.end(), golden_start, golden_end);
        }

        EXPECT_THAT(rnn.GetOutput(), ::testing::ElementsAreArray(ArrayFloatNear(expected)));
    }
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android
