// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "osp/impl/presentation/url_availability_requester.h"

#include <algorithm>
#include <chrono>
#include <memory>
#include <utility>
#include <vector>

#include "osp/impl/presentation/presentation_common.h"
#include "osp/public/network_service_manager.h"
#include "util/chrono_helpers.h"
#include "util/osp_logging.h"

using std::chrono::seconds;

namespace openscreen {
namespace osp {
namespace {

// Watch duration sent to receivers in each availability request; also added to
// the current time to compute the local deadline of the resulting watch.
static constexpr Clock::duration kWatchDuration = seconds(20);
// Margin subtracted from a watch's deadline when deciding whether the watch is
// due for refresh, so it is re-requested before the deadline itself is reached.
static constexpr Clock::duration kWatchRefreshPadding = seconds(2);

// Reorders |urls| in place so that every URL absent from |membership_test|
// precedes every URL present in it. Returns an iterator to the first URL that
// is a member of |membership_test| (or end() if none is). The relative order
// within each group is unspecified.
std::vector<std::string>::iterator PartitionUrlsBySetMembership(
    std::vector<std::string>* urls,
    const std::set<std::string>& membership_test) {
  const auto is_not_member = [&membership_test](const std::string& candidate) {
    return membership_test.count(candidate) == 0;
  };
  return std::partition(urls->begin(), urls->end(), is_not_member);
}

// Moves each string in [first, last) into |target|. The source strings are
// left in a moved-from state; duplicates are collapsed by the set.
void MoveVectorSegment(std::vector<std::string>::iterator first,
                       std::vector<std::string>::iterator last,
                       std::set<std::string>* target) {
  std::for_each(first, last, [target](std::string& url) {
    target->emplace(std::move(url));
  });
}

// Returns the next request id for |endpoint_id|, as handed out by the
// protocol connection client's per-endpoint request id tracker.
uint64_t GetNextRequestId(const uint64_t endpoint_id) {
  auto* const client =
      NetworkServiceManager::Get()->GetProtocolConnectionClient();
  auto* const request_ids = client->endpoint_request_ids();
  return request_ids->GetNextRequestId(endpoint_id);
}

}  // namespace

// Stores |now_function|, which is later used to read the current time when
// computing and refreshing watch deadlines. It must be non-null (checked in
// debug builds).
UrlAvailabilityRequester::UrlAvailabilityRequester(
    ClockNowFunctionPtr now_function)
    : now_function_(now_function) {
  OSP_DCHECK(now_function_);
}

// Defaulted: cleanup relies entirely on the members' own destructors.
UrlAvailabilityRequester::~UrlAvailabilityRequester() = default;

void UrlAvailabilityRequester::AddObserver(const std::vector<std::string>& urls,
                                           ReceiverObserver* observer) {
  for (const auto& url : urls) {
    observers_by_url_[url].push_back(observer);
  }
  for (auto& entry : receiver_by_service_id_) {
    auto& receiver = entry.second;
    receiver->GetOrRequestAvailabilities(urls, observer);
  }
}

// Unregisters |observer| from each URL in |urls|. URLs left with no observers
// are dropped from the observer map and from every receiver's cached
// availabilities, and each receiver is then told to drop the requests and
// watches that covered them.
void UrlAvailabilityRequester::RemoveObserverUrls(
    const std::vector<std::string>& urls,
    ReceiverObserver* observer) {
  std::set<std::string> unobserved_urls;
  for (const std::string& url : urls) {
    const auto found = observers_by_url_.find(url);
    if (found == observers_by_url_.end())
      continue;

    auto& remaining = found->second;
    remaining.erase(std::remove(remaining.begin(), remaining.end(), observer),
                    remaining.end());
    if (!remaining.empty())
      continue;

    // Nobody is watching this URL any more: forget it everywhere.
    unobserved_urls.insert(url);
    observers_by_url_.erase(found);
    for (auto& id_and_receiver : receiver_by_service_id_)
      id_and_receiver.second->known_availability_by_url.erase(url);
  }

  for (auto& id_and_receiver : receiver_by_service_id_) {
    auto& requester = id_and_receiver.second;
    requester->RemoveUnobservedRequests(unobserved_urls);
    requester->RemoveUnobservedWatches(unobserved_urls);
  }
}

// Unregisters |observer| from every URL it is observing. URLs that end up
// with no observers are collected and each receiver is told to drop the
// requests and watches that covered them.
void UrlAvailabilityRequester::RemoveObserver(ReceiverObserver* observer) {
  std::set<std::string> unobserved_urls;
  for (auto& entry : observers_by_url_) {
    auto& observer_list = entry.second;
    const auto new_end =
        std::remove(observer_list.begin(), observer_list.end(), observer);
    if (new_end == observer_list.end())
      continue;  // |observer| was not registered for this URL.

    // Erase the whole tail produced by std::remove. The single-iterator form
    // of erase() would drop only one element and leave stale pointers behind
    // when |observer| was registered more than once for the same URL.
    observer_list.erase(new_end, observer_list.end());
    if (observer_list.empty())
      unobserved_urls.insert(entry.first);
  }
  // NOTE(review): unlike RemoveObserverUrls(), the now-empty map entries and
  // the receivers' cached availabilities are left in place here — confirm
  // whether that asymmetry is intended.

  for (auto& entry : receiver_by_service_id_) {
    auto& receiver = entry.second;
    receiver->RemoveUnobservedRequests(unobserved_urls);
    receiver->RemoveUnobservedWatches(unobserved_urls);
  }
}

void UrlAvailabilityRequester::AddReceiver(const ServiceInfo& info) {
  auto result = receiver_by_service_id_.emplace(
      info.service_id,
      std::make_unique<ReceiverRequester>(
          this, info.service_id,
          info.v4_endpoint.address ? info.v4_endpoint : info.v6_endpoint));
  std::unique_ptr<ReceiverRequester>& receiver = result.first->second;
  std::vector<std::string> urls;
  urls.reserve(observers_by_url_.size());
  for (const auto& url : observers_by_url_)
    urls.push_back(url.first);
  receiver->RequestUrlAvailabilities(std::move(urls));
}

// Currently a no-op: |info| is ignored and no state is updated.
void UrlAvailabilityRequester::ChangeReceiver(const ServiceInfo& info) {}

// Drops the requester for |info.service_id|, if one exists, after letting it
// notify observers that its available URLs are no longer available.
void UrlAvailabilityRequester::RemoveReceiver(const ServiceInfo& info) {
  const auto found = receiver_by_service_id_.find(info.service_id);
  if (found == receiver_by_service_id_.end())
    return;

  found->second->RemoveReceiver();
  receiver_by_service_id_.erase(found);
}

// Lets every requester notify observers that its available URLs are gone,
// then discards all requesters.
void UrlAvailabilityRequester::RemoveAllReceivers() {
  for (auto& id_and_receiver : receiver_by_service_id_)
    id_and_receiver.second->RemoveReceiver();

  receiver_by_service_id_.clear();
}

// Refreshes the watches of every receiver and returns the earliest time at
// which any of them asked to be refreshed again. The result is never later
// than now + kWatchDuration.
Clock::time_point UrlAvailabilityRequester::RefreshWatches() {
  const Clock::time_point now = now_function_();
  Clock::time_point earliest = now + kWatchDuration;
  for (auto& id_and_receiver : receiver_by_service_id_) {
    const Clock::time_point requested =
        id_and_receiver.second->RefreshWatches(now);
    earliest = std::min(earliest, requested);
  }
  return earliest;
}

// Records the owning requester (|listener|) and the receiver's |service_id|,
// and immediately starts a protocol connection to |endpoint|, passing this
// object as the connect callback target.
// NOTE(review): presumably OnConnectionOpened()/OnConnectionFailed() are the
// callbacks invoked with the outcome — confirm against the connection client.
UrlAvailabilityRequester::ReceiverRequester::ReceiverRequester(
    UrlAvailabilityRequester* listener,
    const std::string& service_id,
    const IPEndpoint& endpoint)
    : listener(listener),
      service_id(service_id),
      connect_request(
          NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect(
              endpoint,
              this)) {}

// Defaulted: cleanup relies entirely on the members' own destructors.
UrlAvailabilityRequester::ReceiverRequester::~ReceiverRequester() = default;

void UrlAvailabilityRequester::ReceiverRequester::GetOrRequestAvailabilities(
    const std::vector<std::string>& requested_urls,
    ReceiverObserver* observer) {
  std::vector<std::string> unknown_urls;
  for (const auto& url : requested_urls) {
    auto availability_entry = known_availability_by_url.find(url);
    if (availability_entry == known_availability_by_url.end()) {
      unknown_urls.emplace_back(url);
      continue;
    }

    msgs::UrlAvailability availability = availability_entry->second;
    if (observer) {
      switch (availability) {
        case msgs::UrlAvailability::kAvailable:
          observer->OnReceiverAvailable(url, service_id);
          break;
        case msgs::UrlAvailability::kUnavailable:
        case msgs::UrlAvailability::kInvalid:
          observer->OnReceiverUnavailable(url, service_id);
          break;
      }
    }
  }
  if (!unknown_urls.empty()) {
    RequestUrlAvailabilities(std::move(unknown_urls));
  }
}

// Records a new availability request for |urls|. When a connection exists the
// request is sent right away; if sending fails, the observers of each URL are
// told the request failed and nothing is recorded. Without a connection the
// request is only recorded, with a placeholder watch id of 0.
void UrlAvailabilityRequester::ReceiverRequester::RequestUrlAvailabilities(
    std::vector<std::string> urls) {
  if (urls.empty())
    return;

  const uint64_t request_id = GetNextRequestId(endpoint_id);
  ErrorOr<uint64_t> watch_id_or_error(0);
  if (connection) {
    watch_id_or_error = SendRequest(request_id, urls);
    if (!watch_id_or_error) {
      for (const auto& failed_url : urls) {
        for (auto& observer : listener->observers_by_url_[failed_url])
          observer->OnRequestFailed(failed_url, service_id);
      }
      return;
    }
  }

  request_by_id.emplace(request_id,
                        Request{watch_id_or_error.value(), std::move(urls)});
}

// Encodes and writes a presentation-url-availability-request for |urls| with
// the given |request_id| and a freshly allocated watch id. On success the
// watch is recorded with a deadline of now + kWatchDuration, the event and
// response message watches are installed if missing, and the watch id is
// returned. On encoding failure returns kCborEncoding; the watch id has
// still been consumed in that case.
ErrorOr<uint64_t> UrlAvailabilityRequester::ReceiverRequester::SendRequest(
    uint64_t request_id,
    const std::vector<std::string>& urls) {
  const uint64_t watch_id = next_watch_id++;

  msgs::PresentationUrlAvailabilityRequest cbor_request;
  cbor_request.request_id = request_id;
  cbor_request.urls = urls;
  cbor_request.watch_id = watch_id;
  cbor_request.watch_duration = to_microseconds(kWatchDuration).count();

  msgs::CborEncodeBuffer buffer;
  if (!msgs::EncodePresentationUrlAvailabilityRequest(cbor_request, &buffer))
    return Error::Code::kCborEncoding;

  OSP_VLOG << "writing presentation-url-availability-request";
  connection->Write(buffer.data(), buffer.size());
  watch_by_id.emplace(
      watch_id, Watch{listener->now_function_() + kWatchDuration, urls});

  // Install the message watches only if they are not already installed.
  if (!event_watch) {
    event_watch = GetClientDemuxer()->WatchMessageType(
        endpoint_id, msgs::Type::kPresentationUrlAvailabilityEvent, this);
  }
  if (!response_watch) {
    response_watch = GetClientDemuxer()->WatchMessageType(
        endpoint_id, msgs::Type::kPresentationUrlAvailabilityResponse, this);
  }
  return watch_id;
}

// Re-requests every watch whose padded deadline (deadline minus
// kWatchRefreshPadding) is earlier than |now|, and returns the earliest padded
// deadline among the watches that were not yet due, capped at
// now + kWatchDuration. Watches created by the re-requests are not considered
// in the returned time.
Clock::time_point UrlAvailabilityRequester::ReceiverRequester::RefreshWatches(
    Clock::time_point now) {
  Clock::time_point next_refresh = now + kWatchDuration;
  std::vector<std::vector<std::string>> due_url_sets;

  auto it = watch_by_id.begin();
  while (it != watch_by_id.end()) {
    const Clock::time_point refresh_at =
        it->second.deadline - kWatchRefreshPadding;
    if (now > refresh_at) {
      // Due: take the URLs and drop the old watch.
      due_url_sets.emplace_back(std::move(it->second.urls));
      it = watch_by_id.erase(it);
      continue;
    }
    next_refresh = std::min(next_refresh, refresh_at);
    ++it;
  }

  if (watch_by_id.empty())
    StopWatching(&event_watch);

  for (auto& url_set : due_url_sets)
    RequestUrlAvailabilities(std::move(url_set));

  return next_refresh;
}

// Applies |availabilities| (parallel to |urls|) to the cached availability of
// this receiver and notifies the observers of every URL whose availability is
// new or has changed. URLs without an observer entry are skipped.
//
// Returns kCborInvalidMessage if the two vectors differ in length, kNone
// otherwise.
Error::Code UrlAvailabilityRequester::ReceiverRequester::UpdateAvailabilities(
    const std::vector<std::string>& urls,
    const std::vector<msgs::UrlAvailability>& availabilities) {
  if (urls.size() != availabilities.size()) {
    return Error::Code::kCborInvalidMessage;
  }
  auto availability_it = availabilities.begin();
  for (const auto& url : urls) {
    // Advance unconditionally so that skipping an unobserved URL below cannot
    // shift the pairing of the remaining URLs and availabilities.
    const msgs::UrlAvailability availability = *availability_it;
    ++availability_it;

    auto observer_entry = listener->observers_by_url_.find(url);
    if (observer_entry == listener->observers_by_url_.end())
      continue;
    std::vector<ReceiverObserver*>& observers = observer_entry->second;

    auto result = known_availability_by_url.emplace(url, availability);
    const bool inserted = result.second;
    const bool updated = !inserted && (result.first->second != availability);
    if (!inserted && !updated)
      continue;

    // Record the new value before notifying; emplace() does not overwrite an
    // existing entry, so without this the cache would stay stale.
    result.first->second = availability;

    switch (availability) {
      case msgs::UrlAvailability::kAvailable:
        for (auto* observer : observers)
          observer->OnReceiverAvailable(url, service_id);
        break;
      case msgs::UrlAvailability::kUnavailable:
      case msgs::UrlAvailability::kInvalid:
        for (auto* observer : observers)
          observer->OnReceiverUnavailable(url, service_id);
        break;
      default:
        break;
    }
  }
  return Error::Code::kNone;
}

// Drops every pending request that mentions a URL in |unobserved_urls|. The
// still-observed URLs of the dropped requests are merged into one replacement
// request. When a connection exists the replacement is sent immediately (and
// observers are told of a send failure); otherwise it is only recorded.
// Stops watching for responses once no requests remain.
void UrlAvailabilityRequester::ReceiverRequester::RemoveUnobservedRequests(
    const std::set<std::string>& unobserved_urls) {
  std::set<std::string> still_observed_urls;
  for (auto entry = request_by_id.begin(); entry != request_by_id.end();) {
    Request& request = entry->second;
    auto split = PartitionUrlsBySetMembership(&request.urls, unobserved_urls);
    if (split == request.urls.end()) {
      // No unobserved URL in this request; keep it untouched.
      ++entry;
      continue;
    }
    MoveVectorSegment(request.urls.begin(), split, &still_observed_urls);
    if (connection)
      watch_by_id.erase(request.watch_id);
    // The request's URLs have been moved out, so the request itself must go.
    // Leaving it would keep a stale entry with moved-from URLs and prevent
    // |response_watch| from ever being stopped below.
    entry = request_by_id.erase(entry);
  }

  if (!still_observed_urls.empty()) {
    const uint64_t new_request_id = GetNextRequestId(endpoint_id);
    ErrorOr<uint64_t> watch_id_or_error(0);
    std::vector<std::string> urls(still_observed_urls.begin(),
                                  still_observed_urls.end());
    if (!connection ||
        (watch_id_or_error = SendRequest(new_request_id, urls))) {
      request_by_id.emplace(
          new_request_id, Request{watch_id_or_error.value(), std::move(urls)});
    } else {
      for (const auto& url : urls)
        for (auto& observer : listener->observers_by_url_[url])
          observer->OnRequestFailed(url, service_id);
    }
  }

  if (request_by_id.empty())
    StopWatching(&response_watch);
}

// Drops every watch that mentions a URL in |unobserved_urls| and issues one
// new request covering the still-observed URLs of the dropped watches. Stops
// watching for availability events once no watches remain.
void UrlAvailabilityRequester::ReceiverRequester::RemoveUnobservedWatches(
    const std::set<std::string>& unobserved_urls) {
  std::set<std::string> urls_to_rewatch;
  auto it = watch_by_id.begin();
  while (it != watch_by_id.end()) {
    std::vector<std::string>& watched_urls = it->second.urls;
    const auto split =
        PartitionUrlsBySetMembership(&watched_urls, unobserved_urls);
    if (split != watched_urls.end()) {
      MoveVectorSegment(watched_urls.begin(), split, &urls_to_rewatch);
      it = watch_by_id.erase(it);
    } else {
      ++it;
    }
  }

  // RequestUrlAvailabilities() returns early when given an empty list.
  RequestUrlAvailabilities(
      std::vector<std::string>(urls_to_rewatch.begin(), urls_to_rewatch.end()));
  // TODO(btolsch): These message watch cancels could be tested by expecting
  // messages to fall through to the default watch.
  if (watch_by_id.empty())
    StopWatching(&event_watch);
}

// Tells the observers of every URL cached as available on this receiver that
// the receiver is now unavailable for that URL.
void UrlAvailabilityRequester::ReceiverRequester::RemoveReceiver() {
  for (const auto& url_and_availability : known_availability_by_url) {
    if (url_and_availability.second != msgs::UrlAvailability::kAvailable)
      continue;

    const std::string& url = url_and_availability.first;
    for (auto& observer : listener->observers_by_url_[url])
      observer->OnReceiverUnavailable(url, service_id);
  }
}

// Called when the protocol connection is established. Takes ownership of
// |connection|, records its endpoint id, and sends every request recorded so
// far. Requests that fail to send are dropped; the others have their watch id
// updated to the one assigned by SendRequest().
void UrlAvailabilityRequester::ReceiverRequester::OnConnectionOpened(
    uint64_t request_id,
    std::unique_ptr<ProtocolConnection> connection) {
  connect_request.MarkComplete();
  // TODO(btolsch): This is one place where we need to make sure the QUIC
  // connection stays alive, even without constant traffic.
  endpoint_id = connection->endpoint_id();
  this->connection = std::move(connection);

  auto pending = request_by_id.begin();
  while (pending != request_by_id.end()) {
    ErrorOr<uint64_t> watch_id_or_error =
        SendRequest(pending->first, pending->second.urls);
    if (!watch_id_or_error) {
      pending = request_by_id.erase(pending);
      continue;
    }
    pending->second.watch_id = watch_id_or_error.value();
    ++pending;
  }
}

// Called when the protocol connection could not be established. Notifies the
// observers of every URL in the recorded requests that the request failed,
// then removes this receiver's entry from the owner's receiver map.
void UrlAvailabilityRequester::ReceiverRequester::OnConnectionFailed(
    uint64_t request_id) {
  connect_request.MarkComplete();

  // Collect the URLs of all recorded requests, de-duplicated, so each URL is
  // reported once. The URL strings are moved out of the requests.
  std::set<std::string> waiting_urls;
  for (auto& entry : request_by_id) {
    Request& request = entry.second;
    for (auto& url : request.urls) {
      waiting_urls.emplace(std::move(url));
    }
  }
  for (const auto& url : waiting_urls)
    for (auto& observer : listener->observers_by_url_[url])
      observer->OnRequestFailed(url, service_id);

  // NOTE(review): the map stores its requesters by unique_ptr, so this erase
  // presumably destroys |this|; the id is moved into a local first so the key
  // stays valid during the erase. Nothing may touch members after this line —
  // confirm.
  std::string id = std::move(service_id);
  listener->receiver_by_service_id_.erase(id);
}

// Handles an incoming availability response or availability event.
//
// Responses are matched to a recorded request by id; the request's URLs are
// updated with the reported availabilities and the request is then dropped.
// Events are matched to a watch by id; events for unknown watches are consumed
// without effect. Returns the decoder's result on success, or an error code
// when the message is incomplete, unparseable, or inconsistent with the
// request/watch it refers to. Any other message type yields kCborParsing.
ErrorOr<size_t> UrlAvailabilityRequester::ReceiverRequester::OnStreamMessage(
    uint64_t endpoint_id,
    uint64_t connection_id,
    msgs::Type message_type,
    const uint8_t* buffer,
    size_t buffer_size,
    Clock::time_point now) {
  switch (message_type) {
    case msgs::Type::kPresentationUrlAvailabilityResponse: {
      msgs::PresentationUrlAvailabilityResponse response;
      const ssize_t decode_result =
          msgs::DecodePresentationUrlAvailabilityResponse(buffer, buffer_size,
                                                          &response);
      if (decode_result < 0) {
        if (decode_result == msgs::kParserEOF)
          return Error::Code::kCborIncompleteMessage;
        OSP_LOG_WARN << "parse error: " << decode_result;
        return Error::Code::kCborParsing;
      }

      const auto pending = request_by_id.find(response.request_id);
      if (pending == request_by_id.end()) {
        OSP_LOG_ERROR << "bad response id: " << response.request_id;
        return Error::Code::kCborInvalidResponseId;
      }

      std::vector<std::string>& requested_urls = pending->second.urls;
      if (requested_urls.size() != response.url_availabilities.size()) {
        OSP_LOG_WARN << "bad response size: expected " << requested_urls.size()
                     << " but got " << response.url_availabilities.size();
        return Error::Code::kCborInvalidMessage;
      }

      const Error::Code status =
          UpdateAvailabilities(requested_urls, response.url_availabilities);
      if (status != Error::Code::kNone)
        return status;

      // The request has been answered; forget it.
      request_by_id.erase(response.request_id);
      if (request_by_id.empty())
        StopWatching(&response_watch);
      return decode_result;
    }
    case msgs::Type::kPresentationUrlAvailabilityEvent: {
      msgs::PresentationUrlAvailabilityEvent event;
      const ssize_t decode_result = msgs::DecodePresentationUrlAvailabilityEvent(
          buffer, buffer_size, &event);
      if (decode_result < 0) {
        if (decode_result == msgs::kParserEOF)
          return Error::Code::kCborIncompleteMessage;
        OSP_LOG_WARN << "parse error: " << decode_result;
        return Error::Code::kCborParsing;
      }

      const auto watched = watch_by_id.find(event.watch_id);
      if (watched == watch_by_id.end())
        return decode_result;

      // Work on a copy of the watch's URL list.
      std::vector<std::string> watched_urls = watched->second.urls;
      const Error::Code status =
          UpdateAvailabilities(watched_urls, event.url_availabilities);
      if (status != Error::Code::kNone)
        return status;
      return decode_result;
    }
    default:
      break;
  }
  return Error::Code::kCborParsing;
}

}  // namespace osp
}  // namespace openscreen
