/*
 * Copyright 2018 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fcp/secagg/shared/aes_ctr_prng.h"

#include <algorithm>
#include <cstdint>
#include <string>

#include "fcp/base/monitoring.h"
#include "fcp/secagg/shared/aes_key.h"
#include "fcp/secagg/shared/prng.h"
#include "openssl/cipher.h"
#include "openssl/evp.h"

namespace fcp {
namespace secagg {

// Seeds the PRNG: creates an EVP cipher context configured for AES-256 in
// CTR mode, keyed with `seed` and using an all-zero IV. Aborts (via
// FCP_CHECK) if the context cannot be allocated or initialized.
//
// NOTE(review): EVP_aes_256_ctr reads a 256-bit key from seed.data(); confirm
// that AesKey always provides at least 32 bytes there.
// NOTE(review): a fixed IV means two PRNGs built from the same seed produce the
// same stream; presumably seeds are never reused — confirm with callers.
AesCtrPrng::AesCtrPrng(const AesKey& seed) {
  // Aggregate initialization zero-fills the whole IV (no memset / <cstring>
  // dependency needed).
  uint8_t iv[kIvSize] = {0};

  ctx_ = EVP_CIPHER_CTX_new();
  FCP_CHECK(ctx_ != nullptr);
  FCP_CHECK(1 == EVP_EncryptInit_ex(ctx_, EVP_aes_256_ctr(), nullptr,
                                    seed.data(), iv));

  // Start with the byte cache marked as exhausted (position one past its end)
  // so that the first Rand8() call fills it via GenerateBytes.
  next_byte_pos_ = kCacheSize;
  // No keystream blocks have been produced under this seed yet.
  blocks_generated_ = 0;
}

// Releases the EVP cipher context created in the constructor.
// NOTE(review): ctx_ is a raw owning pointer; confirm the header deletes the
// copy operations, otherwise a copied AesCtrPrng would free it twice.
AesCtrPrng::~AesCtrPrng() {
  EVP_CIPHER_CTX* const owned_ctx = ctx_;
  EVP_CIPHER_CTX_free(owned_ctx);
}

// Writes `cache_size` bytes of AES-256-CTR output into `cache`.
//
// The plaintext fed to the cipher is kAllZeroes, so the bytes written are the
// raw CTR keystream. Requirements (each enforced with FCP_CHECK):
//   * `cache_size` is non-negative,
//   * `cache_size` is a multiple of kBlockSize,
//   * `cache_size` does not exceed kCacheSize (assumed to be the length of
//     kAllZeroes, which is declared in the header — TODO confirm),
//   * the total number of blocks produced under this seed, including this
//     request, stays within kMaxBlocks.
// On success the running block count is advanced by cache_size / kBlockSize.
void AesCtrPrng::GenerateBytes(uint8_t* cache, int cache_size) {
  // A negative multiple of kBlockSize would otherwise slip past the checks
  // below and corrupt the block counter.
  FCP_CHECK(cache_size >= 0)
      << "Requested number of bytes " << cache_size << " is negative";
  FCP_CHECK((cache_size % kBlockSize) == 0)
      << "Number of bytes generated by AesCtrPrng must be a multiple of "
      << kBlockSize;
  FCP_CHECK(cache_size <= kCacheSize)
      << "Requested number of bytes " << cache_size
      << " exceeds maximum cache size " << kCacheSize;

  const size_t blocks_requested = static_cast<size_t>(cache_size) / kBlockSize;
  // Check the post-generation total rather than the current count: checking
  // only `blocks_generated_ <= kMaxBlocks` lets one more full request through
  // after the limit has been reached, overshooting kMaxBlocks.
  FCP_CHECK(blocks_generated_ + blocks_requested <= kMaxBlocks)
      << "AesCtrPrng generated " << kMaxBlocks
      << " blocks and needs a new seed.";

  int bytes_written = 0;
  FCP_CHECK(1 == EVP_EncryptUpdate(ctx_, cache, &bytes_written, kAllZeroes,
                                   cache_size));
  FCP_CHECK(bytes_written == cache_size);
  blocks_generated_ += blocks_requested;
}

// Returns the next pseudorandom byte from the internal cache. Once every byte
// of the cache has been handed out, the whole cache (kCacheSize bytes) is
// refilled with fresh keystream and reading restarts at its first byte.
uint8_t AesCtrPrng::Rand8() {
  const bool cache_exhausted = next_byte_pos_ >= kCacheSize;
  if (cache_exhausted) {
    GenerateBytes(cache_, kCacheSize);
    next_byte_pos_ = 0;
  }
  const uint8_t next_byte = cache_[next_byte_pos_];
  ++next_byte_pos_;
  return next_byte;
}

// Returns 64 pseudorandom bits assembled from eight successive Rand8() draws.
// The draws are packed little-endian: the first byte drawn occupies bits 0-7
// and the eighth occupies bits 56-63.
uint64_t AesCtrPrng::Rand64() {
  uint64_t packed = 0;
  for (int shift = 0; shift < 64; shift += 8) {
    packed |= static_cast<uint64_t>(Rand8()) << shift;
  }
  return packed;
}

// Fills `buffer` directly with keystream, bypassing the Rand8() cache.
// At most kCacheSize bytes are produced per call; the return value is the
// number of bytes actually written, i.e. min(buffer_size, kCacheSize).
// GenerateBytes aborts unless that count is a multiple of kBlockSize.
int AesCtrPrng::RandBuffer(uint8_t* buffer, int buffer_size) {
  const int bytes_to_generate =
      buffer_size < kCacheSize ? buffer_size : kCacheSize;
  GenerateBytes(buffer, bytes_to_generate);
  return bytes_to_generate;
}

}  // namespace secagg
}  // namespace fcp
