// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "discovery/mdns/mdns_trackers.h"

#include <memory>
#include <utility>

#include "discovery/common/config.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_record_changed_callback.h"
#include "discovery/mdns/mdns_sender.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "platform/test/fake_clock.h"
#include "platform/test/fake_task_runner.h"
#include "platform/test/fake_udp_socket.h"

namespace openscreen {
namespace discovery {
namespace {

constexpr Clock::duration kOneSecond =
    Clock::to_duration(std::chrono::seconds(1));
}

using testing::_;
using testing::Args;
using testing::DoAll;
using testing::Invoke;
using testing::Return;
using testing::StrictMock;
using testing::WithArgs;

// gMock action comparing a buffer (arg0 = data pointer, arg1 = byte count)
// against |expected_data| / |expected_size|. The first two bytes are excluded
// from the comparison because they carry a generated message ID.
ACTION_P2(VerifyMessageBytesWithoutId, expected_data, expected_size) {
  const size_t actual_size = arg1;
  ASSERT_EQ(actual_size, expected_size);

  const uint8_t* const actual_data = reinterpret_cast<const uint8_t*>(arg0);
  // Comparison begins at index 2, right after the message ID.
  size_t index = 2;
  while (index < actual_size) {
    EXPECT_EQ(actual_data[index], expected_data[index]);
    ++index;
  }
}

// gMock action expecting that arg0.is_truncated() equals |is_truncated|.
ACTION_P(VerifyTruncated, is_truncated) {
  const auto actual_truncated = arg0.is_truncated();
  EXPECT_EQ(actual_truncated, is_truncated);
}

// gMock action expecting that arg0 carries exactly |record_count| answers.
ACTION_P(VerifyRecordCount, record_count) {
  const size_t expected_count = static_cast<size_t>(record_count);
  EXPECT_EQ(arg0.answers().size(), expected_count);
}

class MockMdnsSender : public MdnsSender {
 public:
  explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {}

  MOCK_METHOD1(SendMulticast, Error(const MdnsMessage&));
  MOCK_METHOD2(SendMessage, Error(const MdnsMessage&, const IPEndpoint&));
};

// gMock test double for MdnsRecordChangedCallback; OnRecordChanged is the only
// mocked method.
class MockRecordChangedCallback : public MdnsRecordChangedCallback {
 public:
  MOCK_METHOD(std::vector<PendingQueryChange>,
              OnRecordChanged,
              (const MdnsRecord&, RecordChangedEvent),
              (override));
};

class MdnsTrackerTest : public testing::Test {
 public:
  MdnsTrackerTest()
      : clock_(Clock::now()),
        task_runner_(&clock_),
        socket_(&task_runner_),
        sender_(&socket_),
        a_question_(DomainName{"testing", "local"},
                    DnsType::kANY,
                    DnsClass::kIN,
                    ResponseType::kMulticast),
        a_record_(DomainName{"testing", "local"},
                  DnsType::kA,
                  DnsClass::kIN,
                  RecordType::kShared,
                  std::chrono::seconds(120),
                  ARecordRdata(IPAddress{172, 0, 0, 1})),
        nsec_record_(
            DomainName{"testing", "local"},
            DnsType::kNSEC,
            DnsClass::kIN,
            RecordType::kShared,
            std::chrono::seconds(120),
            NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA)) {}

  template <class TrackerType>
  void TrackerNoQueryAfterDestruction(TrackerType tracker) {
    tracker.reset();
    // Advance fake clock by a long time interval to make sure if there's a
    // scheduled task, it will run.
    clock_.Advance(std::chrono::hours(1));
  }

  std::unique_ptr<MdnsRecordTracker> CreateRecordTracker(
      const MdnsRecord& record,
      DnsType type) {
    return std::make_unique<MdnsRecordTracker>(
        record, type, &sender_, &task_runner_, &FakeClock::now, &random_,
        [this](const MdnsRecordTracker* tracker, const MdnsRecord& record) {
          expiration_called_ = true;
        });
  }

  std::unique_ptr<MdnsRecordTracker> CreateRecordTracker(
      const MdnsRecord& record) {
    return CreateRecordTracker(record, record.dns_type());
  }

  std::unique_ptr<MdnsQuestionTracker> CreateQuestionTracker(
      const MdnsQuestion& question,
      MdnsQuestionTracker::QueryType query_type =
          MdnsQuestionTracker::QueryType::kContinuous) {
    return std::make_unique<MdnsQuestionTracker>(question, &sender_,
                                                 &task_runner_, &FakeClock::now,
                                                 &random_, config_, query_type);
  }

 protected:
  void AdvanceThroughAllTtlFractions(std::chrono::seconds ttl) {
    constexpr double kTtlFractions[] = {0.83, 0.88, 0.93, 0.98, 1.00};
    Clock::duration time_passed{0};
    for (double fraction : kTtlFractions) {
      Clock::duration time_till_refresh = Clock::to_duration(ttl * fraction);
      Clock::duration delta = time_till_refresh - time_passed;
      time_passed = time_till_refresh;
      clock_.Advance(delta);
    }
  }

  const MdnsRecord& GetRecord(MdnsRecordTracker* tracker) {
    return tracker->record_;
  }

  // clang-format off
  const std::vector<uint8_t> kQuestionQueryBytes = {
      0x00, 0x00,  // ID = 0
      0x00, 0x00,  // FLAGS = None
      0x00, 0x01,  // Question count
      0x00, 0x00,  // Answer count
      0x00, 0x00,  // Authority count
      0x00, 0x00,  // Additional count
      // Question
      0x07, 't', 'e', 's', 't', 'i', 'n', 'g',
      0x05, 'l', 'o', 'c', 'a', 'l',
      0x00,
      0x00, 0xFF,  // TYPE = ANY (255)
      0x00, 0x01,  // CLASS = IN (1)
  };

  const std::vector<uint8_t> kRecordQueryBytes = {
      0x00, 0x00,  // ID = 0
      0x00, 0x00,  // FLAGS = None
      0x00, 0x01,  // Question count
      0x00, 0x00,  // Answer count
      0x00, 0x00,  // Authority count
      0x00, 0x00,  // Additional count
      // Question
      0x07, 't', 'e', 's', 't', 'i', 'n', 'g',
      0x05, 'l', 'o', 'c', 'a', 'l',
      0x00,
      0x00, 0x01,  // TYPE = A (1)
      0x00, 0x01,  // CLASS = IN (1)
  };

  // clang-format on
  Config config_;
  FakeClock clock_;
  FakeTaskRunner task_runner_;
  FakeUdpSocket socket_;
  StrictMock<MockMdnsSender> sender_;
  MdnsRandom random_;

  MdnsQuestion a_question_;
  MdnsRecord a_record_;
  MdnsRecord nsec_record_;

  bool expiration_called_ = false;
};

// Records are re-queried at 80%, 85%, 90% and 95% of their TTL as per RFC 6762
// Section 5.2. There are no subsequent queries to refresh the record after
// that; the record expires once the TTL has passed since the start of
// tracking. The required random variance is from 0% to 2%, making these times
// at most 82%, 87%, 92% and 97% of the TTL. The fake clock is advanced to 83%,
// 88%, 93% and 98% to make sure that each task gets executed.
// https://tools.ietf.org/html/rfc6762#section-5.2

// A freshly created tracker must hold exactly the record it was built from.
TEST_F(MdnsTrackerTest, RecordTrackerRecordAccessor) {
  auto tracker = CreateRecordTracker(a_record_);
  const MdnsRecord& tracked = GetRecord(tracker.get());
  EXPECT_EQ(tracked, a_record_);
}

// Checks how many refresh queries a record tracker sends over its TTL: none
// without an associated question tracker, four with one, and eight with two.
TEST_F(MdnsTrackerTest, RecordTrackerQueryAfterDelayPerQuestionTracker) {
  const auto one_shot = MdnsQuestionTracker::QueryType::kOneShot;
  auto question = CreateQuestionTracker(a_question_, one_shot);
  auto question2 = CreateQuestionTracker(a_question_, one_shot);
  // Two sends are expected while the clock moves forward here; they are
  // verified and cleared so they do not count against the checks below.
  EXPECT_CALL(sender_, SendMulticast(_)).Times(2);
  clock_.Advance(kOneSecond);
  clock_.Advance(kOneSecond);
  testing::Mock::VerifyAndClearExpectations(&sender_);

  const auto ttl = a_record_.ttl();
  auto tracker = CreateRecordTracker(a_record_);

  // No associated question tracker: the strict mock rejects any send.
  AdvanceThroughAllTtlFractions(ttl);
  testing::Mock::VerifyAndClearExpectations(&sender_);

  // One associated question tracker: 4 queries.
  tracker = CreateRecordTracker(a_record_);
  tracker->AddAssociatedQuery(question.get());
  EXPECT_CALL(sender_, SendMulticast(_)).Times(4);
  AdvanceThroughAllTtlFractions(ttl);
  testing::Mock::VerifyAndClearExpectations(&sender_);

  // Two associated question trackers: 8 queries.
  tracker = CreateRecordTracker(a_record_);
  tracker->AddAssociatedQuery(question.get());
  tracker->AddAssociatedQuery(question2.get());
  EXPECT_CALL(sender_, SendMulticast(_)).Times(8);
  AdvanceThroughAllTtlFractions(ttl);
}

// The refresh query sent by a record tracker must contain exactly the question
// of its associated question tracker.
TEST_F(MdnsTrackerTest, RecordTrackerSendsMessage) {
  auto question = CreateQuestionTracker(
      a_question_, MdnsQuestionTracker::QueryType::kOneShot);
  // Consume the single send expected while the clock moves forward here.
  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  clock_.Advance(kOneSecond);
  clock_.Advance(kOneSecond);
  testing::Mock::VerifyAndClearExpectations(&sender_);

  auto tracker = CreateRecordTracker(a_record_);
  tracker->AddAssociatedQuery(question.get());

  auto check_single_question = [this](const MdnsMessage& message) -> Error {
    EXPECT_EQ(message.questions().size(), size_t{1});
    EXPECT_EQ(message.questions()[0], a_question_);
    return Error::None();
  };
  EXPECT_CALL(sender_, SendMulticast(_))
      .Times(1)
      .WillRepeatedly(check_single_question);

  // 83% of the TTL is past the first refresh point.
  const Clock::duration first_refresh =
      Clock::to_duration(a_record_.ttl() * 0.83);
  clock_.Advance(first_refresh);
}

// A destroyed record tracker must not send anything afterwards.
TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterDestruction) {
  TrackerNoQueryAfterDestruction(CreateRecordTracker(a_record_));
}

// When the task runner is so late that the callback only fires at (or after)
// the TTL, the record should expire instead of being re-queried.
TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterLateTask) {
  auto tracker = CreateRecordTracker(a_record_);
  // Lower bound for a late task: jump straight to the full TTL.
  const auto ttl = a_record_.ttl();
  clock_.Advance(ttl);
  // Then jump an arbitrarily long interval to make sure no query is sent
  // later either; the strict mock sender fails the test on any send.
  clock_.Advance(std::chrono::hours(1));
}

// Updating a tracker with an identical record must restart its expiration
// timer: after 60% + 60% of the TTL with an update in between, the record
// must still be alive.
TEST_F(MdnsTrackerTest, RecordTrackerUpdateResetsTtl) {
  expiration_called_ = false;
  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_);
  // Advance time by 60% of record's TTL.
  const Clock::duration advance_time =
      Clock::to_duration(a_record_.ttl() * 0.6);
  clock_.Advance(advance_time);
  // Now update the record; this must reset the expiration time. Confirm the
  // update succeeded before reading its value, as the other Update() tests in
  // this file do.
  auto result = tracker->Update(a_record_);
  ASSERT_TRUE(result.is_value());
  EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kTTLOnly);
  // Advance time by 60% of record's TTL again.
  clock_.Advance(advance_time);
  // The expiration callback must not have been called.
  EXPECT_FALSE(expiration_called_);
}

// ExpireSoon() on a positive record must trigger the expiration callback once
// one second has elapsed.
TEST_F(MdnsTrackerTest, RecordTrackerForceExpiration) {
  expiration_called_ = false;
  auto tracker = CreateRecordTracker(a_record_);
  tracker->ExpireSoon();
  // Expiration is scheduled one second out.
  const std::chrono::seconds expiry_delay{1};
  clock_.Advance(expiry_delay);
  EXPECT_TRUE(expiration_called_);
}

// ExpireSoon() on a tracker holding an NSEC record (tracking type A) must
// trigger the expiration callback once one second has elapsed.
TEST_F(MdnsTrackerTest, NsecRecordTrackerForceExpiration) {
  expiration_called_ = false;
  auto tracker = CreateRecordTracker(nsec_record_, DnsType::kA);
  tracker->ExpireSoon();
  // Expiration is scheduled one second out.
  const std::chrono::seconds expiry_delay{1};
  clock_.Advance(expiry_delay);
  EXPECT_TRUE(expiration_called_);
}

// The expiration callback must have fired once the full TTL has elapsed.
TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallback) {
  expiration_called_ = false;
  auto tracker = CreateRecordTracker(a_record_);
  const auto ttl = a_record_.ttl();
  clock_.Advance(ttl);
  EXPECT_TRUE(expiration_called_);
}

// A goodbye record (TTL of zero, otherwise identical) must schedule expiration
// exactly one second later: not at 999999 us, but at 1000000 us.
TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallbackAfterGoodbye) {
  expiration_called_ = false;
  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_);
  MdnsRecord goodbye_record(a_record_.name(), a_record_.dns_type(),
                            a_record_.dns_class(), a_record_.record_type(),
                            std::chrono::seconds(0), a_record_.rdata());

  // After a goodbye record is received, expiration is scheduled in a second.
  // Confirm the update succeeded before reading its value, as the other
  // Update() tests in this file do.
  auto result = tracker->Update(goodbye_record);
  ASSERT_TRUE(result.is_value());
  EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kGoodbye);

  // Advance clock to just before the expiration time of 1 second.
  clock_.Advance(std::chrono::microseconds(999999));
  EXPECT_FALSE(expiration_called_);
  // Advance clock to exactly the expiration time.
  clock_.Advance(std::chrono::microseconds(1));
  EXPECT_TRUE(expiration_called_);
}

// Updates whose name, type or class differ from the tracked record, or goodbye
// records whose RDATA differs, must be rejected with kParameterInvalid. Each
// result is confirmed to be an error before its error code is read, matching
// how the other Update() tests in this file inspect results.
TEST_F(MdnsTrackerTest, RecordTrackerInvalidPositiveRecordUpdate) {
  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_);

  // Mismatched name.
  MdnsRecord invalid_name(DomainName{"invalid"}, a_record_.dns_type(),
                          a_record_.dns_class(), a_record_.record_type(),
                          a_record_.ttl(), a_record_.rdata());
  auto name_result = tracker->Update(invalid_name);
  ASSERT_TRUE(name_result.is_error());
  EXPECT_EQ(name_result.error(), Error::Code::kParameterInvalid);

  // Mismatched DNS type.
  MdnsRecord invalid_type(a_record_.name(), DnsType::kPTR,
                          a_record_.dns_class(), a_record_.record_type(),
                          a_record_.ttl(),
                          PtrRecordRdata{DomainName{"invalid"}});
  auto type_result = tracker->Update(invalid_type);
  ASSERT_TRUE(type_result.is_error());
  EXPECT_EQ(type_result.error(), Error::Code::kParameterInvalid);

  // Mismatched DNS class.
  MdnsRecord invalid_class(a_record_.name(), a_record_.dns_type(),
                           DnsClass::kANY, a_record_.record_type(),
                           a_record_.ttl(), a_record_.rdata());
  auto class_result = tracker->Update(invalid_class);
  ASSERT_TRUE(class_result.is_error());
  EXPECT_EQ(class_result.error(), Error::Code::kParameterInvalid);

  // RDATA must match the old RDATA for goodbye records.
  MdnsRecord invalid_rdata(a_record_.name(), a_record_.dns_type(),
                           a_record_.dns_class(), a_record_.record_type(),
                           std::chrono::seconds(0),
                           ARecordRdata(IPAddress{172, 0, 0, 2}));
  auto rdata_result = tracker->Update(invalid_rdata);
  ASSERT_TRUE(rdata_result.is_error());
  EXPECT_EQ(rdata_result.error(), Error::Code::kParameterInvalid);
}

// A tracker holding a positive A record accepts an NSEC record that covers
// type A, and rejects an NSEC record that only covers AAAA.
TEST_F(MdnsTrackerTest, RecordTrackerUpdatePositiveResponseWithNegative) {
  // Accepted: the NSEC record covers type A, so the stored record is replaced.
  auto tracker = CreateRecordTracker(a_record_, DnsType::kA);
  auto accepted = tracker->Update(nsec_record_);
  ASSERT_TRUE(accepted.is_value());
  EXPECT_EQ(accepted.value(), MdnsRecordTracker::UpdateType::kRdata);
  EXPECT_EQ(GetRecord(tracker.get()), nsec_record_);

  // Rejected: the NSEC record covers only AAAA, so the stored record stays.
  MdnsRecord non_a_nsec_record(
      nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(),
      nsec_record_.record_type(), nsec_record_.ttl(),
      NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kAAAA));
  tracker = CreateRecordTracker(a_record_, DnsType::kA);
  auto rejected = tracker->Update(non_a_nsec_record);
  ASSERT_TRUE(rejected.is_error());
  EXPECT_EQ(GetRecord(tracker.get()), a_record_);
}

// A tracker holding an NSEC record for type A accepts another NSEC record that
// still covers A, and rejects one that only covers AAAA.
TEST_F(MdnsTrackerTest, RecordTrackerUpdateNegativeResponseWithNegative) {
  // Accepted: the new NSEC record covers both A and AAAA.
  auto tracker = CreateRecordTracker(nsec_record_, DnsType::kA);
  MdnsRecord multiple_nsec_record(
      nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(),
      nsec_record_.record_type(), nsec_record_.ttl(),
      NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA,
                      DnsType::kAAAA));
  auto accepted = tracker->Update(multiple_nsec_record);
  ASSERT_TRUE(accepted.is_value());
  EXPECT_EQ(accepted.value(), MdnsRecordTracker::UpdateType::kRdata);
  EXPECT_EQ(GetRecord(tracker.get()), multiple_nsec_record);

  // Rejected: the new NSEC record covers only AAAA; the stored record stays.
  tracker = CreateRecordTracker(nsec_record_, DnsType::kA);
  MdnsRecord non_a_nsec_record(
      nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(),
      nsec_record_.record_type(), nsec_record_.ttl(),
      NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kAAAA));
  auto rejected = tracker->Update(non_a_nsec_record);
  EXPECT_TRUE(rejected.is_error());
  EXPECT_EQ(GetRecord(tracker.get()), nsec_record_);
}

// A tracker holding an NSEC record for type A accepts a positive A record, and
// rejects a record of a different type (AAAA).
TEST_F(MdnsTrackerTest, RecordTrackerUpdateNegativeResponseWithPositive) {
  // Accepted: the A record replaces the stored NSEC record.
  auto tracker = CreateRecordTracker(nsec_record_, DnsType::kA);
  auto outcome = tracker->Update(a_record_);
  ASSERT_TRUE(outcome.is_value());
  EXPECT_EQ(outcome.value(), MdnsRecordTracker::UpdateType::kRdata);
  EXPECT_EQ(GetRecord(tracker.get()), a_record_);

  // Rejected: an AAAA record does not match; the stored record stays.
  tracker = CreateRecordTracker(nsec_record_, DnsType::kA);
  MdnsRecord aaaa_record(a_record_.name(), DnsType::kAAAA,
                         a_record_.dns_class(), a_record_.record_type(),
                         std::chrono::seconds(0),
                         AAAARecordRdata(IPAddress{0, 0, 0, 0, 0, 0, 0, 1}));
  outcome = tracker->Update(aaaa_record);
  EXPECT_TRUE(outcome.is_error());
  EXPECT_EQ(GetRecord(tracker.get()), nsec_record_);
}

// Destroying a tracker must cancel its expiration: the callback stays unset
// even after the full TTL has elapsed.
TEST_F(MdnsTrackerTest, RecordTrackerNoExpirationCallbackAfterDestruction) {
  expiration_called_ = false;
  {
    // The tracker is destroyed at the end of this scope.
    auto tracker = CreateRecordTracker(a_record_);
  }
  clock_.Advance(a_record_.ttl());
  EXPECT_FALSE(expiration_called_);
}

// The initial query is delayed for up to 120 ms as per RFC 6762 Section 5.2.
// Subsequent queries happen no sooner than a second after the initial query,
// and the interval between queries increases by at least a factor of 2 for
// each subsequent query until it is capped at 1 hour.
// https://tools.ietf.org/html/rfc6762#section-5.2

// question() must return the question the tracker was constructed with.
TEST_F(MdnsTrackerTest, QuestionTrackerQuestionAccessor) {
  auto tracker = CreateQuestionTracker(a_question_);
  EXPECT_EQ(tracker->question(), a_question_);
}

// A continuous question tracker sends one non-truncated query within the first
// 120 ms, then one more each time the clock advances by an interval that
// doubles from 1 second while it stays below one hour.
TEST_F(MdnsTrackerTest, QuestionTrackerQueryAfterDelay) {
  auto tracker = CreateQuestionTracker(a_question_);

  // Initial query.
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce(
          DoAll(WithArgs<0>(VerifyTruncated(false)), Return(Error::None())));
  clock_.Advance(std::chrono::milliseconds(120));

  // Follow-up queries at doubling intervals.
  for (std::chrono::seconds interval{1}; interval < std::chrono::hours(1);
       interval *= 2) {
    EXPECT_CALL(sender_, SendMulticast(_))
        .WillOnce(
            DoAll(WithArgs<0>(VerifyTruncated(false)), Return(Error::None())));
    clock_.Advance(interval);
  }
}

// The first query a question tracker sends must be non-truncated and contain
// exactly the tracked question.
TEST_F(MdnsTrackerTest, QuestionTrackerSendsMessage) {
  auto tracker = CreateQuestionTracker(a_question_);

  auto check_single_question = [this](const MdnsMessage& message) -> Error {
    EXPECT_EQ(message.questions().size(), size_t{1});
    EXPECT_EQ(message.questions()[0], a_question_);
    return Error::None();
  };

  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce(DoAll(WithArgs<0>(VerifyTruncated(false)),
                      check_single_question, Return(Error::None())));

  clock_.Advance(std::chrono::milliseconds(120));
}

// A destroyed question tracker must not send anything afterwards.
TEST_F(MdnsTrackerTest, QuestionTrackerNoQueryAfterDestruction) {
  TrackerNoQueryAfterDestruction(CreateQuestionTracker(a_question_));
}

// With 100 associated records, the initial query is expected to be split over
// three messages carrying 49, 50 and 1 answers; all but the last are flagged
// as truncated.
TEST_F(MdnsTrackerTest, QuestionTrackerSendsMultipleMessages) {
  auto tracker = CreateQuestionTracker(a_question_);

  const int associated_record_count = 100;
  std::vector<std::unique_ptr<MdnsRecordTracker>> answers;
  for (int added = 0; added != associated_record_count; ++added) {
    // |answers| keeps each record tracker alive for the rest of the test.
    answers.push_back(CreateRecordTracker(a_record_));
    tracker->AddAssociatedRecord(answers.back().get());
  }

  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce(DoAll(WithArgs<0>(VerifyTruncated(true)),
                      WithArgs<0>(VerifyRecordCount(49)),
                      Return(Error::None())))
      .WillOnce(DoAll(WithArgs<0>(VerifyTruncated(true)),
                      WithArgs<0>(VerifyRecordCount(50)),
                      Return(Error::None())))
      .WillOnce(DoAll(WithArgs<0>(VerifyTruncated(false)),
                      WithArgs<0>(VerifyRecordCount(1)),
                      Return(Error::None())));

  clock_.Advance(std::chrono::milliseconds(120));
}

}  // namespace discovery
}  // namespace openscreen
