/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <functional>
#include <vector>

#include "EmbeddingLookup.h"
#include "NeuralNetworksWrapper.h"

using ::testing::FloatNear;
using ::testing::Matcher;

namespace android {
namespace nn {
namespace wrapper {

namespace {

std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
                                           float max_abs_error = 1.e-6) {
    std::vector<Matcher<float>> matchers;
    matchers.reserve(values.size());
    for (const float& v : values) {
        matchers.emplace_back(FloatNear(v, max_abs_error));
    }
    return matchers;
}

}  // namespace

using ::testing::ElementsAreArray;

// X-macro listing the model's input operands as (Name, element type) pairs.
// Each entry expands to ACTION(Name, T). EmbeddingLookupOpModel uses it to:
// - declare a `std::vector<T> Name_` buffer,
// - generate a `SetName()` setter,
// - bind the buffer to `EmbeddingLookup::kNameTensor` before execution.
// NOTE(review): Lookup feeds a TENSOR_INT32 operand, so this assumes `int` is
// 32 bits on the target.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Value, float)                         \
    ACTION(Lookup, int)

// X-macro listing the model's output operands (this op has a single output).
#define FOR_ALL_OUTPUT_TENSORS(ACTION) ACTION(Output, float)

class EmbeddingLookupOpModel {
   public:
    EmbeddingLookupOpModel(std::initializer_list<uint32_t> index_shape,
                           std::initializer_list<uint32_t> weight_shape) {
        auto it = weight_shape.begin();
        rows_ = *it++;
        columns_ = *it++;
        features_ = *it;

        std::vector<uint32_t> inputs;

        OperandType LookupTy(Type::TENSOR_INT32, index_shape);
        inputs.push_back(model_.addOperand(&LookupTy));

        OperandType ValueTy(Type::TENSOR_FLOAT32, weight_shape);
        inputs.push_back(model_.addOperand(&ValueTy));

        std::vector<uint32_t> outputs;

        OperandType OutputOpndTy(Type::TENSOR_FLOAT32, weight_shape);
        outputs.push_back(model_.addOperand(&OutputOpndTy));

        auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
            uint32_t sz = 1;
            for (uint32_t d : dims) {
                sz *= d;
            }
            return sz;
        };

        Value_.insert(Value_.end(), multiAll(weight_shape), 0.f);
        Output_.insert(Output_.end(), multiAll(weight_shape), 0.f);

        model_.addOperation(ANEURALNETWORKS_EMBEDDING_LOOKUP, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);

        model_.finish();
    }

    void Invoke() {
        ASSERT_TRUE(model_.isValid());

        Compilation compilation(&model_);
        compilation.finish();
        Execution execution(&compilation);

#define SetInputOrWeight(X, T)                                               \
    ASSERT_EQ(execution.setInput(EmbeddingLookup::k##X##Tensor, X##_.data(), \
                                 sizeof(T) * X##_.size()),                   \
              Result::NO_ERROR);

        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);

#undef SetInputOrWeight

#define SetOutput(X, T)                                                       \
    ASSERT_EQ(execution.setOutput(EmbeddingLookup::k##X##Tensor, X##_.data(), \
                                  sizeof(T) * X##_.size()),                   \
              Result::NO_ERROR);

        FOR_ALL_OUTPUT_TENSORS(SetOutput);

#undef SetOutput

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    }

#define DefineSetter(X, T) \
    void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);

#undef DefineSetter

    void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
        for (uint32_t i = 0; i < rows_; i++) {
            for (uint32_t j = 0; j < columns_; j++) {
                for (uint32_t k = 0; k < features_; k++) {
                    Value_[(i * columns_ + j) * features_ + k] = function(i, j, k);
                }
            }
        }
    }

    const std::vector<float>& GetOutput() const { return Output_; }

   private:
    Model model_;
    uint32_t rows_;
    uint32_t columns_;
    uint32_t features_;

#define DefineTensor(X, T) std::vector<T> X##_;

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
    FOR_ALL_OUTPUT_TENSORS(DefineTensor);

#undef DefineTensor
};

// Looks up rows {1, 0, 2} of a 3x2x4 value tensor in which element (i, j, k)
// encodes its own coordinates as i + j/10 + k/100. Checks that the output
// holds those rows in lookup order.
// TODO: write more tests that exercise the details of the op, such as
// lookup errors and variable input shapes.
TEST(EmbeddingLookupOpTest, SimpleTest) {
    EmbeddingLookupOpModel model({3}, {3, 2, 4});
    model.SetLookup({1, 0, 2});
    model.Set3DWeightMatrix(
            [](int row, int col, int feature) { return row + col / 10.0f + feature / 100.0f; });

    model.Invoke();

    const std::vector<float> expected = {
            1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
            0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
            2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
    };
    EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android
