# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An in-process federated compute server."""

import contextlib
import gzip
import http.server
import socket
import socketserver
import ssl
from typing import Optional

from absl import logging

from fcp.demo import aggregations
from fcp.demo import eligibility_eval_tasks
from fcp.demo import http_actions
from fcp.demo import media
from fcp.demo import plan_utils
from fcp.demo import task_assignments
from fcp.protos import plan_pb2
from fcp.protos.federatedcompute import common_pb2
from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2

_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)

# Template for file name for federated select slices. See
# `FederatedSelectUriInfo.uri_template` for the meaning of the "{served_at_id}"
# and "{key_base10}" substrings.
_FEDERATED_SELECT_NAME_TEMPLATE = '{served_at_id}_{key_base10}'

# Content type used for serialized and compressed Plan messages.
_PLAN_CONTENT_TYPE = 'application/x-protobuf+gzip'

# Content type used for serialzied and compressed TensorFlow checkpoints.
_CHECKPOINT_CONTENT_TYPE = 'application/octet-stream+gzip'


class InProcessServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
  """An in-process HTTP server implementing the Federated Compute protocol."""

  def __init__(self,
               *,
               population_name: str,
               host: str,
               port: int,
               address_family: Optional[socket.AddressFamily] = None):
    self._media_service = media.Service(self._get_forwarding_info)
    self._aggregations_service = aggregations.Service(self._get_forwarding_info,
                                                      self._media_service)
    self._task_assignments_service = task_assignments.Service(
        population_name, self._get_forwarding_info, self._aggregations_service)
    self._eligibility_eval_tasks_service = eligibility_eval_tasks.Service(
        population_name, self._get_forwarding_info
    )
    handler = http_actions.create_handler(
        self._media_service,
        self._aggregations_service,
        self._task_assignments_service,
        self._eligibility_eval_tasks_service,
    )
    if address_family is not None:
      self.address_family = address_family
    http.server.HTTPServer.__init__(self, (host, port), handler)

  async def run_computation(
      self,
      task_name: str,
      plan: plan_pb2.Plan,
      server_checkpoint: bytes,
      task_assignment_mode: _TaskAssignmentMode,
      number_of_clients: int,
  ) -> bytes:
    """Runs a computation, returning the resulting checkpoint.

    If there's already a computation in progress, the new computation will
    not start until the previous one has completed (either successfully or not).

    Args:
      task_name: The name of the task.
      plan: The Plan proto containing the client and server computations.
      server_checkpoint: The starting server checkpoint.
      task_assignment_mode: The task assignment mode to use for the computation.
      number_of_clients: The minimum number of clients to include.

    Returns:
      A TensorFlow checkpoint containing the aggregated results.
    """
    requirements = aggregations.AggregationRequirements(
        minimum_clients_in_server_published_aggregate=number_of_clients,
        plan=plan)
    session_id = self._aggregations_service.create_session(requirements)
    with contextlib.ExitStack() as stack:
      stack.callback(
          lambda: self._aggregations_service.abort_session(session_id))
      with plan_utils.Session(plan, server_checkpoint) as session:
        with self._media_service.create_download_group() as group:
          plan_url = group.add(
              'plan',
              gzip.compress(session.client_plan),
              content_type=_PLAN_CONTENT_TYPE,
          )
          checkpoint_url = group.add(
              'checkpoint',
              gzip.compress(session.client_checkpoint),
              content_type=_CHECKPOINT_CONTENT_TYPE,
          )
          for served_at_id, slices in session.slices.items():
            for i, slice_data in enumerate(slices):
              group.add(
                  _FEDERATED_SELECT_NAME_TEMPLATE.format(
                      served_at_id=served_at_id, key_base10=str(i)
                  ),
                  gzip.compress(slice_data),
                  content_type=_CHECKPOINT_CONTENT_TYPE,
              )
          self._eligibility_eval_tasks_service.add_task(
              task_name, task_assignment_mode
          )
          self._task_assignments_service.add_task(
              task_name,
              task_assignment_mode,
              session_id,
              common_pb2.Resource(uri=plan_url),
              common_pb2.Resource(uri=checkpoint_url),
              group.prefix + _FEDERATED_SELECT_NAME_TEMPLATE,
          )
          try:
            status = await self._aggregations_service.wait(
                session_id,
                num_inputs_aggregated_and_included=number_of_clients)
            if status.status != aggregations.AggregationStatus.PENDING:
              raise ValueError('Aggregation failed.')
          finally:
            self._task_assignments_service.remove_task(session_id)
            self._eligibility_eval_tasks_service.remove_task(task_name)

          stack.pop_all()
          status, intermedia_update = (
              self._aggregations_service.complete_session(session_id))
          if (status.status != aggregations.AggregationStatus.COMPLETED or
              intermedia_update is None):
            raise ValueError('Aggregation failed.')
        logging.debug('%s aggregation complete: %s', task_name, status)
        return session.finalize(intermedia_update)

  def _get_forwarding_info(self) -> common_pb2.ForwardingInfo:
    protocol = 'https' if isinstance(self.socket, ssl.SSLSocket) else 'http'
    return common_pb2.ForwardingInfo(
        target_uri_prefix=(
            f'{protocol}://{self.server_name}:{self.server_port}/'))
