# Copyright 2022 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Test server set up."""

import logging
import os
import sys
import subprocess

from typing import List, Optional, Tuple

from common import DIR_SRC_ROOT, run_ffx_command
from compatible_utils import get_ssh_prefix

# Extend the module search path before importing chrome_test_server_spawner.
# NOTE(review): presumably that module lives in build/util/lib/common under
# the source root - confirm.
sys.path.append(os.path.join(DIR_SRC_ROOT, 'build', 'util', 'lib', 'common'))
# pylint: disable=import-error,wrong-import-position
import chrome_test_server_spawner
# pylint: enable=import-error,wrong-import-position


def port_forward(host_port_pair: str, host_port: int) -> int:
    """Establishes a port forwarding SSH task to a localhost TCP endpoint
    hosted at port |host_port|. Blocks until port forwarding is established.

    Args:
        host_port_pair: Passed to get_ssh_prefix() to build the ssh command
            prefix used for every ssh invocation here.
        host_port: The localhost port that the remote side forwards to.

    Returns the remote port number.

    Raises:
        subprocess.CalledProcessError: if the initial connection check
            fails.
        Exception: if the forwarding request exits with a non-zero code or
            if no port number can be parsed from its output."""

    ssh_prefix = get_ssh_prefix(host_port_pair)

    # Allow a tunnel to be established.
    subprocess.run(ssh_prefix + ['echo', 'true'], check=True)

    forward_cmd = [
        '-O',
        'forward',  # Send SSH mux control signal.
        '-R',
        '0:localhost:%d' % host_port,  # Remote port 0: allocate dynamically.
        '-v',  # Verbose; diagnostics go to stderr.
        '-NT'  # Don't execute command; don't allocate terminal.
    ]
    forward_proc = subprocess.run(ssh_prefix + forward_cmd,
                                  capture_output=True,
                                  check=False,
                                  text=True)
    if forward_proc.returncode != 0:
        # stderr was captured above, so it must be surfaced here or the
        # reason for the failure is lost.
        raise Exception(
            'Got an error code when requesting port forwarding: %d\n%s' %
            (forward_proc.returncode, forward_proc.stderr))

    # The allocated port is read from the first line of stdout.
    output = forward_proc.stdout
    try:
        parsed_port = int(output.splitlines()[0].strip())
    except (IndexError, ValueError) as err:
        # Empty or non-numeric output would otherwise surface as a bare
        # IndexError/ValueError with no hint of what ssh printed.
        raise Exception(
            'Could not parse forwarded port from ssh output: %r' %
            output) from err

    logging.debug('Port forwarding established (local=%d, device=%d)',
                  host_port, parsed_port)
    return parsed_port


# Disable pylint errors since the base class is not from this directory.
# pylint: disable=invalid-name,missing-function-docstring
class SSHPortForwarder(chrome_test_server_spawner.PortForwarder):
    """A chrome_test_server_spawner.PortForwarder that forwards ports with
    SSH remote port forwarding."""

    def __init__(self, host_port_pair: str) -> None:
        # Keyed by host (server) port; each value is the device port that
        # port_forward() returned for it.
        self._port_mapping = {}

        self._host_port_pair = host_port_pair

    def Map(self, port_pairs: List[Tuple[int, int]]) -> None:
        """Forwards the host port (second element) of each pair and records
        the device port returned for it. The first element is ignored."""
        for _, host_port in port_pairs:
            device_port = port_forward(self._host_port_pair, host_port)
            self._port_mapping[host_port] = device_port

    def GetDevicePortForHostPort(self, host_port: int) -> int:
        """Returns the device port recorded for |host_port|. Raises KeyError
        if no mapping is currently recorded for it."""
        return self._port_mapping[host_port]

    def Unmap(self, device_port: int) -> None:
        """Cancels the forwarding that uses |device_port| and forgets it.

        Raises an Exception if no recorded forwarding uses |device_port| or
        if the ssh cancel command exits with a non-zero code."""
        # Reverse lookup: find the host port whose recorded device port
        # matches. The else branch runs only when nothing matched.
        for candidate, mapped_port in self._port_mapping.items():
            if mapped_port == device_port:
                host_port = candidate
                break
        else:
            raise Exception('Unmap called for unknown port: %d' % device_port)

        cancel_cmd = get_ssh_prefix(self._host_port_pair) + [
            '-NT', '-O', 'cancel', '-R',
            '0:localhost:%d' % host_port
        ]
        exit_code = subprocess.run(cancel_cmd, check=False).returncode
        if exit_code != 0:
            raise Exception('Error %d when unmapping port %d' %
                            (exit_code, device_port))

        # Only drop the entry once ssh has confirmed the cancellation.
        del self._port_mapping[host_port]


# pylint: enable=invalid-name,missing-function-docstring


def setup_test_server(
        target_id: Optional[str], test_concurrency: int
) -> Tuple[chrome_test_server_spawner.SpawningServer, str]:
    """Starts a spawning test server and forwards a device port to it.

    Args:
        target_id: The target whose SSH address is queried through ffx and to
            which port forwarding to the test server is established.
        test_concurrency: The number of parallel test jobs that will be run.

    Returns a (SpawningServer, url) tuple, where url is the localhost address
    to use on |target_id| to reach the test server."""

    logging.debug('Starting test server.')

    address_result = run_ffx_command(cmd=('target', 'get-ssh-address'),
                                     target_id=target_id,
                                     capture_output=True)
    ssh_address = address_result.stdout.strip()

    # Allow twice as many spawned test servers as requested jobs, because the
    # TestLauncher can launch more jobs than the --test-launcher-jobs limit.
    # See https://crbug.com/913156#c19.
    max_spawned_servers = test_concurrency * 2
    server = chrome_test_server_spawner.SpawningServer(
        0, SSHPortForwarder(ssh_address), max_spawned_servers)

    # The forwarding is requested before the server is started.
    device_port = port_forward(ssh_address, server.server_port)
    server.Start()

    logging.debug('Test server listening for connections (port=%d)',
                  server.server_port)
    logging.debug('Forwarded port is %d', device_port)

    test_server_url = 'http://localhost:%d' % device_port
    return (server, test_server_url)
