/**
 * Copyright 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "run_tflite.h"

#include <jni.h>
#include <string>
#include <iomanip>
#include <sstream>
#include <fcntl.h>

#include <android/asset_manager_jni.h>
#include <android/log.h>
#include <android/sharedmem.h>
#include <sys/mman.h>

/**
 * Reports whether NNAPI enumerates a device whose name equals the given string.
 *
 * @param env JNI environment of the calling thread.
 * @param _nnApiDeviceName Java string holding the device name to look for; may be null.
 * @return true when an enumerated device has exactly that name; false when the name is
 *         null or cannot be read, when the device count cannot be queried, or when no
 *         enumerated device matches.
 */
extern "C" JNIEXPORT jboolean JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_hasNnApiDevice(
    JNIEnv *env, jobject /* this */, jstring _nnApiDeviceName) {
  if (_nnApiDeviceName == nullptr) {
    return false;
  }

  const char *nnApiDeviceName = env->GetStringUTFChars(_nnApiDeviceName, nullptr);
  if (nnApiDeviceName == nullptr) {
    // GetStringUTFChars failed; a Java exception is already pending.
    return false;
  }
  // Copy the name so the JNI string can be released right away.
  const std::string device_name(nnApiDeviceName);
  env->ReleaseStringUTFChars(_nnApiDeviceName, nnApiDeviceName);

  uint32_t numDevices = 0;
  if (NnApiImplementation()->ANeuralNetworks_getDeviceCount(&numDevices) != 0) {
    return false;
  }

  bool foundDevice = false;
  for (uint32_t i = 0; i < numDevices; i++) {
    ANeuralNetworksDevice *device = nullptr;
    if (NnApiImplementation()->ANeuralNetworks_getDevice(i, &device) != 0 ||
        device == nullptr) {
      continue;  // Skip devices that cannot be queried.
    }
    const char *buffer = nullptr;
    if (NnApiImplementation()->ANeuralNetworksDevice_getName(device, &buffer) != 0 ||
        buffer == nullptr) {
      // Comparing a std::string against a null char pointer is undefined behaviour.
      continue;
    }
    if (device_name == buffer) {
      foundDevice = true;
      break;
    }
  }

  return foundDevice;
}

/**
 * Creates a BenchmarkModel through BenchmarkModel::create and returns it as an opaque
 * handle.
 *
 * The Java strings are converted to UTF chars for the duration of the create call and
 * released afterwards. When the backend is TFLITE_NNAPI and create reported a non-zero
 * NNAPI error code, a NnApiDelegationFailure carrying that code is thrown to Java.
 *
 * @param env JNI environment of the calling thread.
 * @param _modelFileName path of the model file; must not be null.
 * @param _tfliteBackend backend selector forwarded to BenchmarkModel::create.
 * @param _enableIntermediateTensorsDump forwarded to BenchmarkModel::create.
 * @param _nnApiDeviceName optional NNAPI device name; may be null.
 * @param _mmapModel forwarded to BenchmarkModel::create.
 * @param _nnApiCacheDir optional NNAPI cache directory; may be null.
 * @param _nnApiSlHandle pointer to a tflite::nnapi::NnApiSupportLibrary passed as jlong.
 * @return the model pointer as jlong, or 0 when a string argument could not be read.
 */
extern "C"
JNIEXPORT jlong
JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_initModel(
        JNIEnv *env,
        jobject /* this */,
        jstring _modelFileName,
        jint _tfliteBackend,
        jboolean _enableIntermediateTensorsDump,
        jstring _nnApiDeviceName,
        jboolean _mmapModel,
        jstring _nnApiCacheDir,
        jlong _nnApiSlHandle) {
    if (_modelFileName == nullptr) {
        return 0;
    }

    const char *modelFileName = env->GetStringUTFChars(_modelFileName, nullptr);
    const char *nnApiDeviceName = nullptr;
    const char *nnApiCacheDir = nullptr;

    // Releases every UTF buffer acquired so far. Safe with a pending exception.
    const auto releaseStrings = [&]() {
        if (nnApiCacheDir != nullptr) {
            env->ReleaseStringUTFChars(_nnApiCacheDir, nnApiCacheDir);
        }
        if (nnApiDeviceName != nullptr) {
            env->ReleaseStringUTFChars(_nnApiDeviceName, nnApiDeviceName);
        }
        if (modelFileName != nullptr) {
            env->ReleaseStringUTFChars(_modelFileName, modelFileName);
        }
    };

    if (modelFileName == nullptr) {
        return 0;  // Exception already pending.
    }
    if (_nnApiDeviceName != nullptr) {
        nnApiDeviceName = env->GetStringUTFChars(_nnApiDeviceName, nullptr);
        if (nnApiDeviceName == nullptr) {
            releaseStrings();
            return 0;
        }
    }
    if (_nnApiCacheDir != nullptr) {
        nnApiCacheDir = env->GetStringUTFChars(_nnApiCacheDir, nullptr);
        if (nnApiCacheDir == nullptr) {
            releaseStrings();
            return 0;
        }
    }

    const tflite::nnapi::NnApiSupportLibrary *nnApiSlHandle =
            reinterpret_cast<const tflite::nnapi::NnApiSupportLibrary *>(_nnApiSlHandle);
    int nnapiErrno = 0;
    void *handle = BenchmarkModel::create(
        modelFileName, _tfliteBackend, _enableIntermediateTensorsDump, &nnapiErrno,
        nnApiDeviceName, _mmapModel, nnApiCacheDir, nnApiSlHandle);

    // The cache directory string was previously never released.
    releaseStrings();

    if (_tfliteBackend == TFLITE_NNAPI && nnapiErrno != 0) {
        // Each lookup leaves its own exception pending on failure, so the Java caller
        // still observes an error even if the dedicated exception cannot be built.
        jclass nnapiFailureClass = env->FindClass(
            "com/android/nn/benchmark/core/NnApiDelegationFailure");
        if (nnapiFailureClass != nullptr) {
            jmethodID constructor =
                env->GetMethodID(nnapiFailureClass, "<init>", "(I)V");
            if (constructor != nullptr) {
                jobject exception =
                    env->NewObject(nnapiFailureClass, constructor, nnapiErrno);
                if (exception != nullptr) {
                    env->Throw(static_cast<jthrowable>(exception));
                }
            }
        }
    }

    return (jlong)(uintptr_t)handle;
}



/**
 * Deletes the BenchmarkModel referenced by the given handle.
 *
 * @param _modelHandle BenchmarkModel pointer passed as jlong.
 */
extern "C"
JNIEXPORT void
JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_destroyModel(
        JNIEnv *env,
        jobject /* this */,
        jlong _modelHandle) {
    delete reinterpret_cast<BenchmarkModel *>(_modelHandle);
}

/**
 * Copies the Java int array into a vector and forwards it to
 * BenchmarkModel::resizeInputTensors.
 *
 * @param env JNI environment of the calling thread.
 * @param _modelHandle BenchmarkModel pointer passed as jlong.
 * @param _inputShape Java int[] holding the new input shape.
 * @return the result of BenchmarkModel::resizeInputTensors, or false when the handle or
 *         array is null or the array elements cannot be obtained.
 */
extern "C"
JNIEXPORT jboolean
JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_resizeInputTensors(
        JNIEnv *env,
        jobject /* this */,
        jlong _modelHandle,
        jintArray _inputShape) {
    BenchmarkModel* model = reinterpret_cast<BenchmarkModel *>(_modelHandle);
    if (model == nullptr || _inputShape == nullptr) {
        return false;
    }

    jint* shapePtr = env->GetIntArrayElements(_inputShape, nullptr);
    if (shapePtr == nullptr) {
        return false;  // Exception already pending.
    }
    const jsize shapeLen = env->GetArrayLength(_inputShape);

    std::vector<int> shape(shapePtr, shapePtr + shapeLen);
    // The elements were only read, so release without copying back.
    env->ReleaseIntArrayElements(_inputShape, shapePtr, JNI_ABORT);

    return model->resizeInputTensors(std::move(shape));
}

/**
 * RAII container for a list of InferenceInOutSequence to handle JNI data release in destructor.
 *
 * The constructor obtains byte-array element buffers from the Java input/output objects
 * and the destructor releases them. The type is therefore non-copyable: a copy would
 * release the same buffers a second time.
 */
class InferenceInOutSequenceList {
public:
    InferenceInOutSequenceList(JNIEnv *env,
                               const jobject& inOutDataList,
                               bool expectGoldenOutputs);
    ~InferenceInOutSequenceList();

    InferenceInOutSequenceList(const InferenceInOutSequenceList&) = delete;
    InferenceInOutSequenceList& operator=(const InferenceInOutSequenceList&) = delete;

    /** True only when the constructor processed the whole list without error. */
    bool isValid() const { return mValid; }

    /** The native view of the sequences built by the constructor. */
    const std::vector<InferenceInOutSequence>& data() const { return mData; }

private:
    JNIEnv *mEnv;  // not owned.

    std::vector<InferenceInOutSequence> mData;
    // One entry per InferenceInOut in traversal order; null when it has no input array.
    std::vector<jbyteArray> mInputArrays;
    // One entry per InferenceInOut in traversal order; null when it has no expected outputs.
    std::vector<jobjectArray> mOutputArrays;
    bool mValid;
};

/**
 * Builds the native view of a Java List of InferenceInOutSequence.
 *
 * For every InferenceInOut the constructor either obtains the elements of its mInput byte
 * array or, when mInput is null, installs a callback that asks the object's mInputCreator
 * to fill a direct ByteBuffer. The elements of every mExpectedOutputs byte array are
 * obtained as well. The Java arrays are remembered so the destructor can release them.
 *
 * On any failure the constructor returns early with isValid() == false and, in most
 * cases, a Java exception pending.
 *
 * @param env JNI environment of the calling thread; stored but not owned.
 * @param inOutDataList Java List of InferenceInOutSequence objects.
 * @param expectGoldenOutputs when true, an entry without expected outputs raises
 *        IllegalArgumentException.
 */
InferenceInOutSequenceList::InferenceInOutSequenceList(JNIEnv *env,
                                                       const jobject& inOutDataList,
                                                       bool expectGoldenOutputs)
    : mEnv(env), mValid(false) {

    jclass list_class = env->FindClass("java/util/List");
    if (list_class == nullptr) { return; }
    jmethodID list_size = env->GetMethodID(list_class, "size", "()I");
    if (list_size == nullptr) { return; }
    jmethodID list_get = env->GetMethodID(list_class, "get", "(I)Ljava/lang/Object;");
    if (list_get == nullptr) { return; }

    jclass inOutSeq_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOutSequence");
    if (inOutSeq_class == nullptr) { return; }
    jmethodID inOutSeq_size = env->GetMethodID(inOutSeq_class, "size", "()I");
    if (inOutSeq_size == nullptr) { return; }
    jmethodID inOutSeq_get = env->GetMethodID(inOutSeq_class, "get",
                                              "(I)Lcom/android/nn/benchmark/core/InferenceInOut;");
    if (inOutSeq_get == nullptr) { return; }

    jclass inout_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut");
    if (inout_class == nullptr) { return; }
    jfieldID inout_input = env->GetFieldID(inout_class, "mInput", "[B");
    if (inout_input == nullptr) { return; }
    jfieldID inout_expectedOutputs = env->GetFieldID(inout_class, "mExpectedOutputs", "[[B");
    if (inout_expectedOutputs == nullptr) { return; }
    jfieldID inout_inputCreator = env->GetFieldID(inout_class, "mInputCreator",
            "Lcom/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface;");
    if (inout_inputCreator == nullptr) { return; }

    jclass inputCreator_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface");
    if (inputCreator_class == nullptr) { return; }
    jmethodID createInput_method = env->GetMethodID(inputCreator_class, "createInput", "(Ljava/nio/ByteBuffer;)V");
    if (createInput_method == nullptr) { return; }

    // Raises IllegalArgumentException; if the class lookup itself fails, that failure
    // is left pending instead.
    const auto throwIllegalArgument = [env](const char* message) {
        jclass iaeClass = env->FindClass("java/lang/IllegalArgumentException");
        if (iaeClass != nullptr) {
            env->ThrowNew(iaeClass, message);
        }
    };

    // Fetch input/output arrays
    const jint data_count = env->CallIntMethod(inOutDataList, list_size);
    if (env->ExceptionCheck()) { return; }
    if (data_count > 0) {
        mData.reserve(static_cast<size_t>(data_count));
    }

    for (jint seq_index = 0; seq_index < data_count; ++seq_index) {
        jobject inOutSeq = env->CallObjectMethod(inOutDataList, list_get, seq_index);
        if (env->ExceptionCheck()) { return; }
        if (inOutSeq == nullptr) {
            throwIllegalArgument("Null input/output sequence");
            return;
        }

        const jint seqLen = env->CallIntMethod(inOutSeq, inOutSeq_size);
        if (env->ExceptionCheck()) { return; }

        mData.push_back(InferenceInOutSequence{});
        auto& seq = mData.back();
        if (seqLen > 0) {
            seq.reserve(static_cast<size_t>(seqLen));
        }
        for (jint i = 0; i < seqLen; ++i) {
            // Note: this local reference is captured by the input creator callback
            // below, so it is intentionally not deleted here.
            jobject inout = env->CallObjectMethod(inOutSeq, inOutSeq_get, i);
            if (env->ExceptionCheck()) { return; }
            if (inout == nullptr) {
                throwIllegalArgument("Null input/output entry");
                return;
            }

            uint8_t* input_data = nullptr;
            size_t input_len = 0;
            std::function<bool(uint8_t*, size_t)> inputCreator;
            jbyteArray input = static_cast<jbyteArray>(
                    env->GetObjectField(inout, inout_input));
            if (input != nullptr) {
                input_data = reinterpret_cast<uint8_t*>(
                        env->GetByteArrayElements(input, nullptr));
                // Bail out before recording the array so the destructor never tries
                // to release a buffer that was not obtained.
                if (input_data == nullptr) { return; }
                input_len = env->GetArrayLength(input);
            } else {
                inputCreator = [env, inout, inout_inputCreator, createInput_method](
                        uint8_t* buffer, size_t length) {
                    jobject byteBuffer = env->NewDirectByteBuffer(buffer, length);
                    if (byteBuffer == nullptr) { return false; }
                    jobject creator = env->GetObjectField(inout, inout_inputCreator);
                    if (creator == nullptr) {
                        env->DeleteLocalRef(byteBuffer);
                        return false;
                    }
                    env->CallVoidMethod(creator, createInput_method, byteBuffer);
                    const bool ok = !env->ExceptionCheck();
                    // The callback may run many times within one native call; drop
                    // the references it created each time.
                    env->DeleteLocalRef(creator);
                    env->DeleteLocalRef(byteBuffer);
                    return ok;
                };
            }

            jobjectArray expectedOutputs = static_cast<jobjectArray>(
                    env->GetObjectField(inout, inout_expectedOutputs));

            // The three containers are appended together so the destructor can walk
            // them with a single running index.
            mInputArrays.push_back(input);
            mOutputArrays.push_back(expectedOutputs);
            seq.push_back({input_data, input_len, {}, inputCreator});

            // Add expected output to sequence added above
            if (expectedOutputs != nullptr) {
                const jsize expectedOutputsLength = env->GetArrayLength(expectedOutputs);
                auto& outputs = seq.back().outputs;
                outputs.reserve(expectedOutputsLength);

                for (jsize j = 0; j < expectedOutputsLength; ++j) {
                    jbyteArray expectedOutput =
                            static_cast<jbyteArray>(env->GetObjectArrayElement(expectedOutputs, j));
                    if (env->ExceptionCheck()) {
                        return;
                    }
                    if (expectedOutput == nullptr) {
                        throwIllegalArgument("Null expected output array");
                        return;
                    }

                    uint8_t *expectedOutput_data = reinterpret_cast<uint8_t*>(
                            env->GetByteArrayElements(expectedOutput, nullptr));
                    if (expectedOutput_data == nullptr) { return; }
                    size_t expectedOutput_len = env->GetArrayLength(expectedOutput);
                    outputs.push_back({ expectedOutput_data, expectedOutput_len});
                }
            } else if (expectGoldenOutputs) {
                throwIllegalArgument("Expected golden output for every input");
                return;
            }
        }

        // The sequence object itself is not referenced after this point.
        env->DeleteLocalRef(inOutSeq);
    }
    mValid = true;
}

/**
 * Releases every byte-array element buffer recorded in mData, discarding changes
 * (JNI_ABORT).
 *
 * The destructor may run while a Java exception is pending (for example after the
 * constructor bailed out). The pending exception is set aside while the arrays are
 * looked up and re-thrown afterwards. An entry whose expected outputs were only partly
 * obtained has exactly the obtained buffers released.
 */
InferenceInOutSequenceList::~InferenceInOutSequenceList() {
    // Looking up array elements is not permitted with an exception pending, so stash it.
    jthrowable pending = mEnv->ExceptionOccurred();
    if (pending != nullptr) {
        mEnv->ExceptionClear();
    }

    size_t arrayIndex = 0;
    for (size_t seq_index = 0; seq_index < mData.size(); ++seq_index) {
        for (size_t i = 0; i < mData[seq_index].size(); ++i, ++arrayIndex) {
            auto& entry = mData[seq_index][i];

            jbyteArray input = mInputArrays[arrayIndex];
            if (input != nullptr && entry.input != nullptr) {
                mEnv->ReleaseByteArrayElements(
                        input, reinterpret_cast<jbyte*>(entry.input), JNI_ABORT);
            }

            jobjectArray expectedOutputs = mOutputArrays[arrayIndex];
            if (expectedOutputs == nullptr) {
                continue;
            }

            // Only entry.outputs.size() buffers were obtained; on a partially built
            // entry this is smaller than the Java array length.
            const size_t obtainedCount = entry.outputs.size();
            for (size_t j = 0; j < obtainedCount; ++j) {
                jbyteArray expectedOutput = static_cast<jbyteArray>(
                        mEnv->GetObjectArrayElement(expectedOutputs, static_cast<jsize>(j)));
                if (mEnv->ExceptionCheck()) {
                    // Nothing useful can be reported from a destructor; keep releasing.
                    mEnv->ExceptionClear();
                    continue;
                }
                if (expectedOutput == nullptr) {
                    continue;
                }
                mEnv->ReleaseByteArrayElements(
                        expectedOutput, reinterpret_cast<jbyte*>(entry.outputs[j].ptr),
                        JNI_ABORT);
                mEnv->DeleteLocalRef(expectedOutput);
            }
        }
    }

    if (pending != nullptr) {
        // Restore the exception that was pending on entry.
        mEnv->Throw(pending);
        mEnv->DeleteLocalRef(pending);
    }
}

/**
 * Runs BenchmarkModel::benchmark over the given input/output data and appends one Java
 * InferenceResult per native result to resultList.
 *
 * Error arrays are attached only when FLAG_IGNORE_GOLDEN_OUTPUT is not set, and inference
 * outputs only when FLAG_DISCARD_INFERENCE_OUTPUT is not set; otherwise null is passed to
 * the InferenceResult constructor.
 *
 * @param env JNI environment of the calling thread.
 * @param _modelHandle BenchmarkModel pointer passed as jlong.
 * @param inOutDataList Java List of InferenceInOutSequence.
 * @param resultList Java List that receives the InferenceResult objects.
 * @param inferencesSeqMaxCount forwarded to BenchmarkModel::benchmark.
 * @param timeoutSec forwarded to BenchmarkModel::benchmark.
 * @param flags bit mask of FLAG_* values; also forwarded to BenchmarkModel::benchmark.
 * @return the result of BenchmarkModel::benchmark, or false when the handle is null, a
 *         JNI lookup or allocation fails, or the input data is invalid.
 */
extern "C"
JNIEXPORT jboolean
JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_runBenchmark(
        JNIEnv *env,
        jobject /* this */,
        jlong _modelHandle,
        jobject inOutDataList,
        jobject resultList,
        jint inferencesSeqMaxCount,
        jfloat timeoutSec,
        jint flags) {

    BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
    if (model == nullptr) { return false; }

    jclass list_class = env->FindClass("java/util/List");
    if (list_class == nullptr) { return false; }
    jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
    if (list_add == nullptr) { return false; }

    jclass result_class = env->FindClass("com/android/nn/benchmark/core/InferenceResult");
    if (result_class == nullptr) { return false; }
    jmethodID result_ctor = env->GetMethodID(result_class, "<init>", "(F[F[F[[BII)V");
    if (result_ctor == nullptr) { return false; }

    std::vector<InferenceResult> result;

    const bool expectGoldenOutputs = (flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0;
    const bool keepInferenceOutputs = (flags & FLAG_DISCARD_INFERENCE_OUTPUT) == 0;
    InferenceInOutSequenceList data(env, inOutDataList, expectGoldenOutputs);
    if (!data.isValid()) {
        return false;
    }

    // TODO: Remove success boolean from this method and throw an exception in case of problems
    bool success = model->benchmark(data.data(), inferencesSeqMaxCount, timeoutSec, flags, &result);
    if (!success) {
        return false;
    }

    // Looked up once rather than once per result.
    jclass byteArrayClass = nullptr;
    if (keepInferenceOutputs) {
        byteArrayClass = env->FindClass("[B");
        if (byteArrayClass == nullptr) { return false; }
    }

    // Copies a float container into a new Java float[]; returns nullptr on failure.
    // The region copy is skipped for empty containers.
    const auto toFloatArray = [env](const auto& values) -> jfloatArray {
        const jsize count = static_cast<jsize>(values.size());
        jfloatArray array = env->NewFloatArray(count);
        if (array == nullptr || env->ExceptionCheck()) { return nullptr; }
        if (count > 0) {
            env->SetFloatArrayRegion(array, 0, count, values.data());
        }
        return array;
    };

    // Generate results
    for (const InferenceResult &rentry : result) {
        jobjectArray inferenceOutputs = nullptr;
        jfloatArray meanSquareErrorArray = nullptr;
        jfloatArray maxSingleErrorArray = nullptr;

        if (expectGoldenOutputs) {
            meanSquareErrorArray = toFloatArray(rentry.meanSquareErrors);
            if (meanSquareErrorArray == nullptr) { return false; }
            maxSingleErrorArray = toFloatArray(rentry.maxSingleErrors);
            if (maxSingleErrorArray == nullptr) { return false; }
        }

        if (keepInferenceOutputs) {
            const jsize outputCount = static_cast<jsize>(rentry.inferenceOutputs.size());
            inferenceOutputs = env->NewObjectArray(outputCount, byteArrayClass, nullptr);
            if (inferenceOutputs == nullptr || env->ExceptionCheck()) { return false; }

            for (jsize i = 0; i < outputCount; ++i) {
                const auto& output = rentry.inferenceOutputs[i];
                const jsize outputLen = static_cast<jsize>(output.size());
                jbyteArray inferenceOutput = env->NewByteArray(outputLen);
                if (inferenceOutput == nullptr || env->ExceptionCheck()) { return false; }
                if (outputLen > 0) {
                    env->SetByteArrayRegion(inferenceOutput, 0, outputLen,
                                            reinterpret_cast<const jbyte*>(output.data()));
                }
                env->SetObjectArrayElement(inferenceOutputs, i, inferenceOutput);
                // The outer array now holds the element.
                env->DeleteLocalRef(inferenceOutput);
            }
        }

        jobject object = env->NewObject(
            result_class, result_ctor, rentry.computeTimeSec,
            meanSquareErrorArray, maxSingleErrorArray, inferenceOutputs,
            rentry.inputOutputSequenceIndex, rentry.inputOutputIndex);
        if (env->ExceptionCheck() || object == nullptr) { return false; }

        env->CallBooleanMethod(resultList, list_add, object);
        if (env->ExceptionCheck()) { return false; }

        // Releasing local references to objects to avoid local reference table overflow
        // if tests is set to run for long time.
        if (meanSquareErrorArray != nullptr) {
            env->DeleteLocalRef(meanSquareErrorArray);
        }
        if (maxSingleErrorArray != nullptr) {
            env->DeleteLocalRef(maxSingleErrorArray);
        }
        if (inferenceOutputs != nullptr) {
            env->DeleteLocalRef(inferenceOutputs);
        }
        env->DeleteLocalRef(object);
    }

    return success;
}

/**
 * Forwards the given input/output data and dump path to BenchmarkModel::dumpAllLayers.
 *
 * Returns without calling the model when the input data is invalid or the path string
 * cannot be read.
 *
 * @param env JNI environment of the calling thread.
 * @param _modelHandle BenchmarkModel pointer passed as jlong.
 * @param dumpPath Java string passed to dumpAllLayers as UTF chars.
 * @param inOutDataList Java List of InferenceInOutSequence; golden outputs are not
 *        required.
 */
extern "C"
JNIEXPORT void
JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_dumpAllLayers(
        JNIEnv *env,
        jobject /* this */,
        jlong _modelHandle,
        jstring dumpPath,
        jobject inOutDataList) {

    BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);

    InferenceInOutSequenceList data(env, inOutDataList, /*expectGoldenOutputs=*/false);
    if (!data.isValid()) {
        return;
    }

    // The second argument is an optional jboolean* out-parameter, not a flag.
    const char *dumpPathStr = env->GetStringUTFChars(dumpPath, nullptr);
    if (dumpPathStr == nullptr) {
        return;  // Exception already pending.
    }
    model->dumpAllLayers(dumpPathStr, data.data());
    env->ReleaseStringUTFChars(dumpPath, dumpPathStr);
}

/**
 * Reports whether NNAPI enumerates more than one device.
 *
 * @return true when the device count is greater than one; false otherwise, including
 *         when the device count cannot be queried.
 */
extern "C"
JNIEXPORT jboolean
JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_hasAccelerator() {
  uint32_t device_count = 0;
  if (NnApiImplementation()->ANeuralNetworks_getDeviceCount(&device_count) != 0) {
    // Do not trust the count after a failed query.
    return false;
  }
  // We only consider a real device, not 'nnapi-reference'.
  return device_count > 1;
}

/**
 * Appends the name of every NNAPI device to the given Java list.
 *
 * @param env JNI environment of the calling thread.
 * @param resultList Java List that receives one String per device.
 * @return true when every device name was added; false as soon as an NNAPI query, a JNI
 *         lookup, the string creation or the list insertion fails. Names added before
 *         the failure remain in the list.
 */
extern "C"
JNIEXPORT jboolean
JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_getAcceleratorNames(
    JNIEnv *env,
    jclass, /* clazz */
    jobject resultList
    ) {
  uint32_t device_count = 0;
  auto nnapi_result = NnApiImplementation()->ANeuralNetworks_getDeviceCount(&device_count);
  if (nnapi_result != 0) {
    return false;
  }

  jclass list_class = env->FindClass("java/util/List");
  if (list_class == nullptr) { return false; }
  jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
  if (list_add == nullptr) { return false; }

  for (uint32_t i = 0; i < device_count; i++) {
      ANeuralNetworksDevice* device = nullptr;
      nnapi_result = NnApiImplementation()->ANeuralNetworks_getDevice(i, &device);
      if (nnapi_result != 0) {
          return false;
      }
      const char* buffer = nullptr;
      nnapi_result = NnApiImplementation()->ANeuralNetworksDevice_getName(device, &buffer);
      if (nnapi_result != 0 || buffer == nullptr) {
          return false;
      }

      jstring device_name = env->NewStringUTF(buffer);
      if (device_name == nullptr) { return false; }

      env->CallBooleanMethod(resultList, list_add, device_name);
      const bool add_failed = env->ExceptionCheck();
      // The list holds its own reference; drop ours before the next iteration.
      env->DeleteLocalRef(device_name);
      if (add_failed) { return false; }
  }
  return true;
}

/**
 * Creates a Java float[] holding a copy of the given values.
 *
 * @param env JNI environment of the calling thread.
 * @param from values to copy; may be empty.
 * @return a new local reference to the array, or nullptr when the array cannot be
 *         allocated (an exception is then pending).
 */
static jfloatArray convertToJfloatArray(JNIEnv* env, const std::vector<float>& from) {
  const jsize count = static_cast<jsize>(from.size());
  jfloatArray to = env->NewFloatArray(count);
  if (to == nullptr || env->ExceptionCheck()) {
    return nullptr;
  }
  // A region copy needs no element buffer, so there is nothing to null-check or release.
  if (count > 0) {
    env->SetFloatArrayRegion(to, 0, count, from.data());
  }
  return to;
}

/**
 * Runs BenchmarkModel::benchmarkCompilation and wraps its timings in a Java
 * CompilationBenchmarkResult.
 *
 * @param env JNI environment of the calling thread.
 * @param _modelHandle BenchmarkModel pointer passed as jlong.
 * @param maxNumIterations forwarded to benchmarkCompilation.
 * @param warmupTimeoutSec forwarded to benchmarkCompilation.
 * @param runTimeoutSec forwarded to benchmarkCompilation.
 * @param useNnapiSl forwarded to benchmarkCompilation.
 * @return the new CompilationBenchmarkResult, or nullptr when a JNI lookup fails, the
 *         benchmark reports failure, or a result array cannot be created.
 */
extern "C" JNIEXPORT jobject JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_runCompilationBenchmark(
    JNIEnv* env,
    jobject /* this */,
    jlong _modelHandle,
    jint maxNumIterations,
    jfloat warmupTimeoutSec,
    jfloat runTimeoutSec,
    jboolean useNnapiSl) {
  BenchmarkModel* benchmarkModel = reinterpret_cast<BenchmarkModel*>(_modelHandle);

  // Resolve the Java result type before doing any expensive work.
  jclass resultClass =
          env->FindClass("com/android/nn/benchmark/core/CompilationBenchmarkResult");
  if (resultClass == nullptr) {
    return nullptr;
  }
  jmethodID resultConstructor = env->GetMethodID(resultClass, "<init>", "([F[F[FI)V");
  if (resultConstructor == nullptr) {
    return nullptr;
  }

  CompilationBenchmarkResult timings;
  if (!benchmarkModel->benchmarkCompilation(maxNumIterations, warmupTimeoutSec,
                                            runTimeoutSec, useNnapiSl, &timings)) {
    return nullptr;
  }

  // Mirror the native result struct into Java arrays.
  jfloatArray withoutCacheTimes =
          convertToJfloatArray(env, timings.compileWithoutCacheTimeSec);
  if (withoutCacheTimes == nullptr) {
    return nullptr;
  }

  // The two cache-related measurements are optional; null is passed when absent.
  jfloatArray saveToCacheTimes = nullptr;
  if (timings.saveToCacheTimeSec) {
    saveToCacheTimes = convertToJfloatArray(env, timings.saveToCacheTimeSec.value());
    if (saveToCacheTimes == nullptr) {
      return nullptr;
    }
  }

  jfloatArray prepareFromCacheTimes = nullptr;
  if (timings.prepareFromCacheTimeSec) {
    prepareFromCacheTimes =
            convertToJfloatArray(env, timings.prepareFromCacheTimeSec.value());
    if (prepareFromCacheTimes == nullptr) {
      return nullptr;
    }
  }

  jobject javaResult = env->NewObject(resultClass, resultConstructor, withoutCacheTimes,
                                      saveToCacheTimes, prepareFromCacheTimes,
                                      timings.cacheSizeBytes);
  return env->ExceptionCheck() ? nullptr : javaResult;
}
