// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/base/address_tracker_linux_test_util.h"

#include <linux/if.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <stdint.h>
#include <string.h>

#include <vector>

#include "base/check_op.h"
#include "base/logging.h"
#include "net/base/ip_address.h"

bool operator==(const struct ifaddrmsg& lhs, const struct ifaddrmsg& rhs) {
  return memcmp(&lhs, &rhs, sizeof(struct ifaddrmsg)) == 0;
}

namespace net::test {

// Starts a message of the given netlink `type`. The buffer is created with
// NLMSG_HDRLEN bytes; Align() then records the length in the header.
NetlinkMessage::NetlinkMessage(uint16_t type) : buffer_(NLMSG_HDRLEN) {
  auto* const hdr = header();
  hdr->nlmsg_type = type;
  Align();
}

// No custom teardown is needed; the compiler-generated destructor is used.
NetlinkMessage::~NetlinkMessage() = default;

// Appends `length` bytes from `data` as the message payload and then
// re-aligns the buffer. Must be called before any attribute is added: the
// CHECK below fires unless the buffer still holds exactly NLMSG_HDRLEN bytes.
void NetlinkMessage::AddPayload(const void* data, size_t length) {
  CHECK_EQ(static_cast<size_t>(NLMSG_HDRLEN), buffer_.size())
      << "Payload must be added first";
  Append(data, length);
  // Updates nlmsg_len and pads the buffer to the netlink alignment.
  Align();
}

void NetlinkMessage::AddAttribute(uint16_t type,
                                  const void* data,
                                  size_t length) {
  struct nlattr attr;
  attr.nla_len = NLA_HDRLEN + length;
  attr.nla_type = type;
  Append(&attr, sizeof(attr));
  Align();
  Append(data, length);
  Align();
}

// Copies the bytes of this message onto the end of `output`. `output` must
// already have a netlink-aligned size, which the CHECK enforces.
void NetlinkMessage::AppendTo(NetlinkBuffer* output) const {
  CHECK_EQ(NLMSG_ALIGN(output->size()), output->size());
  const auto message_begin = buffer_.cbegin();
  const auto message_end = buffer_.cend();
  output->insert(output->end(), message_begin, message_end);
}

// Copies `length` raw bytes starting at `data` onto the end of the buffer.
// Neither the header length nor the padding is touched here; the callers in
// this file follow up with Align().
void NetlinkMessage::Append(const void* data, size_t length) {
  const char* const first = static_cast<const char*>(data);
  const char* const last = first + length;
  buffer_.insert(buffer_.end(), first, last);
}

// Records the current (unpadded) buffer size in the header's nlmsg_len, then
// resizes the buffer up to the NLMSG_ALIGN boundary and sanity-checks the
// result with NLMSG_OK.
void NetlinkMessage::Align() {
  const size_t unpadded_size = buffer_.size();
  header()->nlmsg_len = unpadded_size;
  buffer_.resize(NLMSG_ALIGN(unpadded_size));
  CHECK(NLMSG_OK(header(), buffer_.size()));
}

// All-ones 32-bit value. Used below as the "infinite" (per its name) lifetime
// written into ifa_cacheinfo fields.
#define INFINITY_LIFE_TIME 0xFFFFFFFF

void MakeAddrMessageWithCacheInfo(uint16_t type,
                                  uint8_t flags,
                                  uint8_t family,
                                  int index,
                                  const IPAddress& address,
                                  const IPAddress& local,
                                  uint32_t preferred_lifetime,
                                  NetlinkBuffer* output) {
  NetlinkMessage nlmsg(type);
  struct ifaddrmsg msg = {};
  msg.ifa_family = family;
  msg.ifa_flags = flags;
  msg.ifa_index = index;
  nlmsg.AddPayload(msg);
  if (address.size()) {
    nlmsg.AddAttribute(IFA_ADDRESS, address.bytes().data(), address.size());
  }
  if (local.size()) {
    nlmsg.AddAttribute(IFA_LOCAL, local.bytes().data(), local.size());
  }
  struct ifa_cacheinfo cache_info = {};
  cache_info.ifa_prefered = preferred_lifetime;
  cache_info.ifa_valid = INFINITY_LIFE_TIME;
  nlmsg.AddAttribute(IFA_CACHEINFO, &cache_info, sizeof(cache_info));
  nlmsg.AppendTo(output);
}

// Convenience wrapper around MakeAddrMessageWithCacheInfo() that passes
// INFINITY_LIFE_TIME as the preferred lifetime.
void MakeAddrMessage(uint16_t type,
                     uint8_t flags,
                     uint8_t family,
                     int index,
                     const IPAddress& address,
                     const IPAddress& local,
                     NetlinkBuffer* output) {
  const uint32_t preferred_lifetime = INFINITY_LIFE_TIME;
  MakeAddrMessageWithCacheInfo(type, flags, family, index, address, local,
                               preferred_lifetime, output);
}

void MakeLinkMessage(uint16_t type,
                     uint32_t flags,
                     uint32_t index,
                     NetlinkBuffer* output,
                     bool clear_output) {
  NetlinkMessage nlmsg(type);
  struct ifinfomsg msg = {};
  msg.ifi_index = index;
  msg.ifi_flags = flags;
  msg.ifi_change = 0xFFFFFFFF;
  nlmsg.AddPayload(msg);
  if (clear_output) {
    output->clear();
  }
  nlmsg.AppendTo(output);
}

// Creates a netlink message generated by wireless_send_event. These events
// should be ignored.
void MakeWirelessLinkMessage(uint16_t type,
                             uint32_t flags,
                             uint32_t index,
                             NetlinkBuffer* output,
                             bool clear_output) {
  NetlinkMessage nlmsg(type);
  struct ifinfomsg msg = {};
  msg.ifi_index = index;
  msg.ifi_flags = flags;
  msg.ifi_change = 0;
  nlmsg.AddPayload(msg);
  char data[8] = {0};
  nlmsg.AddAttribute(IFLA_WIRELESS, data, sizeof(data));
  if (clear_output) {
    output->clear();
  }
  nlmsg.AppendTo(output);
}

}  // namespace net::test
