/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <C2AllocatorIon.h>
#include <C2Buffer.h>
#include <C2BufferPriv.h>
#include <C2Config.h>
#include <C2Debug.h>
#include <codec2/hidl/client.h>
#include <gui/BufferQueue.h>
#include <gui/IConsumerListener.h>
#include <gui/IProducerListener.h>

#include <chrono>
#include <condition_variable>
#include <fstream>
#include <iostream>

#include "../includes/common.h"

using android::C2AllocatorIon;
using std::chrono_literals::operator""ms;

// Number of consecutive QUEUE_TIMEOUT waits without progress before giving up.
#define MAXIMUM_NUMBER_OF_RETRIES 20
// Length of a single condition-variable wait on the work queue.
#define QUEUE_TIMEOUT 400ms
// Number of C2Work objects pre-allocated for the input work queue.
#define MAXIMUM_NUMBER_OF_INPUT_BUFFERS 8
// Rounds _sz up to the next multiple of _align (_align must be a power of two).
// Arguments are fully parenthesized so expressions such as `c ? a : b` or
// `2 << 1` expand correctly.
#define ALIGN(_sz, _align) (((_sz) + ((_align) - 1)) & ~((_align) - 1))

// Forward declaration: processes one finished C2Work returned by the component
// (applies config updates, records EOS/CSD, and returns the work item to
// workQueue). Defined after Codec2VideoDecHidlTestBase, which calls it from
// handleWorkDone().
void workDone(const std::shared_ptr<android::Codec2Client::Component> &component,
              std::unique_ptr<C2Work> &work, std::list<uint64_t> &flushedIndices,
              std::mutex &queueLock, std::condition_variable &queueCondition,
              std::list<std::unique_ptr<C2Work>> &workQueue, bool &eos, bool &csd,
              uint32_t &framesReceived);

/*
 * One record of the stream-info file: describes a single access unit of the
 * elementary stream that is fed to the decoder.
 */
struct FrameInfo {
    // Number of payload bytes read from the elementary stream for this frame.
    int bytesCount;
    // 1-based bit position of the C2FrameData flag to set on the input; 0 = none.
    uint32_t flags;
    // Value written to the input ordinal's timestamp.
    int64_t timestamp;
};

/*
 * C2Buffer holding a single shared view of a C2LinearBlock, starting at the
 * block's own offset.
 */
class LinearBuffer : public C2Buffer {
public:
    // Expose the block's full size.
    explicit LinearBuffer(const std::shared_ptr<C2LinearBlock> &block)
          : LinearBuffer(block, block->size()) {}

    // Expose only the first `size` bytes after the block's offset.
    explicit LinearBuffer(const std::shared_ptr<C2LinearBlock> &block, size_t size)
          : C2Buffer({block->share(block->offset(), size, ::C2Fence())}) {}
};

/*
 * Handle Callback functions onWorkDone(), onTripped(),
 * onError(), onDeath(), onFramesRendered()
 */
/*
 * Codec2 client listener for this test.
 * onWorkDone() hands the completed work items to `callBack` when one was
 * supplied. onTripped(), onError(), onDeath(), onInputBufferDone() and
 * onFrameRendered() are deliberately no-ops.
 */
struct CodecListener : public android::Codec2Client::Listener {
public:
    CodecListener(
            const std::function<void(std::list<std::unique_ptr<C2Work>> &workItems)> fn = nullptr)
          : callBack(fn) {}

    virtual void onWorkDone(const std::weak_ptr<android::Codec2Client::Component> & /*comp*/,
                            std::list<std::unique_ptr<C2Work>> &workItems) override {
        if (!callBack) {
            return;
        }
        callBack(workItems);
    }

    // --- Notifications below are intentionally ignored. ---

    virtual void onTripped(
            const std::weak_ptr<android::Codec2Client::Component> & /*comp*/,
            const std::vector<std::shared_ptr<C2SettingResult>> & /*settingResults*/) override {}

    virtual void onError(const std::weak_ptr<android::Codec2Client::Component> & /*comp*/,
                         uint32_t /*errorCode*/) override {}

    virtual void onDeath(
            const std::weak_ptr<android::Codec2Client::Component> & /*comp*/) override {}

    virtual void onInputBufferDone(uint64_t /*frameIndex*/, size_t /*arrayIndex*/) override {}

    virtual void onFrameRendered(uint64_t /*bufferQueueId*/, int32_t /*slotId*/,
                                 int64_t /*timestampNs*/) override {}

    // Invoked from onWorkDone(); may be empty.
    std::function<void(std::list<std::unique_ptr<C2Work>> &workItems)> callBack;
};

class Codec2VideoDecHidlTestBase {
public:
    bool SetUp() {
        mClient = getClient();
        if (!mClient) {
            return false;
        }
        mListener.reset(new CodecListener([this](std::list<std::unique_ptr<C2Work>> &workItems) {
            handleWorkDone(workItems);
        }));
        if (!mListener) {
            return false;
        }
        if (android::Codec2Client::CreateComponentByName(mComponentName.c_str(), mListener,
                                                         &mComponent, &mClient) != C2_OK) {
            return false;
        }
        for (int i = 0; i < MAXIMUM_NUMBER_OF_INPUT_BUFFERS; ++i) {
            mWorkQueue.emplace_back(new C2Work);
        }
        std::shared_ptr<C2AllocatorStore> store = android::GetCodec2PlatformAllocatorStore();
        if (store->fetchAllocator(C2AllocatorStore::DEFAULT_LINEAR, &mLinearAllocator) != C2_OK) {
            return false;
        }
        mLinearPool = std::make_shared<C2PooledBlockPool>(mLinearAllocator, ++mBlockPoolId);
        if (!mLinearPool) {
            return false;
        }
        mEos = false;
        mHasVulnerability = false;
        mTimestampUs = 0u;
        mWorkResult = C2_OK;
        mFramesReceived = 0;
        return true;
    }

    ~Codec2VideoDecHidlTestBase() {
        if (mComponent != nullptr) {
            mComponent->release();
            mComponent = nullptr;
        }
    }

    std::shared_ptr<android::Codec2Client> getClient() {
        auto instances = android::Codec2Client::GetServiceNames();
        for (std::string instance : instances) {
            std::shared_ptr<android::Codec2Client> client =
                    android::Codec2Client::CreateFromService(instance.c_str());
            std::vector<C2Component::Traits> components = client->listComponents();
            for (C2Component::Traits traits : components) {
                if (instance.compare(traits.owner)) {
                    continue;
                }
                if (traits.domain == DOMAIN_VIDEO && traits.kind == KIND_DECODER &&
                    mComponentName.compare(traits.name)) {
                    return android::Codec2Client::
                            CreateFromService(instance.c_str(),
                                              !bool(android::Codec2Client::
                                                            CreateFromService("default", true)));
                }
            }
        }
        return nullptr;
    }

    void checkBufferOK(std::unique_ptr<C2Work> &work) {
        const C2GraphicView output = work->worklets.front()
                                             ->output.buffers[0]
                                             ->data()
                                             .graphicBlocks()
                                             .front()
                                             .map()
                                             .get();
        uint8_t *uPlane = const_cast<uint8_t *>(output.data()[C2PlanarLayout::PLANE_U]);
        uint8_t *vPlane = const_cast<uint8_t *>(output.data()[C2PlanarLayout::PLANE_V]);
        const uint8_t ul[] = {109, 109, 109, 109, 109, 109, 109};
        const uint8_t vl[] = {121, 121, 121, 121, 121, 121, 121};
        const uint8_t ur[] = {114, 114, 120, 120, 122, 127, 127};
        const uint8_t vr[] = {126, 121, 123, 121, 123, 126, 126};
        if (!memcmp(uPlane - 7, ul, 7) && !memcmp(vPlane - 7, vl, 7) &&
            !memcmp(uPlane + 1, ur, 7) && !memcmp(vPlane + 1, vr, 7)) {
            mHasVulnerability |= true;
        }
    }

    // Config output pixel format
    bool configPixelFormat(uint32_t format) {
        std::vector<std::unique_ptr<C2SettingResult>> failures;
        C2StreamPixelFormatInfo::output pixelformat(0u, format);

        std::vector<C2Param *> configParam{&pixelformat};
        c2_status_t status = mComponent->config(configParam, C2_DONT_BLOCK, &failures);
        if (status == C2_OK && failures.size() == 0u) {
            return true;
        }
        return false;
    }

    // callback function to process onWorkDone received by Listener
    void handleWorkDone(std::list<std::unique_ptr<C2Work>> &workItems) {
        for (std::unique_ptr<C2Work> &work : workItems) {
            if (!work->worklets.empty()) {
                // For decoder components current timestamp always exceeds
                // previous timestamp if output is in display order
                mWorkResult |= work->result;
                bool codecConfig = ((work->worklets.front()->output.flags &
                                     C2FrameData::FLAG_CODEC_CONFIG) != 0);
                if (!codecConfig && !work->worklets.front()->output.buffers.empty()) {
                    checkBufferOK(work);
                }
                bool mCsd = false;
                workDone(mComponent, work, mFlushedIndices, mQueueLock, mQueueCondition, mWorkQueue,
                         mEos, mCsd, mFramesReceived);
                (void)mCsd;
            }
        }
    }

    const std::string mComponentName = "c2.android.hevc.decoder";
    bool mEos;
    bool mHasVulnerability;
    uint64_t mTimestampUs;
    int32_t mWorkResult;
    uint32_t mFramesReceived;
    std::list<uint64_t> mFlushedIndices;

    C2BlockPool::local_id_t mBlockPoolId;
    std::shared_ptr<C2BlockPool> mLinearPool;
    std::shared_ptr<C2Allocator> mLinearAllocator;

    std::mutex mQueueLock;
    std::condition_variable mQueueCondition;
    std::list<std::unique_ptr<C2Work>> mWorkQueue;

    std::shared_ptr<android::Codec2Client> mClient;
    std::shared_ptr<android::Codec2Client::Listener> mListener;
    std::shared_ptr<android::Codec2Client::Component> mComponent;
};

// process onWorkDone received by Listener
/*
 * Processes one completed C2Work received by the listener.
 *  - Applies any sample-rate / channel-count / picture-size config updates back
 *    to `component`.
 *  - Sets `csd` if a non-empty init-data update was reported.
 *  - Unless the output is FLAG_INCOMPLETE:
 *    - counts the frame and records EOS;
 *    - strips the work item and returns it to `workQueue`;
 *    - removes its frame index from `flushedIndices`.
 * `work` must have at least one worklet (handleWorkDone() checks this).
 */
void workDone(const std::shared_ptr<android::Codec2Client::Component> &component,
              std::unique_ptr<C2Work> &work, std::list<uint64_t> &flushedIndices,
              std::mutex &queueLock, std::condition_variable &queueCondition,
              std::list<std::unique_ptr<C2Work>> &workQueue, bool &eos, bool &csd,
              uint32_t &framesReceived) {
    C2Worklet &worklet = *work->worklets.front();

    // Handle configuration changes in work done.
    if (!worklet.output.configUpdate.empty()) {
        std::vector<std::unique_ptr<C2Param>> updates = std::move(worklet.output.configUpdate);
        std::vector<C2Param *> configParam;
        std::vector<std::unique_ptr<C2SettingResult>> failures;
        for (const std::unique_ptr<C2Param> &update : updates) {
            C2Param *param = update.get();
            if (param->index() == C2StreamInitDataInfo::output::PARAM_TYPE) {
                auto *csdBuffer = static_cast<C2StreamInitDataInfo::output *>(param);
                if (csdBuffer->flexCount() > 0) {
                    csd = true;
                }
            } else if ((param->index() == C2StreamSampleRateInfo::output::PARAM_TYPE) ||
                       (param->index() == C2StreamChannelCountInfo::output::PARAM_TYPE) ||
                       (param->index() == C2StreamPictureSizeInfo::output::PARAM_TYPE)) {
                configParam.push_back(param);
            }
        }
        component->config(configParam, C2_DONT_BLOCK, &failures);
        assert(failures.size() == 0u);
    }

    if (worklet.output.flags != C2FrameData::FLAG_INCOMPLETE) {
        ++framesReceived;
        eos = (worklet.output.flags & C2FrameData::FLAG_END_OF_STREAM) != 0;
        const uint64_t frameIndex = work->input.ordinal.frameIndex.peeku();
        work->input.buffers.clear();
        work->worklets.clear();  // `worklet` is dangling from here on
        {
            std::lock_guard<std::mutex> l(queueLock);
            workQueue.push_back(std::move(work));
            // Search under the lock: the decode loop appends to flushedIndices
            // while holding queueLock, so an unlocked traversal would race.
            auto frameIndexIt =
                    std::find(flushedIndices.begin(), flushedIndices.end(), frameIndex);
            if (frameIndexIt != flushedIndices.end()) {
                flushedIndices.erase(frameIndexIt);
            }
            queueCondition.notify_all();
        }
    }
}

/*
 * Queues every frame described by *Info to `component`, reading each frame's
 * payload sequentially from ifStream.
 *
 * For each frame:
 *  - a free C2Work is taken from workQueue, waiting up to
 *    MAXIMUM_NUMBER_OF_RETRIES * QUEUE_TIMEOUT;
 *  - its index is recorded in flushedIndices;
 *  - its payload is copied into a page-aligned linear block from linearPool.
 * The last frame is tagged FLAG_END_OF_STREAM.
 *
 * Returns false on any of:
 *  - timeout waiting for a work item;
 *  - a malformed FrameInfo (negative size or flags > 32);
 *  - a short read or allocation failure;
 *  - a queue() error.
 */
bool decodeNFrames(const std::shared_ptr<android::Codec2Client::Component> &component,
                   std::mutex &queueLock, std::condition_variable &queueCondition,
                   std::list<std::unique_ptr<C2Work>> &workQueue,
                   std::list<uint64_t> &flushedIndices, std::shared_ptr<C2BlockPool> &linearPool,
                   std::ifstream &ifStream, android::Vector<FrameInfo> *Info) {
    typedef std::unique_lock<std::mutex> ULock;
    const int frameCount = static_cast<int>(Info->size());
    for (int frameID = 0; frameID < frameCount; ++frameID) {
        // Prepare C2Work: reuse one returned by the component.
        std::unique_ptr<C2Work> work;
        int retryCount = 0;
        while (!work && (retryCount < MAXIMUM_NUMBER_OF_RETRIES)) {
            ULock l(queueLock);
            if (!workQueue.empty()) {
                work.swap(workQueue.front());
                workQueue.pop_front();
            } else {
                queueCondition.wait_for(l, QUEUE_TIMEOUT);
                ++retryCount;
            }
        }
        if (!work) {
            return false;  // "Wait for generating C2Work exceeded timeout"
        }

        const FrameInfo &frame = (*Info)[frameID];
        // FrameInfo::flags is a 1-based bit position; 0 means no flag.
        uint32_t flags = 0;
        if (frame.flags) {
            if (frame.flags > 32) {
                return false;  // shifting by >= 32 would be undefined behaviour
            }
            flags = 1u << (frame.flags - 1);
        }
        if (frameID == frameCount - 1) {
            flags |= C2FrameData::FLAG_END_OF_STREAM;
        }

        work->input.flags = (C2FrameData::flags_t)flags;
        work->input.ordinal.timestamp = frame.timestamp;
        work->input.ordinal.frameIndex = frameID;
        {
            ULock l(queueLock);
            flushedIndices.emplace_back(frameID);
        }

        const int size = frame.bytesCount;
        if (size < 0) {
            return false;
        }

        work->input.buffers.clear();
        if (size > 0) {
            // Owned by unique_ptr so it is released on every return path.
            std::unique_ptr<char, decltype(&free)> data(static_cast<char *>(malloc(size)), &free);
            if (!data) {
                return false;
            }
            ifStream.read(data.get(), size);
            if (ifStream.gcount() != size) {
                return false;
            }

            // Compute in size_t to avoid int overflow near INT_MAX.
            const size_t pageSize = static_cast<size_t>(getpagesize());
            const size_t alignedSize = ALIGN(static_cast<size_t>(size), pageSize);
            std::shared_ptr<C2LinearBlock> block;
            if (linearPool->fetchLinearBlock(static_cast<uint32_t>(alignedSize),
                                             {C2MemoryUsage::CPU_READ, C2MemoryUsage::CPU_WRITE},
                                             &block) != C2_OK ||
                !block) {
                return false;
            }

            // Write View
            C2WriteView view = block->map().get();
            if (view.error() != C2_OK || alignedSize != view.capacity() ||
                0u != view.offset() || alignedSize != view.size()) {
                return false;
            }

            memcpy(view.base(), data.get(), size);
            work->input.buffers.emplace_back(std::make_shared<LinearBuffer>(block, size));
        }
        work->worklets.clear();
        work->worklets.emplace_back(new C2Worklet);
        std::list<std::unique_ptr<C2Work>> items;
        items.push_back(std::move(work));

        // DO THE DECODING
        if (component->queue(&items) != C2_OK) {
            return false;
        }
    }
    return true;
}

// Wait for all the inputs to be consumed by the plugin.
/*
 * Blocks until workQueue holds at least bufferCount items, i.e. the component
 * has handed back that many C2Work objects. It also stops after
 * MAXIMUM_NUMBER_OF_RETRIES consecutive QUEUE_TIMEOUT waits during which the
 * queue size did not change; any change resets that count.
 */
void waitOnInputConsumption(std::mutex &queueLock, std::condition_variable &queueCondition,
                            std::list<std::unique_ptr<C2Work>> &workQueue,
                            size_t bufferCount = MAXIMUM_NUMBER_OF_INPUT_BUFFERS) {
    uint32_t lastSeen = 0;
    {
        std::unique_lock<std::mutex> guard(queueLock);
        lastSeen = workQueue.size();
    }
    for (uint32_t idleWaits = 0;
         idleWaits < MAXIMUM_NUMBER_OF_RETRIES && lastSeen < bufferCount;) {
        std::unique_lock<std::mutex> guard(queueLock);
        if (lastSeen == workQueue.size()) {
            // No progress since last look: wait, and count this as an idle round.
            queueCondition.wait_for(guard, QUEUE_TIMEOUT);
            ++idleWaits;
        } else {
            lastSeen = workQueue.size();
            idleWaits = 0;
        }
    }
}

// Populate Info vector and return number of CSDs
/*
 * Populates *frameInfo from the stream-info file at path `info`.
 * Each record is three whitespace-separated values: bytesCount, flags and
 * timestamp. flags is a 1-based bit position of a C2FrameData flag (0 = none).
 *
 * Returns the number of codec-config (CSD) records, or -1 if either:
 *  - the file cannot be opened;
 *  - it ends in the middle of a record.
 */
int32_t populateInfoVector(const std::string &info, android::Vector<FrameInfo> *frameInfo) {
    std::ifstream eleInfo(info);
    if (!eleInfo.is_open()) {
        return -1;
    }
    int32_t numCsds = 0;
    int32_t bytesCount = 0;
    uint32_t flags = 0;
    // Parsed as int64_t to match FrameInfo::timestamp without truncation.
    int64_t timestamp = 0;
    while (eleInfo >> bytesCount) {
        if (!(eleInfo >> flags >> timestamp)) {
            return -1;  // truncated or malformed record
        }
        // Guard the shift: positions outside 1..32 are not valid flag bits.
        const bool codecConfig = flags != 0 && flags <= 32 &&
                                 ((1u << (flags - 1)) & C2FrameData::FLAG_CODEC_CONFIG) != 0;
        if (codecConfig) {
            ++numCsds;
        }
        frameInfo->push_back({bytesCount, flags, timestamp});
    }
    return numCsds;
}

// Returns EXIT_FAILURE from the enclosing function when `condition` is true.
// The do { } while (0) wrapper makes the expansion a single statement. It is
// therefore safe as the body of an unbraced if/else and consumes the caller's
// trailing semicolon.
#define RETURN_FAILURE(condition) \
    do {                          \
        if ((condition)) {        \
            return EXIT_FAILURE;  \
        }                         \
    } while (0)

/*
 * Usage: <binary> <elementary-stream file> <stream-info file>
 *
 * Sets up the HEVC decoder and feeds it every frame listed in the info file.
 * After waiting for the work items to come back, it checks that:
 *  - every frame was received;
 *  - the component stopped cleanly;
 *  - no work reported an error.
 * Returns EXIT_VULNERABLE if an output buffer matched the vulnerable pattern.
 * Returns EXIT_FAILURE on any setup or decode failure, otherwise EXIT_SUCCESS.
 */
int main(int argc, char **argv) {
    if (argc != 3) {
        return EXIT_FAILURE;
    }

    Codec2VideoDecHidlTestBase handle;
    if (!handle.SetUp() || !handle.configPixelFormat(HAL_PIXEL_FORMAT_YCBCR_420_888)) {
        return EXIT_FAILURE;
    }

    android::Vector<FrameInfo> Info;
    const std::string eleStreamInfo{argv[2]};
    if (populateInfoVector(eleStreamInfo, &Info) < 0 || handle.mComponent->start() != C2_OK) {
        return EXIT_FAILURE;
    }

    const std::string eleStream{argv[1]};
    std::ifstream ifStream(eleStream, std::ifstream::binary);
    if (!ifStream.is_open()) {
        return EXIT_FAILURE;
    }
    const bool queued = decodeNFrames(handle.mComponent, handle.mQueueLock,
                                      handle.mQueueCondition, handle.mWorkQueue,
                                      handle.mFlushedIndices, handle.mLinearPool, ifStream, &Info);
    if (!queued) {
        return EXIT_FAILURE;
    }

    // Block until the inputs have been consumed, unless EOS was already seen.
    if (!handle.mEos) {
        waitOnInputConsumption(handle.mQueueLock, handle.mQueueCondition, handle.mWorkQueue);
    }
    ifStream.close();

    if (handle.mFramesReceived != Info.size()) {
        return EXIT_FAILURE;
    }
    if (handle.mComponent->stop() != C2_OK) {
        return EXIT_FAILURE;
    }
    if (handle.mWorkResult != C2_OK) {
        return EXIT_FAILURE;
    }
    return handle.mHasVulnerability ? EXIT_VULNERABLE : EXIT_SUCCESS;
}
