/*
 *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
#include <cstdint>
#include <deque>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "api/array_view.h"
#include "api/task_queue/pending_task_safety_flag.h"
#include "api/task_queue/task_queue_base.h"
#include "api/test/create_network_emulation_manager.h"
#include "api/test/network_emulation_manager.h"
#include "api/units/time_delta.h"
#include "call/simulated_network.h"
#include "net/dcsctp/public/dcsctp_options.h"
#include "net/dcsctp/public/dcsctp_socket.h"
#include "net/dcsctp/public/types.h"
#include "net/dcsctp/socket/dcsctp_socket.h"
#include "net/dcsctp/testing/testing_macros.h"
#include "net/dcsctp/timer/task_queue_timeout.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/gunit.h"
#include "rtc_base/logging.h"
#include "rtc_base/socket_address.h"
#include "rtc_base/strings/string_format.h"
#include "rtc_base/time_utils.h"
#include "test/gmock.h"

#if !defined(WEBRTC_ANDROID) && defined(NDEBUG) && \
    !defined(THREAD_SANITIZER) && !defined(MEMORY_SANITIZER)
#define DCSCTP_NDEBUG_TEST(t) t
#else
// In debug mode, and when MSAN or TSAN sanitizers are enabled, these tests are
// too expensive to run due to extensive consistency checks that iterate on all
// outstanding chunks. Same with low-end Android devices, which have
// difficulties with these tests.
#define DCSCTP_NDEBUG_TEST(t) DISABLED_##t
#endif

namespace dcsctp {
namespace {
using ::testing::AllOf;
using ::testing::Ge;
using ::testing::Le;
using ::testing::SizeIs;

constexpr StreamID kStreamId(1);
constexpr PPID kPpid(53);
constexpr size_t kSmallPayloadSize = 10;
constexpr size_t kLargePayloadSize = 10000;
constexpr size_t kHugePayloadSize = 262144;
constexpr size_t kBufferedAmountLowThreshold = kLargePayloadSize * 2;
constexpr webrtc::TimeDelta kPrintBandwidthDuration =
    webrtc::TimeDelta::Seconds(1);
constexpr webrtc::TimeDelta kBenchmarkRuntime(webrtc::TimeDelta::Seconds(10));
constexpr webrtc::TimeDelta kAWhile(webrtc::TimeDelta::Seconds(1));

inline int GetUniqueSeed() {
  static int seed = 0;
  return ++seed;
}

DcSctpOptions MakeOptionsForTest() {
  DcSctpOptions options;

  // Throughput numbers are affected by the MTU. Ensure it's constant.
  options.mtu = 1200;

  // By disabling the heartbeat interval, there will no timers at all running
  // when the socket is idle, which makes it easy to just continue the test
  // until there are no more scheduled tasks. Note that it _will_ run for longer
  // than necessary as timers aren't cancelled when they are stopped (as that's
  // not supported), but it's still simulated time and passes quickly.
  options.heartbeat_interval = DurationMs(0);
  return options;
}

// When doing throughput tests, knowing what each actor should do.
enum class ActorMode {
  kAtRest,
  kThroughputSender,
  kThroughputReceiver,
  kLimitedRetransmissionSender,
};

// An abstraction around EmulatedEndpoint, representing a bound socket that
// will send its packet to a given destination.
class BoundSocket : public webrtc::EmulatedNetworkReceiverInterface {
 public:
  void Bind(webrtc::EmulatedEndpoint* endpoint) {
    endpoint_ = endpoint;
    uint16_t port = endpoint->BindReceiver(0, this).value();
    source_address_ =
        rtc::SocketAddress(endpoint_->GetPeerLocalAddress(), port);
  }

  void SetDestination(const BoundSocket& socket) {
    dest_address_ = socket.source_address_;
  }

  void SetReceiver(std::function<void(rtc::CopyOnWriteBuffer)> receiver) {
    receiver_ = std::move(receiver);
  }

  void SendPacket(rtc::ArrayView<const uint8_t> data) {
    endpoint_->SendPacket(source_address_, dest_address_,
                          rtc::CopyOnWriteBuffer(data.data(), data.size()));
  }

 private:
  // Implementation of `webrtc::EmulatedNetworkReceiverInterface`.
  void OnPacketReceived(webrtc::EmulatedIpPacket packet) override {
    receiver_(std::move(packet.data));
  }

  std::function<void(rtc::CopyOnWriteBuffer)> receiver_;
  webrtc::EmulatedEndpoint* endpoint_ = nullptr;
  rtc::SocketAddress source_address_;
  rtc::SocketAddress dest_address_;
};

// Sends at a constant rate but with random packet sizes.
class SctpActor : public DcSctpSocketCallbacks {
 public:
  SctpActor(absl::string_view name,
            BoundSocket& emulated_socket,
            const DcSctpOptions& sctp_options)
      : log_prefix_(std::string(name) + ": "),
        thread_(rtc::Thread::Current()),
        emulated_socket_(emulated_socket),
        timeout_factory_(
            *thread_,
            [this]() { return TimeMillis(); },
            [this](dcsctp::TimeoutID timeout_id) {
              sctp_socket_.HandleTimeout(timeout_id);
            }),
        random_(GetUniqueSeed()),
        sctp_socket_(name, *this, nullptr, sctp_options),
        last_bandwidth_printout_(TimeMs(TimeMillis())) {
    emulated_socket.SetReceiver([this](rtc::CopyOnWriteBuffer buf) {
      // The receiver will be executed on the NetworkEmulation task queue, but
      // the dcSCTP socket is owned by `thread_` and is not thread-safe.
      thread_->PostTask([this, buf] { this->sctp_socket_.ReceivePacket(buf); });
    });
  }

  void PrintBandwidth() {
    TimeMs now = TimeMillis();
    DurationMs duration = now - last_bandwidth_printout_;

    double bitrate_mbps =
        static_cast<double>(received_bytes_ * 8) / *duration / 1000;
    RTC_LOG(LS_INFO) << log_prefix()
                     << rtc::StringFormat("Received %0.2f Mbps", bitrate_mbps);

    received_bitrate_mbps_.push_back(bitrate_mbps);
    received_bytes_ = 0;
    last_bandwidth_printout_ = now;
    // Print again in a second.
    if (mode_ == ActorMode::kThroughputReceiver) {
      thread_->PostDelayedTask(
          SafeTask(safety_.flag(), [this] { PrintBandwidth(); }),
          kPrintBandwidthDuration);
    }
  }

  void SendPacket(rtc::ArrayView<const uint8_t> data) override {
    emulated_socket_.SendPacket(data);
  }

  std::unique_ptr<Timeout> CreateTimeout(
      webrtc::TaskQueueBase::DelayPrecision precision) override {
    return timeout_factory_.CreateTimeout(precision);
  }

  TimeMs TimeMillis() override { return TimeMs(rtc::TimeMillis()); }

  uint32_t GetRandomInt(uint32_t low, uint32_t high) override {
    return random_.Rand(low, high);
  }

  void OnMessageReceived(DcSctpMessage message) override {
    received_bytes_ += message.payload().size();
    last_received_message_ = std::move(message);
  }

  void OnError(ErrorKind error, absl::string_view message) override {
    RTC_LOG(LS_WARNING) << log_prefix() << "Socket error: " << ToString(error)
                        << "; " << message;
  }

  void OnAborted(ErrorKind error, absl::string_view message) override {
    RTC_LOG(LS_ERROR) << log_prefix() << "Socket abort: " << ToString(error)
                      << "; " << message;
  }

  void OnConnected() override {}

  void OnClosed() override {}

  void OnConnectionRestarted() override {}

  void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams,
                            absl::string_view reason) override {}

  void OnStreamsResetPerformed(
      rtc::ArrayView<const StreamID> outgoing_streams) override {}

  void OnIncomingStreamsReset(
      rtc::ArrayView<const StreamID> incoming_streams) override {}

  void NotifyOutgoingMessageBufferEmpty() override {}

  void OnBufferedAmountLow(StreamID stream_id) override {
    if (mode_ == ActorMode::kThroughputSender) {
      std::vector<uint8_t> payload(kHugePayloadSize);
      sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)),
                        SendOptions());

    } else if (mode_ == ActorMode::kLimitedRetransmissionSender) {
      while (sctp_socket_.buffered_amount(kStreamId) <
             kBufferedAmountLowThreshold * 2) {
        SendOptions send_options;
        send_options.max_retransmissions = 0;
        sctp_socket_.Send(
            DcSctpMessage(kStreamId, kPpid,
                          std::vector<uint8_t>(kLargePayloadSize)),
            send_options);

        send_options.max_retransmissions = absl::nullopt;
        sctp_socket_.Send(
            DcSctpMessage(kStreamId, kPpid,
                          std::vector<uint8_t>(kSmallPayloadSize)),
            send_options);
      }
    }
  }

  absl::optional<DcSctpMessage> ConsumeReceivedMessage() {
    if (!last_received_message_.has_value()) {
      return absl::nullopt;
    }
    DcSctpMessage ret = *std::move(last_received_message_);
    last_received_message_ = absl::nullopt;
    return ret;
  }

  DcSctpSocket& sctp_socket() { return sctp_socket_; }

  void SetActorMode(ActorMode mode) {
    mode_ = mode;
    if (mode_ == ActorMode::kThroughputSender) {
      sctp_socket_.SetBufferedAmountLowThreshold(kStreamId,
                                                 kBufferedAmountLowThreshold);
      std::vector<uint8_t> payload(kHugePayloadSize);
      sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)),
                        SendOptions());

    } else if (mode_ == ActorMode::kLimitedRetransmissionSender) {
      sctp_socket_.SetBufferedAmountLowThreshold(kStreamId,
                                                 kBufferedAmountLowThreshold);
      std::vector<uint8_t> payload(kHugePayloadSize);
      sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)),
                        SendOptions());

    } else if (mode == ActorMode::kThroughputReceiver) {
      thread_->PostDelayedTask(
          SafeTask(safety_.flag(), [this] { PrintBandwidth(); }),
          kPrintBandwidthDuration);
    }
  }

  // Returns the average bitrate, stripping the first `remove_first_n` that
  // represent the time it took to ramp up the congestion control algorithm.
  double avg_received_bitrate_mbps(size_t remove_first_n = 3) const {
    std::vector<double> bitrates = received_bitrate_mbps_;
    bitrates.erase(bitrates.begin(), bitrates.begin() + remove_first_n);

    double sum = 0;
    for (double bitrate : bitrates) {
      sum += bitrate;
    }

    return sum / bitrates.size();
  }

 private:
  std::string log_prefix() const {
    rtc::StringBuilder sb;
    sb << log_prefix_;
    sb << rtc::TimeMillis();
    sb << ": ";
    return sb.Release();
  }

  ActorMode mode_ = ActorMode::kAtRest;
  const std::string log_prefix_;
  rtc::Thread* thread_;
  BoundSocket& emulated_socket_;
  TaskQueueTimeoutFactory timeout_factory_;
  webrtc::Random random_;
  DcSctpSocket sctp_socket_;
  size_t received_bytes_ = 0;
  absl::optional<DcSctpMessage> last_received_message_;
  TimeMs last_bandwidth_printout_;
  // Per-second received bitrates, in Mbps
  std::vector<double> received_bitrate_mbps_;
  webrtc::ScopedTaskSafety safety_;
};

class DcSctpSocketNetworkTest : public testing::Test {
 protected:
  DcSctpSocketNetworkTest()
      : options_(MakeOptionsForTest()),
        emulation_(webrtc::CreateNetworkEmulationManager(
            webrtc::TimeMode::kSimulated)) {}

  void MakeNetwork(const webrtc::BuiltInNetworkBehaviorConfig& config) {
    webrtc::EmulatedEndpoint* endpoint_a =
        emulation_->CreateEndpoint(webrtc::EmulatedEndpointConfig());
    webrtc::EmulatedEndpoint* endpoint_z =
        emulation_->CreateEndpoint(webrtc::EmulatedEndpointConfig());

    webrtc::EmulatedNetworkNode* node1 = emulation_->CreateEmulatedNode(config);
    webrtc::EmulatedNetworkNode* node2 = emulation_->CreateEmulatedNode(config);

    emulation_->CreateRoute(endpoint_a, {node1}, endpoint_z);
    emulation_->CreateRoute(endpoint_z, {node2}, endpoint_a);

    emulated_socket_a_.Bind(endpoint_a);
    emulated_socket_z_.Bind(endpoint_z);

    emulated_socket_a_.SetDestination(emulated_socket_z_);
    emulated_socket_z_.SetDestination(emulated_socket_a_);
  }

  void Sleep(webrtc::TimeDelta duration) {
    // Sleep in one-millisecond increments, to let timers expire when expected.
    for (int i = 0; i < duration.ms(); ++i) {
      emulation_->time_controller()->AdvanceTime(webrtc::TimeDelta::Millis(1));
    }
  }

  DcSctpOptions options_;
  std::unique_ptr<webrtc::NetworkEmulationManager> emulation_;
  BoundSocket emulated_socket_a_;
  BoundSocket emulated_socket_z_;
};

TEST_F(DcSctpSocketNetworkTest, CanConnectAndShutdown) {
  webrtc::BuiltInNetworkBehaviorConfig pipe_config;
  MakeNetwork(pipe_config);

  SctpActor sender("A", emulated_socket_a_, options_);
  SctpActor receiver("Z", emulated_socket_z_, options_);
  EXPECT_THAT(sender.sctp_socket().state(), SocketState::kClosed);

  sender.sctp_socket().Connect();
  Sleep(kAWhile);
  EXPECT_THAT(sender.sctp_socket().state(), SocketState::kConnected);

  sender.sctp_socket().Shutdown();
  Sleep(kAWhile);
  EXPECT_THAT(sender.sctp_socket().state(), SocketState::kClosed);
}

TEST_F(DcSctpSocketNetworkTest, CanSendLargeMessage) {
  webrtc::BuiltInNetworkBehaviorConfig pipe_config;
  pipe_config.queue_delay_ms = 30;
  MakeNetwork(pipe_config);

  SctpActor sender("A", emulated_socket_a_, options_);
  SctpActor receiver("Z", emulated_socket_z_, options_);
  sender.sctp_socket().Connect();

  constexpr size_t kPayloadSize = 100 * 1024;

  std::vector<uint8_t> payload(kPayloadSize);
  sender.sctp_socket().Send(DcSctpMessage(kStreamId, kPpid, payload),
                            SendOptions());

  Sleep(kAWhile);

  ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage message,
                              receiver.ConsumeReceivedMessage());

  EXPECT_THAT(message.payload(), SizeIs(kPayloadSize));

  sender.sctp_socket().Shutdown();
  Sleep(kAWhile);
}

TEST_F(DcSctpSocketNetworkTest, CanSendMessagesReliablyWithLowBandwidth) {
  webrtc::BuiltInNetworkBehaviorConfig pipe_config;
  pipe_config.queue_delay_ms = 30;
  pipe_config.link_capacity_kbps = 1000;
  MakeNetwork(pipe_config);

  SctpActor sender("A", emulated_socket_a_, options_);
  SctpActor receiver("Z", emulated_socket_z_, options_);
  sender.sctp_socket().Connect();

  sender.SetActorMode(ActorMode::kThroughputSender);
  receiver.SetActorMode(ActorMode::kThroughputReceiver);

  Sleep(kBenchmarkRuntime);
  sender.SetActorMode(ActorMode::kAtRest);
  receiver.SetActorMode(ActorMode::kAtRest);

  Sleep(kAWhile);

  sender.sctp_socket().Shutdown();

  Sleep(kAWhile);

  // Verify that the bitrates are in the range of 0.5-1.0 Mbps.
  double bitrate = receiver.avg_received_bitrate_mbps();
  EXPECT_THAT(bitrate, AllOf(Ge(0.5), Le(1.0)));
}

TEST_F(DcSctpSocketNetworkTest,
       DCSCTP_NDEBUG_TEST(CanSendMessagesReliablyWithMediumBandwidth)) {
  webrtc::BuiltInNetworkBehaviorConfig pipe_config;
  pipe_config.queue_delay_ms = 30;
  pipe_config.link_capacity_kbps = 18000;
  MakeNetwork(pipe_config);

  SctpActor sender("A", emulated_socket_a_, options_);
  SctpActor receiver("Z", emulated_socket_z_, options_);
  sender.sctp_socket().Connect();

  sender.SetActorMode(ActorMode::kThroughputSender);
  receiver.SetActorMode(ActorMode::kThroughputReceiver);

  Sleep(kBenchmarkRuntime);
  sender.SetActorMode(ActorMode::kAtRest);
  receiver.SetActorMode(ActorMode::kAtRest);

  Sleep(kAWhile);

  sender.sctp_socket().Shutdown();

  Sleep(kAWhile);

  // Verify that the bitrates are in the range of 16-18 Mbps.
  double bitrate = receiver.avg_received_bitrate_mbps();
  EXPECT_THAT(bitrate, AllOf(Ge(16), Le(18)));
}

TEST_F(DcSctpSocketNetworkTest, CanSendMessagesReliablyWithMuchPacketLoss) {
  webrtc::BuiltInNetworkBehaviorConfig config;
  config.queue_delay_ms = 30;
  config.loss_percent = 1;
  MakeNetwork(config);

  SctpActor sender("A", emulated_socket_a_, options_);
  SctpActor receiver("Z", emulated_socket_z_, options_);
  sender.sctp_socket().Connect();

  sender.SetActorMode(ActorMode::kThroughputSender);
  receiver.SetActorMode(ActorMode::kThroughputReceiver);

  Sleep(kBenchmarkRuntime);
  sender.SetActorMode(ActorMode::kAtRest);
  receiver.SetActorMode(ActorMode::kAtRest);

  Sleep(kAWhile);

  sender.sctp_socket().Shutdown();

  Sleep(kAWhile);

  // TCP calculator gives: 1200 MTU, 60ms RTT and 1% packet loss -> 1.6Mbps.
  // This test is doing slightly better (doesn't have any additional header
  // overhead etc). Verify that the bitrates are in the range of 1.5-2.5 Mbps.
  double bitrate = receiver.avg_received_bitrate_mbps();
  EXPECT_THAT(bitrate, AllOf(Ge(1.5), Le(2.5)));
}

TEST_F(DcSctpSocketNetworkTest, DCSCTP_NDEBUG_TEST(HasHighBandwidth)) {
  webrtc::BuiltInNetworkBehaviorConfig pipe_config;
  pipe_config.queue_delay_ms = 30;
  MakeNetwork(pipe_config);

  SctpActor sender("A", emulated_socket_a_, options_);
  SctpActor receiver("Z", emulated_socket_z_, options_);
  sender.sctp_socket().Connect();

  sender.SetActorMode(ActorMode::kThroughputSender);
  receiver.SetActorMode(ActorMode::kThroughputReceiver);

  Sleep(kBenchmarkRuntime);

  sender.SetActorMode(ActorMode::kAtRest);
  receiver.SetActorMode(ActorMode::kAtRest);
  Sleep(kAWhile);

  sender.sctp_socket().Shutdown();
  Sleep(kAWhile);

  // Verify that the bitrate is in the range of 540-640 Mbps
  double bitrate = receiver.avg_received_bitrate_mbps();
  EXPECT_THAT(bitrate, AllOf(Ge(520), Le(640)));
}
}  // namespace
}  // namespace dcsctp
