// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <array>
#include <cstdint>
#include <cstdlib>
#include <random>
#include <span>
#include <tuple>
#include <utility>
#include <vector>

#include "fuzztest/domain_core.h"
#include "fuzztest/fuzztest.h"
#include "gtest/gtest.h"
#include "nearby_protocol.h"
#include "np_cpp_ffi_types.h"
#include "shared_test_util.h"

// Redefine the shared test data as std::vector<uint8_t>, since fuzztest does
// not support template class use in its input domains (i.e. std::array<T, N>),
// and we want to use our test data to seed the fuzzer.

// Copies any byte container exposing begin()/end() into a std::vector<uint8_t>.
template <typename ByteContainer>
static std::vector<uint8_t> ToSeedVector(const ByteContainer& bytes) {
  return std::vector<uint8_t>(bytes.begin(), bytes.end());
}

// The seed vectors are only ever read (copied into seeds), so they are const.
static const std::vector<uint8_t> V0AdvEmptyVec =
    ToSeedVector(V0AdvEmptyBytes);
static const std::vector<uint8_t> V1AdvEmptyVec =
    ToSeedVector(V1AdvEmptyBytes);
static const std::vector<uint8_t> V0AdvPlaintextVec =
    ToSeedVector(V0AdvPlaintextBytes);
static const std::vector<uint8_t> V0AdvPlaintextMultiDeVec =
    ToSeedVector(V0AdvPlaintextMultiDeBytes);
static const std::vector<uint8_t> V1AdvPlaintextVec =
    ToSeedVector(V1AdvPlaintextBytes);
static const std::vector<uint8_t> V0AdvEncryptedVec =
    ToSeedVector(V0AdvEncryptedBytes);
static const std::vector<uint8_t> V1AdvEncryptedVec =
    ToSeedVector(V1AdvEncryptedBytes);

// Walks a deserialization result through the result-processing APIs. Declared
// here so the deserializer harnesses below can call it; defined further down.
void HandleAdvertisementResult(nearby_protocol::DeserializeAdvertisementResult);

// Deserializes `adv_bytes` against a credential book with no credentials
// added, then walks whatever result comes back.
//
// `adv_bytes` is expected to fit in a ByteBuffer<255> (the fuzz domain caps
// inputs at 255 bytes). If the conversion fails, the test is failed and the
// function returns instead of dereferencing the failed result.
void PlaintextDeserializer(std::span<const uint8_t> adv_bytes) {
  nearby_protocol::CredentialSlab slab;
  nearby_protocol::CredentialBook book(slab);
  auto buffer = nearby_protocol::ByteBuffer<255>::TryFromSpan(adv_bytes);
  // Fatal assertion: `*buffer` below must never be evaluated on failure.
  ASSERT_TRUE(buffer.ok());

  // The extra parentheses avoid the most vexing parse (otherwise this would
  // declare a function taking a `ByteBuffer<255>*` named `buffer`).
  nearby_protocol::RawAdvertisementPayload payload(
      (nearby_protocol::ByteBuffer<255>(*buffer)));
  auto deserialize_result =
      nearby_protocol::Deserializer::DeserializeAdvertisement(payload, book);

  // Since we are seeding with valid data, we can add extra calls into the
  // result processing APIs to ensure none of the internal asserts are
  // triggered.
  HandleAdvertisementResult(std::move(deserialize_result));
}

// Fuzzes PlaintextDeserializer with arbitrary byte strings of 0-255 bytes (the
// capacity of the ByteBuffer<255> the harness converts into). It is seeded with
// the shared empty, plaintext and encrypted V0/V1 test advertisements.
FUZZ_TEST(NpCppDeserializationFuzzers, PlaintextDeserializer)
    .WithDomains(fuzztest::Arbitrary<std::vector<uint8_t>>()
                     .WithMinSize(0)
                     .WithMaxSize(255))
    .WithSeeds({V0AdvEmptyVec, V1AdvEmptyVec, V0AdvPlaintextVec,
                V0AdvPlaintextMultiDeVec, V1AdvPlaintextVec, V0AdvEncryptedVec,
                V1AdvEncryptedVec});

// Credential material generated automatically by the fuzzer.
//
// Member order is significant: instances are built with positional and
// designated aggregate initializers in this file, so fields must not be
// reordered.
struct IdentityData {
  // 32-byte field type used for the seed, HMAC and key members below.
  using Bytes32 = std::array<uint8_t, 32>;

  // Credential identifier.
  uint32_t credential_id;
  // Key seed, passed to both the V0 and the V1 matchable credential.
  Bytes32 key_seed;
  // HMAC passed to the V0 matchable credential.
  Bytes32 legacy_metadata_key_hmac;
  // HMACs and public key passed to the V1 matchable credential.
  Bytes32 expected_unsigned_identity_token_hmac;
  Bytes32 expected_signed_identity_token_hmac;
  Bytes32 pub_key;
  // Encrypted metadata bytes attached to the matched credential data.
  std::vector<uint8_t> encrypted_metadata_bytes;
};

static struct IdentityData V0TestCaseIdentityData {
  .credential_id = static_cast<uint32_t>(rand()), .key_seed = V0AdvKeySeed,
  .legacy_metadata_key_hmac = V0AdvLegacyIdentityTokenHmac,
  .encrypted_metadata_bytes = V0AdvEncryptedMetadata
};

static struct IdentityData V1TestCaseIdentityData {
  .credential_id = static_cast<uint32_t>(rand()), .key_seed = V1AdvKeySeed,
  .expected_unsigned_identity_token_hmac =
      V1AdvExpectedMicExtendedSaltIdentityTokenHmac,
  .expected_signed_identity_token_hmac =
      V1AdvExpectedSignatureIdentityTokenHmac,
  .pub_key = V1AdvPublicKey, .encrypted_metadata_bytes = V1AdvEncryptedMetadata,
};

// Now let's try feeding the fuzzer some credential data that can successfully
// decrypt advertisements, to improve its efficiency and provide extra coverage
// of the credential/iteration code paths.
// TODO: Add more interesting credential seed data once we have C++
// serialization APIs, i.e. multiple sections with multiple valid credentials
// and combinations of matching and undecryptable sections, etc.
//
// Registers every entry of `identities` as both a V0 and a V1 credential, then
// deserializes `adv_bytes` (expected to fit in a ByteBuffer<255>) against the
// resulting book and walks the result.
void DeserializeWithCredentials(std::span<const IdentityData> identities,
                                std::span<const uint8_t> adv_bytes) {
  nearby_protocol::CredentialSlab slab;
  // Populate the slab with the fuzzer-generated credential data. Iterate by
  // const reference: each IdentityData owns a heap-allocated metadata vector
  // that `match_data` is built from, so there is no need to copy it.
  for (const auto& data : identities) {
    // Pass the identity's own (fuzzed) credential id through instead of a
    // fixed placeholder, so that field actually reaches the credential APIs.
    nearby_protocol::MatchedCredentialData match_data(
        data.credential_id, data.encrypted_metadata_bytes);
    nearby_protocol::V0MatchableCredential v0_cred(
        data.key_seed, data.legacy_metadata_key_hmac, match_data);
    // adding v0 credentials is infallible
    slab.AddV0Credential(v0_cred);

    nearby_protocol::V1MatchableCredential v1_cred(
        data.key_seed, data.expected_unsigned_identity_token_hmac,
        data.expected_signed_identity_token_hmac, data.pub_key, match_data);
    // Unlike V0, adding a V1 credential returns a result and may fail for
    // fuzzer-generated data (presumably e.g. an invalid public key); the
    // outcome is deliberately ignored.
    [[maybe_unused]] auto result = slab.AddV1Credential(v1_cred);
  }

  nearby_protocol::CredentialBook book(slab);
  auto buffer = nearby_protocol::ByteBuffer<255>::TryFromSpan(adv_bytes);
  // Fatal assertion: `*buffer` below must never be evaluated on failure.
  ASSERT_TRUE(buffer.ok());

  // The extra parentheses avoid the most vexing parse.
  nearby_protocol::RawAdvertisementPayload payload(
      (nearby_protocol::ByteBuffer<255>(*buffer)));
  auto deserialize_result =
      nearby_protocol::Deserializer::DeserializeAdvertisement(payload, book);

  // Since we are seeding with valid data, we can add extra calls into the
  // result processing APIs to ensure none of the internal asserts are
  // triggered.
  HandleAdvertisementResult(std::move(deserialize_result));
}

std::vector<std::tuple<std::vector<IdentityData>, std::vector<uint8_t>>>
DeserializeWithCredentialSeedData() {
  return {
      {{{V0TestCaseIdentityData, V1TestCaseIdentityData}, V0AdvEncryptedVec},
       {{V0TestCaseIdentityData, V1TestCaseIdentityData}, V1AdvEncryptedVec}},
  };
}

// Regression test that replays a fixed DeserializeWithCredentials input,
// presumably a reproducer found by the fuzzer. Per the test name, the
// credential data includes an invalid public key. The run must complete
// without tripping any assert.
TEST(NpCppDeserializationFuzzers, InvalidPublicKeyInCredential) {
  // Each entry lists IdentityData fields in declaration order: credential_id,
  // key_seed, legacy_metadata_key_hmac, expected_unsigned_identity_token_hmac,
  // expected_signed_identity_token_hmac, pub_key, encrypted_metadata_bytes.
  std::vector<IdentityData> identities;
  // First identity: only the V0 legacy metadata key HMAC is non-zero (besides
  // the key seed); the V1 HMACs and public key are all zeros.
  identities.push_back(
      {1804289383,
       {17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17,
        17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17},
       {136, 51,  222, 213, 77,  0,   146, 232, 128, 112, 213,
        31,  24,  236, 34,  69,  117, 124, 36,  223, 227, 140,
        178, 222, 119, 182, 120, 133, 252, 165, 103, 77},
       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
       fuzztest::ToByteArray("")});
  // Second identity: the legacy V0 HMAC is zeros; the V1 signed HMAC and
  // public key are set (the unsigned HMAC is zeros).
  identities.push_back({846930886,
                        {49, 67, 99, 30,  202, 232, 151, 75,  150, 80,  204,
                         28, 72, 37, 14,  129, 88,  6,   129, 81,  249, 235,
                         37, 35, 3,  212, 151, 109, 149, 25,  145, 57},
                        {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
                        {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
                        {28,  188, 235, 220, 23,  181, 145, 229, 7,   157, 112,
                         193, 232, 75,  204, 219, 75,  15,  118, 131, 89,  98,
                         10,  45,  85,  11,  59,  54,  164, 146, 139, 19},
                        {109, 13,  182, 9,   16,  177, 83,  196, 126, 16,  22,
                         20,  156, 159, 242, 20,  15,  236, 83,  118, 227, 7,
                         217, 211, 158, 174, 231, 69,  44,  3,   236, 109},
                        fuzztest::ToByteArray("")});

  // Fixed advertisement payload bytes captured with the credentials above.
  DeserializeWithCredentials(
      identities,
      fuzztest::ToByteArray(
          " g\221\020\010\255iF\004]"
          "\256m\267\367\\\323\270\254\360\277u\220\001\276s3\244v\204J\t"
          "\017+"
          "\231G\337\213F\312\026\316\023\265nS\256("
          "VD\016\246\215\353\241\021\257N\033\340\216\365\272\220O."
          "\224\374\336\246\177]"
          "\3107\265\357\312\254\213\237\033\324\306\021\205\323g92\321\202"
          "\312"
          "N\271F\003\203hV\013\314\375z*\276\007"));
}

// Regression test that replays a fixed advertisement payload (presumably a
// fuzzer-found input) through PlaintextDeserializer. The run must complete
// without tripping any assert.
TEST(NpCppDeserializationFuzzers, PlaintextDeserializerRegression) {
  PlaintextDeserializer(fuzztest::ToByteArray(
      " 4\221\020\005\255iF\004]"
      "\256\275m\267\367\\\270\254\360\277u\220\001G\337\213F\312\026\316\023"
      "\265nS\256(VD\016\243\215\353\241\021\257N\020\340\216\365\272\220O."
      "\224\374\336\246\177]"
      "\3107\267\357\312\254\213\237\033\324\306\021\205\323g92\321\202\312N"
      "\271D\003\203hV\013\314\375z*\276\007"));
}

// Fuzzes DeserializeWithCredentials with 0-1000 arbitrary identities and 0-255
// advertisement bytes. It is seeded with the static test identities paired
// with the encrypted V0 and V1 test advertisements.
FUZZ_TEST(NpCppDeserializationFuzzers, DeserializeWithCredentials)
    .WithDomains(fuzztest::Arbitrary<std::vector<IdentityData>>()
                     .WithMinSize(0)
                     .WithMaxSize(1000),
                 fuzztest::Arbitrary<std::vector<uint8_t>>()
                     .WithMinSize(0)
                     .WithMaxSize(255))
    .WithSeeds(DeserializeWithCredentialSeedData);

// Let's encourage the fuzzer to try with a lot of credentials and make sure
// nothing falls apart. By default the generated vector sizes are somewhere
// between 1-20, even with a max of 1000 set, so a high minimum is set here to
// ensure the higher end of credential iteration is hit. No seeds are provided
// for this variant.
FUZZ_TEST(NpCppDeserializationFuzzersLotsOfCredentials,
          DeserializeWithCredentials)
    .WithDomains(
        fuzztest::Arbitrary<std::vector<IdentityData>>().WithMinSize(10000),
        fuzztest::Arbitrary<std::vector<uint8_t>>().WithMinSize(0).WithMaxSize(
            255));

// Helpers to trigger result processing code paths. Each takes its argument by
// value. All but HandleV1DataElement are defined below.
// NOTE(review): HandleV1DataElement is declared here but is neither defined nor
// called in this file; HandleV1Section does not inspect the data elements it
// fetches.
void HandleV0Adv(nearby_protocol::DeserializedV0Advertisement);
void HandleLegibleV0Adv(nearby_protocol::LegibleDeserializedV0Advertisement);
void HandleV0IdentityKind(nearby_protocol::DeserializedV0IdentityKind);
void HandleDataElement(nearby_protocol::V0DataElement);

void HandleV1Adv(nearby_protocol::DeserializedV1Advertisement);
void HandleV1Section(nearby_protocol::DeserializedV1Section);
void HandleV1DataElement(nearby_protocol::V1DataElement);

// Dispatches a top-level deserialization result to the V0 or V1 handler. The
// Error case is intentionally a no-op.
void HandleAdvertisementResult(
    nearby_protocol::DeserializeAdvertisementResult result) {
  using Kind = nearby_protocol::DeserializeAdvertisementResultKind;
  const auto kind = result.GetKind();
  switch (kind) {
    case Kind::V0:
      HandleV0Adv(result.IntoV0());
      return;
    case Kind::V1:
      HandleV1Adv(result.IntoV1());
      return;
    case Kind::Error:
      return;
  }
}

// Forwards a legible V0 advertisement to HandleLegibleV0Adv. Advertisements
// with no matching credentials are ignored.
void HandleV0Adv(nearby_protocol::DeserializedV0Advertisement result) {
  using Kind = nearby_protocol::DeserializedV0AdvertisementKind;
  switch (result.GetKind()) {
    case Kind::NoMatchingCredentials:
      return;
    case Kind::Legible:
      HandleLegibleV0Adv(result.IntoLegible());
      return;
  }
}

// Visits each data element of a legible V0 advertisement in index order,
// stopping at the first index whose lookup fails. The element count is read
// before the advertisement is converted into its payload.
void HandleLegibleV0Adv(
    nearby_protocol::LegibleDeserializedV0Advertisement legible_adv) {
  const auto de_count = legible_adv.GetNumberOfDataElements();
  auto v0_payload = legible_adv.IntoPayload();
  for (int index = 0; index < de_count; ++index) {
    auto maybe_de = v0_payload.TryGetDataElement(index);
    if (!maybe_de.ok()) {
      break;
    }
    HandleDataElement(maybe_de.value());
  }
}

// Calls the typed accessor matching a V0 data element's kind, so any internal
// asserts in those accessors are exercised. The accessor results are discarded.
void HandleDataElement(nearby_protocol::V0DataElement de) {
  using Kind = nearby_protocol::V0DataElementKind;
  switch (de.GetKind()) {
    case Kind::Actions: {
      [[maybe_unused]] const auto actions = de.AsActions();
      break;
    }
    case Kind::TxPower: {
      [[maybe_unused]] const auto tx_power = de.AsTxPower();
      break;
    }
  }
}

// Visits each legible section of a V1 advertisement in index order, stopping
// at the first section that cannot be retrieved. The undecryptable-section
// count is queried only to exercise that accessor.
void HandleV1Adv(nearby_protocol::DeserializedV1Advertisement adv) {
  const auto num_legible = adv.GetNumLegibleSections();
  [[maybe_unused]] const auto num_undecryptable =
      adv.GetNumUndecryptableSections();
  for (int section_idx = 0; section_idx < num_legible; ++section_idx) {
    auto maybe_section = adv.TryGetSection(section_idx);
    if (!maybe_section.ok()) {
      break;
    }
    HandleV1Section(maybe_section.value());
  }
}

// Fetches each data element of a V1 section in index order, stopping at the
// first failed lookup.
// TODO(review): the fetched elements are not inspected any further;
// HandleV1DataElement is declared in this file but not defined.
void HandleV1Section(nearby_protocol::DeserializedV1Section section) {
  const auto de_count = section.NumberOfDataElements();
  for (int index = 0; index < de_count; ++index) {
    if (!section.TryGetDataElement(index).ok()) {
      break;
    }
  }
}
