// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "discovery/dnssd/impl/publisher_impl.h"

#include <algorithm>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "discovery/common/reporting_client.h"
#include "discovery/dnssd/impl/conversion_layer.h"
#include "discovery/dnssd/impl/instance_key.h"
#include "discovery/dnssd/impl/network_interface_config.h"
#include "discovery/mdns/public/mdns_constants.h"
#include "platform/api/task_runner.h"
#include "platform/base/error.h"
#include "util/trace_logging.h"

namespace openscreen {
namespace discovery {
namespace {

// Builds the endpoint for |instance| using the instance, service and domain
// ids from |key|, the TXT data of |instance|, and the network interface of
// |network_config|. One IP endpoint on |instance|'s port is added per address
// family the config currently has: IPv4 first, then IPv6.
DnsSdInstanceEndpoint CreateEndpoint(
    DnsSdInstance instance,
    InstanceKey key,
    const NetworkInterfaceConfig& network_config) {
  std::vector<IPEndpoint> ip_endpoints;
  const bool has_v4 = network_config.HasAddressV4();
  const bool has_v6 = network_config.HasAddressV6();
  if (has_v4) {
    ip_endpoints.push_back({network_config.address_v4(), instance.port()});
  }
  if (has_v6) {
    ip_endpoints.push_back({network_config.address_v6(), instance.port()});
  }

  DnsSdInstanceEndpoint result(key.instance_id(), key.service_id(),
                               key.domain_id(), instance.txt(),
                               network_config.network_interface(),
                               std::move(ip_endpoints));
  return result;
}

// Builds the endpoint for |instance| but takes its instance/service/domain ids
// from |name| rather than from |instance| itself. This is used when the
// published name differs from the one derived from |instance|.
DnsSdInstanceEndpoint UpdateDomain(
    const DomainName& name,
    DnsSdInstance instance,
    const NetworkInterfaceConfig& network_config) {
  InstanceKey key_from_name(name);
  return CreateEndpoint(std::move(instance), std::move(key_from_name),
                        network_config);
}

// Builds the endpoint for |instance| using the ids carried by |instance|
// itself.
DnsSdInstanceEndpoint CreateEndpoint(
    DnsSdInstance instance,
    const NetworkInterfaceConfig& network_config) {
  // Derive the key first: |instance| is moved into the call below.
  InstanceKey own_key(instance);
  return CreateEndpoint(std::move(instance), std::move(own_key),
                        network_config);
}

template <typename T>
inline typename std::map<DnsSdInstance, T>::iterator FindKey(
    std::map<DnsSdInstance, T>* instances,
    const InstanceKey& key) {
  return std::find_if(instances->begin(), instances->end(),
                      [&key](const std::pair<DnsSdInstance, T>& pair) {
                        return key == InstanceKey(pair.first);
                      });
}

// Removes from |instances| every entry whose key has service id |service_id|.
// Returns how many entries were removed.
template <typename T>
int EraseInstancesWithServiceId(std::map<DnsSdInstance, T>* instances,
                                const std::string& service_id) {
  int erased = 0;
  auto cursor = instances->begin();
  while (cursor != instances->end()) {
    if (cursor->first.service_id() != service_id) {
      ++cursor;
      continue;
    }
    // map::erase returns the iterator following the removed element.
    cursor = instances->erase(cursor);
    ++erased;
  }
  return erased;
}

}  // namespace

// Stores the mDNS service used for probing and record registration, the
// client used to report recoverable errors, the task runner all calls must
// run on, and the network configuration used to build endpoints. None of
// these may be null.
PublisherImpl::PublisherImpl(MdnsService* publisher,
                             ReportingClient* reporting_client,
                             TaskRunner* task_runner,
                             const NetworkInterfaceConfig* network_config)
    : mdns_publisher_(publisher),
      reporting_client_(reporting_client),
      task_runner_(task_runner),
      network_config_(network_config) {
  OSP_DCHECK(mdns_publisher_);
  OSP_DCHECK(reporting_client_);
  OSP_DCHECK(task_runner_);
  // Dereferenced unconditionally by Register(), UpdateRegistration(),
  // UpdatePublishedRegistration() and OnDomainFound().
  OSP_DCHECK(network_config_);
}

PublisherImpl::~PublisherImpl() = default;

// Starts publishing |instance| on behalf of |client|.
//
// Outcomes:
// - If an identical instance is already published, the call is treated as an
//   update and the result of UpdateRegistration() is returned.
// - If the instance is still pending (probing), kOperationInProgress is
//   returned.
// - Otherwise the instance is recorded as pending and an mDNS probe is
//   started for its domain name. The result of StartProbe() is returned.
Error PublisherImpl::Register(const DnsSdInstance& instance, Client* client) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  OSP_DCHECK(client != nullptr);

  if (published_instances_.find(instance) != published_instances_.end()) {
    // Already published: apply it as an update instead of inserting a second
    // pending copy and starting another probe for the same instance.
    return UpdateRegistration(instance);
  }
  if (pending_instances_.find(instance) != pending_instances_.end()) {
    return Error::Code::kOperationInProgress;
  }

  InstanceKey key(instance);
  const IPAddress& address = network_config_->GetAddress();
  OSP_DCHECK(address);
  pending_instances_.emplace(CreateEndpoint(instance, *network_config_),
                             client);

  OSP_DVLOG << "Registering instance '" << instance.instance_id() << "'";

  return mdns_publisher_->StartProbe(this, GetDomainName(key), address);
}

// Replaces the data of a registered instance with that of |instance|.
//
// If an instance with the same InstanceKey is still pending, its stored copy
// is swapped for the new data and the ongoing probe is left untouched.
// Otherwise the update is forwarded to UpdatePublishedRegistration().
Error PublisherImpl::UpdateRegistration(const DnsSdInstance& instance) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  auto pending_it = FindKey(&pending_instances_, InstanceKey(instance));

  OSP_DVLOG << "Updating instance '" << instance.instance_id() << "'";

  if (pending_it == pending_instances_.end()) {
    return UpdatePublishedRegistration(instance);
  }

  // The instance, service, and domain ids are unchanged, so the probe being
  // run for this name remains valid; only the stored data is replaced.
  Client* const owner = pending_it->second;
  pending_instances_.erase(pending_it);
  pending_instances_.emplace(CreateEndpoint(instance, *network_config_),
                             owner);
  return Error::None();
}

// Applies |instance| as an update to the published instance with the same
// InstanceKey.
//
// Record changes between the old and new endpoints are handled as follows:
// - records only in the old endpoint are unregistered;
// - records only in the new endpoint are registered;
// - records in both but different are updated.
// Every change is attempted even if an earlier one fails, and the stored
// published instance is replaced regardless.
//
// Returns kParameterInvalid if no matching instance is published, or if the
// update would not change the published endpoint. Otherwise returns the last
// error reported by the mDNS service, or Error::None().
Error PublisherImpl::UpdatePublishedRegistration(
    const DnsSdInstance& instance) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  auto published_instance_it =
      FindKey(&published_instances_, InstanceKey(instance));

  // Check preconditions called out in header. Specifically, the updated
  // instance must be making changes to an already published instance.
  if (published_instance_it == published_instances_.end()) {
    return Error::Code::kParameterInvalid;
  }

  // Keep the domain name of the currently published endpoint, which may
  // differ from the one derived from |instance| if it was renamed while
  // probing. Not const, so it can be moved into the map at the end.
  DnsSdInstanceEndpoint updated_endpoint =
      UpdateDomain(GetDomainName(InstanceKey(published_instance_it->second)),
                   instance, *network_config_);
  if (published_instance_it->second == updated_endpoint) {
    return Error::Code::kParameterInvalid;
  }

  // Pair up old and new records by DnsType. By design there can only be one
  // record of each DnsType per endpoint, so the type identifies the record.
  // First in each pair is the old record, second is the new record; either
  // may be absent.
  std::map<DnsType,
           std::pair<absl::optional<MdnsRecord>, absl::optional<MdnsRecord>>>
      changed_records;
  // Not const: the records are moved into |changed_records| below, and
  // std::move on a const object would silently copy instead.
  std::vector<MdnsRecord> old_records =
      GetDnsRecords(published_instance_it->second);
  std::vector<MdnsRecord> new_records = GetDnsRecords(updated_endpoint);

  for (MdnsRecord& record : old_records) {
    const auto type = record.dns_type();
    OSP_DCHECK(changed_records.find(type) == changed_records.end());
    changed_records[type].first = std::move(record);
  }
  for (MdnsRecord& record : new_records) {
    const auto type = record.dns_type();
    changed_records[type].second = std::move(record);
  }

  // Apply the changes recorded in |changed_records|.
  Error total_result = Error::None();
  for (const auto& entry : changed_records) {
    const absl::optional<MdnsRecord>& old_record = entry.second.first;
    const absl::optional<MdnsRecord>& new_record = entry.second.second;
    OSP_DCHECK(old_record.has_value() || new_record.has_value());
    if (!old_record.has_value()) {
      TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.RegisterRecord");
      auto error = mdns_publisher_->RegisterRecord(new_record.value());
      TRACE_SET_RESULT(error);
      if (!error.ok()) {
        total_result = error;
      }
    } else if (!new_record.has_value()) {
      TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.UnregisterRecord");
      auto error = mdns_publisher_->UnregisterRecord(old_record.value());
      TRACE_SET_RESULT(error);
      if (!error.ok()) {
        total_result = error;
      }
    } else if (old_record.value() != new_record.value()) {
      TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.UpdateRegisteredRecord");
      auto error = mdns_publisher_->UpdateRegisteredRecord(old_record.value(),
                                                           new_record.value());
      TRACE_SET_RESULT(error);
      if (!error.ok()) {
        total_result = error;
      }
    }
  }

  // Replace the old published instance with the updated one.
  published_instances_.erase(published_instance_it);
  published_instances_.emplace(instance, std::move(updated_endpoint));

  return total_result;
}

// Removes every published or pending instance whose service id is |service|.
//
// For each matching published instance, all of its mDNS records are
// unregistered. Every unregistration is attempted even if an earlier one
// fails. Matching pending instances are simply dropped; no probe is cancelled
// here.
//
// Returns the number of instances removed, or the last unregistration error
// if any occurred. Instances are removed from the maps in either case.
ErrorOr<int> PublisherImpl::DeregisterAll(const std::string& service) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  OSP_DVLOG << "Deregistering all instances";

  int removed_count = 0;
  Error error = Error::None();
  for (auto it = published_instances_.begin();
       it != published_instances_.end();) {
    if (it->second.service_id() == service) {
      for (const auto& mdns_record : GetDnsRecords(it->second)) {
        TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.UnregisterRecord");
        auto publisher_error = mdns_publisher_->UnregisterRecord(mdns_record);
        // Trace the result of this call, not the aggregated |error|.
        TRACE_SET_RESULT(publisher_error);
        if (!publisher_error.ok()) {
          error = publisher_error;
        }
      }
      removed_count++;
      it = published_instances_.erase(it);
    } else {
      ++it;
    }
  }

  removed_count += EraseInstancesWithServiceId(&pending_instances_, service);

  if (!error.ok()) {
    return error;
  }
  return removed_count;
}

void PublisherImpl::OnDomainFound(const DomainName& requested_name,
                                  const DomainName& confirmed_name) {
  TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  OSP_DVLOG << "Domain successfully claimed: '" << confirmed_name.ToString()
            << "' based on requested name: '" << requested_name.ToString()
            << "'";

  auto it = FindKey(&pending_instances_, InstanceKey(requested_name));

  if (it == pending_instances_.end()) {
    // This will be hit if the instance was deregister'd before the probe phase
    // was completed.
    return;
  }

  DnsSdInstance requested_instance = std::move(it->first);
  DnsSdInstanceEndpoint endpoint =
      CreateEndpoint(requested_instance, *network_config_);
  Client* const client = it->second;
  pending_instances_.erase(it);

  InstanceKey requested_key(requested_instance);

  if (requested_name != confirmed_name) {
    OSP_DCHECK(HasValidDnsRecordAddress(confirmed_name));
    endpoint =
        UpdateDomain(confirmed_name, requested_instance, *network_config_);
  }

  for (const auto& mdns_record : GetDnsRecords(endpoint)) {
    TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.RegisterRecord");
    Error result = mdns_publisher_->RegisterRecord(mdns_record);
    if (!result.ok()) {
      reporting_client_->OnRecoverableError(
          Error(Error::Code::kRecordPublicationError, result.ToString()));
    }
  }

  auto pair = published_instances_.emplace(std::move(requested_instance),
                                           std::move(endpoint));
  client->OnEndpointClaimed(pair.first->first, pair.first->second);
}

}  // namespace discovery
}  // namespace openscreen
