/**
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cstdlib>

// Make private members declared in the project headers below accessible to this test.
#define private public

#include "../includes/common.h"
#include "gatekeeper.h"

using namespace gatekeeper;

// Set to true by the replacement operator delete when the pointer being deleted
// equals authTokenKey.
bool isVulnerable{false};
// Address of the key most recently handed out by DerivedGateKeeper::GetAuthTokenKey;
// nullptr until that hook has been called.
const uint8_t *authTokenKey{nullptr};

// Replacement global (scalar) allocation function: every `new` in this program
// is served by malloc, so the replacement operator delete can inspect the
// pointers it receives.
//
// A replaceable operator new must never return nullptr, because callers may
// assume a non-null result. It must also return a distinct non-null pointer
// even for a zero-byte request, where malloc is allowed to return nullptr.
void *operator new(decltype(sizeof(0)) n) noexcept(false) {
    void *block = malloc(n == 0 ? 1 : n);
    if (block == nullptr) {
        // This build may have exceptions disabled, so fail hard instead of
        // throwing std::bad_alloc.
        abort();
    }
    return block;
}

// Replacement global (scalar) deallocation function, used as the vulnerability probe.
//
// If the pointer being deleted is the key handed out by
// DerivedGateKeeper::GetAuthTokenKey (recorded in authTokenKey), isVulnerable is
// set. A null pointer is never treated as a match. Deleting nullptr before the
// key has been recorded (authTokenKey is still nullptr) must not produce a false
// positive.
//
// Memory is intentionally never returned to the heap. The previous version
// guarded free() with `if (!ptr)`, so free() only ever received nullptr (a
// no-op). This keeps that effective behavior, but states it explicitly.
// NOTE(review): main() deletes its token buffer after handing it to
// SizedBuffer, which may take ownership of it. Confirm against gatekeeper.h
// before making this function actually free memory, or a double free may result.
void operator delete(void *ptr) noexcept {
    if (ptr != nullptr && ptr == authTokenKey) {
        isVulnerable = true;
    }
    // Deliberately no free(ptr): see the note above.
}

// Minimal GateKeeper implementation used only to drive GateKeeper::MintAuthToken.
//
// Every hook is an inert stub except GetAuthTokenKey, which hands out a freshly
// allocated one-byte key and records its address in the global authTokenKey.
// This lets the replacement operator delete notice if that key is deleted.
// Unused parameters are left unnamed (names kept in comments) so no unused-
// parameter attributes are needed.
class DerivedGateKeeper : public GateKeeper {
   protected:
    // Allocates a new one-byte key, publishes it through auth_token_key and
    // records it in authTokenKey. The length out-parameter is not written.
    bool GetAuthTokenKey(const uint8_t **auth_token_key, uint32_t * /*length*/) const {
        uint8_t *freshKey = new uint8_t();
        authTokenKey = freshKey;
        *auth_token_key = freshKey;
        return true;
    }

    // Stub: leaves both out-parameters untouched.
    void GetPasswordKey(const uint8_t ** /*password_key*/, uint32_t * /*length*/) {}

    // Stub: writes nothing into the signature buffer.
    void ComputePasswordSignature(uint8_t * /*signature*/, uint32_t /*signature_length*/,
                                  const uint8_t * /*key*/, uint32_t /*key_length*/,
                                  const uint8_t * /*password*/, uint32_t /*password_length*/,
                                  salt_t /*salt*/) const {}

    // Stub: leaves the destination buffer untouched.
    void GetRandom(void * /*random*/, uint32_t /*requested_size*/) const {}

    // Stub: writes nothing into the signature buffer.
    void ComputeSignature(uint8_t * /*signature*/, uint32_t /*signature_length*/,
                          const uint8_t * /*key*/, uint32_t /*key_length*/,
                          const uint8_t * /*message*/, const uint32_t /*length*/) const {}

    // Stub clock: always reports 0 ms since boot. Returned as a plain 0
    // because this is a timestamp, not a process exit status.
    uint64_t GetMillisecondsSinceBoot() const { return 0; }

    // Failure-record stubs: each reports failure (false) without touching its arguments.
    bool GetFailureRecord(uint32_t /*uid*/, secure_id_t /*user_id*/,
                          failure_record_t * /*record*/, bool /*secure*/) {
        return false;
    }
    bool ClearFailureRecord(uint32_t /*uid*/, secure_id_t /*user_id*/, bool /*secure*/) {
        return false;
    }
    bool WriteFailureRecord(uint32_t /*uid*/, failure_record_t * /*record*/, bool /*secure*/) {
        return false;
    }

    // Stub: no retry timeout (0).
    uint32_t ComputeRetryTimeout(const failure_record_t * /*record*/) { return 0; }

    virtual bool IsHardwareBacked() const { return false; }

    // Stub: verification always fails.
    bool DoVerify(const password_handle_t * /*expected_handle*/,
                  const SizedBuffer & /*password*/) {
        return false;
    }
};

int main() {
    uint8_t *auth_token = new uint8_t();
    uint32_t length = sizeof(uint32_t);
    SizedBuffer *sb = new SizedBuffer(auth_token, length);
    uint64_t timestamp = 1;
    secure_id_t user_id = 1;
    secure_id_t authenticator_id = 1;
    uint64_t challenge = 0;

    DerivedGateKeeper *object = new DerivedGateKeeper();
    object->MintAuthToken(sb, timestamp, user_id, authenticator_id, challenge);

    delete auth_token;
    delete object;
    delete sb;
    return (isVulnerable) ? EXIT_VULNERABLE : EXIT_SUCCESS;
}
