// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "cast/sender/public/sender_socket_factory.h"

#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/sender/channel/cast_auth_util.h"
#include "cast/sender/channel/message_util.h"
#include "platform/base/tls_connect_options.h"
#include "util/crypto/certificate_utils.h"
#include "util/osp_logging.h"

using ::cast::channel::CastMessage;

namespace openscreen {
namespace cast {

SenderSocketFactory::Client::~Client() = default;

bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a,
               int b) {
  return a && a->socket->socket_id() < b;
}

bool operator<(int a,
               const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) {
  return b && a < b->socket->socket_id();
}

SenderSocketFactory::SenderSocketFactory(Client* client,
                                         TaskRunner* task_runner)
    : client_(client), task_runner_(task_runner) {
  OSP_DCHECK(client);
  OSP_DCHECK(task_runner);
}

SenderSocketFactory::~SenderSocketFactory() {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
}

void SenderSocketFactory::set_factory(TlsConnectionFactory* factory) {
  OSP_DCHECK(factory);
  factory_ = factory;
}

void SenderSocketFactory::Connect(const IPEndpoint& endpoint,
                                  DeviceMediaPolicy media_policy,
                                  CastSocket::Client* client) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  OSP_DCHECK(client);
  auto it = FindPendingConnection(endpoint);
  if (it == pending_connections_.end()) {
    pending_connections_.emplace_back(
        PendingConnection{endpoint, media_policy, client});
    factory_->Connect(endpoint, TlsConnectOptions{true});
  }
}

void SenderSocketFactory::OnAccepted(
    TlsConnectionFactory* factory,
    std::vector<uint8_t> der_x509_peer_cert,
    std::unique_ptr<TlsConnection> connection) {
  OSP_NOTREACHED();
  OSP_LOG_FATAL << "This factory is connect-only";
}

void SenderSocketFactory::OnConnected(
    TlsConnectionFactory* factory,
    std::vector<uint8_t> der_x509_peer_cert,
    std::unique_ptr<TlsConnection> connection) {
  const IPEndpoint& endpoint = connection->GetRemoteEndpoint();
  auto it = FindPendingConnection(endpoint);
  if (it == pending_connections_.end()) {
    OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: "
                   << endpoint;
    return;
  }
  DeviceMediaPolicy media_policy = it->media_policy;
  CastSocket::Client* client = it->client;
  pending_connections_.erase(it);

  ErrorOr<bssl::UniquePtr<X509>> peer_cert =
      ImportCertificate(der_x509_peer_cert.data(), der_x509_peer_cert.size());
  if (!peer_cert) {
    client_->OnError(this, endpoint, peer_cert.error());
    return;
  }

  auto socket =
      MakeSerialDelete<CastSocket>(task_runner_, std::move(connection), this);
  pending_auth_.emplace_back(
      new PendingAuth{endpoint, media_policy, std::move(socket), client,
                      std::make_unique<AuthContext>(AuthContext::Create()),
                      std::move(peer_cert.value())});
  PendingAuth& pending = *pending_auth_.back();

  CastMessage auth_challenge =
      CreateAuthChallengeMessage(*pending.auth_context);
  Error error = pending.socket->Send(auth_challenge);
  if (!error.ok()) {
    pending_auth_.pop_back();
    client_->OnError(this, endpoint, error);
  }
}

void SenderSocketFactory::OnConnectionFailed(TlsConnectionFactory* factory,
                                             const IPEndpoint& remote_address) {
  auto it = FindPendingConnection(remote_address);
  if (it == pending_connections_.end()) {
    return;
  }
  pending_connections_.erase(it);
  client_->OnError(this, remote_address, Error::Code::kConnectionFailed);
}

void SenderSocketFactory::OnError(TlsConnectionFactory* factory, Error error) {
  std::vector<PendingConnection> connections;
  pending_connections_.swap(connections);
  for (const PendingConnection& pending : connections) {
    client_->OnError(this, pending.endpoint, error);
  }
}

std::vector<SenderSocketFactory::PendingConnection>::iterator
SenderSocketFactory::FindPendingConnection(const IPEndpoint& endpoint) {
  return std::find_if(pending_connections_.begin(), pending_connections_.end(),
                      [&endpoint](const PendingConnection& pending) {
                        return pending.endpoint == endpoint;
                      });
}

void SenderSocketFactory::OnError(CastSocket* socket, Error error) {
  auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
                         [id = socket->socket_id()](
                             const std::unique_ptr<PendingAuth>& pending_auth) {
                           return pending_auth->socket->socket_id() == id;
                         });
  if (it == pending_auth_.end()) {
    OSP_DLOG_ERROR << "Got error for unknown pending socket";
    return;
  }
  IPEndpoint endpoint = (*it)->endpoint;
  pending_auth_.erase(it);
  client_->OnError(this, endpoint, error);
}

void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) {
  auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
                         [id = socket->socket_id()](
                             const std::unique_ptr<PendingAuth>& pending_auth) {
                           return pending_auth->socket->socket_id() == id;
                         });
  if (it == pending_auth_.end()) {
    OSP_DLOG_ERROR << "Got message for unknown pending socket";
    return;
  }

  std::unique_ptr<PendingAuth> pending = std::move(*it);
  pending_auth_.erase(it);
  if (!IsAuthMessage(message)) {
    client_->OnError(this, pending->endpoint,
                     Error::Code::kCastV2AuthenticationError);
    return;
  }

  ErrorOr<CastDeviceCertPolicy> policy_or_error = AuthenticateChallengeReply(
      message, pending->peer_cert.get(), *pending->auth_context);
  if (policy_or_error.is_error()) {
    OSP_DLOG_WARN << "Authentication failed for " << pending->endpoint
                  << " with error: " << policy_or_error.error();
    client_->OnError(this, pending->endpoint, policy_or_error.error());
    return;
  }

  if (policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly &&
      pending->media_policy == DeviceMediaPolicy::kIncludesVideo) {
    client_->OnError(this, pending->endpoint,
                     Error::Code::kCastV2ChannelPolicyMismatch);
    return;
  }
  pending->socket->set_audio_only(policy_or_error.value() ==
                                  CastDeviceCertPolicy::kAudioOnly);

  pending->socket->SetClient(pending->client);
  client_->OnConnected(this, pending->endpoint,
                       std::unique_ptr<CastSocket>(pending->socket.release()));
}

}  // namespace cast
}  // namespace openscreen
