/*
 * Copyright 2015 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "test_channel_transport.h"

#include <errno.h>   // for errno, EBADF
#include <stddef.h>  // for size_t

#include <cstdint>      // for uint8_t
#include <cstring>      // for strerror
#include <type_traits>  // for remove_extent_t

#include "log.h"
#include "net/async_data_channel.h"  // for AsyncDataChannel

using std::vector;

namespace rootcanal {

// Adopts `server` as the transport's socket server, installs
// `connection_callback` as its on-connect callback and starts listening.
//
// Returns false (without calling into the server) when `server` is null,
// true otherwise.
// NOTE(review): any result reported by StartListening() is not consulted
// here — confirm whether it can fail and should affect the return value.
bool TestChannelTransport::SetUp(std::shared_ptr<AsyncDataChannelServer> server,
                                 ConnectCallback connection_callback) {
  socket_server_ = server;
  // The null check must precede the calls below; dereferencing a null
  // server would crash before the function could report failure.
  if (socket_server_ == nullptr) {
    return false;
  }
  socket_server_->SetOnConnectCallback(connection_callback);
  socket_server_->StartListening();
  return true;
}

// Stops listening for new connections and closes the socket server.
// Does nothing when no server is currently held (for example when SetUp was
// never called or was given a null server).
void TestChannelTransport::CleanUp() {
  if (socket_server_ == nullptr) {
    return;
  }
  socket_server_->StopListening();
  socket_server_->Close();
}

// Reads one command from `socket` and hands it to the registered command
// handler.
//
// The bytes consumed are, in order:
//   1 byte   length of the command name
//   N bytes  command name
//   1 byte   number of arguments
//   then for each argument: 1 byte length followed by that many bytes.
//
// An empty command name or the name "CLOSE_TEST_CHANNEL" ends the session:
// `unwatch` is invoked and the socket is closed. A read that returns a byte
// count other than the one requested is logged and ends the session the same
// way, because the rest of the stream can no longer be framed; no command is
// dispatched in that case.
void TestChannelTransport::OnCommandReady(AsyncDataChannel* socket,
                                          std::function<void(void)> unwatch) {
  // Single exit path for every way the session can end, so the socket is
  // unwatched and closed exactly once.
  const auto end_session = [&]() {
    INFO("Test channel closed");
    unwatch();
    socket->Close();
  };

  uint8_t command_name_size = 0;
  ssize_t bytes_read = socket->Recv(&command_name_size, 1);
  if (bytes_read != 1) {
    INFO("Unexpected (command_name_size) bytes_read: {} != {}, {}", bytes_read, 1, strerror(errno));
    end_session();
    return;
  }

  // A zero-length name is treated as a close request; there is nothing more
  // to read for it.
  if (command_name_size == 0) {
    end_session();
    return;
  }

  std::vector<uint8_t> command_name_raw(command_name_size);
  bytes_read = socket->Recv(command_name_raw.data(), command_name_size);
  if (bytes_read != command_name_size) {
    INFO("Unexpected (command_name) bytes_read: {} != {}, {}", bytes_read, command_name_size,
         strerror(errno));
    end_session();
    return;
  }
  const std::string command_name(command_name_raw.begin(), command_name_raw.end());

  if (command_name == "CLOSE_TEST_CHANNEL") {
    end_session();
    return;
  }

  uint8_t num_args = 0;
  bytes_read = socket->Recv(&num_args, 1);
  if (bytes_read != 1) {
    INFO("Unexpected (num_args) bytes_read: {} != {}, {}", bytes_read, 1, strerror(errno));
    end_session();
    return;
  }

  std::vector<std::string> args;
  args.reserve(num_args);
  for (uint8_t i = 0; i < num_args; ++i) {
    uint8_t arg_size = 0;
    bytes_read = socket->Recv(&arg_size, 1);
    if (bytes_read != 1) {
      INFO("Unexpected (arg_size) bytes_read: {} != {}, {}", bytes_read, 1, strerror(errno));
      end_session();
      return;
    }

    std::vector<uint8_t> arg(arg_size);
    // An empty argument is valid and needs no read.
    if (arg_size != 0) {
      bytes_read = socket->Recv(arg.data(), arg_size);
      if (bytes_read != arg_size) {
        INFO("Unexpected (arg) bytes_read: {} != {}, {}", bytes_read, arg_size, strerror(errno));
        end_session();
        return;
      }
    }
    args.emplace_back(arg.begin(), arg.end());
  }

  command_handler_(command_name, args);
}

// Sends `response` over `socket` as a 4-byte little-endian length followed by
// the payload bytes. Payloads longer than 0xffff bytes are truncated to
// 0xffff bytes.
//
// If either send fails with EBADF a warning is logged and the function
// returns; any other short or failed send trips an assertion.
void TestChannelTransport::SendResponse(std::shared_ptr<AsyncDataChannel> socket,
                                        const std::string& response) {
  // Cap to 64K
  constexpr size_t kMaxResponseSize = 0xffff;
  const size_t size = response.size() > kMaxResponseSize ? kMaxResponseSize : response.size();

  const uint8_t size_buf[4] = {
          static_cast<uint8_t>(size & 0xff), static_cast<uint8_t>((size >> 8) & 0xff),
          static_cast<uint8_t>((size >> 16) & 0xff), static_cast<uint8_t>((size >> 24) & 0xff)};
  ssize_t written = socket->Send(size_buf, 4);
  if (written == -1 && errno == EBADF) {
    WARNING("Unable to send a response.  EBADF");
    return;
  }
  // The file's other log calls use `{}` placeholders; printf-style
  // specifiers would be emitted verbatim instead of the values.
  ASSERT_LOG(written == 4, "What happened? written = {} errno = {}", written, errno);

  written = socket->Send(reinterpret_cast<const uint8_t*>(response.data()), size);
  // Treat EBADF on the payload the same way as on the length header.
  if (written == -1 && errno == EBADF) {
    WARNING("Unable to send a response.  EBADF");
    return;
  }
  ASSERT_LOG(written == static_cast<ssize_t>(size), "What happened? written = {} errno = {}",
             written, errno);
}

void TestChannelTransport::RegisterCommandHandler(
        const std::function<void(const std::string&, const std::vector<std::string>&)>& callback) {
  command_handler_ = callback;
}

}  // namespace rootcanal
