// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "platform/impl/tls_connection_factory_posix.h"

#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <openssl/ssl.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <cstring>
#include <utility>
#include <vector>

#include "platform/api/task_runner.h"
#include "platform/api/tls_connection_factory.h"
#include "platform/base/tls_connect_options.h"
#include "platform/base/tls_credentials.h"
#include "platform/base/tls_listen_options.h"
#include "platform/impl/stream_socket.h"
#include "platform/impl/tls_connection_posix.h"
#include "util/crypto/certificate_utils.h"
#include "util/crypto/openssl_util.h"
#include "util/osp_logging.h"
#include "util/trace_logging.h"

namespace openscreen {

namespace {

// Returns the certificate presented by the peer of |ssl|, DER-encoded.
//
// Returns an error if the peer did not present a certificate (per the OpenSSL
// documentation, SSL_get_peer_certificate() returns nullptr in that case) or
// if exporting the certificate to DER fails.
ErrorOr<std::vector<uint8_t>> GetDEREncodedPeerCertificate(const SSL& ssl) {
  // SSL_get_peer_certificate() hands back a new reference that the caller
  // must free; the UniquePtr releases it on every return path.
  bssl::UniquePtr<X509> peer_cert(SSL_get_peer_certificate(&ssl));
  if (!peer_cert) {
    // Dereferencing a null certificate here would be undefined behavior.
    return Error::Code::kFatalSSLError;
  }
  return ExportX509CertificateToDer(*peer_cert);
}

}  // namespace

// static
// Platform-API factory hook: on POSIX, builds a TlsConnectionFactoryPosix that
// reports to |client| and runs its work on |task_runner|.
std::unique_ptr<TlsConnectionFactory> TlsConnectionFactory::CreateFactory(
    Client* client,
    TaskRunner* task_runner) {
  return std::make_unique<TlsConnectionFactoryPosix>(client, task_runner);
}

// Stores the collaborators this factory needs. |client| receives all
// connection/error notifications and |task_runner| is where that work is
// posted; both are required. |platform_client| is not checked here — the
// other methods test it for null before use.
TlsConnectionFactoryPosix::TlsConnectionFactoryPosix(
    Client* client,
    TaskRunner* task_runner,
    PlatformClientPosix* platform_client)
    : client_(client),
      task_runner_(task_runner),
      platform_client_(platform_client) {
  OSP_DCHECK(client_ != nullptr);
  OSP_DCHECK(task_runner_ != nullptr);
}

// Must be destroyed on the task runner. Unregisters this factory as an accept
// observer so the data router stops calling back into it.
TlsConnectionFactoryPosix::~TlsConnectionFactoryPosix() {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  if (!platform_client_) {
    return;
  }
  platform_client_->tls_data_router()->DeregisterAcceptObserver(this);
}

// TODO(rwkeane): Add support for resuming sessions.
// TODO(rwkeane): Integrate with Auth.
void TlsConnectionFactoryPosix::Connect(const IPEndpoint& remote_address,
                                        const TlsConnectOptions& options) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  TRACE_SCOPED(TraceCategory::kSsl, "TlsConnectionFactoryPosix::Connect");
  IPAddress::Version version = remote_address.address.version();
  std::unique_ptr<TlsConnectionPosix> connection(
      new TlsConnectionPosix(version, task_runner_));
  Error connect_error = connection->socket_->Connect(remote_address);
  if (!connect_error.ok()) {
    TRACE_SET_RESULT(connect_error);
    DispatchConnectionFailed(remote_address);
    return;
  }

  if (!ConfigureSsl(connection.get())) {
    return;
  }

  if (options.unsafely_skip_certificate_validation) {
    // Verifies the server certificate but does not make errors fatal.
    SSL_set_verify(connection->ssl_.get(), SSL_VERIFY_NONE, nullptr);
  } else {
    // Make server certificate errors fatal.
    SSL_set_verify(connection->ssl_.get(), SSL_VERIFY_PEER, nullptr);
  }

  Connect(std::move(connection));
}

// Installs the server certificate and RSA private key from |credentials| into
// the shared SSL context used for accepted connections. On any failure a
// kSocketListenFailure error is dispatched to the client and
// |listen_credentials_set_| is left unchanged.
void TlsConnectionFactoryPosix::SetListenCredentials(
    const TlsCredentials& credentials) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  EnsureInitialized();

  ErrorOr<bssl::UniquePtr<X509>> cert = ImportCertificate(
      credentials.der_x509_cert.data(), credentials.der_x509_cert.size());
  ErrorOr<bssl::UniquePtr<EVP_PKEY>> pkey =
      ImportRSAPrivateKey(credentials.der_rsa_private_key.data(),
                          credentials.der_rsa_private_key.size());

  // Initialize() leaves |ssl_context_| null when SSL_CTX_new() fails; the
  // SSL_CTX_use_*() calls below would dereference a null context, so treat
  // that the same as a credential failure (GetSslConnection() guards the same
  // way).
  if (!ssl_context_ || !cert || !pkey ||
      SSL_CTX_use_certificate(ssl_context_.get(), cert.value().get()) != 1 ||
      SSL_CTX_use_PrivateKey(ssl_context_.get(), pkey.value().get()) != 1) {
    DispatchError(Error::Code::kSocketListenFailure);
    TRACE_SET_RESULT(Error::Code::kSocketListenFailure);
    return;
  }

  listen_credentials_set_ = true;
}

// Binds a listening TCP socket on |local_address| and hands it to the platform
// client's TLS data router, which notifies this factory of pending
// connections. A socket that ends up closed is reported as
// kSocketListenFailure.
void TlsConnectionFactoryPosix::Listen(const IPEndpoint& local_address,
                                       const TlsListenOptions& options) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  // SetListenCredentials() must have succeeded before listening.
  OSP_DCHECK(listen_credentials_set_);

  auto listener = std::make_unique<StreamSocketPosix>(local_address);
  listener->Bind();
  listener->Listen(options.backlog_size);

  const auto listener_state = listener->state();
  if (listener_state == TcpSocketState::kClosed) {
    DispatchError(Error::Code::kSocketListenFailure);
    TRACE_SET_RESULT(Error::Code::kSocketListenFailure);
    return;
  }
  OSP_DCHECK(listener_state == TcpSocketState::kListening);

  OSP_DCHECK(platform_client_);
  if (!platform_client_) {
    return;
  }
  platform_client_->tls_data_router()->RegisterAcceptObserver(
      std::move(listener), this);
}

// Called when |socket| has a connection waiting. The Accept() is deferred to
// the task runner; weak pointers make the task a no-op if either the factory
// or the listening socket is destroyed before it runs.
void TlsConnectionFactoryPosix::OnConnectionPending(StreamSocketPosix* socket) {
  task_runner_->PostTask([factory = weak_factory_.GetWeakPtr(),
                          listener = socket->GetWeakPtr()] {
    if (!factory || !listener) {
      return;
    }

    ErrorOr<std::unique_ptr<StreamSocket>> accepted = listener->Accept();
    if (!accepted.is_error()) {
      factory->OnSocketAccepted(std::move(accepted.value()));
      return;
    }

    // OnConnectionPending() may fire several times before a queued task runs,
    // so later tasks can find nothing left to accept. kAgain signals that
    // case and is ignored; only other errors are reported.
    if (accepted.error().code() != Error::Code::kAgain) {
      factory->DispatchError(std::move(accepted.error()));
    }
  });
}

// Wraps a freshly accepted TCP |socket| in a TLS connection, attaches an SSL
// object, and begins the server-side handshake.
void TlsConnectionFactoryPosix::OnSocketAccepted(
    std::unique_ptr<StreamSocket> socket) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  TRACE_SCOPED(TraceCategory::kSsl,
               "TlsConnectionFactoryPosix::OnSocketAccepted");

  std::unique_ptr<TlsConnectionPosix> connection(
      new TlsConnectionPosix(std::move(socket), task_runner_));

  // ConfigureSsl() reports its own failures; only proceed on success.
  if (ConfigureSsl(connection.get())) {
    Accept(std::move(connection));
  }
}

// Creates a new SSL object from the shared context, binds it to the file
// descriptor of |connection|'s socket, and transfers ownership of it to
// |connection|. Returns false (after dispatching an error or connection
// failure to the client) if any step fails.
bool TlsConnectionFactoryPosix::ConfigureSsl(TlsConnectionPosix* connection) {
  ErrorOr<bssl::UniquePtr<SSL>> ssl_or_error = GetSslConnection();
  if (ssl_or_error.is_error()) {
    DispatchError(ssl_or_error.error());
    TRACE_SET_RESULT(ssl_or_error.error());
    return false;
  }

  bssl::UniquePtr<SSL> ssl = std::move(ssl_or_error.value());
  // SSL_set_fd() returns 0 on failure.
  if (SSL_set_fd(ssl.get(), connection->socket_->socket_handle().fd) == 0) {
    DispatchConnectionFailed(connection->GetRemoteEndpoint());
    TRACE_SET_RESULT(Error(Error::Code::kSocketBindFailure));
    return false;
  }

  connection->ssl_ = std::move(ssl);
  return true;
}

// Returns a new SSL object created from the factory's shared SSL_CTX,
// initializing that context on first use. Returns kFatalSSLError if the
// context is unavailable or SSL_new() fails.
ErrorOr<bssl::UniquePtr<SSL>> TlsConnectionFactoryPosix::GetSslConnection() {
  EnsureInitialized();

  SSL_CTX* const context = ssl_context_.get();
  SSL* const ssl = (context != nullptr) ? SSL_new(context) : nullptr;
  if (ssl == nullptr) {
    return Error::Code::kFatalSSLError;
  }
  return bssl::UniquePtr<SSL>(ssl);
}

// Runs Initialize() exactly once for the lifetime of this factory, guarded by
// |init_instance_flag_|.
void TlsConnectionFactoryPosix::EnsureInitialized() {
  std::call_once(init_instance_flag_, &TlsConnectionFactoryPosix::Initialize,
                 this);
}

void TlsConnectionFactoryPosix::Initialize() {
  EnsureOpenSSLInit();
  SSL_CTX* context = SSL_CTX_new(TLS_method());
  if (context == nullptr) {
    return;
  }

  SSL_CTX_set_mode(context, SSL_MODE_ENABLE_PARTIAL_WRITE);

  ssl_context_.reset(context);
}

void TlsConnectionFactoryPosix::Connect(
    std::unique_ptr<TlsConnectionPosix> connection) {
  if (connection->socket_->state() == TcpSocketState::kClosed) {
    return;
  }
  OSP_DCHECK(connection->socket_->state() == TcpSocketState::kConnected);
  ClearOpenSSLERRStack(CURRENT_LOCATION);
  const int connection_status = SSL_connect(connection->ssl_.get());
  if (connection_status != 1) {
    Error error = GetSSLError(connection->ssl_.get(), connection_status);
    if (error.code() == Error::Code::kAgain) {
      task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
                              conn = std::move(connection)]() mutable {
        if (auto* self = weak_this.get()) {
          self->Connect(std::move(conn));
        }
      });
      return;
    } else {
      OSP_DVLOG << "SSL_connect failed with error: " << error;
      DispatchConnectionFailed(connection->GetRemoteEndpoint());
      TRACE_SET_RESULT(error);
      return;
    }
  }

  ErrorOr<std::vector<uint8_t>> der_peer_cert =
      GetDEREncodedPeerCertificate(*connection->ssl_);
  if (!der_peer_cert) {
    DispatchConnectionFailed(connection->GetRemoteEndpoint());
    TRACE_SET_RESULT(der_peer_cert.error());
    return;
  }

  connection->RegisterConnectionWithDataRouter(platform_client_);
  task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
                          der = std::move(der_peer_cert.value()),
                          moved_connection = std::move(connection)]() mutable {
    if (auto* self = weak_this.get()) {
      self->client_->OnConnected(self, std::move(der),
                                 std::move(moved_connection));
    }
  });
}

void TlsConnectionFactoryPosix::Accept(
    std::unique_ptr<TlsConnectionPosix> connection) {
  if (connection->socket_->state() == TcpSocketState::kClosed) {
    return;
  }
  OSP_DCHECK(connection->socket_->state() == TcpSocketState::kConnected);

  ClearOpenSSLERRStack(CURRENT_LOCATION);
  const int connection_status = SSL_accept(connection->ssl_.get());
  if (connection_status != 1) {
    Error error = GetSSLError(connection->ssl_.get(), connection_status);
    if (error.code() == Error::Code::kAgain) {
      task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
                              conn = std::move(connection)]() mutable {
        if (auto* self = weak_this.get()) {
          self->Accept(std::move(conn));
        }
      });
      return;
    } else {
      OSP_DVLOG << "SSL_accept failed with error: " << error;
      DispatchConnectionFailed(connection->GetRemoteEndpoint());
      TRACE_SET_RESULT(error);
      return;
    }
  }

  ErrorOr<std::vector<uint8_t>> der_peer_cert =
      GetDEREncodedPeerCertificate(*connection->ssl_);
  std::vector<uint8_t> der;
  if (der_peer_cert) {
    der = std::move(der_peer_cert.value());
  }
  connection->RegisterConnectionWithDataRouter(platform_client_);
  task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
                          der = std::move(der),
                          moved_connection = std::move(connection)]() mutable {
    if (auto* self = weak_this.get()) {
      self->client_->OnAccepted(self, std::move(der),
                                std::move(moved_connection));
    }
  });
}

// Asynchronously reports a failed connection to |remote_endpoint| via
// Client::OnConnectionFailed(). Nothing is reported if the factory has been
// destroyed by the time the task runs.
void TlsConnectionFactoryPosix::DispatchConnectionFailed(
    const IPEndpoint& remote_endpoint) {
  task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
                          endpoint = remote_endpoint] {
    auto* const self = weak_this.get();
    if (self == nullptr) {
      return;
    }
    self->client_->OnConnectionFailed(self, endpoint);
  });
}

// Asynchronously reports |error| via Client::OnError(). Nothing is reported if
// the factory has been destroyed by the time the task runs.
void TlsConnectionFactoryPosix::DispatchError(Error error) {
  task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
                          pending_error = std::move(error)]() mutable {
    auto* const self = weak_this.get();
    if (self == nullptr) {
      return;
    }
    self->client_->OnError(self, std::move(pending_error));
  });
}

}  // namespace openscreen
