# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests server and client side metadata API."""

import logging
import unittest
import weakref

import grpc
from grpc import _channel

from tests.unit import test_common
from tests.unit.framework.common import test_constants

_CHANNEL_ARGS = (
    ("grpc.primary_user_agent", "primary-agent"),
    ("grpc.secondary_user_agent", "secondary-agent"),
)

_REQUEST = b"\x00\x00\x00"
_RESPONSE = b"\x00\x00\x00"

_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"

_INVOCATION_METADATA = (
    (
        b"invocation-md-key",
        "invocation-md-value",
    ),
    (
        "invocation-md-key-bin",
        b"\x00\x01",
    ),
)
_EXPECTED_INVOCATION_METADATA = (
    (
        "invocation-md-key",
        "invocation-md-value",
    ),
    (
        "invocation-md-key-bin",
        b"\x00\x01",
    ),
)

_INITIAL_METADATA = (
    (b"initial-md-key", "initial-md-value"),
    ("initial-md-key-bin", b"\x00\x02"),
)
_EXPECTED_INITIAL_METADATA = (
    (
        "initial-md-key",
        "initial-md-value",
    ),
    (
        "initial-md-key-bin",
        b"\x00\x02",
    ),
)

_TRAILING_METADATA = (
    (
        "server-trailing-md-key",
        "server-trailing-md-value",
    ),
    (
        "server-trailing-md-key-bin",
        b"\x00\x03",
    ),
)
_EXPECTED_TRAILING_METADATA = _TRAILING_METADATA


def _user_agent(metadata):
    for key, val in metadata:
        if key == "user-agent":
            return val
    raise KeyError("No user agent!")


def validate_client_metadata(test, servicer_context):
    invocation_metadata = servicer_context.invocation_metadata()
    test.assertTrue(
        test_common.metadata_transmitted(
            _EXPECTED_INVOCATION_METADATA, invocation_metadata
        )
    )
    user_agent = _user_agent(invocation_metadata)
    test.assertTrue(
        user_agent.startswith("primary-agent " + _channel._USER_AGENT)
    )
    test.assertTrue(user_agent.endswith("secondary-agent"))


def handle_unary_unary(test, request, servicer_context):
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)
    return _RESPONSE


def handle_unary_stream(test, request, servicer_context):
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)
    for _ in range(test_constants.STREAM_LENGTH):
        yield _RESPONSE


def handle_stream_unary(test, request_iterator, servicer_context):
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)
    # TODO(issue:#6891) We should be able to remove this loop
    for request in request_iterator:
        pass
    return _RESPONSE


def handle_stream_stream(test, request_iterator, servicer_context):
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for request in request_iterator:
        yield _RESPONSE


class _MethodHandler(grpc.RpcMethodHandler):
    def __init__(self, test, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        if self.request_streaming and self.response_streaming:
            self.stream_stream = lambda x, y: handle_stream_stream(test, x, y)
        elif self.request_streaming:
            self.stream_unary = lambda x, y: handle_stream_unary(test, x, y)
        elif self.response_streaming:
            self.unary_stream = lambda x, y: handle_unary_stream(test, x, y)
        else:
            self.unary_unary = lambda x, y: handle_unary_unary(test, x, y)


class _GenericHandler(grpc.GenericRpcHandler):
    def __init__(self, test):
        self._test = test

    def service(self, handler_call_details):
        if handler_call_details.method == _UNARY_UNARY:
            return _MethodHandler(self._test, False, False)
        elif handler_call_details.method == _UNARY_STREAM:
            return _MethodHandler(self._test, False, True)
        elif handler_call_details.method == _STREAM_UNARY:
            return _MethodHandler(self._test, True, False)
        elif handler_call_details.method == _STREAM_STREAM:
            return _MethodHandler(self._test, True, True)
        else:
            return None


class MetadataTest(unittest.TestCase):
    def setUp(self):
        self._server = test_common.test_server()
        self._server.add_generic_rpc_handlers(
            (_GenericHandler(weakref.proxy(self)),)
        )
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._channel = grpc.insecure_channel(
            "localhost:%d" % port, options=_CHANNEL_ARGS
        )

    def tearDown(self):
        self._server.stop(0)
        self._channel.close()

    def testUnaryUnary(self):
        multi_callable = self._channel.unary_unary(
            _UNARY_UNARY, _registered_method=True
        )
        unused_response, call = multi_callable.with_call(
            _REQUEST, metadata=_INVOCATION_METADATA
        )
        self.assertTrue(
            test_common.metadata_transmitted(
                _EXPECTED_INITIAL_METADATA, call.initial_metadata()
            )
        )
        self.assertTrue(
            test_common.metadata_transmitted(
                _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
            )
        )

    def testUnaryStream(self):
        multi_callable = self._channel.unary_stream(
            _UNARY_STREAM, _registered_method=True
        )
        call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA)
        self.assertTrue(
            test_common.metadata_transmitted(
                _EXPECTED_INITIAL_METADATA, call.initial_metadata()
            )
        )
        for _ in call:
            pass
        self.assertTrue(
            test_common.metadata_transmitted(
                _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
            )
        )

    def testStreamUnary(self):
        multi_callable = self._channel.stream_unary(
            _STREAM_UNARY, _registered_method=True
        )
        unused_response, call = multi_callable.with_call(
            iter([_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_INVOCATION_METADATA,
        )
        self.assertTrue(
            test_common.metadata_transmitted(
                _EXPECTED_INITIAL_METADATA, call.initial_metadata()
            )
        )
        self.assertTrue(
            test_common.metadata_transmitted(
                _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
            )
        )

    def testStreamStream(self):
        multi_callable = self._channel.stream_stream(
            _STREAM_STREAM, _registered_method=True
        )
        call = multi_callable(
            iter([_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_INVOCATION_METADATA,
        )
        self.assertTrue(
            test_common.metadata_transmitted(
                _EXPECTED_INITIAL_METADATA, call.initial_metadata()
            )
        )
        for _ in call:
            pass
        self.assertTrue(
            test_common.metadata_transmitted(
                _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
            )
        )


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)
