// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "osp/public/message_demuxer.h"

#include <memory>
#include <utility>

#include "osp/impl/quic/quic_connection.h"
#include "platform/base/error.h"
#include "util/big_endian.h"
#include "util/osp_logging.h"

namespace openscreen {
namespace osp {

// static
// Decodes the variable-length unsigned integer at the front of |buffer|. The
// expected encoding is described here:
// https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16
//
// The two most significant bits of the first byte select the encoded length
// (1, 2, 4 or 8 bytes); the remaining bits of that byte and all following
// bytes hold the value in big-endian order.
//
// |num_bytes_decoded| receives the encoded length whenever |buffer| is
// non-empty, including when kCborIncompleteMessage is returned because the
// buffer is too short.
ErrorOr<uint64_t> MessageTypeDecoder::DecodeVarUint(
    const std::vector<uint8_t>& buffer,
    size_t* num_bytes_decoded) {
  if (buffer.empty()) {
    return Error::Code::kCborIncompleteMessage;
  }

  const uint8_t length_selector = static_cast<uint8_t>(buffer[0] >> 6);
  const size_t encoded_length = size_t{1} << length_selector;
  *num_bytes_decoded = encoded_length;

  // The integer must be fully present AND be followed by at least one byte of
  // message payload, so a buffer that holds exactly the integer is still
  // reported as incomplete.
  if (buffer.size() <= encoded_length) {
    return Error::Code::kCborIncompleteMessage;
  }

  // Drop the two length-selector bits from the first byte, then shift in the
  // remaining bytes most-significant first.
  uint64_t value = buffer[0] & 0x3F;
  for (size_t i = 1; i < encoded_length; ++i) {
    value = (value << 8) | buffer[i];
  }
  return value;
}

// static
// Decodes the message type tag at the front of |buffer|. The tag is a
// variable-length integer in the encoding described here:
// https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16
//
// |num_bytes_decoded| is forwarded to DecodeVarUint, which stores the encoded
// length of the tag in it. Returns the decoded type, the error produced by
// DecodeVarUint, or kCborInvalidMessage when the decoded value maps to
// msgs::Type::kUnknown.
ErrorOr<msgs::Type> MessageTypeDecoder::DecodeType(
    const std::vector<uint8_t>& buffer,
    size_t* num_bytes_decoded) {
  ErrorOr<uint64_t> raw_type = DecodeVarUint(buffer, num_bytes_decoded);
  if (raw_type.is_error()) {
    return raw_type.error();
  }

  msgs::Type type = msgs::TypeEnumValidator::SafeCast(raw_type.value());
  if (type != msgs::Type::kUnknown) {
    return type;
  }
  return Error::Code::kCborInvalidMessage;
}

// static
// Out-of-line definition of the static constexpr data member. Before C++17 a
// namespace-scope definition like this one is required if the member is
// odr-used (for example, bound to a const reference). It carries no
// initializer because the value is given at the in-class declaration.
constexpr size_t MessageDemuxer::kDefaultBufferLimit;

// Creates a watch whose members take the defaults given in the class
// definition. WatchMessageType() and SetDefaultMessageTypeWatch() return such
// a watch when a callback is already registered for the requested type.
// NOTE(review): presumably |parent_| defaults to nullptr in the header so that
// destroying a default-constructed watch is a no-op -- confirm.
MessageDemuxer::MessageWatch::MessageWatch() = default;

MessageDemuxer::MessageWatch::MessageWatch(MessageDemuxer* parent,
                                           bool is_default,
                                           uint64_t endpoint_id,
                                           msgs::Type message_type)
    : parent_(parent),
      is_default_(is_default),
      endpoint_id_(endpoint_id),
      message_type_(message_type) {}

// Move constructor: takes over the registration described by |other|.
// |other.parent_| is reset to nullptr so that the destructor of |other| does
// not unregister the callback that this watch now represents.
MessageDemuxer::MessageWatch::MessageWatch(
    MessageDemuxer::MessageWatch&& other) noexcept
    : parent_(std::exchange(other.parent_, nullptr)),
      is_default_(other.is_default_),
      endpoint_id_(other.endpoint_id_),
      message_type_(other.message_type_) {}

// Unregisters the callback this watch represents. A watch without a parent
// (for example one that has been moved from) does nothing.
MessageDemuxer::MessageWatch::~MessageWatch() {
  if (!parent_) {
    return;
  }

  if (!is_default_) {
    // Endpoint-specific registration.
    OSP_VLOG << "dropping handler for type: "
             << static_cast<int>(message_type_);
    parent_->StopWatchingMessageType(endpoint_id_, message_type_);
    return;
  }

  // Default (any-endpoint) registration.
  OSP_VLOG << "dropping default handler for type: "
           << static_cast<int>(message_type_);
  parent_->StopDefaultMessageTypeWatch(message_type_);
}

// Move assignment implemented as a member-wise swap. The registration that
// this watch held before the assignment ends up in |other| and is released
// when |other| is destroyed; self-assignment leaves the watch unchanged.
MessageDemuxer::MessageWatch& MessageDemuxer::MessageWatch::operator=(
    MessageWatch&& other) noexcept {
  std::swap(parent_, other.parent_);
  std::swap(is_default_, other.is_default_);
  std::swap(endpoint_id_, other.endpoint_id_);
  std::swap(message_type_, other.message_type_);
  return *this;
}

MessageDemuxer::MessageDemuxer(ClockNowFunctionPtr now_function,
                               size_t buffer_limit = kDefaultBufferLimit)
    : now_function_(now_function), buffer_limit_(buffer_limit) {
  OSP_DCHECK(now_function_);
}

// Nothing is released explicitly here; members clean up after themselves.
// NOTE(review): MessageWatch objects keep a raw |parent_| pointer to this
// demuxer and call back into it from their destructors, so they presumably
// must not outlive it -- confirm against the contract in the header.
MessageDemuxer::~MessageDemuxer() = default;

// Registers |callback| for messages of |message_type| received from
// |endpoint_id|.
//
// Returns an empty MessageWatch if a callback is already registered for that
// (endpoint, type) pair. Otherwise any data already buffered for the endpoint
// whose leading message has the requested type is dispatched immediately, and
// a watch that unregisters the callback on destruction is returned.
MessageDemuxer::MessageWatch MessageDemuxer::WatchMessageType(
    uint64_t endpoint_id,
    msgs::Type message_type,
    MessageCallback* callback) {
  auto callbacks_entry = message_callbacks_.find(endpoint_id);
  if (callbacks_entry == message_callbacks_.end()) {
    callbacks_entry =
        message_callbacks_
            .emplace(endpoint_id, std::map<msgs::Type, MessageCallback*>{})
            .first;
  }
  auto emplace_result = callbacks_entry->second.emplace(message_type, callback);
  if (!emplace_result.second)
    return MessageWatch();
  auto endpoint_entry = buffers_.find(endpoint_id);
  if (endpoint_entry != buffers_.end()) {
    for (auto& buffer : endpoint_entry->second) {
      if (buffer.second.empty())
        continue;
      // The type tag is a variable-length integer, so it has to be decoded
      // rather than read from the first byte: tags that need more than one
      // byte would otherwise never match. Buffers whose tag cannot be decoded
      // yet (or is invalid) are left for the regular data path to deal with.
      size_t type_byte_length = 0;
      ErrorOr<msgs::Type> buffered_type =
          MessageTypeDecoder::DecodeType(buffer.second, &type_byte_length);
      if (buffered_type.is_error() || buffered_type.value() != message_type)
        continue;
      HandleStreamBufferLoop(endpoint_id, buffer.first, callbacks_entry,
                             &buffer.second);
    }
  }
  return MessageWatch(this, false, endpoint_id, message_type);
}

// Registers |callback| as the default handler for |message_type|. Default
// handlers are tried for any endpoint when no endpoint-specific callback
// handled the data.
//
// Returns an empty MessageWatch if a default callback already exists for the
// type. Otherwise every buffered stream whose leading message has the
// requested type is dispatched immediately, and a watch that unregisters the
// default callback on destruction is returned.
MessageDemuxer::MessageWatch MessageDemuxer::SetDefaultMessageTypeWatch(
    msgs::Type message_type,
    MessageCallback* callback) {
  auto emplace_result = default_callbacks_.emplace(message_type, callback);
  if (!emplace_result.second)
    return MessageWatch();
  for (auto& endpoint_buffers : buffers_) {
    auto endpoint_id = endpoint_buffers.first;
    for (auto& stream_map : endpoint_buffers.second) {
      if (stream_map.second.empty())
        continue;
      // The type tag is a variable-length integer, so it has to be decoded
      // rather than read from the first byte: tags that need more than one
      // byte would otherwise never match. Buffers whose tag cannot be decoded
      // yet (or is invalid) are left for the regular data path to deal with.
      size_t type_byte_length = 0;
      ErrorOr<msgs::Type> buffered_type =
          MessageTypeDecoder::DecodeType(stream_map.second, &type_byte_length);
      if (buffered_type.is_error() || buffered_type.value() != message_type)
        continue;
      auto connection_id = stream_map.first;
      // Endpoint-specific callbacks still take precedence inside
      // HandleStreamBufferLoop, so look them up for this endpoint.
      auto callbacks_entry = message_callbacks_.find(endpoint_id);
      HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry,
                             &stream_map.second);
    }
  }
  // The endpoint id is unused for default watches.
  return MessageWatch(this, true, 0, message_type);
}

// Feeds newly received stream bytes into the demuxer.
//   endpoint_id / connection_id: identify the stream the bytes belong to.
//   data, data_size: the received bytes. A |data_size| of zero discards the
//                    stream's buffer.
//
// The bytes are appended to the stream's buffer and dispatch to the
// registered callbacks is attempted. If more than |buffer_limit_| bytes remain
// buffered afterwards, the stream's buffer is dropped.
void MessageDemuxer::OnStreamData(uint64_t endpoint_id,
                                  uint64_t connection_id,
                                  const uint8_t* data,
                                  size_t data_size) {
  OSP_VLOG << __func__ << ": [" << endpoint_id << ", " << connection_id
           << "] - (" << data_size << ")";
  auto& stream_map = buffers_[endpoint_id];
  if (!data_size) {
    stream_map.erase(connection_id);
    if (stream_map.empty())
      buffers_.erase(endpoint_id);
    return;
  }
  std::vector<uint8_t>& buffer = stream_map[connection_id];
  buffer.insert(buffer.end(), data, data + data_size);

  auto callbacks_entry = message_callbacks_.find(endpoint_id);
  HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, &buffer);

  if (buffer.size() > buffer_limit_) {
    // |buffer| and (possibly) |stream_map| dangle after these erases and must
    // not be used below this point.
    stream_map.erase(connection_id);
    // Mirror the zero-size path above: do not leave an empty per-endpoint map
    // behind once its last stream has been dropped.
    if (stream_map.empty())
      buffers_.erase(endpoint_id);
  }
}

// Removes the callback registered for (|endpoint_id|, |message_type|), if
// any. Unknown endpoints and unregistered types are ignored.
void MessageDemuxer::StopWatchingMessageType(uint64_t endpoint_id,
                                             msgs::Type message_type) {
  // Use find() rather than operator[] so that an unknown endpoint does not get
  // an empty entry inserted as a side effect.
  auto endpoint_entry = message_callbacks_.find(endpoint_id);
  if (endpoint_entry == message_callbacks_.end())
    return;
  // Erasing by key is a no-op when the type is absent, whereas erasing the
  // end() iterator returned by a failed find() is undefined behavior.
  // The endpoint's map is intentionally left in place even when it becomes
  // empty: HandleStreamBufferLoop holds an iterator to it while callbacks run,
  // and removing the entry would invalidate that iterator if a watch were
  // destroyed from inside a callback.
  endpoint_entry->second.erase(message_type);
}

// Removes the default callback registered for |message_type|; does nothing if
// no default callback is registered for that type.
void MessageDemuxer::StopDefaultMessageTypeWatch(msgs::Type message_type) {
  auto registered = default_callbacks_.find(message_type);
  if (registered != default_callbacks_.end()) {
    default_callbacks_.erase(registered);
  }
}

// Repeatedly dispatches the messages at the front of |buffer|.
//
// Each pass first offers the buffer to the endpoint-specific callbacks
// referenced by |callbacks_entry| (skipped when it equals
// message_callbacks_.end()); if that pass did not handle a message, the
// default callbacks are tried. Passes repeat while the last one consumed
// payload bytes and the buffer is not empty. Returns the result of the final
// pass.
MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBufferLoop(
    uint64_t endpoint_id,
    uint64_t connection_id,
    std::map<uint64_t, std::map<msgs::Type, MessageCallback*>>::iterator
        callbacks_entry,
    std::vector<uint8_t>* buffer) {
  HandleStreamBufferResult pass_result;
  do {
    pass_result = HandleStreamBufferResult{false, 0};

    const bool has_endpoint_callbacks =
        callbacks_entry != message_callbacks_.end();
    if (has_endpoint_callbacks) {
      OSP_VLOG << "attempting endpoint-specific handling";
      pass_result = HandleStreamBuffer(endpoint_id, connection_id,
                                       &callbacks_entry->second, buffer);
    }

    if (!pass_result.handled && !default_callbacks_.empty()) {
      OSP_VLOG << "attempting generic message handling";
      pass_result = HandleStreamBuffer(endpoint_id, connection_id,
                                       &default_callbacks_, buffer);
    }

    OSP_VLOG_IF(!pass_result.handled) << "no message handler matched";
  } while (pass_result.consumed && !buffer->empty());
  return pass_result;
}

// Dispatches the messages at the front of |buffer| to the matching entries of
// |message_callbacks|, removing each fully consumed message (type tag plus
// payload) from the buffer.
//
// Processing stops when:
//   - the type tag or the payload is incomplete (the buffer is kept so that
//     later data can complete the message),
//   - the type tag is invalid or the callback reports any other error (the
//     buffer is cleared),
//   - no callback is registered for the decoded type, or
//   - the buffer has been fully consumed.
//
// Returns whether any callback was invoked and the total number of payload
// bytes consumed (type tag bytes are not counted).
//
// TODO(rwkeane) Use absl::Span for the buffer
MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBuffer(
    uint64_t endpoint_id,
    uint64_t connection_id,
    std::map<msgs::Type, MessageCallback*>* message_callbacks,
    std::vector<uint8_t>* buffer) {
  size_t consumed = 0;
  size_t total_consumed = 0;
  bool handled = false;
  do {
    consumed = 0;
    size_t msg_type_byte_length = 0;
    ErrorOr<msgs::Type> message_type =
        MessageTypeDecoder::DecodeType(*buffer, &msg_type_byte_length);
    if (message_type.is_error()) {
      // An incomplete type tag only means more stream data is needed; keep the
      // bytes received so far, exactly as is done for an incomplete payload
      // below. Any other error means the stream contents cannot be parsed.
      if (message_type.error().code() !=
          Error::Code::kCborIncompleteMessage) {
        buffer->clear();
      }
      break;
    }
    auto callback_entry = message_callbacks->find(message_type.value());
    if (callback_entry == message_callbacks->end())
      break;
    handled = true;
    OSP_VLOG << "handling message type "
             << static_cast<int>(message_type.value());
    // DecodeType only succeeds when the buffer is longer than the type tag, so
    // this subtraction cannot underflow.
    const size_t payload_size = buffer->size() - msg_type_byte_length;
    auto consumed_or_error = callback_entry->second->OnStreamMessage(
        endpoint_id, connection_id, message_type.value(),
        buffer->data() + msg_type_byte_length, payload_size, now_function_());
    if (!consumed_or_error) {
      if (consumed_or_error.error().code() !=
          Error::Code::kCborIncompleteMessage) {
        buffer->clear();
        break;
      }
    } else {
      consumed = consumed_or_error.value();
      if (consumed > payload_size) {
        // A callback claiming more bytes than it was given would make the
        // erase below run past the end of the buffer. Treat it like any other
        // unrecoverable parse error.
        buffer->clear();
        break;
      }
      buffer->erase(buffer->begin(),
                    buffer->begin() + consumed + msg_type_byte_length);
    }
    total_consumed += consumed;
  } while (consumed && !buffer->empty());
  return HandleStreamBufferResult{handled, total_consumed};
}

// Ends the registration held by |*watch|. Move assignment swaps state, so the
// empty watch below takes over whatever |*watch| held and releases it in its
// destructor, leaving |*watch| with the state of a default-constructed watch.
void StopWatching(MessageDemuxer::MessageWatch* watch) {
  MessageDemuxer::MessageWatch released;
  *watch = std::move(released);
}

}  // namespace osp
}  // namespace openscreen
