# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entry point for running stress tests."""

from concurrent import futures
import queue
import threading

from absl import app
from absl.flags import argparse_flags
import grpc

from src.proto.grpc.testing import metrics_pb2_grpc
from src.proto.grpc.testing import test_pb2_grpc
from tests.interop import methods
from tests.interop import resources
from tests.qps import histogram
from tests.stress import metrics_server
from tests.stress import test_runner


def _args(argv):
    parser = argparse_flags.ArgumentParser()
    parser.add_argument(
        "--server_addresses",
        help="comma separated list of hostname:port to run servers on",
        default="localhost:8080",
        type=str,
    )
    parser.add_argument(
        "--test_cases",
        help="comma separated list of testcase:weighting of tests to run",
        default="large_unary:100",
        type=str,
    )
    parser.add_argument(
        "--test_duration_secs",
        help="number of seconds to run the stress test",
        default=-1,
        type=int,
    )
    parser.add_argument(
        "--num_channels_per_server",
        help="number of channels per server",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--num_stubs_per_channel",
        help="number of stubs to create per channel",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--metrics_port",
        help="the port to listen for metrics requests on",
        default=8081,
        type=int,
    )
    parser.add_argument(
        "--use_test_ca",
        help="Whether to use our fake CA. Requires --use_tls=true",
        default=False,
        type=bool,
    )
    parser.add_argument(
        "--use_tls", help="Whether to use TLS", default=False, type=bool
    )
    parser.add_argument(
        "--server_host_override",
        help="the server host to which to claim to connect",
        type=str,
    )
    return parser.parse_args(argv[1:])


def _test_case_from_arg(test_case_arg):
    """Return the ``methods.TestCase`` member whose value equals the argument.

    Args:
      test_case_arg: the test-case name as given on the command line.

    Returns:
      The first ``methods.TestCase`` member with a matching ``value``.

    Raises:
      ValueError: if no member has that value.
    """
    candidates = (
        member
        for member in methods.TestCase
        if test_case_arg == member.value
    )
    found = next(candidates, None)
    if found is None:
        raise ValueError("No test case {}!".format(test_case_arg))
    return found


def _parse_weighted_test_cases(test_case_args):
    """Parse the ``--test_cases`` flag into a test-case -> weight mapping.

    Args:
      test_case_args: comma-separated ``name:weight`` entries, e.g.
        ``"large_unary:100,empty_unary:50"``. Whitespace around names and
        weights is ignored.

    Returns:
      A dict mapping each ``methods.TestCase`` to its integer weight. If a
      test case is listed more than once, the last weight wins.

    Raises:
      ValueError: if an entry has no ``:weight`` part, the weight is not an
        integer, or the name is not a known test case.
    """
    weighted_test_cases = {}
    for entry in test_case_args.split(","):
        name, separator, weight = entry.partition(":")
        if not separator:
            raise ValueError(
                "Test case entry {!r} must have the form name:weight".format(
                    entry
                )
            )
        test_case = _test_case_from_arg(name.strip())
        try:
            weighted_test_cases[test_case] = int(weight)
        except ValueError as err:
            raise ValueError(
                "Invalid weight {!r} for test case {!r}".format(
                    weight, name.strip()
                )
            ) from err
    return weighted_test_cases


def _get_channel(target, args):
    """Create a channel to ``target`` and wait until it is ready.

    Args:
      target: a ``hostname:port`` string naming a test server.
      args: the parsed flags; reads ``use_tls``, ``use_test_ca`` and
        ``server_host_override``.

    Returns:
      A channel on which ``grpc.channel_ready_future`` has completed.
    """
    if args.use_tls:
        if args.use_test_ca:
            root_certificates = resources.test_root_certificates()
        else:
            root_certificates = None  # will load default roots.
        channel_credentials = grpc.ssl_channel_credentials(
            root_certificates=root_certificates
        )
        # --server_host_override has no default, so only set the target-name
        # override when it was actually supplied (never pass None as the
        # option value).
        options = ()
        if args.server_host_override:
            options = (
                (
                    "grpc.ssl_target_name_override",
                    args.server_host_override,
                ),
            )
        channel = grpc.secure_channel(
            target, channel_credentials, options=options
        )
    else:
        channel = grpc.insecure_channel(target)

    # Wait (with no timeout) for the channel to be ready before any RPCs are
    # sent; release the channel if the wait is interrupted or fails.
    try:
        grpc.channel_ready_future(channel).result()
    except BaseException:
        channel.close()
        raise
    return channel


def run_test(args):
    """Run the stress test configured by ``args``.

    Steps:
      1. Start an insecure metrics server on ``args.metrics_port``.
      2. For every server address, create ``num_channels_per_server``
         channels with ``num_stubs_per_channel`` test runners each.
      3. Run the runners until ``test_duration_secs`` elapses, or
         indefinitely when it is negative.

    If a runner reports an exception on the shared queue before then, that
    exception is re-raised here. Runners, channels and the metrics server are
    always shut down before returning.

    Args:
      args: the parsed flags produced by ``_args``.
    """
    test_cases = _parse_weighted_test_cases(args.test_cases)
    test_server_targets = args.server_addresses.split(",")
    # Propagate any client exceptions with a queue
    exception_queue = queue.Queue()
    stop_event = threading.Event()
    hist = histogram.Histogram(1, 1)
    channels = []
    runners = []
    started_runners = []

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=25))
    metrics_pb2_grpc.add_MetricsServiceServicer_to_server(
        metrics_server.MetricsServer(hist), server
    )
    server.add_insecure_port("[::]:{}".format(args.metrics_port))
    server.start()

    # Everything after server.start() is guarded, so a failure or interrupt
    # while connecting still stops the runners, channels and metrics server.
    try:
        for test_server_target in test_server_targets:
            for _ in range(args.num_channels_per_server):
                channel = _get_channel(test_server_target, args)
                channels.append(channel)
                for _ in range(args.num_stubs_per_channel):
                    stub = test_pb2_grpc.TestServiceStub(channel)
                    runners.append(
                        test_runner.TestRunner(
                            stub, test_cases, hist, exception_queue, stop_event
                        )
                    )

        for runner in runners:
            runner.start()
            started_runners.append(runner)

        timeout_secs = args.test_duration_secs
        if timeout_secs < 0:
            timeout_secs = None  # no deadline: wait for a runner failure
        try:
            raise exception_queue.get(block=True, timeout=timeout_secs)
        except queue.Empty:
            # No exceptions thrown, success
            pass
    finally:
        stop_event.set()
        # Only join runners whose start() was reached. Presumably TestRunner
        # is thread-like, and joining a never-started thread raises.
        for runner in started_runners:
            runner.join()
        for channel in channels:
            channel.close()
        server.stop(None)


if __name__ == "__main__":
    # absl's app.run parses argv with the custom _args parser and hands the
    # resulting namespace to run_test.
    app.run(main=run_test, flags_parser=_args)
