/*
 *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "rtc_base/nat_socket_factory.h"

#include "rtc_base/arraysize.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/nat_server.h"
#include "rtc_base/virtual_socket_server.h"

namespace rtc {

// Packs the given socketaddress into the buffer in buf, in the quasi-STUN
// format that the natserver uses.
// Returns 0 if an invalid address is passed.
size_t PackAddressForNAT(char* buf,
                         size_t buf_size,
                         const SocketAddress& remote_addr) {
  const IPAddress& ip = remote_addr.ipaddr();
  int family = ip.family();
  buf[0] = 0;
  buf[1] = family;
  // Writes the port.
  *(reinterpret_cast<uint16_t*>(&buf[2])) = HostToNetwork16(remote_addr.port());
  if (family == AF_INET) {
    RTC_DCHECK(buf_size >= kNATEncodedIPv4AddressSize);
    in_addr v4addr = ip.ipv4_address();
    memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
    return kNATEncodedIPv4AddressSize;
  } else if (family == AF_INET6) {
    RTC_DCHECK(buf_size >= kNATEncodedIPv6AddressSize);
    in6_addr v6addr = ip.ipv6_address();
    memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
    return kNATEncodedIPv6AddressSize;
  }
  return 0U;
}

// Decodes the remote address from a packet that has been encoded with the nat's
// quasi-STUN format. Returns the length of the address (i.e., the offset into
// data where the original packet starts).
size_t UnpackAddressFromNAT(const char* buf,
                            size_t buf_size,
                            SocketAddress* remote_addr) {
  RTC_DCHECK(buf_size >= 8);
  RTC_DCHECK(buf[0] == 0);
  int family = buf[1];
  uint16_t port =
      NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf[2])));
  if (family == AF_INET) {
    const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
    *remote_addr = SocketAddress(IPAddress(*v4addr), port);
    return kNATEncodedIPv4AddressSize;
  } else if (family == AF_INET6) {
    RTC_DCHECK(buf_size >= 20);
    const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
    *remote_addr = SocketAddress(IPAddress(*v6addr), port);
    return kNATEncodedIPv6AddressSize;
  }
  return 0U;
}

// NATSocket
// A Socket that wraps an internal socket obtained from a
// NATInternalSocketFactory. When the factory reports a NAT server address
// (`server_addr_`):
//  - datagram payloads are framed with the peer address in the NAT's
//    quasi-STUN format (see PackAddressForNAT / UnpackAddressFromNAT);
//  - stream connections are established by sending the destination address to
//    the NAT server and waiting for a one-byte reply (0 == success).
// When no NAT server address is set, calls go straight to the internal socket.
class NATSocket : public Socket, public sigslot::has_slots<> {
 public:
  explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
      : sf_(sf), family_(family), type_(type) {}

  ~NATSocket() override { delete socket_; }

  SocketAddress GetLocalAddress() const override {
    return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
  }

  SocketAddress GetRemoteAddress() const override {
    return remote_addr_;  // will be NIL if not connected
  }

  int Bind(const SocketAddress& addr) override {
    if (socket_) {  // already bound, bubble up error
      return -1;
    }

    return BindInternal(addr);
  }

  // Records `addr` as the peer. For SOCK_STREAM the internal socket connects
  // to the NAT server, or to `addr` directly when there is none. When a NAT
  // server is used, the destination is sent to it from OnConnectEvent.
  int Connect(const SocketAddress& addr) override {
    int result = 0;
    // If we're not already bound (meaning `socket_` is null), bind to ANY
    // address.
    if (!socket_) {
      result = BindInternal(SocketAddress(GetAnyIP(family_), 0));
      if (result < 0) {
        return result;
      }
    }

    if (type_ == SOCK_STREAM) {
      result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
    } else {
      connected_ = true;
    }

    if (result >= 0) {
      remote_addr_ = addr;
    }

    return result;
  }

  int Send(const void* data, size_t size) override {
    RTC_DCHECK(connected_);
    return SendTo(data, size, remote_addr_);
  }

  // Sends `data` to `addr`. When NATed over UDP, the packet is prefixed with
  // the encoded destination and sent to the NAT server. Returns the number of
  // payload bytes sent (header excluded) or a negative error.
  int SendTo(const void* data,
             size_t size,
             const SocketAddress& addr) override {
    RTC_DCHECK(!connected_ || addr == remote_addr_);
    if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
      return socket_->SendTo(data, size, addr);
    }
    // This array will be too large for IPv4 packets, but only by 12 bytes.
    std::unique_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
    size_t addrlength =
        PackAddressForNAT(buf.get(), size + kNATEncodedIPv6AddressSize, addr);
    size_t encoded_size = size + addrlength;
    memcpy(buf.get() + addrlength, data, size);
    int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
    if (result >= 0) {
      RTC_DCHECK(result == static_cast<int>(encoded_size));
      result = result - static_cast<int>(addrlength);
    }
    return result;
  }

  int Recv(void* data, size_t size, int64_t* timestamp) override {
    SocketAddress addr;
    return RecvFrom(data, size, &addr, timestamp);
  }

  // Receives into `data` (capacity `size`). When NATed over UDP, strips the
  // address header added by the NAT server and reports the real sender in
  // `out_addr`. Payload bytes beyond `size` are discarded. Returns the number
  // of bytes copied into `data`, 0 if the packet was dropped (malformed
  // header, or sender other than the connected peer), or a negative error.
  int RecvFrom(void* data,
               size_t size,
               SocketAddress* out_addr,
               int64_t* timestamp) override {
    if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
      return socket_->RecvFrom(data, size, out_addr, timestamp);
    }
    // Read at most the requested amount plus the largest possible address
    // header. Bound the read by this call's `size`, not by buf_'s capacity,
    // which may have grown on an earlier, larger call.
    const size_t wire_size = size + kNATEncodedIPv6AddressSize;
    Grow(wire_size);

    // Read the packet from the socket.
    SocketAddress remote_addr;
    const int result =
        socket_->RecvFrom(buf_.get(), wire_size, &remote_addr, timestamp);
    if (result < 0) {
      return result;
    }
    RTC_DCHECK(remote_addr == server_addr_);

    // TODO: we need better framing so we know how many bytes we can
    // return before we need to read the next address. For UDP, this will be
    // fine as long as the reader always reads everything in the packet.
    const size_t received = static_cast<size_t>(result);
    RTC_DCHECK(received < wire_size);

    // Decode the wire packet into the actual results. Reject packets too short
    // to carry a header, or whose header is unknown or longer than the packet;
    // otherwise `received - addrlength` below would underflow.
    if (received < kNATEncodedIPv4AddressSize) {
      RTC_LOG(LS_ERROR) << "Dropping NAT packet too short for address header";
      return 0;
    }
    SocketAddress real_remote_addr;
    const size_t addrlength =
        UnpackAddressFromNAT(buf_.get(), received, &real_remote_addr);
    if (addrlength == 0 || addrlength > received) {
      RTC_LOG(LS_ERROR) << "Dropping NAT packet with malformed address header";
      return 0;
    }

    // Make sure this packet should be delivered before returning it.
    if (!connected_ || (real_remote_addr == remote_addr_)) {
      // Never copy more than the caller's buffer can hold: the header of an
      // IPv4 packet is 12 bytes shorter than the slack reserved above.
      size_t payload_size = received - addrlength;
      if (payload_size > size) {
        payload_size = size;
      }
      memcpy(data, buf_.get() + addrlength, payload_size);
      if (out_addr)
        *out_addr = real_remote_addr;
      return static_cast<int>(payload_size);
    }
    RTC_LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
                      << real_remote_addr.ToString();
    return 0;  // Tell the caller we didn't read anything
  }

  int Close() override {
    int result = 0;
    if (socket_) {
      result = socket_->Close();
      if (result >= 0) {
        connected_ = false;
        remote_addr_ = SocketAddress();
        delete socket_;
        socket_ = nullptr;
      }
    }
    return result;
  }

  // Like GetOption/SetOption, fail instead of dereferencing a null socket when
  // the NATSocket has not been bound yet.
  int Listen(int backlog) override {
    return socket_ ? socket_->Listen(backlog) : -1;
  }
  Socket* Accept(SocketAddress* paddr) override {
    return socket_ ? socket_->Accept(paddr) : nullptr;
  }
  int GetError() const override {
    return socket_ ? socket_->GetError() : error_;
  }
  void SetError(int error) override {
    if (socket_) {
      socket_->SetError(error);
    } else {
      error_ = error;
    }
  }
  ConnState GetState() const override {
    return connected_ ? CS_CONNECTED : CS_CLOSED;
  }
  int GetOption(Option opt, int* value) override {
    return socket_ ? socket_->GetOption(opt, value) : -1;
  }
  int SetOption(Option opt, int value) override {
    return socket_ ? socket_->SetOption(opt, value) : -1;
  }

  void OnConnectEvent(Socket* socket) {
    // If we're NATed, we need to send a message with the real addr to use.
    RTC_DCHECK(socket == socket_);
    if (server_addr_.IsNil()) {
      connected_ = true;
      SignalConnectEvent(this);
    } else {
      SendConnectRequest();
    }
  }
  void OnReadEvent(Socket* socket) {
    // If we're NATed, we need to process the connect reply.
    RTC_DCHECK(socket == socket_);
    if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
      HandleConnectReply();
    } else {
      SignalReadEvent(this);
    }
  }
  void OnWriteEvent(Socket* socket) {
    RTC_DCHECK(socket == socket_);
    SignalWriteEvent(this);
  }
  void OnCloseEvent(Socket* socket, int error) {
    RTC_DCHECK(socket == socket_);
    SignalCloseEvent(this, error);
  }

 private:
  // Creates the internal socket through `sf_`, which also fills in
  // `server_addr_`, binds it to `addr`, and forwards its signals. On failure
  // the internal socket is discarded and `server_addr_` is cleared.
  int BindInternal(const SocketAddress& addr) {
    RTC_DCHECK(!socket_);

    int result;
    socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
    result = (socket_) ? socket_->Bind(addr) : -1;
    if (result >= 0) {
      socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
      socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
      socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
      socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
    } else {
      server_addr_.Clear();
      delete socket_;
      socket_ = nullptr;
    }

    return result;
  }

  // Makes sure the buffer is at least the given size. The buffer never shrinks.
  void Grow(size_t new_size) {
    if (size_ < new_size) {
      buf_ = std::make_unique<char[]>(new_size);
      size_ = new_size;
    }
  }

  // Sends the destination address to the server to tell it to connect.
  void SendConnectRequest() {
    char buf[kNATEncodedIPv6AddressSize];
    size_t length = PackAddressForNAT(buf, arraysize(buf), remote_addr_);
    socket_->Send(buf, length);
  }

  // Handles the byte sent back from the server and fires the appropriate event.
  void HandleConnectReply() {
    char code = 0;
    const int received = socket_->Recv(&code, sizeof(code), nullptr);
    if (received <= 0) {
      // No reply byte was read, so `code` carries no information. Do not act
      // on it; a later read event (or close event) will follow.
      return;
    }
    if (code == 0) {
      connected_ = true;
      SignalConnectEvent(this);
    } else {
      Close();
      SignalCloseEvent(this, code);
    }
  }

  NATInternalSocketFactory* sf_;
  int family_;
  int type_;
  bool connected_ = false;
  SocketAddress remote_addr_;
  SocketAddress server_addr_;  // address of the NAT server
  Socket* socket_ = nullptr;   // owned; null until bound
  // Need to hold error in case it occurs before the socket is created.
  int error_ = 0;
  // Scratch buffer for NAT-framed datagram reads, and its capacity in bytes.
  std::unique_ptr<char[]> buf_;
  size_t size_ = 0;
};

// NATSocketFactory
NATSocketFactory::NATSocketFactory(SocketFactory* factory,
                                   const SocketAddress& nat_udp_addr,
                                   const SocketAddress& nat_tcp_addr)
    : factory_(factory),
      nat_udp_addr_(nat_udp_addr),
      nat_tcp_addr_(nat_tcp_addr) {}

Socket* NATSocketFactory::CreateSocket(int family, int type) {
  return new NATSocket(this, family, type);
}

Socket* NATSocketFactory::CreateInternalSocket(int family,
                                               int type,
                                               const SocketAddress& local_addr,
                                               SocketAddress* nat_addr) {
  if (type == SOCK_STREAM) {
    *nat_addr = nat_tcp_addr_;
  } else {
    *nat_addr = nat_udp_addr_;
  }
  return factory_->CreateSocket(family, type);
}

// NATSocketServer
NATSocketServer::NATSocketServer(SocketServer* server)
    : server_(server), msg_queue_(nullptr) {}

NATSocketServer::Translator* NATSocketServer::GetTranslator(
    const SocketAddress& ext_ip) {
  return nats_.Get(ext_ip);
}

// Creates a translator (a NAT of the given `type`) whose external address is
// `ext_ip` and whose internal address is `int_ip`. Returns null if a
// translator with this external address already exists.
NATSocketServer::Translator* NATSocketServer::AddTranslator(
    const SocketAddress& ext_ip,
    const SocketAddress& int_ip,
    NATType type) {
  // Refuse to register a second translator for the same external address.
  if (nats_.Get(ext_ip) != nullptr) {
    return nullptr;
  }

  Translator* const translator =
      new Translator(this, type, int_ip, server_, ext_ip);
  return nats_.Add(ext_ip, translator);
}

// Destroys the translator registered for `ext_ip`, if any.
void NATSocketServer::RemoveTranslator(const SocketAddress& ext_ip) {
  nats_.Remove(ext_ip);
}

// Returns a new NATSocket whose internal socket will be obtained from this
// server's CreateInternalSocket().
Socket* NATSocketServer::CreateSocket(int family, int type) {
  NATSocket* const nat_socket = new NATSocket(this, family, type);
  return nat_socket;
}

// Remembers `queue` and forwards it to the wrapped server.
void NATSocketServer::SetMessageQueue(Thread* queue) {
  msg_queue_ = queue;
  server_->SetMessageQueue(queue);
}

// Delegates waiting to the wrapped server.
bool NATSocketServer::Wait(webrtc::TimeDelta max_wait_duration,
                           bool process_io) {
  const bool result = server_->Wait(max_wait_duration, process_io);
  return result;
}

// Delegates wake-up to the wrapped server.
void NATSocketServer::WakeUp() {
  server_->WakeUp();
}

// Creates the socket that backs a NATSocket bound to `local_addr`. If a
// translator owns `local_addr` as a client, the socket is created on that
// translator's internal factory, and `nat_addr` receives the translator's
// internal TCP or UDP address (by `type`). Otherwise the socket comes from the
// wrapped server and `nat_addr` is left untouched.
Socket* NATSocketServer::CreateInternalSocket(int family,
                                              int type,
                                              const SocketAddress& local_addr,
                                              SocketAddress* nat_addr) {
  Translator* const owner = nats_.FindClient(local_addr);
  if (owner == nullptr) {
    return server_->CreateSocket(family, type);
  }

  Socket* const internal_socket =
      owner->internal_factory()->CreateSocket(family, type);
  if (type == SOCK_STREAM) {
    *nat_addr = owner->internal_tcp_address();
  } else {
    *nat_addr = owner->internal_udp_address();
  }
  return internal_socket;
}

// NATSocketServer::Translator
// Builds a private VirtualSocketServer network that shares `server`'s message
// queue, plus a NATServer of the given `type`. The NATServer is given the
// private network, `int_ip` (as both of its internal addresses), `ext_factory`
// and `ext_ip`.
NATSocketServer::Translator::Translator(NATSocketServer* server,
                                        NATType type,
                                        const SocketAddress& int_ip,
                                        SocketFactory* ext_factory,
                                        const SocketAddress& ext_ip)
    : server_(server) {
  internal_server_ = std::make_unique<VirtualSocketServer>();
  auto* const private_network = internal_server_.get();
  // Attach the queue before the NATServer is constructed on this network.
  private_network->SetMessageQueue(server_->queue());
  nat_server_ = std::make_unique<NATServer>(type, private_network, int_ip,
                                            int_ip, ext_factory, ext_ip);
}

// Detaches the private network from the shared message queue.
NATSocketServer::Translator::~Translator() {
  internal_server_->SetMessageQueue(nullptr);
}

// Returns the nested translator registered for `ext_ip`, or null.
NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
    const SocketAddress& ext_ip) {
  Translator* const nested = nats_.Get(ext_ip);
  return nested;
}

// Creates a nested translator whose external address `ext_ip` lives behind
// this one. Returns null if a nested translator with this external address
// already exists.
NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
    const SocketAddress& ext_ip,
    const SocketAddress& int_ip,
    NATType type) {
  // Refuse to register a second translator for the same external address.
  if (nats_.Get(ext_ip) != nullptr) {
    return nullptr;
  }

  // Register `ext_ip` as our client before constructing the child, so that
  // NATSocketServer::CreateInternalSocket routes sockets bound to it through
  // this translator.
  AddClient(ext_ip);
  Translator* const nested =
      new Translator(server_, type, int_ip, server_, ext_ip);
  return nats_.Add(ext_ip, nested);
}

// Destroys the nested translator for `ext_ip` and forgets it as a client.
void NATSocketServer::Translator::RemoveTranslator(
    const SocketAddress& ext_ip) {
  nats_.Remove(ext_ip);
  RemoveClient(ext_ip);
}

// Registers `int_ip` as a client behind this translator. Returns false if it
// was already registered.
bool NATSocketServer::Translator::AddClient(const SocketAddress& int_ip) {
  // std::set::insert reports through `.second` whether the key was new.
  return clients_.insert(int_ip).second;
}

// Unregisters `int_ip`; does nothing if it is not a client.
void NATSocketServer::Translator::RemoveClient(const SocketAddress& int_ip) {
  clients_.erase(int_ip);
}

// Returns the translator that owns `int_ip`: this one if it is a direct
// client, otherwise the result of searching the nested translators (which
// may be null).
NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
    const SocketAddress& int_ip) {
  if (clients_.count(int_ip) > 0) {
    return this;
  }
  return nats_.FindClient(int_ip);
}

// NATSocketServer::TranslatorMap
NATSocketServer::TranslatorMap::~TranslatorMap() {
  for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
    delete it->second;
  }
}

NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
    const SocketAddress& ext_ip) {
  TranslatorMap::iterator it = find(ext_ip);
  return (it != end()) ? it->second : nullptr;
}

NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
    const SocketAddress& ext_ip,
    Translator* nat) {
  (*this)[ext_ip] = nat;
  return nat;
}

void NATSocketServer::TranslatorMap::Remove(const SocketAddress& ext_ip) {
  TranslatorMap::iterator it = find(ext_ip);
  if (it != end()) {
    delete it->second;
    erase(it);
  }
}

NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
    const SocketAddress& int_ip) {
  Translator* nat = nullptr;
  for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
    nat = it->second->FindClient(int_ip);
  }
  return nat;
}

}  // namespace rtc
