/**
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <fstream>
#include <iostream>
#include <limits>
#include <vector>
#include "../includes/common.h"
#include "aacdecoder.h"
#include "aacdecoder_lib.h"
#include "sbr_ram.h"

// Number of transport layers requested from aacDecoder_Open().
constexpr uint8_t kNumberOfLayers{1};
// Channel count used to size the scratch PCM buffer for each decoded frame.
constexpr uint8_t kMaxChannelCount{8};
// Flipped to true by Codec::validateQmfDomainBounds() when an out-of-bounds
// QMF channel slot is observed; it selects the process exit code.
bool kIsVulnerable{false};

class Codec {
   public:
    Codec() = default;
    ~Codec() { deInitDecoder(); }
    bool initDecoder();
    void decodeFrames(UCHAR *data, UINT size);
    void deInitDecoder();
    int validateQmfDomainBounds();

   private:
    HANDLE_AACDECODER mAacDecoderHandle = nullptr;
    AAC_DECODER_ERROR mErrorCode = AAC_DEC_OK;
};

bool Codec::initDecoder() {
    mAacDecoderHandle = aacDecoder_Open(TT_MP4_ADIF, kNumberOfLayers);
    if (!mAacDecoderHandle) {
        return false;
    }
    return true;
}

void Codec::deInitDecoder() {
    aacDecoder_Close(mAacDecoderHandle);
    mAacDecoderHandle = nullptr;
}

/**
 * Pushes the input through the decoder, one fill at a time.
 *
 * Each round offers every remaining byte to aacDecoder_Fill().
 * - If the fill is rejected, one byte is skipped and the next offset is tried.
 * - Otherwise one frame is decoded into a scratch PCM buffer. The PCM output
 *   and the DecodeFrame() result are discarded; only the decoder's internal
 *   state is inspected later. The cursor then advances past the bytes Fill()
 *   consumed.
 *
 * The loop ends when the input is exhausted or a successful fill consumes
 * nothing.
 *
 * @param data Start of the input bitstream.
 * @param size Number of bytes available at data.
 */
void Codec::decodeFrames(UCHAR *data, UINT size) {
    constexpr int kPcmCapacity = 2048 * kMaxChannelCount;

    while (size != 0) {
        UINT bufferSize = size;
        // In: bytes offered. Out: bytes Fill() left unconsumed.
        UINT bytesLeft = size;
        mErrorCode = aacDecoder_Fill(mAacDecoderHandle, &data, &bufferSize, &bytesLeft);

        if (mErrorCode == AAC_DEC_OK) {
            INT_PCM pcm[kPcmCapacity];
            aacDecoder_DecodeFrame(mAacDecoderHandle, pcm, kPcmCapacity, 0);

            const bool madeProgress = bytesLeft < bufferSize;
            if (!madeProgress) {
                return;
            }
            data += bufferSize - bytesLeft;
            size = bytesLeft;
            continue;
        }

        // Fill() rejected this position: step over one byte and retry.
        ++data;
        --size;
    }
}

/**
 * Detects SBR channels whose QMF slots point outside the decoder's qmfDomain.
 *
 * This mirrors the channel-to-slot assignment done by
 * sbrDecoder_AssignQmfChannels2SbrChannels(): SBR channels are numbered
 * consecutively across all non-null SBR elements. For each one, the addresses
 * of its QmfDomainIn/QmfDomainOut slots are compared against the address of
 * the decoder's qmfModeCurr member, which is used as the end of the qmfDomain
 * region.
 *
 * @return EXIT_FAILURE if no decoder is open. Otherwise EXIT_VULNERABLE if
 *         kIsVulnerable is set (set here when a slot lies at or past the bound),
 *         else EXIT_SUCCESS.
 */
int Codec::validateQmfDomainBounds() {
    // Previously the handle was dereferenced unconditionally, which crashed
    // when initDecoder() had failed.
    if (mAacDecoderHandle == nullptr) {
        return EXIT_FAILURE;
    }

    // NOTE(review): assumes qmfModeCurr directly follows qmfDomain in the
    // decoder instance layout; confirm against aacdecoder.h.
    const void *qmfDomainBound = &(mAacDecoderHandle->qmfModeCurr);

    HANDLE_SBRDECODER hSbrDecoder = mAacDecoderHandle->hSbrDecoder;
    // Without an SBR decoder or QMF domain there are no slots to inspect.
    if (hSbrDecoder == nullptr || hSbrDecoder->pQmfDomain == nullptr) {
        return kIsVulnerable ? EXIT_VULNERABLE : EXIT_SUCCESS;
    }

    // Referring sbrDecoder_AssignQmfChannels2SbrChannels()
    int absChOffset = 0;
    for (int element = 0; element < hSbrDecoder->numSbrElements; ++element) {
        const auto *sbrElement = hSbrDecoder->pSbrElement[element];
        if (sbrElement == nullptr) {
            continue;  // null elements do not advance the channel offset
        }
        for (int channel = 0; channel < sbrElement->nChannels; ++channel) {
            const int absCh = absChOffset + channel;
            // These addresses are computed, not dereferenced. Out-of-range
            // indices are exactly what is being probed, so the relational
            // comparison below is a layout-based heuristic.
            const void *qmfDomainInChPtr = &hSbrDecoder->pQmfDomain->QmfDomainIn[absCh];
            const void *qmfDomainOutChPtr = &hSbrDecoder->pQmfDomain->QmfDomainOut[absCh];
            if (qmfDomainBound <= qmfDomainInChPtr || qmfDomainBound <= qmfDomainOutChPtr) {
                kIsVulnerable = true;
            }
        }
        absChOffset += sbrElement->nChannels;
    }
    return kIsVulnerable ? EXIT_VULNERABLE : EXIT_SUCCESS;
}

int main(int argc, char *argv[]) {
    if (argc != 2) {
        return EXIT_FAILURE;
    }

    std::ifstream file;
    file.open(argv[1], std::ios::in | std::ios::binary);
    if (!file.is_open()) {
        return EXIT_FAILURE;
    }
    file.ignore(std::numeric_limits<std::streamsize>::max());
    std::streamsize length = file.gcount();
    file.clear();
    file.seekg(0, std::ios_base::beg);
    UCHAR *data = new UCHAR[length];
    file.read(reinterpret_cast<char *>(data), length);
    file.close();

    Codec codec = Codec();
    if (codec.initDecoder()) {
        codec.decodeFrames(data, length);
    }
    return codec.validateQmfDomainBounds();
}
