// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "discovery/mdns/mdns_querier.h"

#include <algorithm>
#include <array>
#include <bitset>
#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>

#include "discovery/common/config.h"
#include "discovery/common/reporting_client.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_sender.h"
#include "discovery/mdns/public/mdns_constants.h"

namespace openscreen {
namespace discovery {
namespace {

// Concrete record types into which an NSEC record asserting the nonexistence
// of DnsType::kANY is expanded when it is processed (see
// MdnsQuerier::ProcessRecord()).
constexpr std::array<DnsType, 5> kTranslatedNsecAnyQueryTypes{
    DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA, DnsType::kSRV};

bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) {
  if (record.dns_type() != DnsType::kNSEC) {
    return false;
  }

  const NsecRecordRdata& nsec = absl::get<NsecRecordRdata>(record.rdata());

  // RFC 6762 section 6.1, the NSEC bit must NOT be set in the received NSEC
  // record to indicate this is an mDNS NSEC record rather than a traditional
  // DNS NSEC record.
  if (std::find(nsec.types().begin(), nsec.types().end(), DnsType::kNSEC) !=
      nsec.types().end()) {
    return false;
  }

  return std::find_if(nsec.types().begin(), nsec.types().end(),
                      [type](DnsType stored_type) {
                        return stored_type == type ||
                               stored_type == DnsType::kANY;
                      }) != nsec.types().end();
}

// Hash functor mapping a DnsType to its numeric value, suitable as the Hash
// parameter of an unordered container.
// NOTE(review): nothing else in this file appears to reference it.
struct HashDnsType {
  size_t operator()(DnsType type) const {
    const size_t hash = static_cast<size_t>(type);
    return hash;
  }
};

// Ordering predicate used to sort mDNS records. It guarantees that:
// - All MdnsRecords with the same name end up adjacent to each other.
// - An NSEC record with a given name sorts before every non-NSEC record with
//   that same name.
// Records that tie on both criteria fall back to MdnsRecord's operator<.
bool CompareRecordByNameAndType(const MdnsRecord& first,
                                const MdnsRecord& second) {
  if (first.name() == second.name()) {
    const bool first_is_nsec = first.dns_type() == DnsType::kNSEC;
    const bool second_is_nsec = second.dns_type() == DnsType::kNSEC;
    if (first_is_nsec == second_is_nsec) {
      return first < second;
    }
    return first_is_nsec;
  }

  return first.name() < second.name();
}

// Small set of DnsType values backed by a bitset. Each type occupies the bit
// equal to its numeric value, except DnsType::kANY, which is stored in bit 0.
// Types whose numeric value does not fit in the bitset cannot be represented.
// They are ignored rather than passed to std::bitset, whose test() throws
// std::out_of_range for such positions. Those types can come straight from a
// received message.
class DnsTypeBitSet {
 public:
  // Returns whether any types are currently stored in this data structure.
  bool IsEmpty() const { return elements_.none(); }

  // Attempts to insert the given type into this data structure. Returns
  // true iff the type was not already present. A type that cannot be
  // represented is dropped and false is returned.
  bool Insert(DnsType type) {
    const size_t bit = ToBit(type);
    if (bit >= elements_.size()) {
      return false;
    }

    const bool was_set = elements_.test(bit);
    elements_.set(bit);
    return !was_set;
  }

  // Iterates over all members of the provided container, inserting each
  // DnsType contained within to this instance. Returns true iff any element
  // inserted was not already present in this instance.
  template <typename Container>
  bool Insert(const Container& container) {
    bool has_element_been_inserted = false;
    for (DnsType type : container) {
      has_element_been_inserted |= Insert(type);
    }
    return has_element_been_inserted;
  }

  // Attempts to remove the given type from this data structure. Returns true
  // iff the type was present prior to this call. Removing DnsType::kANY clears
  // every stored type. A type that cannot be represented is never present, so
  // false is returned for it.
  bool Remove(DnsType type) {
    if (IsEmpty()) {
      return false;
    }
    if (type == DnsType::kANY) {
      elements_.reset();
      return true;
    }

    const size_t bit = ToBit(type);
    if (bit >= elements_.size()) {
      return false;
    }

    const bool was_set = elements_.test(bit);
    elements_.reset(bit);
    return was_set;
  }

  // Returns the DnsTypes currently stored in this data structure. If kANY is
  // stored, only {kANY} is returned. Otherwise the stored members of
  // kSupportedDnsTypes are returned.
  std::vector<DnsType> GetTypes() const {
    if (elements_.test(0)) {
      return {DnsType::kANY};
    }

    std::vector<DnsType> types;
    for (DnsType type : kSupportedDnsTypes) {
      if (type == DnsType::kANY) {
        continue;
      }

      const size_t bit = ToBit(type);
      if (bit < elements_.size() && elements_.test(bit)) {
        types.push_back(type);
      }
    }
    return types;
  }

 private:
  // Maps a DnsType to its bit position: 0 for kANY, else its numeric value.
  static size_t ToBit(DnsType type) {
    if (type == DnsType::kANY) {
      return 0;
    }
    return static_cast<size_t>(static_cast<uint16_t>(type));
  }

  std::bitset<64> elements_;
};

// Modifies |records| such that no NSEC record signifies the nonexistance of a
// record which is also present in the same message. Order of the input vector
// is NOT preserved.
// NOTE: |records| is not of type MdnsRecord::ConstRef because the members must
// be modified.
// TODO(b/170353378): Break this logic into a separate processing module between
// the MdnsReader and the MdnsQuerier.
void RemoveInvalidNsecFlags(std::vector<MdnsRecord>* records) {
  // Sort the records so NSEC records are first so that only one iteration
  // through all records is needed.
  std::sort(records->begin(), records->end(), CompareRecordByNameAndType);

  // The set of NSEC records that need to be removed from |records|. This can't
  // be done as part of the below loop because it would invalidate the iterator
  // that's still being used.
  std::vector<std::vector<MdnsRecord>::iterator> nsecs_to_delete;

  // Process all elements.
  for (auto it = records->begin(); it != records->end();) {
    if (it->dns_type() != DnsType::kNSEC) {
      it++;
      continue;
    }

    // Track whether the current NSEC record in the input vector has been
    // modified by some step of this algorithm, be that merging with another
    // record, removing a DnsType, or any other modification.
    bool has_changed = false;

    // The types for the new record to create, if |has_changed|.
    const NsecRecordRdata& nsec_rdata = absl::get<NsecRecordRdata>(it->rdata());
    DnsTypeBitSet types;
    for (DnsType type : nsec_rdata.types()) {
      types.Insert(type);
    }
    auto nsec = it;
    it++;

    // Combine multiple NSECs to simplify the following code. This probably
    // won't happen, but the RFC doesn't exclude the possibility, so account for
    // it. Define the TTL of this new NSEC record created by this merge process
    // to be the minimum of all merged NSEC records.
    std::chrono::seconds new_ttl = nsec->ttl();
    while (it != records->end() && it->name() == nsec->name() &&
           it->dns_type() == DnsType::kNSEC) {
      has_changed |=
          types.Insert(absl::get<NsecRecordRdata>(it->rdata()).types());
      new_ttl = std::min(new_ttl, it->ttl());
      it = records->erase(it);
    }

    // Remove any types associated with a known record type.
    for (; it != records->end() && it->name() == nsec->name(); it++) {
      OSP_DCHECK(it->dns_type() != DnsType::kNSEC);
      has_changed |= types.Remove(it->dns_type());
    }

    // Modify the stored NSEC record, if needed.
    if (has_changed && types.IsEmpty()) {
      nsecs_to_delete.push_back(nsec);
    } else if (has_changed) {
      NsecRecordRdata new_rdata(nsec_rdata.next_domain_name(),
                                types.GetTypes());
      *nsec = MdnsRecord(nsec->name(), nsec->dns_type(), nsec->dns_class(),
                         nsec->record_type(), new_ttl, std::move(new_rdata));
    }
  }

  // Erase invalid NSEC records. Go backwards to avoid invalidating the
  // remaining iterators.
  for (auto erase_it = nsecs_to_delete.rbegin();
       erase_it != nsecs_to_delete.rend(); erase_it++) {
    records->erase(*erase_it);
  }
}

}  // namespace

// Creates an LRU cache of record trackers holding at most
// |config.querier_max_records_cached| entries.
// Requirements on the arguments:
// - |querier| is notified whenever a tracker expires.
// - |sender|, |task_runner|, |now_function| and |random_delay| are handed to
//   every tracker created by StartTracking().
// - |reporting_client| receives recoverable update errors.
// All of them must be non-null.
MdnsQuerier::RecordTrackerLruCache::RecordTrackerLruCache(
    MdnsQuerier* querier,
    MdnsSender* sender,
    MdnsRandom* random_delay,
    TaskRunner* task_runner,
    ClockNowFunctionPtr now_function,
    ReportingClient* reporting_client,
    const Config& config)
    : querier_(querier),
      sender_(sender),
      random_delay_(random_delay),
      task_runner_(task_runner),
      now_function_(now_function),
      reporting_client_(reporting_client),
      config_(config) {
  // |querier_| is dereferenced by every tracker's expiration callback and
  // |now_function_| is passed to every new tracker, so check them as well.
  OSP_DCHECK(querier_);
  OSP_DCHECK(sender_);
  OSP_DCHECK(random_delay_);
  OSP_DCHECK(task_runner_);
  OSP_DCHECK(now_function_);
  OSP_DCHECK(reporting_client_);
  OSP_DCHECK_GT(config_.querier_max_records_cached, 0);
}

// Returns every tracker stored under |name|, regardless of type or class.
std::vector<std::reference_wrapper<const MdnsRecordTracker>>
MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name) {
  // kANY for both the type and the class acts as a wildcard.
  return Find(name, DnsType::kANY, DnsClass::kANY);
}

// Returns every tracker stored under |name| whose type equals |dns_type| and
// whose class equals |dns_class|. Passing kANY for either matches any value.
std::vector<std::reference_wrapper<const MdnsRecordTracker>>
MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name,
                                         DnsType dns_type,
                                         DnsClass dns_class) {
  const bool any_type = dns_type == DnsType::kANY;
  const bool any_class = dns_class == DnsClass::kANY;

  std::vector<RecordTrackerConstRef> matches;
  const auto range = records_.equal_range(name);
  for (auto entry = range.first; entry != range.second; ++entry) {
    const MdnsRecordTracker& candidate = *entry->second;
    const bool type_ok = any_type || candidate.dns_type() == dns_type;
    const bool class_ok = any_class || candidate.dns_class() == dns_class;
    if (type_ok && class_ok) {
      matches.push_back(std::cref(candidate));
    }
  }

  return matches;
}

int MdnsQuerier::RecordTrackerLruCache::Erase(const DomainName& domain,
                                              TrackerApplicableCheck check) {
  auto pair = records_.equal_range(domain);
  int count = 0;
  for (RecordMap::iterator it = pair.first; it != pair.second;) {
    if (check(*it->second)) {
      lru_order_.erase(it->second);
      it = records_.erase(it);
      count++;
    } else {
      it++;
    }
  }

  return count;
}

int MdnsQuerier::RecordTrackerLruCache::ExpireSoon(
    const DomainName& domain,
    TrackerApplicableCheck check) {
  auto pair = records_.equal_range(domain);
  int count = 0;
  for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
    if (check(*it->second)) {
      MoveToEnd(it);
      it->second->ExpireSoon();
      count++;
    }
  }

  return count;
}

int MdnsQuerier::RecordTrackerLruCache::Update(const MdnsRecord& record,
                                               TrackerApplicableCheck check) {
  return Update(record, check, [](const MdnsRecordTracker& t) {});
}

int MdnsQuerier::RecordTrackerLruCache::Update(
    const MdnsRecord& record,
    TrackerApplicableCheck check,
    TrackerChangeCallback on_rdata_update) {
  auto pair = records_.equal_range(record.name());
  int count = 0;
  for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
    if (check(*it->second)) {
      auto result = it->second->Update(record);

      if (result.is_error()) {
        reporting_client_->OnRecoverableError(
            Error(Error::Code::kUpdateReceivedRecordFailure,
                  result.error().ToString()));
        continue;
      }

      count++;
      if (result.value() == MdnsRecordTracker::UpdateType::kGoodbye) {
        it->second->ExpireSoon();
        MoveToEnd(it);
      } else {
        MoveToBeginning(it);
        if (result.value() == MdnsRecordTracker::UpdateType::kRdata) {
          on_rdata_update(*it->second);
        }
      }
    }
  }

  return count;
}

// Creates a tracker for |record| under |dns_type| at the front of the LRU
// list, first expiring trackers from the back while the cache is at capacity.
// Returns the new tracker.
const MdnsRecordTracker& MdnsQuerier::RecordTrackerLruCache::StartTracking(
    MdnsRecord record,
    DnsType dns_type) {
  // Make room for the new tracker. Each ExpireNow() call is relied upon to
  // erase one of the tracked records (through the expiration callback below).
  while (lru_order_.size() >=
         static_cast<size_t>(config_.querier_max_records_cached)) {
    OSP_DVLOG << "Maximum cacheable record count exceeded ("
              << config_.querier_max_records_cached << ")";
    lru_order_.back().ExpireNow();
  }

  // Forward expirations to the owning querier.
  auto expiration_callback = [this](const MdnsRecordTracker* expired_tracker,
                                    const MdnsRecord& expired_record) {
    querier_->OnRecordExpired(expired_tracker, expired_record);
  };

  // Copy the key before |record| is moved into the tracker.
  DomainName name = record.name();
  lru_order_.emplace_front(std::move(record), dns_type, sender_, task_runner_,
                           now_function_, random_delay_,
                           std::move(expiration_callback));
  records_.emplace(std::move(name), lru_order_.begin());

  return lru_order_.front();
}

// Marks the tracker referenced by |it| as most recently used by moving it to
// the front of the LRU list, keeping the stored list iterator up to date.
void MdnsQuerier::RecordTrackerLruCache::MoveToBeginning(
    MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
  auto& position = it->second;
  lru_order_.splice(lru_order_.begin(), lru_order_, position);
  position = lru_order_.begin();
}

// Marks the tracker referenced by |it| as least recently used by moving it to
// the back of the LRU list, keeping the stored list iterator up to date.
void MdnsQuerier::RecordTrackerLruCache::MoveToEnd(
    MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
  auto& position = it->second;
  lru_order_.splice(lru_order_.end(), lru_order_, position);
  position = lru_order_.end();
  --position;
}

// Creates a querier that registers itself with |receiver| for incoming
// response messages. All pointer arguments and |now_function| must be
// non-null; this is checked with OSP_DCHECK below.
// NOTE: |records_| is constructed from the member |config_| (after |config|
// has been moved into it), which relies on |config_| being declared before
// |records_| in the class definition.
MdnsQuerier::MdnsQuerier(MdnsSender* sender,
                         MdnsReceiver* receiver,
                         TaskRunner* task_runner,
                         ClockNowFunctionPtr now_function,
                         MdnsRandom* random_delay,
                         ReportingClient* reporting_client,
                         Config config)
    : sender_(sender),
      receiver_(receiver),
      task_runner_(task_runner),
      now_function_(now_function),
      random_delay_(random_delay),
      reporting_client_(reporting_client),
      config_(std::move(config)),
      records_(this,
               sender_,
               random_delay_,
               task_runner_,
               now_function_,
               reporting_client_,
               config_) {
  OSP_DCHECK(sender_);
  OSP_DCHECK(receiver_);
  OSP_DCHECK(task_runner_);
  OSP_DCHECK(now_function_);
  OSP_DCHECK(random_delay_);
  OSP_DCHECK(reporting_client_);

  // Start receiving response messages (delivered to OnMessageReceived()).
  receiver_->AddResponseCallback(this);
}

// Undoes the AddResponseCallback() registration made in the constructor.
MdnsQuerier::~MdnsQuerier() {
  receiver_->RemoveResponseCallback(this);
}

// NOTE: The code below mostly uses explicit loops instead of std::find_if, for
// better readability, brevity and homogeneity. Using std::find_if results in a
// few more lines of code, and readability suffers from the extra lambdas.

void MdnsQuerier::StartQuery(const DomainName& name,
                             DnsType dns_type,
                             DnsClass dns_class,
                             MdnsRecordChangedCallback* callback) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  OSP_DCHECK(callback);
  OSP_DCHECK(CanBeQueried(dns_type));

  // Add a new callback if haven't seen it before
  auto callbacks_it = callbacks_.equal_range(name);
  for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
    const CallbackInfo& callback_info = entry->second;
    if (dns_type == callback_info.dns_type &&
        dns_class == callback_info.dns_class &&
        callback == callback_info.callback) {
      // Already have this callback
      return;
    }
  }
  callbacks_.emplace(name, CallbackInfo{callback, dns_type, dns_class});

  // Notify the new callback with previously cached records.
  // NOTE: In the future, could allow callers to fetch cached records after
  // adding a callback, for example to prime the UI.
  std::vector<PendingQueryChange> pending_changes;
  const std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
      records_.Find(name, dns_type, dns_class);
  for (const MdnsRecordTracker& tracker : trackers) {
    if (!tracker.is_negative_response()) {
      MdnsRecord stored_record(name, tracker.dns_type(), tracker.dns_class(),
                               tracker.record_type(), tracker.ttl(),
                               tracker.rdata());
      std::vector<PendingQueryChange> new_changes = callback->OnRecordChanged(
          std::move(stored_record), RecordChangedEvent::kCreated);
      pending_changes.insert(pending_changes.end(), new_changes.begin(),
                             new_changes.end());
    }
  }

  // Add a new question if haven't seen it before
  auto questions_it = questions_.equal_range(name);
  const bool is_question_already_tracked =
      std::find_if(questions_it.first, questions_it.second,
                   [dns_type, dns_class](const auto& entry) {
                     const MdnsQuestion& tracked_question =
                         entry.second->question();
                     return dns_type == tracked_question.dns_type() &&
                            dns_class == tracked_question.dns_class();
                   }) != questions_it.second;
  if (!is_question_already_tracked) {
    AddQuestion(
        MdnsQuestion(name, dns_type, dns_class, ResponseType::kMulticast));
  }

  // Apply any pending changes from the OnRecordChanged() callbacks.
  ApplyPendingChanges(std::move(pending_changes));
}

// Unregisters |callback| for |name|, |dns_type| and |dns_class|. If no other
// callback remains registered for exactly that (name, type, class) key, the
// matching question is removed as well. Types that cannot be queried are
// ignored.
void MdnsQuerier::StopQuery(const DomainName& name,
                            DnsType dns_type,
                            DnsClass dns_class,
                            MdnsRecordChangedCallback* callback) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  OSP_DCHECK(callback);

  if (!CanBeQueried(dns_type)) {
    return;
  }

  // Find and remove the callback, counting the other callbacks still
  // registered for the same DomainName + DnsType + DnsClass.
  // NOTE: every path through the loop must advance |entry|. Callbacks
  // registered under |name| for a different type or class are skipped over.
  // Otherwise the loop would never terminate.
  int callbacks_for_key = 0;
  auto callbacks_it = callbacks_.equal_range(name);
  for (auto entry = callbacks_it.first; entry != callbacks_it.second;) {
    const CallbackInfo& callback_info = entry->second;
    const bool matches_key = dns_type == callback_info.dns_type &&
                             dns_class == callback_info.dns_class;
    if (matches_key && callback == callback_info.callback) {
      entry = callbacks_.erase(entry);
      continue;
    }

    if (matches_key) {
      ++callbacks_for_key;
    }
    ++entry;
  }

  // Exit if there are still callbacks registered for DomainName + DnsType +
  // DnsClass
  if (callbacks_for_key > 0) {
    return;
  }

  // Find and delete a question that does not have any associated callbacks
  auto questions_it = questions_.equal_range(name);
  for (auto entry = questions_it.first; entry != questions_it.second; ++entry) {
    const MdnsQuestion& tracked_question = entry->second->question();
    if (dns_type == tracked_question.dns_type() &&
        dns_class == tracked_question.dns_class()) {
      questions_.erase(entry);
      return;
    }
  }
}

void MdnsQuerier::ReinitializeQueries(const DomainName& name) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  // Get the ongoing queries and their callbacks.
  std::vector<CallbackInfo> callbacks;
  auto its = callbacks_.equal_range(name);
  for (auto it = its.first; it != its.second; it++) {
    callbacks.push_back(std::move(it->second));
  }
  callbacks_.erase(name);

  // Remove all known questions and answers.
  questions_.erase(name);
  records_.Erase(name, [](const MdnsRecordTracker& tracker) { return true; });

  // Restart the queries.
  for (const auto& cb : callbacks) {
    StartQuery(name, cb.dns_type, cb.dns_class, cb.callback);
  }
}

void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  OSP_DCHECK(message.type() == MessageType::Response);

  OSP_DVLOG << "Received mDNS Response message with "
            << message.answers().size() << " answers and "
            << message.additional_records().size()
            << " additional records. Processing...";

  std::vector<MdnsRecord> records_to_process;

  // Add any records that are relevant for this querier.
  bool found_relevant_records = false;
  for (const MdnsRecord& record : message.answers()) {
    if (ShouldAnswerRecordBeProcessed(record)) {
      records_to_process.push_back(record);
      found_relevant_records = true;
    }
  }

  // If any of the message's answers are relevant, add all additional records.
  // Else, since the message has already been received and parsed, use any
  // individual records relevant to this querier to update the cache.
  for (const MdnsRecord& record : message.additional_records()) {
    if (found_relevant_records || ShouldAnswerRecordBeProcessed(record)) {
      records_to_process.push_back(record);
    }
  }

  // Drop NSEC records associated with a non-NSEC record of the same type.
  RemoveInvalidNsecFlags(&records_to_process);

  // Process all remaining records.
  for (const MdnsRecord& record_to_process : records_to_process) {
    ProcessRecord(record_to_process);
  }

  OSP_DVLOG << "\tmDNS Response processed (" << records_to_process.size()
            << " records accepted)!";

  // TODO(crbug.com/openscreen/83): Check authority records.
}

bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) {
  // First, accept the record if it's associated with an ongoing question.
  const auto questions_range = questions_.equal_range(answer.name());
  const auto it = std::find_if(
      questions_range.first, questions_range.second,
      [&answer](const auto& pair) {
        return (pair.second->question().dns_type() == DnsType::kANY ||
                IsNegativeResponseFor(answer,
                                      pair.second->question().dns_type()) ||
                pair.second->question().dns_type() == answer.dns_type()) &&
               (pair.second->question().dns_class() == DnsClass::kANY ||
                pair.second->question().dns_class() == answer.dns_class());
      });
  if (it != questions_range.second) {
    return true;
  }

  // If not, check if it corresponds to an already existing record. This is
  // required because records which are already stored may either have been
  // received in an additional records section, or are associated with a query
  // which is no longer active.
  std::vector<DnsType> types{answer.dns_type()};
  if (answer.dns_type() == DnsType::kNSEC) {
    const auto& nsec_rdata = absl::get<NsecRecordRdata>(answer.rdata());
    types = nsec_rdata.types();
  }

  for (DnsType type : types) {
    std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
        records_.Find(answer.name(), type, answer.dns_class());
    if (!trackers.empty()) {
      return true;
    }
  }

  // In all other cases, the record isn't relevant. Drop it.
  return false;
}

// Called when |tracker| expires. Reports the expiration of a positive record
// to callbacks (negative responses are not reported), then removes the tracker
// from the cache.
void MdnsQuerier::OnRecordExpired(const MdnsRecordTracker* tracker,
                                  const MdnsRecord& record) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  const bool is_positive = !tracker->is_negative_response();
  if (is_positive) {
    ProcessCallbacks(record, RecordChangedEvent::kExpired);
  }

  const auto is_expired_tracker =
      [tracker](const MdnsRecordTracker& candidate) {
        return &candidate == tracker;
      };
  records_.Erase(record.name(), is_expired_tracker);
}

void MdnsQuerier::ProcessRecord(const MdnsRecord& record) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  // Skip all records that can't be processed.
  if (!CanBeProcessed(record.dns_type())) {
    return;
  }

  // Ignore NSEC records if the embedder has configured us to do so.
  if (config_.ignore_nsec_responses && record.dns_type() == DnsType::kNSEC) {
    return;
  }

  // Get the types which the received record is associated with. In most cases
  // this will only be the type of the provided record, but in the case of
  // NSEC records this will be all records which the record dictates the
  // nonexistence of.
  std::vector<DnsType> types;
  int types_count = 0;
  const DnsType* types_ptr = nullptr;
  if (record.dns_type() == DnsType::kNSEC) {
    const auto& nsec_rdata = absl::get<NsecRecordRdata>(record.rdata());
    if (std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(),
                  DnsType::kANY) != nsec_rdata.types().end()) {
      types_ptr = kTranslatedNsecAnyQueryTypes.data();
      types_count = kTranslatedNsecAnyQueryTypes.size();
    } else {
      types_ptr = nsec_rdata.types().data();
      types_count = nsec_rdata.types().size();
    }
  } else {
    types.push_back(record.dns_type());
    types_ptr = types.data();
    types_count = types.size();
  }

  // Apply the update for each type that the record is associated with.
  for (int i = 0; i < types_count; ++i) {
    DnsType dns_type = types_ptr[i];
    switch (record.record_type()) {
      case RecordType::kShared: {
        ProcessSharedRecord(record, dns_type);
        break;
      }
      case RecordType::kUnique: {
        ProcessUniqueRecord(record, dns_type);
        break;
      }
    }
  }
}

// Handles a shared record under |dns_type|. A tracker with identical type,
// class and RDATA is refreshed. Otherwise the record is added to the cache and
// reported as created. NSEC records are ignored here.
void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record,
                                      DnsType dns_type) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  OSP_DCHECK(record.record_type() == RecordType::kShared);

  // By design, NSEC records are never shared records.
  if (record.dns_type() == DnsType::kNSEC) {
    return;
  }

  // A matching tracker means this shared record is already known. Since its
  // RDATA is identical, this is only a TTL refresh.
  const int refreshed =
      records_.Update(record, [&record](const MdnsRecordTracker& tracker) {
        return record.dns_type() == tracker.dns_type() &&
               record.dns_class() == tracker.dns_class() &&
               record.rdata() == tracker.rdata();
      });
  if (refreshed != 0) {
    return;
  }

  // Have never before seen this shared record.
  AddRecord(record, dns_type);
  ProcessCallbacks(record, RecordChangedEvent::kCreated);
}

void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
                                      DnsType dns_type) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  OSP_DCHECK(record.record_type() == RecordType::kUnique);

  std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
      records_.Find(record.name(), dns_type, record.dns_class());
  size_t num_records_for_key = trackers.size();

  // Have not seen any records with this key before. This case is expected the
  // first time a record is received.
  if (num_records_for_key == size_t{0}) {
    const bool will_exist = record.dns_type() != DnsType::kNSEC;
    AddRecord(record, dns_type);
    if (will_exist) {
      ProcessCallbacks(record, RecordChangedEvent::kCreated);
    }
  } else if (num_records_for_key == size_t{1}) {
    // There is exactly one tracker associated with this key. This is the
    // expected case when a record matching this one has already been seen.
    ProcessSinglyTrackedUniqueRecord(record, trackers[0]);
  } else {
    // Multiple records with the same key.
    ProcessMultiTrackedUniqueRecord(record, dns_type);
  }
}

void MdnsQuerier::ProcessSinglyTrackedUniqueRecord(
    const MdnsRecord& record,
    const MdnsRecordTracker& tracker) {
  const bool existed_previously = !tracker.is_negative_response();
  const bool will_exist = record.dns_type() != DnsType::kNSEC;

  // Calculate the callback to call on record update success while the old
  // record still exists.
  MdnsRecord record_for_callback = record;
  if (existed_previously && !will_exist) {
    record_for_callback =
        MdnsRecord(record.name(), tracker.dns_type(), tracker.dns_class(),
                   tracker.record_type(), tracker.ttl(), tracker.rdata());
  }

  auto on_rdata_change = [this, r = std::move(record_for_callback),
                          existed_previously,
                          will_exist](const MdnsRecordTracker& tracker) {
    // If RDATA on the record is different, notify that the record has
    // been updated.
    if (existed_previously && will_exist) {
      ProcessCallbacks(r, RecordChangedEvent::kUpdated);
    } else if (existed_previously) {
      // Do not expire the tracker, because it still holds an NSEC record.
      ProcessCallbacks(r, RecordChangedEvent::kExpired);
    } else if (will_exist) {
      ProcessCallbacks(r, RecordChangedEvent::kCreated);
    }
  };

  int updated_count = records_.Update(
      record, [&tracker](const MdnsRecordTracker& t) { return &tracker == &t; },
      std::move(on_rdata_change));
  OSP_DCHECK_EQ(updated_count, 1);
}

// Handles a unique (cache-flush) record when several trackers already exist
// for its name, |dns_type| and class.
// - A tracker holding identical RDATA, if any, is refreshed.
// - Every tracker holding different RDATA is scheduled to expire soon.
// - If no tracker held identical RDATA, the received record is added as a new
//   tracker (and reported as created unless it is an NSEC record), replacing
//   the flushed ones.
void MdnsQuerier::ProcessMultiTrackedUniqueRecord(const MdnsRecord& record,
                                                  DnsType dns_type) {
  auto update_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
    return tracker.dns_type() == dns_type &&
           tracker.dns_class() == record.dns_class() &&
           tracker.rdata() == record.rdata();
  };
  // update_check requires identical RDATA, so an RDATA change is impossible.
  int update_count = records_.Update(
      record, std::move(update_check),
      [](const MdnsRecordTracker&) { OSP_NOTREACHED(); });
  OSP_DCHECK_LE(update_count, 1);

  auto expire_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
    return tracker.dns_type() == dns_type &&
           tracker.dns_class() == record.dns_class() &&
           tracker.rdata() != record.rdata();
  };
  int expire_count =
      records_.ExpireSoon(record.name(), std::move(expire_check));
  OSP_DCHECK_GE(expire_count, 1);

  // No existing tracker holds this RDATA. The trackers with other RDATA were
  // scheduled to expire above, so the received record must be tracked in
  // their place. Also requiring |expire_count| to be zero here (as the DCHECK
  // above rules out) would silently drop the new record.
  if (update_count == 0) {
    AddRecord(record, dns_type);
    if (record.dns_type() != DnsType::kNSEC) {
      ProcessCallbacks(record, RecordChangedEvent::kCreated);
    }
  }
}

void MdnsQuerier::ProcessCallbacks(const MdnsRecord& record,
                                   RecordChangedEvent event) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  std::vector<PendingQueryChange> pending_changes;
  auto callbacks_it = callbacks_.equal_range(record.name());
  for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
    const CallbackInfo& callback_info = entry->second;
    if ((callback_info.dns_type == DnsType::kANY ||
         record.dns_type() == callback_info.dns_type) &&
        (callback_info.dns_class == DnsClass::kANY ||
         record.dns_class() == callback_info.dns_class)) {
      std::vector<PendingQueryChange> new_changes =
          callback_info.callback->OnRecordChanged(record, event);
      pending_changes.insert(pending_changes.end(), new_changes.begin(),
                             new_changes.end());
    }
  }

  ApplyPendingChanges(std::move(pending_changes));
}

void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
  auto tracker = std::make_unique<MdnsQuestionTracker>(
      question, sender_, task_runner_, now_function_, random_delay_, config_);
  MdnsQuestionTracker* ptr = tracker.get();
  questions_.emplace(question.name(), std::move(tracker));

  // Let all records associated with this question know that there is a new
  // query that can be used for their refresh.
  std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
      records_.Find(question.name(), question.dns_type(), question.dns_class());
  for (const MdnsRecordTracker& tracker : trackers) {
    // NOTE: When the pointed to object is deleted, its dtor removes itself
    // from all associated records.
    ptr->AddAssociatedRecord(&tracker);
  }
}

// Starts tracking |record| under |type|. |type| is the record's own type, or,
// for an NSEC record, a type whose nonexistence it asserts. The new tracker is
// associated with every ongoing question it answers, for known-answer
// suppression. As in AddQuestion() and ShouldAnswerRecordBeProcessed(), a
// question for DnsType::kANY or DnsClass::kANY matches any type or class.
void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
  // Add the new record.
  const auto& tracker = records_.StartTracking(record, type);

  // Let all questions associated with this record know that there is a new
  // record that answers them (for known answer suppression).
  auto query_it = questions_.equal_range(record.name());
  for (auto entry = query_it.first; entry != query_it.second; ++entry) {
    const MdnsQuestion& query = entry->second->question();
    const bool is_relevant_type = type == DnsType::kANY ||
                                  query.dns_type() == DnsType::kANY ||
                                  type == query.dns_type();
    const bool is_relevant_class = record.dns_class() == DnsClass::kANY ||
                                   query.dns_class() == DnsClass::kANY ||
                                   record.dns_class() == query.dns_class();
    if (is_relevant_type && is_relevant_class) {
      // NOTE: When the pointed to object is deleted, its dtor removes itself
      // from all associated queries.
      entry->second->AddAssociatedRecord(&tracker);
    }
  }
}

void MdnsQuerier::ApplyPendingChanges(
    std::vector<PendingQueryChange> pending_changes) {
  for (auto& pending_change : pending_changes) {
    switch (pending_change.change_type) {
      case PendingQueryChange::kStartQuery:
        StartQuery(std::move(pending_change.name), pending_change.dns_type,
                   pending_change.dns_class, pending_change.callback);
        break;
      case PendingQueryChange::kStopQuery:
        StopQuery(std::move(pending_change.name), pending_change.dns_type,
                  pending_change.dns_class, pending_change.callback);
        break;
    }
  }
}

}  // namespace discovery
}  // namespace openscreen
