# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Python implementation of the gRPC interoperability test client."""

import json
import os

from absl import app
from absl.flags import argparse_flags
from google import auth as google_auth
from google.auth import jwt as google_auth_jwt
import grpc

from src.proto.grpc.testing import test_pb2_grpc
from tests.interop import methods
from tests.interop import resources


def parse_interop_client_args(argv):
    """Parse the interop client's command-line flags.

    Args:
      argv: The full argument vector; ``argv[0]`` (the program name) is
        skipped before parsing.

    Returns:
      The parsed flag namespace, with one attribute per flag defined below.
    """
    parser = argparse_flags.ArgumentParser()
    add_flag = parser.add_argument

    # Where to connect and what to run.
    add_flag(
        "--server_host",
        type=str,
        default="localhost",
        help="the host to which to connect",
    )
    add_flag(
        "--server_port",
        type=int,
        required=True,
        help="the port to which to connect",
    )
    add_flag(
        "--test_case",
        type=str,
        default="large_unary",
        help="the test case to execute",
    )

    # Boolean transport-security switches; all default to off.
    boolean_flags = (
        ("--use_tls", "require a secure connection"),
        ("--use_alts", "require an ALTS secure connection"),
        ("--use_test_ca", "replace platform root CAs with ca.pem"),
    )
    for flag_name, flag_help in boolean_flags:
        add_flag(
            flag_name,
            type=resources.parse_bool,
            default=False,
            help=flag_help,
        )

    add_flag(
        "--custom_credentials_type",
        choices=["compute_engine_channel_creds"],
        default=None,
        help="use google default credentials",
    )
    add_flag(
        "--server_host_override",
        type=str,
        help="the server host to which to claim to connect",
    )
    add_flag("--oauth_scope", type=str, help="scope for OAuth tokens")
    add_flag(
        "--default_service_account",
        type=str,
        help="email address of the default service account",
    )
    add_flag(
        "--grpc_test_use_grpclb_with_child_policy",
        type=str,
        help=(
            "If non-empty, set a static service config on channels created by "
            "grpc::CreateTestChannel, that configures the grpclb LB policy "
            "with a child policy being the value of this flag (e.g. round_robin "
            "or pick_first)."
        ),
    )

    program_args = argv[1:]
    return parser.parse_args(program_args)


def _create_call_credentials(args):
    """Return the per-call credentials required by ``args.test_case``, if any.

    Only three test cases attach call credentials:

    * ``oauth2_auth_token``: application default credentials are refreshed
      once, and the resulting access token is attached as a fixed token.
    * ``compute_engine_creds``: application default credentials are attached
      to each call through ``AuthMetadataPlugin``.
    * ``jwt_token_creds``: on-demand JWT credentials are loaded from the
      service-account file named by the environment variable
      ``google_auth.environment_vars.CREDENTIALS``.

    Args:
      args: Parsed interop client flags. Reads ``test_case``, and
        ``oauth_scope`` for the OAuth2 and compute-engine cases.

    Returns:
      A ``grpc.CallCredentials`` for the cases above, otherwise ``None``.

    Raises:
      KeyError: For ``jwt_token_creds`` when the credentials environment
        variable is not set.
    """
    # `from google import auth` does not load the `google.auth.transport.*`
    # submodules, so each branch imports the ones it uses before accessing
    # them as attributes. The imports stay local so that non-auth test cases
    # do not need the extra transport dependencies.
    if args.test_case == "oauth2_auth_token":
        import google.auth.transport.requests  # noqa: F401

        google_credentials, unused_project_id = google_auth.default(
            scopes=[args.oauth_scope]
        )
        # Fetch a token now; the call credentials carry this token as-is.
        google_credentials.refresh(google_auth.transport.requests.Request())
        return grpc.access_token_call_credentials(google_credentials.token)
    elif args.test_case == "compute_engine_creds":
        import google.auth.transport.grpc  # noqa: F401
        import google.auth.transport.requests  # noqa: F401

        google_credentials, unused_project_id = google_auth.default(
            scopes=[args.oauth_scope]
        )
        return grpc.metadata_call_credentials(
            google_auth.transport.grpc.AuthMetadataPlugin(
                credentials=google_credentials,
                request=google_auth.transport.requests.Request(),
            )
        )
    elif args.test_case == "jwt_token_creds":
        import google.auth.transport.grpc  # noqa: F401

        google_credentials = (
            google_auth_jwt.OnDemandCredentials.from_service_account_file(
                os.environ[google_auth.environment_vars.CREDENTIALS]
            )
        )
        return grpc.metadata_call_credentials(
            google_auth.transport.grpc.AuthMetadataPlugin(
                credentials=google_credentials, request=None
            )
        )
    else:
        return None


def get_secure_channel_parameters(args):
    """Build the credentials and options for a secure channel to the server.

    The credential mode is chosen in this priority order:
    ``--custom_credentials_type``, then ``--use_tls``, then ``--use_alts``.

    Args:
      args: Parsed interop client flags.

    Returns:
      A ``(channel_credentials, channel_options)`` pair for
      ``grpc.secure_channel``. ``channel_options`` is a tuple of
      ``(key, value)`` option pairs.

    Raises:
      ValueError: If ``--custom_credentials_type`` is unrecognized, if it is
        combined with a test case that creates its own call credentials, or
        if none of the three secure modes is requested.
    """
    call_credentials = _create_call_credentials(args)

    channel_opts = ()
    child_policy = args.grpc_test_use_grpclb_with_child_policy
    if child_policy:
        # Static service config selecting grpclb with the requested child
        # policy. json.dumps quotes and escapes the policy name, so an
        # unusual flag value cannot produce malformed JSON.
        service_config = {
            "loadBalancingConfig": [
                {"grpclb": {"childPolicy": [{child_policy: {}}]}}
            ]
        }
        channel_opts += (("grpc.service_config", json.dumps(service_config)),)

    if args.custom_credentials_type is not None:
        if args.custom_credentials_type != "compute_engine_channel_creds":
            raise ValueError(
                "Unknown credentials type '{}'".format(
                    args.custom_credentials_type
                )
            )
        # Explicit check rather than `assert`: asserts are stripped under -O,
        # which would silently drop the test case's call credentials.
        if call_credentials is not None:
            raise ValueError(
                "--custom_credentials_type=compute_engine_channel_creds cannot"
                " be combined with test case '{}', which creates its own call"
                " credentials".format(args.test_case)
            )
        # `from google import auth` does not load the transport submodules;
        # import them explicitly before accessing them as attributes.
        import google.auth.transport.grpc  # noqa: F401
        import google.auth.transport.requests  # noqa: F401

        google_credentials, unused_project_id = google_auth.default(
            scopes=[args.oauth_scope]
        )
        call_creds = grpc.metadata_call_credentials(
            google_auth.transport.grpc.AuthMetadataPlugin(
                credentials=google_credentials,
                request=google_auth.transport.requests.Request(),
            )
        )
        channel_credentials = grpc.compute_engine_channel_credentials(
            call_creds
        )
    elif args.use_tls:
        # None makes gRPC load its default root certificates.
        root_certificates = (
            resources.test_root_certificates() if args.use_test_ca else None
        )
        channel_credentials = grpc.ssl_channel_credentials(root_certificates)
        if call_credentials is not None:
            channel_credentials = grpc.composite_channel_credentials(
                channel_credentials, call_credentials
            )

        if args.server_host_override:
            channel_opts += (
                ("grpc.ssl_target_name_override", args.server_host_override),
            )
    elif args.use_alts:
        # NOTE(review): any call credentials from the test case are not
        # attached on the ALTS path -- confirm this is intended.
        channel_credentials = grpc.alts_channel_credentials()
    else:
        # Without this branch, channel_credentials would be unbound below.
        raise ValueError(
            "A secure channel requires --use_tls, --use_alts or"
            " --custom_credentials_type"
        )

    return channel_credentials, channel_opts


def _create_channel(args):
    """Open a channel to ``server_host:server_port``.

    A secure channel is used when TLS, ALTS or a custom credentials type was
    requested; otherwise the channel is insecure.

    Args:
      args: Parsed interop client flags.

    Returns:
      A ``grpc.Channel`` connected to the configured target.
    """
    target = "{}:{}".format(args.server_host, args.server_port)

    wants_secure = (
        args.use_tls
        or args.use_alts
        or args.custom_credentials_type is not None
    )
    if not wants_secure:
        return grpc.insecure_channel(target)

    credentials, channel_options = get_secure_channel_parameters(args)
    return grpc.secure_channel(target, credentials, channel_options)


def create_stub(channel, args):
    """Wrap ``channel`` in the stub class appropriate for ``args.test_case``.

    The ``unimplemented_service`` test case needs ``UnimplementedServiceStub``;
    every other test case uses ``TestServiceStub``.
    """
    stub_class = (
        test_pb2_grpc.UnimplementedServiceStub
        if args.test_case == "unimplemented_service"
        else test_pb2_grpc.TestServiceStub
    )
    return stub_class(channel)


def _test_case_from_arg(test_case_arg):
    """Return the ``methods.TestCase`` member whose value equals the argument.

    Raises:
      ValueError: If no member has that value.
    """
    matching = (
        member for member in methods.TestCase if test_case_arg == member.value
    )
    found = next(matching, None)
    if found is None:
        raise ValueError('No test case "%s"!' % test_case_arg)
    return found


def test_interoperability(args):
    """Run the interop test case named by ``args.test_case`` against the server.

    Args:
      args: Parsed interop client flags.

    Raises:
      ValueError: If ``args.test_case`` does not name a known test case.
    """
    # Resolve the test case first, so an unknown name fails before any
    # credentials are fetched or a channel is opened.
    test_case = _test_case_from_arg(args.test_case)
    # The channel is closed when the test finishes, including when it raises.
    with _create_channel(args) as channel:
        stub = create_stub(channel, args)
        test_case.test_interoperability(stub, args)


if __name__ == "__main__":
    # Hand absl this module's argparse-based parser so that the flags defined
    # in parse_interop_client_args are the ones accepted on the command line.
    app.run(
        test_interoperability,
        flags_parser=parse_interop_client_args,
    )
