#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests using the callback client for pw_rpc."""

import unittest
from unittest import mock
from typing import Any

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, descriptors, packets
from pw_rpc.internal import packet_pb2

# Proto source compiled by _CallbackClientImplTestBase.setUp() via
# python_protos.Library.from_strings(). PublicService declares one method of
# each RPC kind: unary, server streaming, client streaming, and bidirectional
# streaming.
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""

# ID of the channel registered with the client under test; packets addressed
# to any other channel ID are treated as belonging to an unknown channel.
CLIENT_CHANNEL_ID: int = 489


def _message_bytes(msg) -> bytes:
    """Returns msg as wire bytes.

    Raw bytes pass through untouched; anything else is assumed to be a
    protobuf message and is serialized with SerializeToString().
    """
    if not isinstance(msg, bytes):
        return msg.SerializeToString()
    return msg


class _CallbackClientImplTestBase(unittest.TestCase):
    """Supports writing tests that require responses from an RPC server.

    Server packets are queued with the _enqueue_* helpers and fed to the
    client by _process_enqueued_packets(). That happens either explicitly or
    automatically from the channel output (_handle_packet) once the client
    has sent send_responses_after_packets packets.
    """

    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage

        self._client = client.Client.from_modules(
            callback_client.Impl(),
            [client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)],
            self._protos.modules(),
        )
        self._service = self._client.channel(
            CLIENT_CHANNEL_ID
        ).rpcs.pw.test1.PublicService

        # Every packet the client has sent, decoded, in send order.
        self.requests: list[packet_pb2.RpcPacket] = []
        # Serialized packets to feed to the client, each paired with the
        # status the client's process_packet() is expected to return for it.
        self._next_packets: list[tuple[bytes, Status]] = []
        # Enqueued packets are delivered when the client sends this many
        # packets. It is a float so it can temporarily be set to infinity.
        self.send_responses_after_packets: float = 1

        # When set, the channel output raises this instead of recording the
        # outgoing packet.
        self.output_exception: Exception | None = None

    def last_request(self) -> packet_pb2.RpcPacket:
        """Returns the most recent packet sent by the client."""
        assert self.requests
        return self.requests[-1]

    def _enqueue_response(
        self,
        channel_id: int = CLIENT_CHANNEL_ID,
        method: descriptors.Method | None = None,
        status: Status = Status.OK,
        payload: Any = b'',
        *,
        ids: tuple[int, int] | None = None,
        process_status: Status = Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a RESPONSE packet for delivery to the client.

        Args:
          channel_id: channel the packet is addressed to.
          method: target method; mutually exclusive with ids.
          status: status carried by the response packet.
          payload: raw bytes or a protobuf message to serialize.
          ids: explicit (service_id, method_id), used when method is None.
          process_status: status process_packet() is expected to return.
          call_id: call the response is addressed to.
        """
        if method is not None:
            assert ids is None, 'Pass either method or ids, not both'
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None, 'Either method or ids is required'
            service_id, method_id = ids

        self._next_packets.append(
            (
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.RESPONSE,
                    channel_id=channel_id,
                    service_id=service_id,
                    method_id=method_id,
                    call_id=call_id,
                    status=status.value,
                    payload=_message_bytes(payload),
                ).SerializeToString(),
                process_status,
            )
        )

    def _enqueue_server_stream(
        self,
        channel_id: int,
        method,
        response,
        process_status=Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a SERVER_STREAM packet carrying response (bytes or proto)."""
        self._next_packets.append(
            (
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.SERVER_STREAM,
                    channel_id=channel_id,
                    service_id=method.service.id,
                    method_id=method.id,
                    call_id=call_id,
                    payload=_message_bytes(response),
                ).SerializeToString(),
                process_status,
            )
        )

    def _enqueue_error(
        self,
        channel_id: int,
        service,
        method,
        status: Status,
        process_status=Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a SERVER_ERROR packet.

        service and method may each be a descriptor (its .id is used) or a
        raw integer ID, so that tests can address unknown services/methods.
        """
        service_id = service if isinstance(service, int) else service.id
        method_id = method if isinstance(method, int) else method.id

        self._next_packets.append(
            (
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.SERVER_ERROR,
                    channel_id=channel_id,
                    service_id=service_id,
                    method_id=method_id,
                    call_id=call_id,
                    status=status.value,
                ).SerializeToString(),
                process_status,
            )
        )

    def _handle_packet(self, data: bytes) -> None:
        """Channel output: records the sent packet and maybe replies.

        Raises output_exception if one is set. Otherwise the enqueued packets
        are delivered once send_responses_after_packets sends have occurred.
        """
        if self.output_exception:
            raise self.output_exception  # pylint: disable=raising-bad-type

        self.requests.append(packets.decode(data))

        if self.send_responses_after_packets > 1:
            self.send_responses_after_packets -= 1
            return

        self._process_enqueued_packets()

    def _process_enqueued_packets(self) -> None:
        """Delivers every enqueued packet to the client, then clears the queue.

        Asserts that process_packet() returns the status recorded with each
        packet.
        """
        # Set send_responses_after_packets to infinity to prevent potential
        # infinite recursion when a packet causes another packet to send.
        send_after_count = self.send_responses_after_packets
        self.send_responses_after_packets = float('inf')

        try:
            # Packets enqueued while this loop runs (e.g. by a callback) are
            # appended to the list being iterated, so they are delivered too.
            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))
        finally:
            # Reset even if an assertion or the client raised. Otherwise the
            # counter would stay at infinity and later responses would never
            # be delivered, or stale packets would be replayed.
            self._next_packets.clear()
            self.send_responses_after_packets = send_after_count

    def _sent_payload(self, message_type: type) -> Any:
        """Decodes the payload of the last sent packet as message_type."""
        message = message_type()
        message.ParseFromString(self.last_request().payload)
        return message


# Disable docstring requirements for test functions.
# pylint: disable=missing-function-docstring


class CallbackClientImplTest(_CallbackClientImplTestBase):
    """Tests the callback_client.Impl client implementation."""

    def test_callback_exceptions_suppressed(self) -> None:
        stub = self._service.SomeUnary

        self._enqueue_response(CLIENT_CHANNEL_ID, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        # The callback's exception must be logged rather than propagated.
        with self.assertLogs(callback_client.__package__, 'ERROR') as logs:
            stub.invoke(
                self._request(), mock.Mock(side_effect=Exception(exception_msg))
            )

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(CLIENT_CHANNEL_ID, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_ignore_bad_packets_with_pending_rpc(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(
            CLIENT_CHANNEL_ID, ids=(999, method.id), process_status=Status.OK
        )
        # Bad method
        self._enqueue_response(
            CLIENT_CHANNEL_ID, ids=(service_id, 999), process_status=Status.OK
        )
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            ids=(service_id, self._service.SomeBidiStreaming.method.id),
            process_status=Status.OK,
        )

        # The only valid response; the bad packets above must not disturb it.
        self._enqueue_response(
            CLIENT_CHANNEL_ID, method, process_status=Status.OK
        )

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_server_error_for_unknown_call_sends_no_errors(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_error(
            999,
            service_id,
            method,
            Status.NOT_FOUND,
            process_status=Status.NOT_FOUND,
        )
        # Bad service
        self._enqueue_error(
            CLIENT_CHANNEL_ID, 999, method.id, Status.INVALID_ARGUMENT
        )
        # Bad method
        self._enqueue_error(
            CLIENT_CHANNEL_ID, service_id, 999, Status.INVALID_ARGUMENT
        )
        # For RPC not pending
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            service_id,
            self._service.SomeBidiStreaming.method.id,
            Status.NOT_FOUND,
        )

        self._process_enqueued_packets()

        # The client must not answer stray server errors with packets of its
        # own.
        self.assertEqual(self.requests, [])

    def test_exception_if_payload_fails_to_decode(self) -> None:
        method = self._service.SomeUnary.method

        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            method,
            Status.OK,
            b'INVALID DATA!!!',
            process_status=Status.OK,
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self._service.SomeUnary(magic_number=6)

        self.assertIs(context.exception.status, Status.DATA_LOSS)

    def test_rpc_help_contains_method_name(self) -> None:
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self) -> None:
        impl = callback_client.Impl(None, 1.5)

        self.assertIsNone(impl.default_unary_timeout_s)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self) -> None:
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(CLIENT_CHANNEL_ID, lambda *a, **b: None)],
            self._protos.modules(),
        )
        rpcs = rpc_client.channel(CLIENT_CHANNEL_ID).rpcs

        # The stream timeout (100) is expected only for server-streaming
        # RPCs; the others get the unary timeout (99).
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeClientStreaming.default_timeout_s,
            99,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s, 100
        )

    def test_rpc_provides_request_type(self) -> None:
        self.assertIs(
            self._service.SomeUnary.request,
            self._service.SomeUnary.method.request_type,
        )

    def test_rpc_provides_response_type(self) -> None:
        # Previously a copy of the request-type test; check the response type.
        self.assertIs(
            self._service.SomeUnary.response,
            self._service.SomeUnary.method.response_type,
        )


class UnaryTest(_CallbackClientImplTestBase):
    """Tests for invoking a unary RPC."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeUnary
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        for _ in range(3):
            self._enqueue_response(
                CLIENT_CHANNEL_ID,
                self.method,
                Status.ABORTED,
                self.method.response_type(payload='0_o'),
            )

            status, response = self._service.SomeUnary(
                self.method.request_type(magic_number=6)
            )

            self.assertEqual(
                6, self._sent_payload(self.method.request_type).magic_number
            )

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_nonblocking_call(self) -> None:
        for _ in range(3):
            callback = mock.Mock()
            call = self.rpc.invoke(
                self._request(magic_number=5), callback, callback
            )

            # Enqueued after invoke() so the response can target this call's
            # ID; it is then delivered explicitly.
            self._enqueue_response(
                CLIENT_CHANNEL_ID,
                self.method,
                Status.ABORTED,
                self.method.response_type(payload='0_o'),
                call_id=call.call_id,
            )
            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='0_o')),
                    mock.call(call, Status.ABORTED),
                ]
            )

            self.assertEqual(
                5, self._sent_payload(self.method.request_type).magic_number
            )

    def test_concurrent_nonblocking_calls(self) -> None:
        # Start several calls to the same method
        callbacks_and_calls: list[
            tuple[mock.Mock, callback_client.call.Call]
        ] = []
        for _ in range(3):
            callback = mock.Mock()
            call = self.rpc.invoke(self._request(magic_number=5), callback)
            callbacks_and_calls.append((callback, call))

        # Respond only to the last call
        last_callback, last_call = callbacks_and_calls.pop()
        last_payload = self.method.response_type(payload='last payload')
        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            self.method,
            payload=last_payload,
            call_id=last_call.call_id,
        )
        self._process_enqueued_packets()

        # Assert that only the last caller received a response
        last_callback.assert_called_once_with(last_call, last_payload)
        for remaining_callback, _ in callbacks_and_calls:
            remaining_callback.assert_not_called()

        # Respond to the other callers and check for receipt
        other_payload = self.method.response_type(payload='other payload')
        for callback, call in callbacks_and_calls:
            self._enqueue_response(
                CLIENT_CHANNEL_ID,
                self.method,
                payload=other_payload,
                call_id=call.call_id,
            )
            self._process_enqueued_packets()
            callback.assert_called_once_with(call, other_payload)

    def test_open(self) -> None:
        # Every send raises, so open() cannot transmit anything. The opened
        # call must still receive the enqueued server packets.
        self.output_exception = IOError('something went wrong sending!')

        for _ in range(3):
            self._enqueue_response(
                CLIENT_CHANNEL_ID,
                self.method,
                Status.ABORTED,
                self.method.response_type(payload='0_o'),
            )

            callback = mock.Mock()
            call = self.rpc.open(
                self._request(magic_number=5), callback, callback
            )
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='0_o')),
                    mock.call(call, Status.ABORTED),
                ]
            )

    def test_blocking_server_error(self) -> None:
        for _ in range(3):
            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.NOT_FOUND,
            )

            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(
                    self.method.request_type(magic_number=6)
                )

            self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_cancel(self) -> None:
        callback = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), callback
            )

            # Drop the request packet so last_request() below can only be the
            # cancellation packet.
            self.assertGreater(len(self.requests), 0)
            self.requests.clear()

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            self.assertEqual(
                self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR
            )
            self.assertEqual(self.last_request().status, Status.CANCELLED.value)

        callback.assert_not_called()

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args={'magic_number': 1138})
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138
        )

    def test_blocking_timeout_as_argument(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_blocking_timeout_set_default(self) -> None:
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIsNone(first_call.error)
        self.assertIsNone(second_call.error)

    def test_nonblocking_exception_in_callback(self) -> None:
        exception = ValueError('something went wrong! (intentionally)')

        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        call = self.rpc.invoke(on_completed=mock.Mock(side_effect=exception))

        # wait() surfaces the callback failure as a RuntimeError chained to
        # the original exception.
        with self.assertRaises(RuntimeError) as context:
            call.wait()

        self.assertEqual(context.exception.__cause__, exception)

    def test_unary_response(self) -> None:
        proto = self._protos.packages.pw.test1.SomeMessage(magic_number=123)
        self.assertEqual(
            repr(callback_client.UnaryResponse(Status.ABORTED, proto)),
            '(Status.ABORTED, pw.test1.SomeMessage(magic_number=123))',
        )
        self.assertEqual(
            repr(callback_client.UnaryResponse(Status.OK, None)),
            '(Status.OK, None)',
        )

    def test_on_call_hook(self) -> None:
        hook_function = mock.Mock()

        # Rebuild the client so its Impl carries the hook under test.
        self._client = client.Client.from_modules(
            callback_client.Impl(on_call_hook=hook_function),
            [client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)],
            self._protos.modules(),
        )

        self._service = self._client.channel(
            CLIENT_CHANNEL_ID
        ).rpcs.pw.test1.PublicService

        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)
        self._service.SomeUnary(self.method.request_type(magic_number=6))

        hook_function.assert_called_once()
        self.assertEqual(
            hook_function.call_args[0][0].method.full_name,
            self.method.full_name,
        )


class ServerStreamingTest(_CallbackClientImplTestBase):
    """Tests for server streaming RPCs."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeServerStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.ABORTED
            )

            self.assertEqual(
                [rep1, rep2],
                self._service.SomeServerStreaming(magic_number=4).responses,
            )

            self.assertEqual(
                4, self._sent_payload(self.method.request_type).magic_number
            )

    def test_nonblocking_call(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.ABORTED
            )

            # The queued packets are delivered as soon as invoke() sends the
            # request, so the callbacks have run by the time it returns.
            callback = mock.Mock()
            call = self.rpc.invoke(
                self._request(magic_number=3), callback, callback
            )

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='!!!')),
                    mock.call(call, self.method.response_type(payload='?')),
                    mock.call(call, Status.ABORTED),
                ]
            )

            self.assertEqual(
                3, self._sent_payload(self.method.request_type).magic_number
            )

    def test_open(self) -> None:
        # Every send raises; open() must still deliver server packets.
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.ABORTED
            )

            callback = mock.Mock()
            call = self.rpc.open(
                self._request(magic_number=3), callback, callback
            )
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='!!!')),
                    mock.call(call, self.method.response_type(payload='?')),
                    mock.call(call, Status.ABORTED),
                ]
            )

    def test_nonblocking_cancel(self) -> None:
        resp = self.rpc.method.response_type(payload='!!!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.rpc.method, resp)

        callback = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            call, self.rpc.method.response_type(payload='!!!')
        )

        callback.reset_mock()

        call.cancel()

        self.assertEqual(
            self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR
        )
        self.assertEqual(self.last_request().status, Status.CANCELLED.value)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, resp)
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        call = self.rpc.invoke(
            self._request(magic_number=3), callback, callback
        )

        callback.assert_has_calls(
            [
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, Status.OK),
            ]
        )

    def test_request_completion(self) -> None:
        resp = self.rpc.method.response_type(payload='!!!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.rpc.method, resp)

        callback = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            call, self.rpc.method.response_type(payload='!!!')
        )

        callback.reset_mock()

        call.request_completion()

        self.assertEqual(
            self.last_request().type,
            packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
        )

        # Ensure the RPC can be called after being completed.
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, resp)
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        call = self.rpc.invoke(
            self._request(magic_number=3), callback, callback
        )

        callback.assert_has_calls(
            [
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, Status.OK),
            ]
        )

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args={'magic_number': 1138})
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138
        )

    def test_blocking_timeout(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)

    def test_nonblocking_iteration_timeout(self) -> None:
        call = self._service.SomeServerStreaming.invoke(timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in call:
                pass

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIsNone(first_call.error)
        self.assertIsNone(second_call.error)

    def test_nonblocking_iterate_over_count(self) -> None:
        reply = self.method.response_type(payload='!?')

        for _ in range(4):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)

        call = self.rpc.invoke()

        # Four replies are consumed across three separate reads: 1 + 1 + 2.
        self.assertEqual(list(call.get_responses(count=1)), [reply])
        self.assertEqual(next(iter(call)), reply)
        self.assertEqual(list(call.get_responses(count=2)), [reply, reply])

    def test_nonblocking_iterate_after_completed_doesnt_block(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        call = self.rpc.invoke()

        # Once completed and drained, further iteration returns immediately.
        self.assertEqual(list(call.get_responses()), [reply])
        self.assertEqual(list(call.get_responses()), [])
        self.assertEqual(list(call), [])


class ClientStreamingTest(_CallbackClientImplTestBase):
    """Tests for client streaming RPCs."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeClientStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        response = self.method.response_type(payload='yo')
        self._enqueue_response(
            CLIENT_CHANNEL_ID, self.method, Status.OK, response
        )

        results = self.rpc(requests)
        self.assertIs(results.status, Status.OK)
        self.assertEqual(results.response, response)

    def test_blocking_server_error(self) -> None:
        requests = [self.method.request_type(magic_number=123)]

        # send_responses_after_packets keeps its default of 1, so the error is
        # delivered as soon as the client sends its first packet.
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.NOT_FOUND,
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a successful client streaming RPC ended by the server."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=31)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                31, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1
            )

            stream.send(magic_number=32)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                32, self._sent_payload(self.method.request_type).magic_number
            )

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_open(self) -> None:
        # Every send raises; open() must still deliver the server response.
        self.output_exception = IOError('something went wrong sending!')
        payload = self.method.response_type(payload='-_-')

        for _ in range(3):
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.OK, payload
            )

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, payload),
                    mock.call(call, Status.OK),
                ]
            )

    def test_nonblocking_finish(self) -> None:
        """Tests a client streaming RPC ended by the client."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=37)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                37, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1
            )

            stream.finish_and_wait()
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
                self.last_request().type,
            )

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_nonblocking_cancel(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            stream.send(magic_number=37)

            self.assertTrue(stream.cancel())
            # Packet type and status are plain ints: compare by value.
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_ERROR, self.last_request().type
            )
            self.assertEqual(Status.CANCELLED.value, self.last_request().status)
            self.assertFalse(stream.cancel())

            self.assertTrue(stream.completed())
            self.assertIs(stream.error, Status.CANCELLED)

    def test_nonblocking_server_error(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            # Delivered in response to the CLIENT_STREAM packet sent below.
            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.INVALID_ARGUMENT,
            )
            stream.send(magic_number=2**32 - 1)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            # Error will be sent in response to the CLIENT_REQUEST_COMPLETION
            # packet.
            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.INVALID_ARGUMENT,
            )

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        call = self._service.SomeClientStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            call.send(payload='hello')

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_response(
            CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE, reply
        )

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.response, reply)

        # Repeated calls on a completed RPC return the same result.
        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.UNAVAILABLE,
        )

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertIsNone(call.response)

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIsNone(first_call.error)
        self.assertIsNone(second_call.error)


class BidirectionalStreamingTest(_CallbackClientImplTestBase):
    """Tests for bidirectional streaming RPCs."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeBidiStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        """A blocking call returns the final status and no responses."""
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.NOT_FOUND)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.NOT_FOUND)
        self.assertFalse(results.responses)

    def test_blocking_server_error(self) -> None:
        """A server error during a blocking call is raised as RpcError."""
        requests = [self.method.request_type(magic_number=123)]

        # Unlike test_blocking_call, send_responses_after_packets is not
        # overridden here, so the enqueued error is delivered on the base
        # class's default schedule.
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.NOT_FOUND,
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a bidirectional streaming RPC ended by the server."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            responses: list = []
            # Bind this iteration's list as a default argument so each
            # callback appends to its own list.
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            stream.send(magic_number=55)
            self.assertIs(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                55, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([], responses)

            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)

            stream.send(magic_number=66)
            self.assertIs(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                66, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

            stream.send(magic_number=77)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)

    def test_open(self) -> None:
        """open() sends no packets but still dispatches incoming packets."""
        # NOTE(review): presumably makes any outgoing packet raise (see the
        # base class); open() is expected to send nothing, as asserted below.
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)
            self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

            # One mock serves as the response, completion and error callback.
            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='!!!')),
                    mock.call(call, self.method.response_type(payload='?')),
                    mock.call(call, Status.OK),
                ]
            )

    @mock.patch('pw_rpc.callback_client.call.Call._default_response')
    def test_nonblocking(self, callback) -> None:
        """The default response callback receives a server stream response."""
        reply = self.method.response_type(payload='This is the payload!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)

        self._service.SomeBidiStreaming.invoke()

        callback.assert_called_once_with(mock.ANY, reply)

    def test_nonblocking_server_error(self) -> None:
        """A server error completes the stream and is raised on finish."""
        rep1 = self.method.response_type(payload='!!!')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)

            stream.send(magic_number=55)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1], responses)

            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.OUT_OF_RANGE,
            )

            stream.send(magic_number=99999)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1], responses)

            # An error-terminated call has no status, only an error.
            self.assertIsNone(stream.status)
            self.assertIs(Status.OUT_OF_RANGE, stream.error)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()
            self.assertIs(context.exception.status, Status.OUT_OF_RANGE)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        """An error that arrives after the client ends its stream is raised."""
        for _ in range(3):
            stream = self._service.SomeBidiStreaming.invoke()

            # Error will be sent in response to the CLIENT_REQUEST_COMPLETION
            # packet.
            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.INVALID_ARGUMENT,
            )

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        """Sending on a bidirectional stream after cancel() raises CANCELLED."""
        call = self._service.SomeBidiStreaming.invoke()
        self.assertTrue(call.cancel())

        # Send a well-formed SomeMessage (its only field is magic_number) so
        # the cancellation is the only possible cause of the RpcError.
        with self.assertRaises(callback_client.RpcError) as context:
            call.send(magic_number=123)

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        """Repeated finish_and_wait() calls return the first result again."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_response(
            CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE
        )

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.responses, [reply])

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        """finish_and_wait() re-raises the error and keeps prior responses."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.UNAVAILABLE,
        )

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertEqual(call.responses, [reply])

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        """Invoking the same RPC twice leaves neither call with an error."""
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIsNone(first_call.error)
        self.assertIsNone(second_call.error)

    def test_stream_response(self) -> None:
        """StreamResponse's repr shows the status and the response list."""
        proto = self._protos.packages.pw.test1.SomeMessage(magic_number=123)
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.ABORTED, [proto] * 2)),
            '(Status.ABORTED, [pw.test1.SomeMessage(magic_number=123), '
            'pw.test1.SomeMessage(magic_number=123)])',
        )
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.OK, [])),
            '(Status.OK, [])',
        )


if __name__ == '__main__':
    # When executed as a script, discover and run the TestCases in this module.
    unittest.main(module='__main__')
