// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "discovery/mdns/mdns_records.h"

#include <algorithm>
#include <cctype>
#include <limits>
#include <sstream>
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "discovery/mdns/mdns_writer.h"

namespace openscreen {
namespace discovery {

namespace {

// Largest RDATA payload accepted by RawRecordRdata; the record length that
// precedes RDATA is a uint16_t (see the MaxWireSize() comments below).
constexpr size_t kMaxRawRecordSize = std::numeric_limits<uint16_t>::max();

// Bound on the number of entries per MdnsMessage section. TryCreate() and the
// constructor DCHECKs require each section size to be strictly below it.
constexpr size_t kMaxMessageFieldEntryCount =
    std::numeric_limits<uint16_t>::max();

// Three-way, case-insensitive comparison of |x| and |y|.
// Returns -1 if |x| sorts first, 1 if |y| sorts first, and 0 if they are equal
// ignoring case. Characters are compared as unsigned byte values after
// lowercasing. When one string is a prefix of the other, the shorter one sorts
// first.
inline int CompareIgnoreCase(const std::string& x, const std::string& y) {
  const size_t common_length = std::min(x.size(), y.size());
  for (size_t i = 0; i < common_length; ++i) {
    // std::tolower() has undefined behavior for negative arguments other than
    // EOF, which a plain (signed) char >= 0x80 would produce, so convert
    // through unsigned char first.
    const int x_char = std::tolower(static_cast<unsigned char>(x[i]));
    const int y_char = std::tolower(static_cast<unsigned char>(y[i]));
    if (x_char != y_char) {
      return x_char < y_char ? -1 : 1;
    }
  }
  if (x.size() == y.size()) {
    return 0;
  }
  return x.size() < y.size() ? -1 : 1;
}

// Returns true if the serialized rdata of |lhs| is lexicographically later
// than that of |rhs|, following the probe-tiebreaking rule of RFC 6762
// section 8.2:
// - bytes are compared as unsigned values;
// - if one rdata is a prefix of the other, the one with remaining data is
//   later.
// Both |lhs| and |rhs| must hold an RDataType.
template <typename RDataType>
bool IsGreaterThan(const Rdata& lhs, const Rdata& rhs) {
  const RDataType& lhs_cast = absl::get<RDataType>(lhs);
  const RDataType& rhs_cast = absl::get<RDataType>(rhs);

  // Write() prepends a uint16_t record length. Most MaxWireSize()
  // implementations already count it; the extra bytes are slack for any that
  // do not. Heap buffers are used because these sizes can approach 64 KiB and
  // variable-length arrays are not standard C++.
  constexpr size_t kRecordLengthSize = sizeof(uint16_t);
  std::vector<uint8_t> lhs_bytes(lhs_cast.MaxWireSize() + kRecordLengthSize);
  std::vector<uint8_t> rhs_bytes(rhs_cast.MaxWireSize() + kRecordLengthSize);
  MdnsWriter lhs_writer(lhs_bytes.data(), lhs_bytes.size());
  MdnsWriter rhs_writer(rhs_bytes.data(), rhs_bytes.size());

  const bool lhs_write = lhs_writer.Write(lhs_cast);
  const bool rhs_write = rhs_writer.Write(rhs_cast);
  OSP_DCHECK(lhs_write);
  OSP_DCHECK(rhs_write);

  // Compare only the bytes actually written, skipping the record length.
  const size_t lhs_written = lhs_writer.offset();
  const size_t rhs_written = rhs_writer.offset();
  const size_t common_size = std::min(lhs_written, rhs_written);
  for (size_t i = kRecordLengthSize; i < common_size; ++i) {
    if (lhs_bytes[i] != rhs_bytes[i]) {
      return lhs_bytes[i] > rhs_bytes[i];
    }
  }

  // One rdata is a prefix of the other: the one with more data is later. This
  // must use the written lengths, not the MaxWireSize() upper bounds.
  return lhs_written > rhs_written;
}

// Selects the typed rdata comparison for a record of DNS type |type|. It
// returns true when |lhs| serializes lexicographically later than |rhs|.
// Types without a dedicated rdata class are compared as RawRecordRdata.
// NOTE(review): both rdata values must hold the alternative implied by |type|.
// MdnsRecord::IsValidConfig() also accepts RawRecordRdata for any type, and
// OptRecordRdata for kOPT. Those combinations would not match the absl::get<>
// in the typed overload, so confirm such records are never compared here.
bool IsGreaterThan(DnsType type, const Rdata& lhs, const Rdata& rhs) {
  if (type == DnsType::kA) {
    return IsGreaterThan<ARecordRdata>(lhs, rhs);
  }
  if (type == DnsType::kAAAA) {
    return IsGreaterThan<AAAARecordRdata>(lhs, rhs);
  }
  if (type == DnsType::kPTR) {
    return IsGreaterThan<PtrRecordRdata>(lhs, rhs);
  }
  if (type == DnsType::kSRV) {
    return IsGreaterThan<SrvRecordRdata>(lhs, rhs);
  }
  if (type == DnsType::kTXT) {
    return IsGreaterThan<TxtRecordRdata>(lhs, rhs);
  }
  if (type == DnsType::kNSEC) {
    return IsGreaterThan<NsecRecordRdata>(lhs, rhs);
  }
  return IsGreaterThan<RawRecordRdata>(lhs, rhs);
}

}  // namespace

// A domain label is valid when it is non-empty and at most kMaxLabelLength
// bytes long.
bool IsValidDomainLabel(absl::string_view label) {
  if (label.empty()) {
    return false;
  }
  return label.size() <= kMaxLabelLength;
}

DomainName::DomainName() = default;

// The next three constructors delegate to the iterator-range constructor
// (declared elsewhere, presumably in the header), which receives the labels in
// order.
DomainName::DomainName(std::vector<std::string> labels)
    : DomainName(labels.begin(), labels.end()) {}

DomainName::DomainName(const std::vector<absl::string_view>& labels)
    : DomainName(labels.begin(), labels.end()) {}

DomainName::DomainName(std::initializer_list<absl::string_view> labels)
    : DomainName(labels.begin(), labels.end()) {}

// Adopts |labels| as-is together with a caller-supplied |max_wire_size|. No
// validation or size computation is performed here.
DomainName::DomainName(std::vector<std::string> labels, size_t max_wire_size)
    : max_wire_size_(max_wire_size), labels_(std::move(labels)) {}

DomainName::DomainName(const DomainName& other) = default;

DomainName::DomainName(DomainName&& other) noexcept = default;

DomainName& DomainName::operator=(const DomainName& rhs) = default;

DomainName& DomainName::operator=(DomainName&& rhs) = default;

std::string DomainName::ToString() const {
  return absl::StrJoin(labels_, ".");
}

// Orders names label by label using a case-insensitive comparison of each
// label. When one name's labels are a prefix of the other's, the shorter name
// sorts first.
bool DomainName::operator<(const DomainName& rhs) const {
  return std::lexicographical_compare(
      labels_.begin(), labels_.end(), rhs.labels_.begin(), rhs.labels_.end(),
      [](const std::string& a, const std::string& b) {
        return CompareIgnoreCase(a, b) < 0;
      });
}

bool DomainName::operator<=(const DomainName& rhs) const {
  return !(rhs < *this);
}

bool DomainName::operator>(const DomainName& rhs) const {
  return rhs < *this;
}

bool DomainName::operator>=(const DomainName& rhs) const {
  return !(*this < rhs);
}

// Two names are equal when they have the same number of labels and each pair
// of labels matches ignoring case.
bool DomainName::operator==(const DomainName& rhs) const {
  return std::equal(labels_.begin(), labels_.end(), rhs.labels_.begin(),
                    rhs.labels_.end(),
                    [](const std::string& a, const std::string& b) {
                      return CompareIgnoreCase(a, b) == 0;
                    });
}

bool DomainName::operator!=(const DomainName& rhs) const {
  return !(*this == rhs);
}

// Returns the maximum wire size recorded when this name was constructed.
size_t DomainName::MaxWireSize() const {
  return max_wire_size_;
}

// static
// Wraps |rdata| as RawRecordRdata. Returns kIndexOutOfBounds if it exceeds
// kMaxRawRecordSize bytes.
ErrorOr<RawRecordRdata> RawRecordRdata::TryCreate(std::vector<uint8_t> rdata) {
  if (rdata.size() <= kMaxRawRecordSize) {
    return RawRecordRdata(std::move(rdata));
  }
  return Error::Code::kIndexOutOfBounds;
}

RawRecordRdata::RawRecordRdata() = default;

// Takes ownership of |rdata|. Oversized input is only caught by a DCHECK here;
// use TryCreate() for input whose size has not been validated.
RawRecordRdata::RawRecordRdata(std::vector<uint8_t> rdata)
    : rdata_(std::move(rdata)) {
  // Ensure RDATA length does not exceed the maximum allowed.
  OSP_DCHECK(rdata_.size() <= kMaxRawRecordSize);
}

// Copies the |size| bytes starting at |begin|.
RawRecordRdata::RawRecordRdata(const uint8_t* begin, size_t size)
    : RawRecordRdata(std::vector<uint8_t>(begin, begin + size)) {}

RawRecordRdata::RawRecordRdata(const RawRecordRdata& other) = default;

RawRecordRdata::RawRecordRdata(RawRecordRdata&& other) noexcept = default;

RawRecordRdata& RawRecordRdata::operator=(const RawRecordRdata& rhs) = default;

RawRecordRdata& RawRecordRdata::operator=(RawRecordRdata&& rhs) = default;

// Byte-wise equality of the stored RDATA.
bool RawRecordRdata::operator==(const RawRecordRdata& rhs) const {
  return rdata_ == rhs.rdata_;
}

bool RawRecordRdata::operator!=(const RawRecordRdata& rhs) const {
  return !(*this == rhs);
}

size_t RawRecordRdata::MaxWireSize() const {
  // max_wire_size includes uint16_t record length field.
  return sizeof(uint16_t) + rdata_.size();
}

SrvRecordRdata::SrvRecordRdata() = default;

// Stores the SRV fields; |target| is moved into place.
SrvRecordRdata::SrvRecordRdata(uint16_t priority,
                               uint16_t weight,
                               uint16_t port,
                               DomainName target)
    : priority_(priority),
      weight_(weight),
      port_(port),
      target_(std::move(target)) {}

SrvRecordRdata::SrvRecordRdata(const SrvRecordRdata& other) = default;

SrvRecordRdata::SrvRecordRdata(SrvRecordRdata&& other) noexcept = default;

SrvRecordRdata& SrvRecordRdata::operator=(const SrvRecordRdata& rhs) = default;

SrvRecordRdata& SrvRecordRdata::operator=(SrvRecordRdata&& rhs) = default;

// Field-wise equality. The target is compared with DomainName::operator==,
// which ignores label case.
bool SrvRecordRdata::operator==(const SrvRecordRdata& rhs) const {
  return priority_ == rhs.priority_ && weight_ == rhs.weight_ &&
         port_ == rhs.port_ && target_ == rhs.target_;
}

bool SrvRecordRdata::operator!=(const SrvRecordRdata& rhs) const {
  return !(*this == rhs);
}

// Upper bound on the serialized size: the record length, the priority, weight
// and port fields, and the target name's maximum wire size.
size_t SrvRecordRdata::MaxWireSize() const {
  // max_wire_size includes uint16_t record length field.
  return sizeof(uint16_t) + sizeof(priority_) + sizeof(weight_) +
         sizeof(port_) + target_.MaxWireSize();
}

ARecordRdata::ARecordRdata() = default;

// Stores an IPv4 address and the interface it is associated with.
// OSP_CHECK-fails if |ipv4_address| is not an IPv4 address.
ARecordRdata::ARecordRdata(IPAddress ipv4_address,
                           NetworkInterfaceIndex interface_index)
    : ipv4_address_(std::move(ipv4_address)),
      interface_index_(interface_index) {
  OSP_CHECK(ipv4_address_.IsV4());
}

ARecordRdata::ARecordRdata(const ARecordRdata& other) = default;

ARecordRdata::ARecordRdata(ARecordRdata&& other) noexcept = default;

ARecordRdata& ARecordRdata::operator=(const ARecordRdata& rhs) = default;

ARecordRdata& ARecordRdata::operator=(ARecordRdata&& rhs) = default;

// Equality includes the interface index, so the same address on two
// interfaces compares unequal.
bool ARecordRdata::operator==(const ARecordRdata& rhs) const {
  return ipv4_address_ == rhs.ipv4_address_ &&
         interface_index_ == rhs.interface_index_;
}

bool ARecordRdata::operator!=(const ARecordRdata& rhs) const {
  return !(*this == rhs);
}

// Fixed size: record length plus the 4-byte address. The interface index is
// not part of the wire format.
size_t ARecordRdata::MaxWireSize() const {
  // max_wire_size includes uint16_t record length field.
  return sizeof(uint16_t) + IPAddress::kV4Size;
}

AAAARecordRdata::AAAARecordRdata() = default;

// Stores an IPv6 address and the interface it is associated with.
// OSP_CHECK-fails if |ipv6_address| is not an IPv6 address.
AAAARecordRdata::AAAARecordRdata(IPAddress ipv6_address,
                                 NetworkInterfaceIndex interface_index)
    : ipv6_address_(std::move(ipv6_address)),
      interface_index_(interface_index) {
  OSP_CHECK(ipv6_address_.IsV6());
}

AAAARecordRdata::AAAARecordRdata(const AAAARecordRdata& other) = default;

AAAARecordRdata::AAAARecordRdata(AAAARecordRdata&& other) noexcept = default;

AAAARecordRdata& AAAARecordRdata::operator=(const AAAARecordRdata& rhs) =
    default;

AAAARecordRdata& AAAARecordRdata::operator=(AAAARecordRdata&& rhs) = default;

// Equality includes the interface index, so the same address on two
// interfaces compares unequal.
bool AAAARecordRdata::operator==(const AAAARecordRdata& rhs) const {
  return ipv6_address_ == rhs.ipv6_address_ &&
         interface_index_ == rhs.interface_index_;
}

bool AAAARecordRdata::operator!=(const AAAARecordRdata& rhs) const {
  return !(*this == rhs);
}

// Fixed size: record length plus the 16-byte address. The interface index is
// not part of the wire format.
size_t AAAARecordRdata::MaxWireSize() const {
  // max_wire_size includes uint16_t record length field.
  return sizeof(uint16_t) + IPAddress::kV6Size;
}

PtrRecordRdata::PtrRecordRdata() = default;

// Takes ownership of |ptr_domain|. The by-value parameter is moved rather than
// copied, consistent with the other rdata constructors in this file.
PtrRecordRdata::PtrRecordRdata(DomainName ptr_domain)
    : ptr_domain_(std::move(ptr_domain)) {}

PtrRecordRdata::PtrRecordRdata(const PtrRecordRdata& other) = default;

PtrRecordRdata::PtrRecordRdata(PtrRecordRdata&& other) noexcept = default;

PtrRecordRdata& PtrRecordRdata::operator=(const PtrRecordRdata& rhs) = default;

PtrRecordRdata& PtrRecordRdata::operator=(PtrRecordRdata&& rhs) = default;

// Equal when the pointed-to domains are equal, ignoring label case.
bool PtrRecordRdata::operator==(const PtrRecordRdata& rhs) const {
  return ptr_domain_ == rhs.ptr_domain_;
}

bool PtrRecordRdata::operator!=(const PtrRecordRdata& rhs) const {
  return !(*this == rhs);
}

size_t PtrRecordRdata::MaxWireSize() const {
  // max_wire_size includes uint16_t record length field.
  return sizeof(uint16_t) + ptr_domain_.MaxWireSize();
}

// Validates |texts| and converts each entry to a std::string, precomputing
// the rdata's maximum wire size.
// Returns kParameterInvalid if any entry is empty or longer than 255 bytes.
// 255 is the most a TXT character-string's one-byte length prefix can
// describe (RFC 6763 section 6.1).
// static
ErrorOr<TxtRecordRdata> TxtRecordRdata::TryCreate(std::vector<Entry> texts) {
  // An empty TXT rdata is sized as the uint16_t record length field plus one
  // byte (presumably a single zero-length string on the wire).
  if (texts.empty()) {
    return TxtRecordRdata(std::vector<std::string>(), sizeof(uint16_t) + 1);
  }

  std::vector<std::string> str_texts;
  str_texts.reserve(texts.size());
  // max_wire_size includes uint16_t record length field.
  size_t max_wire_size = sizeof(uint16_t);
  for (const auto& text : texts) {
    if (text.empty() || text.size() > std::numeric_limits<uint8_t>::max()) {
      return Error::Code::kParameterInvalid;
    }
    str_texts.emplace_back(reinterpret_cast<const char*>(text.data()),
                           text.size());
    // Include the length byte in the size calculation.
    max_wire_size += text.size() + 1;
  }
  return TxtRecordRdata(std::move(str_texts), max_wire_size);
}

TxtRecordRdata::TxtRecordRdata() = default;

// Builds the rdata through TryCreate() and adopts the result. If TryCreate()
// rejects |texts| (e.g. an empty entry), value() is called on an error result.
// NOTE(review): presumably that CHECK-fails, so callers with unvalidated input
// should use TryCreate() directly.
TxtRecordRdata::TxtRecordRdata(std::vector<Entry> texts) {
  ErrorOr<TxtRecordRdata> rdata = TxtRecordRdata::TryCreate(std::move(texts));
  *this = std::move(rdata.value());
}

// Adopts already-validated |texts| and their precomputed |max_wire_size|.
TxtRecordRdata::TxtRecordRdata(std::vector<std::string> texts,
                               size_t max_wire_size)
    : max_wire_size_(max_wire_size), texts_(std::move(texts)) {}

TxtRecordRdata::TxtRecordRdata(const TxtRecordRdata& other) = default;

TxtRecordRdata::TxtRecordRdata(TxtRecordRdata&& other) noexcept = default;

TxtRecordRdata& TxtRecordRdata::operator=(const TxtRecordRdata& rhs) = default;

TxtRecordRdata& TxtRecordRdata::operator=(TxtRecordRdata&& rhs) = default;

// Entries are compared in order and case-sensitively; max_wire_size_ is not
// compared.
bool TxtRecordRdata::operator==(const TxtRecordRdata& rhs) const {
  return texts_ == rhs.texts_;
}

bool TxtRecordRdata::operator!=(const TxtRecordRdata& rhs) const {
  return !(*this == rhs);
}

// Returns the size computed by TryCreate(), which includes the uint16_t record
// length field.
size_t TxtRecordRdata::MaxWireSize() const {
  return max_wire_size_;
}

NsecRecordRdata::NsecRecordRdata() = default;

// Stores |next_domain_name| and |types| (sorted), and precomputes
// |encoded_types_| as the type bit maps of RFC 4034 section 4.1.2. Each window
// block that contains at least one type is encoded as:
// - the block number (the high byte of the type);
// - the bitmap length in bytes;
// - the bitmap itself, where low byte N maps to bit (7 - N % 8) of byte N / 8.
// Trailing all-zero bitmap bytes are never emitted.
NsecRecordRdata::NsecRecordRdata(DomainName next_domain_name,
                                 std::vector<DnsType> types)
    : types_(std::move(types)), next_domain_name_(std::move(next_domain_name)) {
  // Sort the types_ array for easier comparison later.
  std::sort(types_.begin(), types_.end());

  // Calculate the bitmaps as described in RFC 4034 Section 4.1.2.
  std::vector<uint8_t> block_contents;
  uint8_t current_block = 0;
  for (auto type : types_) {
    const uint16_t type_int = static_cast<uint16_t>(type);
    const uint8_t block = static_cast<uint8_t>(type_int >> 8);
    const uint8_t block_position = static_cast<uint8_t>(type_int & 0xFF);
    const uint8_t byte_bit_is_at = block_position >> 3;         // First 5 bits.
    const uint8_t byte_mask = 0x80 >> (block_position & 0x07);  // Last 3 bits.

    // If the block has changed, write the previous block's info and all of its
    // contents to the |encoded_types_| vector.
    if (block > current_block) {
      if (!block_contents.empty()) {
        encoded_types_.push_back(current_block);
        encoded_types_.push_back(static_cast<uint8_t>(block_contents.size()));
        encoded_types_.insert(encoded_types_.end(), block_contents.begin(),
                              block_contents.end());
      }
      block_contents = std::vector<uint8_t>();
      current_block = block;
    }

    // Make sure |block_contents| is large enough to hold the bit representing
    // the new type, then set it.
    if (block_contents.size() <= byte_bit_is_at) {
      block_contents.insert(block_contents.end(),
                            byte_bit_is_at - block_contents.size() + 1, 0x00);
    }

    block_contents[byte_bit_is_at] |= byte_mask;
  }

  // Flush the final (or only) block.
  if (!block_contents.empty()) {
    encoded_types_.push_back(current_block);
    encoded_types_.push_back(static_cast<uint8_t>(block_contents.size()));
    encoded_types_.insert(encoded_types_.end(), block_contents.begin(),
                          block_contents.end());
  }
}

NsecRecordRdata::NsecRecordRdata(const NsecRecordRdata& other) = default;

NsecRecordRdata::NsecRecordRdata(NsecRecordRdata&& other) noexcept = default;

NsecRecordRdata& NsecRecordRdata::operator=(const NsecRecordRdata& rhs) =
    default;

NsecRecordRdata& NsecRecordRdata::operator=(NsecRecordRdata&& rhs) = default;

// Compares the sorted type lists and the next domain name. The encoded bitmaps
// are derived from the types, so they are not compared separately.
bool NsecRecordRdata::operator==(const NsecRecordRdata& rhs) const {
  return types_ == rhs.types_ && next_domain_name_ == rhs.next_domain_name_;
}

bool NsecRecordRdata::operator!=(const NsecRecordRdata& rhs) const {
  return !(*this == rhs);
}

// NOTE(review): unlike the other rdata MaxWireSize() implementations in this
// file, this does not add sizeof(uint16_t) for the record length field. If the
// writer does not rely on that, MdnsRecord::MaxWireSize() undercounts NSEC
// records by 2 bytes. Confirm how MdnsWriter uses this value before changing
// it.
size_t NsecRecordRdata::MaxWireSize() const {
  return next_domain_name_.MaxWireSize() + encoded_types_.size();
}

// Payload size plus the OPTION-CODE and OPTION-LENGTH fields, each a uint16_t
// (RFC 6891 section 6.1.2).
size_t OptRecordRdata::Option::MaxWireSize() const {
  constexpr size_t kOptionHeaderSize = sizeof(uint16_t) + sizeof(uint16_t);
  return kOptionHeaderSize + data.size();
}

// Total order over options: by code, then declared length, then payload size,
// then payload bytes.
bool OptRecordRdata::Option::operator>(
    const OptRecordRdata::Option& rhs) const {
  if (code != rhs.code) {
    return code > rhs.code;
  }
  if (length != rhs.length) {
    return length > rhs.length;
  }
  if (data.size() != rhs.data.size()) {
    return data.size() > rhs.data.size();
  }
  // Same-size payloads: |this| is greater exactly when |rhs| sorts before it.
  return std::lexicographical_compare(rhs.data.begin(), rhs.data.end(),
                                      data.begin(), data.end());
}

bool OptRecordRdata::Option::operator<(
    const OptRecordRdata::Option& rhs) const {
  return rhs > *this;
}

bool OptRecordRdata::Option::operator>=(
    const OptRecordRdata::Option& rhs) const {
  return !(rhs > *this);
}

bool OptRecordRdata::Option::operator<=(
    const OptRecordRdata::Option& rhs) const {
  return !(*this > rhs);
}

// Equal when neither option orders after the other.
bool OptRecordRdata::Option::operator==(
    const OptRecordRdata::Option& rhs) const {
  return !(*this > rhs) && !(rhs > *this);
}

bool OptRecordRdata::Option::operator!=(
    const OptRecordRdata::Option& rhs) const {
  return !(*this == rhs);
}

OptRecordRdata::OptRecordRdata() = default;

// Adds each option's wire size to max_wire_size_, then sorts the options so
// that equality does not depend on the input order.
// NOTE(review): the starting value of max_wire_size_ comes from its
// declaration, which is not shown here.
OptRecordRdata::OptRecordRdata(std::vector<Option> options)
    : options_(std::move(options)) {
  for (const auto& option : options_) {
    max_wire_size_ += option.MaxWireSize();
  }
  std::sort(options_.begin(), options_.end());
}

OptRecordRdata::OptRecordRdata(const OptRecordRdata& other) = default;

OptRecordRdata::OptRecordRdata(OptRecordRdata&& other) noexcept = default;

OptRecordRdata& OptRecordRdata::operator=(const OptRecordRdata& rhs) = default;

OptRecordRdata& OptRecordRdata::operator=(OptRecordRdata&& rhs) = default;

// Compares the sorted option lists element-wise.
bool OptRecordRdata::operator==(const OptRecordRdata& rhs) const {
  return options_ == rhs.options_;
}

bool OptRecordRdata::operator!=(const OptRecordRdata& rhs) const {
  return !(*this == rhs);
}

// static
// Builds a record after checking the fields with IsValidConfig(). Returns
// kParameterInvalid if that check fails.
ErrorOr<MdnsRecord> MdnsRecord::TryCreate(DomainName name,
                                          DnsType dns_type,
                                          DnsClass dns_class,
                                          RecordType record_type,
                                          std::chrono::seconds ttl,
                                          Rdata rdata) {
  const bool valid = IsValidConfig(name, dns_type, ttl, rdata);
  if (valid) {
    return MdnsRecord(std::move(name), dns_type, dns_class, record_type, ttl,
                      std::move(rdata));
  }
  return Error::Code::kParameterInvalid;
}

MdnsRecord::MdnsRecord() = default;

// Stores the given fields. Validity is only DCHECKed here; use TryCreate() to
// get an error instead for unvalidated input.
MdnsRecord::MdnsRecord(DomainName name,
                       DnsType dns_type,
                       DnsClass dns_class,
                       RecordType record_type,
                       std::chrono::seconds ttl,
                       Rdata rdata)
    : name_(std::move(name)),
      dns_type_(dns_type),
      dns_class_(dns_class),
      record_type_(record_type),
      ttl_(ttl),
      rdata_(std::move(rdata)) {
  OSP_DCHECK(IsValidConfig(name_, dns_type, ttl_, rdata_));
}

MdnsRecord::MdnsRecord(const MdnsRecord& other) = default;

MdnsRecord::MdnsRecord(MdnsRecord&& other) noexcept = default;

MdnsRecord& MdnsRecord::operator=(const MdnsRecord& rhs) = default;

MdnsRecord& MdnsRecord::operator=(MdnsRecord&& rhs) = default;

// static
// Returns true when |ttl| fits the unsigned 32-bit TTL field and |rdata| holds
// either the rdata type matching |dns_type| or RawRecordRdata (accepted for any
// type).
bool MdnsRecord::IsValidConfig(const DomainName& name,
                               DnsType dns_type,
                               std::chrono::seconds ttl,
                               const Rdata& rdata) {
  // NOTE: Although the name_ field was initially expected to be non-empty, this
  // validation is no longer accurate for some record types (such as OPT
  // records). To ensure that future record types correctly parse into
  // RawRecordData types and do not invalidate the received message, this check
  // has been removed.

  // std::chrono::seconds is signed. Without the lower bound, negative TTLs
  // would pass and later wrap around when written as a uint32_t.
  const bool ttl_in_range =
      ttl.count() >= 0 && ttl.count() <= std::numeric_limits<uint32_t>::max();
  if (!ttl_in_range) {
    return false;
  }

  return (dns_type == DnsType::kSRV &&
          absl::holds_alternative<SrvRecordRdata>(rdata)) ||
         (dns_type == DnsType::kA &&
          absl::holds_alternative<ARecordRdata>(rdata)) ||
         (dns_type == DnsType::kAAAA &&
          absl::holds_alternative<AAAARecordRdata>(rdata)) ||
         (dns_type == DnsType::kPTR &&
          absl::holds_alternative<PtrRecordRdata>(rdata)) ||
         (dns_type == DnsType::kTXT &&
          absl::holds_alternative<TxtRecordRdata>(rdata)) ||
         (dns_type == DnsType::kNSEC &&
          absl::holds_alternative<NsecRecordRdata>(rdata)) ||
         (dns_type == DnsType::kOPT &&
          absl::holds_alternative<OptRecordRdata>(rdata)) ||
         absl::holds_alternative<RawRecordRdata>(rdata);
}

// Full equality: the same record (see IsReannouncementOf()) with the same TTL.
bool MdnsRecord::operator==(const MdnsRecord& rhs) const {
  return IsReannouncementOf(rhs) && ttl_ == rhs.ttl_;
}

bool MdnsRecord::operator!=(const MdnsRecord& rhs) const {
  return !(*this == rhs);
}

// Orders records by name, then record type (a kUnique record sorts after a
// record of a different record type), then DNS class, then DNS type, then the
// serialized rdata bytes.
// NOTE(review): kClassMask is applied to the DNS *type* below. RFC 6762
// section 8.2 describes excluding the cache-flush bit from the *class* (here
// modeled separately by RecordType). This is likely a no-op for defined types;
// confirm intent.
bool MdnsRecord::operator>(const MdnsRecord& rhs) const {
  // Returns the record which is lexicographically later. The determination of
  // "lexicographically later" is performed by first comparing the record class,
  // then the record type, then raw comparison of the binary content of the
  // rdata without regard for meaning or structure.
  // NOTE: Per RFC, the TTL is not included in this comparison.
  if (name() != rhs.name()) {
    return name() > rhs.name();
  }

  if (record_type() != rhs.record_type()) {
    return record_type() == RecordType::kUnique;
  }

  if (dns_class() != rhs.dns_class()) {
    return dns_class() > rhs.dns_class();
  }

  uint16_t this_type = static_cast<uint16_t>(dns_type()) & kClassMask;
  uint16_t other_type = static_cast<uint16_t>(rhs.dns_type()) & kClassMask;
  if (this_type != other_type) {
    return this_type > other_type;
  }

  return IsGreaterThan(dns_type(), rdata(), rhs.rdata());
}

bool MdnsRecord::operator<(const MdnsRecord& rhs) const {
  return rhs > *this;
}

bool MdnsRecord::operator<=(const MdnsRecord& rhs) const {
  return !(*this > rhs);
}

bool MdnsRecord::operator>=(const MdnsRecord& rhs) const {
  return !(*this < rhs);
}

// True when |rhs| describes the same record (DNS type, class, record type,
// name and rdata), regardless of TTL.
bool MdnsRecord::IsReannouncementOf(const MdnsRecord& rhs) const {
  return dns_type_ == rhs.dns_type_ && dns_class_ == rhs.dns_class_ &&
         record_type_ == rhs.record_type_ && name_ == rhs.name_ &&
         rdata_ == rhs.rdata_;
}

// Upper bound on the serialized record. The rdata's own MaxWireSize() is
// expected to account for the record length field (as the rdata
// implementations above note).
size_t MdnsRecord::MaxWireSize() const {
  auto wire_size_visitor = [](auto&& arg) { return arg.MaxWireSize(); };
  // NAME size, 2-byte TYPE, 2-byte CLASS, 4-byte TTL, RDATA size
  return name_.MaxWireSize() + absl::visit(wire_size_visitor, rdata_) + 8;
}

std::string MdnsRecord::ToString() const {
  std::stringstream ss;
  ss << "name: '" << name_.ToString() << "'";
  ss << ", type: " << dns_type_;

  if (dns_type_ == DnsType::kPTR) {
    const DomainName& target = absl::get<PtrRecordRdata>(rdata_).ptr_domain();
    ss << ", target: '" << target.ToString() << "'";
  } else if (dns_type_ == DnsType::kSRV) {
    const DomainName& target = absl::get<SrvRecordRdata>(rdata_).target();
    ss << ", target: '" << target.ToString() << "'";
  } else if (dns_type_ == DnsType::kNSEC) {
    const auto& nsec_rdata = absl::get<NsecRecordRdata>(rdata_);
    std::vector<DnsType> types = nsec_rdata.types();
    ss << ", representing [";
    if (!types.empty()) {
      auto it = types.begin();
      ss << *it++;
      while (it != types.end()) {
        ss << ", " << *it++;
      }
      ss << "]";
    }
  }

  return ss.str();
}

// Creates a unique, class-IN address record for |name|. IPv4 addresses yield
// an A record with kARecordTtl; all other addresses yield an AAAA record with
// kAAAARecordTtl.
MdnsRecord CreateAddressRecord(DomainName name, const IPAddress& address) {
  if (address.IsV4()) {
    return MdnsRecord(std::move(name), DnsType::kA, DnsClass::kIN,
                      RecordType::kUnique, kARecordTtl, ARecordRdata(address));
  }
  return MdnsRecord(std::move(name), DnsType::kAAAA, DnsClass::kIN,
                    RecordType::kUnique, kAAAARecordTtl,
                    AAAARecordRdata(address));
}

// static
// Builds a question. Returns kParameterInvalid if |name| is empty.
ErrorOr<MdnsQuestion> MdnsQuestion::TryCreate(DomainName name,
                                              DnsType dns_type,
                                              DnsClass dns_class,
                                              ResponseType response_type) {
  if (name.empty()) {
    return Error::Code::kParameterInvalid;
  }

  return MdnsQuestion(std::move(name), dns_type, dns_class, response_type);
}

// Stores the question fields. OSP_CHECK-fails if |name| is empty; use
// TryCreate() to get an error instead.
MdnsQuestion::MdnsQuestion(DomainName name,
                           DnsType dns_type,
                           DnsClass dns_class,
                           ResponseType response_type)
    : name_(std::move(name)),
      dns_type_(dns_type),
      dns_class_(dns_class),
      response_type_(response_type) {
  OSP_CHECK(!name_.empty());
}

// Field-wise equality. The name comparison ignores label case.
bool MdnsQuestion::operator==(const MdnsQuestion& rhs) const {
  return dns_type_ == rhs.dns_type_ && dns_class_ == rhs.dns_class_ &&
         response_type_ == rhs.response_type_ && name_ == rhs.name_;
}

bool MdnsQuestion::operator!=(const MdnsQuestion& rhs) const {
  return !(*this == rhs);
}

size_t MdnsQuestion::MaxWireSize() const {
  // NAME size, 2-byte TYPE, 2-byte CLASS
  return name_.MaxWireSize() + 4;
}

// static
// Builds a message from the given sections. Returns kParameterInvalid if any
// section holds kMaxMessageFieldEntryCount or more entries.
ErrorOr<MdnsMessage> MdnsMessage::TryCreate(
    uint16_t id,
    MessageType type,
    std::vector<MdnsQuestion> questions,
    std::vector<MdnsRecord> answers,
    std::vector<MdnsRecord> authority_records,
    std::vector<MdnsRecord> additional_records) {
  const bool sections_fit =
      questions.size() < kMaxMessageFieldEntryCount &&
      answers.size() < kMaxMessageFieldEntryCount &&
      authority_records.size() < kMaxMessageFieldEntryCount &&
      additional_records.size() < kMaxMessageFieldEntryCount;
  if (!sections_fit) {
    return Error::Code::kParameterInvalid;
  }

  return MdnsMessage(id, type, std::move(questions), std::move(answers),
                     std::move(authority_records),
                     std::move(additional_records));
}

MdnsMessage::MdnsMessage(uint16_t id, MessageType type)
    : id_(id), type_(type) {}

// Adopts all four sections and adds their entries' maximum wire sizes to
// max_wire_size_. Section sizes are only DCHECKed here.
MdnsMessage::MdnsMessage(uint16_t id,
                         MessageType type,
                         std::vector<MdnsQuestion> questions,
                         std::vector<MdnsRecord> answers,
                         std::vector<MdnsRecord> authority_records,
                         std::vector<MdnsRecord> additional_records)
    : id_(id),
      type_(type),
      questions_(std::move(questions)),
      answers_(std::move(answers)),
      authority_records_(std::move(authority_records)),
      additional_records_(std::move(additional_records)) {
  OSP_DCHECK(questions_.size() < kMaxMessageFieldEntryCount);
  OSP_DCHECK(answers_.size() < kMaxMessageFieldEntryCount);
  OSP_DCHECK(authority_records_.size() < kMaxMessageFieldEntryCount);
  OSP_DCHECK(additional_records_.size() < kMaxMessageFieldEntryCount);

  const auto section_wire_size = [](const auto& entries) {
    size_t total = 0;
    for (const auto& entry : entries) {
      total += entry.MaxWireSize();
    }
    return total;
  };
  max_wire_size_ += section_wire_size(questions_) +
                    section_wire_size(answers_) +
                    section_wire_size(authority_records_) +
                    section_wire_size(additional_records_);
}

bool MdnsMessage::operator==(const MdnsMessage& rhs) const {
  return id_ == rhs.id_ && type_ == rhs.type_ && questions_ == rhs.questions_ &&
         answers_ == rhs.answers_ &&
         authority_records_ == rhs.authority_records_ &&
         additional_records_ == rhs.additional_records_;
}

bool MdnsMessage::operator!=(const MdnsMessage& rhs) const {
  return !(*this == rhs);
}

bool MdnsMessage::IsProbeQuery() const {
  // A message is a probe query if it contains records in the authority section
  // which answer the question being asked.
  if (questions().empty() || authority_records().empty()) {
    return false;
  }

  for (const MdnsQuestion& question : questions_) {
    for (const MdnsRecord& record : authority_records_) {
      if (question.name() == record.name() &&
          ((question.dns_type() == record.dns_type()) ||
           (question.dns_type() == DnsType::kANY)) &&
          ((question.dns_class() == record.dns_class()) ||
           (question.dns_class() == DnsClass::kANY))) {
        return true;
      }
    }
  }

  return false;
}

size_t MdnsMessage::MaxWireSize() const {
  return max_wire_size_;
}

// The Add* methods append an entry to the corresponding section and add its
// maximum wire size to max_wire_size_. The section-size limit is only
// DCHECKed, against the size before the entry is appended.
void MdnsMessage::AddQuestion(MdnsQuestion question) {
  OSP_DCHECK(questions_.size() < kMaxMessageFieldEntryCount);
  max_wire_size_ += question.MaxWireSize();
  questions_.emplace_back(std::move(question));
}

void MdnsMessage::AddAnswer(MdnsRecord record) {
  OSP_DCHECK(answers_.size() < kMaxMessageFieldEntryCount);
  max_wire_size_ += record.MaxWireSize();
  answers_.emplace_back(std::move(record));
}

void MdnsMessage::AddAuthorityRecord(MdnsRecord record) {
  OSP_DCHECK(authority_records_.size() < kMaxMessageFieldEntryCount);
  max_wire_size_ += record.MaxWireSize();
  authority_records_.emplace_back(std::move(record));
}

void MdnsMessage::AddAdditionalRecord(MdnsRecord record) {
  OSP_DCHECK(additional_records_.size() < kMaxMessageFieldEntryCount);
  max_wire_size_ += record.MaxWireSize();
  additional_records_.emplace_back(std::move(record));
}

// True if adding |record| would keep the message's maximum wire size strictly
// below kMaxMulticastMessageSize. Does not modify the message.
bool MdnsMessage::CanAddRecord(const MdnsRecord& record) {
  return (max_wire_size_ + record.MaxWireSize()) < kMaxMulticastMessageSize;
}

// Returns sequential message ids from a function-local counter. The counter
// starts at 0 and wraps around after 65535.
// NOTE(review): the counter is not synchronized. Confirm this is only called
// from a single thread.
uint16_t CreateMessageId() {
  static uint16_t id(0);
  return id++;
}

bool CanBePublished(DnsType type) {
  // NOTE: A 'default' switch statement has intentionally been avoided below to
  // enforce that new DnsTypes added must be added below through a compile-time
  // check.
  switch (type) {
    case DnsType::kA:
    case DnsType::kAAAA:
    case DnsType::kPTR:
    case DnsType::kTXT:
    case DnsType::kSRV:
      return true;
    case DnsType::kOPT:
    case DnsType::kNSEC:
    case DnsType::kANY:
      break;
  }

  return false;
}

bool CanBePublished(const MdnsRecord& record) {
  return CanBePublished(record.dns_type());
}

bool CanBeQueried(DnsType type) {
  // NOTE: A 'default' switch statement has intentionally been avoided below to
  // enforce that new DnsTypes added must be added below through a compile-time
  // check.
  switch (type) {
    case DnsType::kA:
    case DnsType::kAAAA:
    case DnsType::kPTR:
    case DnsType::kTXT:
    case DnsType::kSRV:
    case DnsType::kANY:
      return true;
    case DnsType::kOPT:
    case DnsType::kNSEC:
      break;
  }

  return false;
}

bool CanBeProcessed(DnsType type) {
  // NOTE: A 'default' switch statement has intentionally been avoided below to
  // enforce that new DnsTypes added must be added below through a compile-time
  // check.
  switch (type) {
    case DnsType::kA:
    case DnsType::kAAAA:
    case DnsType::kPTR:
    case DnsType::kTXT:
    case DnsType::kSRV:
    case DnsType::kNSEC:
      return true;
    case DnsType::kOPT:
    case DnsType::kANY:
      break;
  }

  return false;
}

}  // namespace discovery
}  // namespace openscreen
