# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for task_assignments."""

import http
from unittest import mock
import uuid

from absl.testing import absltest

from google.rpc import code_pb2
from fcp.demo import aggregations
from fcp.demo import http_actions
from fcp.demo import task_assignments
from fcp.protos.federatedcompute import common_pb2
from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
from fcp.protos.federatedcompute import task_assignments_pb2

_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)

POPULATION_NAME = 'test/population'
FORWARDING_INFO = common_pb2.ForwardingInfo(
    target_uri_prefix='https://forwarding.example/')


class TaskAssignmentsTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self.mock_aggregations_service = self.enter_context(
        mock.patch.object(aggregations, 'Service', autospec=True))
    self.mock_aggregations_service.pre_authorize_clients.return_value = ['']

  def test_start_task_assignment_with_wrong_population(self):
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)
    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name='other/population', session_id='session-id')
    with self.assertRaises(http_actions.HttpError) as cm:
      service.start_task_assignment(request)
    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)

  @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
  def test_start_task_assignment_with_no_tasks(self, mock_uuid):
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)
    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name=POPULATION_NAME, session_id='session-id')
    operation = service.start_task_assignment(request)
    self.assertEqual(operation.name, f'operations/{mock_uuid.return_value}')
    self.assertTrue(operation.done)

    metadata = task_assignments_pb2.StartTaskAssignmentMetadata()
    operation.metadata.Unpack(metadata)
    self.assertEqual(metadata,
                     task_assignments_pb2.StartTaskAssignmentMetadata())

    response = task_assignments_pb2.StartTaskAssignmentResponse()
    operation.response.Unpack(response)
    self.assertEqual(
        response,
        task_assignments_pb2.StartTaskAssignmentResponse(
            rejection_info=common_pb2.RejectionInfo()
        ),
    )

  def test_start_task_assignment_with_multiple_assignment_task(self):
    service = task_assignments.Service(
        POPULATION_NAME, lambda: FORWARDING_INFO, self.mock_aggregations_service
    )
    service.add_task(
        'task',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session',
        common_pb2.Resource(uri='https://task.example/plan'),
        common_pb2.Resource(uri='https://task.example/checkpoint'),
        'https://task.example/{key_base10}',
    )

    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name=POPULATION_NAME, session_id='session-id'
    )
    operation = service.start_task_assignment(request)
    self.assertTrue(operation.done)

    response = task_assignments_pb2.StartTaskAssignmentResponse()
    operation.response.Unpack(response)
    self.assertEqual(
        response,
        task_assignments_pb2.StartTaskAssignmentResponse(
            rejection_info=common_pb2.RejectionInfo()))

  @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
  def test_start_task_assignment_with_one_task(self, mock_uuid):
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)

    self.mock_aggregations_service.pre_authorize_clients.return_value = [
        'token'
    ]

    task_plan = common_pb2.Resource(uri='https://task.example/plan')
    task_checkpoint = common_pb2.Resource(uri='https://task.example/checkpoint')
    task_federated_select_uri_template = 'https://task.example/{key_base10}'
    service.add_task(
        'task',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session',
        task_plan,
        task_checkpoint,
        task_federated_select_uri_template,
    )

    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name=POPULATION_NAME, session_id='session-id')
    operation = service.start_task_assignment(request)
    self.assertEqual(operation.name, f'operations/{mock_uuid.return_value}')
    self.assertTrue(operation.done)

    metadata = task_assignments_pb2.StartTaskAssignmentMetadata()
    operation.metadata.Unpack(metadata)
    self.assertEqual(metadata,
                     task_assignments_pb2.StartTaskAssignmentMetadata())

    response = task_assignments_pb2.StartTaskAssignmentResponse()
    operation.response.Unpack(response)
    self.assertEqual(
        response,
        task_assignments_pb2.StartTaskAssignmentResponse(
            task_assignment=task_assignments_pb2.TaskAssignment(
                aggregation_data_forwarding_info=FORWARDING_INFO,
                aggregation_info=(
                    task_assignments_pb2.TaskAssignment.AggregationInfo()
                ),
                session_id=request.session_id,
                aggregation_id='aggregation-session',
                authorization_token='token',
                task_name='task',
                plan=task_plan,
                init_checkpoint=task_checkpoint,
                federated_select_uri_info=(
                    task_assignments_pb2.FederatedSelectUriInfo(
                        uri_template=task_federated_select_uri_template
                    )
                ),
            )
        ),
    )

    self.mock_aggregations_service.pre_authorize_clients.assert_called_once_with(
        'aggregation-session', num_tokens=1)

  def test_start_task_assignment_with_multiple_tasks(self):
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)

    self.mock_aggregations_service.pre_authorize_clients.return_value = [
        'token'
    ]

    task1_plan = common_pb2.Resource(uri='https://task1.example/plan')
    task1_checkpoint = common_pb2.Resource(
        uri='https://task1.example/checkpoint')
    task1_federated_select_uri_template = 'https://task1.example/{key_base10}'
    service.add_task(
        'task1',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session1',
        task1_plan,
        task1_checkpoint,
        task1_federated_select_uri_template,
    )
    task2_plan = common_pb2.Resource(uri='https://task2.example/plan')
    task2_checkpoint = common_pb2.Resource(
        uri='https://task2.example/checkpoint')
    task2_federated_select_uri_template = 'https://task2.example/{key_base10}'
    service.add_task(
        'task2',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session2',
        task2_plan,
        task2_checkpoint,
        task2_federated_select_uri_template,
    )

    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name=POPULATION_NAME, session_id='session-id')

    # Initially, task1 should be used.
    operation = service.start_task_assignment(request)
    response = task_assignments_pb2.StartTaskAssignmentResponse()
    operation.response.Unpack(response)
    self.assertEqual(
        response,
        task_assignments_pb2.StartTaskAssignmentResponse(
            task_assignment=task_assignments_pb2.TaskAssignment(
                aggregation_data_forwarding_info=FORWARDING_INFO,
                aggregation_info=(
                    task_assignments_pb2.TaskAssignment.AggregationInfo()
                ),
                session_id=request.session_id,
                aggregation_id='aggregation-session1',
                authorization_token='token',
                task_name='task1',
                plan=task1_plan,
                init_checkpoint=task1_checkpoint,
                federated_select_uri_info=(
                    task_assignments_pb2.FederatedSelectUriInfo(
                        uri_template=task1_federated_select_uri_template
                    )
                ),
            )
        ),
    )
    self.mock_aggregations_service.pre_authorize_clients.assert_called_with(
        'aggregation-session1', num_tokens=1)

    # After task1 is removed, task2 should be used.
    service.remove_task('aggregation-session1')
    operation = service.start_task_assignment(request)
    response = task_assignments_pb2.StartTaskAssignmentResponse()
    operation.response.Unpack(response)
    self.assertEqual(
        response,
        task_assignments_pb2.StartTaskAssignmentResponse(
            task_assignment=task_assignments_pb2.TaskAssignment(
                aggregation_data_forwarding_info=FORWARDING_INFO,
                aggregation_info=(
                    task_assignments_pb2.TaskAssignment.AggregationInfo()
                ),
                session_id=request.session_id,
                aggregation_id='aggregation-session2',
                authorization_token='token',
                task_name='task2',
                plan=task2_plan,
                init_checkpoint=task2_checkpoint,
                federated_select_uri_info=(
                    task_assignments_pb2.FederatedSelectUriInfo(
                        uri_template=task2_federated_select_uri_template
                    )
                ),
            )
        ),
    )
    self.mock_aggregations_service.pre_authorize_clients.assert_called_with(
        'aggregation-session2', num_tokens=1)

    # After task2 is removed, the client should be rejected.
    service.remove_task('aggregation-session2')
    operation = service.start_task_assignment(request)
    response = task_assignments_pb2.StartTaskAssignmentResponse()
    operation.response.Unpack(response)
    self.assertEqual(
        response,
        task_assignments_pb2.StartTaskAssignmentResponse(
            rejection_info=common_pb2.RejectionInfo()))

  def test_perform_multiple_task_assignments_with_wrong_population(self):
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)
    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
        population_name='other/population',
        session_id='session-id',
        task_names=['task1', 'task2', 'task3'])
    with self.assertRaises(http_actions.HttpError) as cm:
      service.perform_multiple_task_assignments(request)
    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)

  def test_perform_multiple_task_assignments_without_tasks(self):
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)

    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
        population_name=POPULATION_NAME,
        session_id='session-id',
        task_names=['task1', 'task2', 'task3'])
    self.assertEqual(
        service.perform_multiple_task_assignments(request),
        task_assignments_pb2.PerformMultipleTaskAssignmentsResponse())

  def test_perform_multiple_task_assignments_with_multiple_tasks(self):
    self.mock_aggregations_service.pre_authorize_clients.side_effect = (
        lambda session_id, num_tokens=1: [f'token-for-{session_id}'])
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)

    task1_plan = common_pb2.Resource(uri='https://task1.example/plan')
    task1_checkpoint = common_pb2.Resource(
        uri='https://task1.example/checkpoint')
    task1_federated_select_uri_template = 'https://task1.example/{key_base10}'
    service.add_task(
        'task1',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session1',
        task1_plan,
        task1_checkpoint,
        task1_federated_select_uri_template,
    )
    task2_plan = common_pb2.Resource(uri='https://task2.example/plan')
    task2_checkpoint = common_pb2.Resource(
        uri='https://task2.example/checkpoint')
    task2_federated_select_uri_template = 'https://task2.example/{key_base10}'
    service.add_task(
        'task2',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session2',
        task2_plan,
        task2_checkpoint,
        task2_federated_select_uri_template,
    )
    # Tasks using other TaskAssignmentModes should be skipped.
    task3_plan = common_pb2.Resource(uri='https://task3.example/plan')
    task3_checkpoint = common_pb2.Resource(
        uri='https://task3.example/checkpoint'
    )
    task3_federated_select_uri_template = 'https://task3.example/{key_base10}'
    service.add_task(
        'task3',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session3',
        task3_plan,
        task3_checkpoint,
        task3_federated_select_uri_template,
    )

    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
        population_name=POPULATION_NAME,
        session_id='session-id',
        task_names=['task1', 'task2', 'task3'])
    self.assertCountEqual(
        service.perform_multiple_task_assignments(request).task_assignments,
        [
            task_assignments_pb2.TaskAssignment(
                aggregation_data_forwarding_info=FORWARDING_INFO,
                aggregation_info=(
                    task_assignments_pb2.TaskAssignment.AggregationInfo()
                ),
                session_id=request.session_id,
                aggregation_id='aggregation-session1',
                authorization_token='token-for-aggregation-session1',
                task_name='task1',
                plan=task1_plan,
                init_checkpoint=task1_checkpoint,
                federated_select_uri_info=(
                    task_assignments_pb2.FederatedSelectUriInfo(
                        uri_template=task1_federated_select_uri_template
                    )
                ),
            ),
            task_assignments_pb2.TaskAssignment(
                aggregation_data_forwarding_info=FORWARDING_INFO,
                aggregation_info=(
                    task_assignments_pb2.TaskAssignment.AggregationInfo()
                ),
                session_id=request.session_id,
                aggregation_id='aggregation-session2',
                authorization_token='token-for-aggregation-session2',
                task_name='task2',
                plan=task2_plan,
                init_checkpoint=task2_checkpoint,
                federated_select_uri_info=(
                    task_assignments_pb2.FederatedSelectUriInfo(
                        uri_template=task2_federated_select_uri_template
                    )
                ),
            ),
            # 'task3' should be omitted since there isn't a corresponding task.
        ],
    )

  def test_add_task_with_invalid_task_assignment_mode(self):
    service = task_assignments.Service(
        POPULATION_NAME, lambda: FORWARDING_INFO, self.mock_aggregations_service
    )
    with self.assertRaises(ValueError):
      service.add_task(
          'task',
          _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_UNSPECIFIED,
          'aggregation-session',
          common_pb2.Resource(uri='https://task.example/plan'),
          common_pb2.Resource(uri='https://task.example/checkpoint'),
          'https://task.example/{key_base10}',
      )

  def test_remove_multiple_assignment_task(self):
    service = task_assignments.Service(
        POPULATION_NAME, lambda: FORWARDING_INFO, self.mock_aggregations_service
    )
    service.add_task(
        'task',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session',
        common_pb2.Resource(uri='https://task.example/plan'),
        common_pb2.Resource(uri='https://task.example/checkpoint'),
        'https://task.example/{key_base10}',
    )
    service.remove_task('aggregation-session')

    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
        population_name=POPULATION_NAME,
        session_id='session-id',
        task_names=['task'],
    )
    self.assertEmpty(
        service.perform_multiple_task_assignments(request).task_assignments
    )

  def test_remove_missing_task(self):
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)
    with self.assertRaises(KeyError):
      service.remove_task('does-not-exist')

  def test_report_task_result(self):
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)
    request = task_assignments_pb2.ReportTaskResultRequest(
        population_name=POPULATION_NAME,
        session_id='session-id',
        aggregation_id='aggregation-id',
        computation_status_code=code_pb2.ABORTED)
    self.assertEqual(
        service.report_task_result(request),
        task_assignments_pb2.ReportTaskResultResponse())

  def test_report_task_result_with_wrong_population(self):
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)
    request = task_assignments_pb2.ReportTaskResultRequest(
        population_name='other/population',
        session_id='session-id',
        aggregation_id='aggregation-id',
        computation_status_code=code_pb2.ABORTED)
    with self.assertRaises(http_actions.HttpError) as cm:
      service.report_task_result(request)
    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)


if __name__ == '__main__':
  absltest.main()
