#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Proxy for transfer integration testing.

This module contains a proxy for transfer integration testing.  It is capable
of introducing various link failures into the connection between the client
and server.
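
Example invocation (a sketch: the port numbers and config file name are
arbitrary; the ProxyConfig text proto is read from stdin):

    proxy.py --server-port 3300 --client-port 3301 < proxy.config.txtpb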
"""

import abc
import argparse
import asyncio
from enum import Enum
import logging
import random
import socket
import sys
import time
from typing import Awaitable, Callable, Iterable, NamedTuple

from google.protobuf import text_format

from pigweed.pw_rpc.internal import packet_pb2
from pigweed.pw_transfer import transfer_pb2
from pigweed.pw_transfer.integration_test import config_pb2
from pw_hdlc import decode
from pw_transfer import ProtocolVersion
from pw_transfer.chunk import Chunk

_LOG = logging.getLogger('pw_transfer_integration_test_proxy')

# This is the maximum size of the socket receive buffers. Ideally, this is set
# to the lowest allowed value to minimize buffering between the proxy and
# clients so rate limiting causes the client to block and wait for the
# integration test proxy to drain rather than allowing OS buffers to backlog
# large quantities of data.
#
# Note that the OS may choose to not strictly follow this requested buffer
# size.  Still, setting this value to be relatively small does reduce buffer
# sizes significantly enough to better reflect typical inter-device
# communication.
#
# For this to be effective, clients should also configure their sockets to a
# smaller send buffer size.
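#
# For example, a client might shrink its send buffer as follows (a sketch;
# the exact size is the client's choice):
#
#   client_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)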
_RECEIVE_BUFFER_SIZE = 2048


class EventType(Enum):
    TRANSFER_START = 1
    PARAMETERS_RETRANSMIT = 2
    PARAMETERS_CONTINUE = 3
    START_ACK_CONFIRMATION = 4


class Event(NamedTuple):
    type: EventType
    chunk: Chunk


class Filter(abc.ABC):
    """An abstract interface for manipulating a stream of data.

    ``Filter``s are used to implement various transforms that simulate real
    world link properties.  Some examples include: data corruption, packet
    loss, packet reordering, rate limiting, and latency modeling.

    A ``Filter`` implementation should implement the ``process`` method
    and call ``self.send_data()`` when it has data to send.
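
    Example (a minimal sketch of a pass-through filter):

        class Passthrough(Filter):
            async def process(self, data: bytes) -> None:
                # Forward all incoming data unchanged.
                await self.send_data(data)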
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        self.send_data = send_data

    @abc.abstractmethod
    async def process(self, data: bytes) -> None:
        """Processes incoming data.

        Implementations of this method may send arbitrary data, or none, using
        the ``self.send_data()`` handler.
        """

    async def __call__(self, data: bytes) -> None:
        await self.process(data)


class HdlcPacketizer(Filter):
    """A filter which aggregates data into complete HDLC packets.

    Since the proxy transport (SOCK_STREAM) has no framing and we want some
    filters to operate on whole frames, this filter can be used so that
    downstream filters see whole frames.
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        super().__init__(send_data)
        self.decoder = decode.FrameDecoder()

    async def process(self, data: bytes) -> None:
        for frame in self.decoder.process(data):
            await self.send_data(frame.raw_encoded)


class DataDropper(Filter):
    """A filter which drops some data.

    DataDropper will drop data passed through ``process()`` at the specified
    ``rate`` (a probability between 0.0 and 1.0).
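
    Example filter config entry (text proto; a sketch based on the config_pb2
    fields read in _handle_simplex_connection):

        data_dropper { rate: 0.01 seed: 5 }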
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        rate: float,
        seed: int | None = None,
    ):
        super().__init__(send_data)
        self._rate = rate
        self._name = name
        if seed is None:
            seed = time.time_ns()
        self._rng = random.Random(seed)
        _LOG.info(f'{name} DataDropper initialized with seed {seed}')

    async def process(self, data: bytes) -> None:
        if self._rng.uniform(0.0, 1.0) < self._rate:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')
        else:
            await self.send_data(data)


class KeepDropQueue(Filter):
    """A filter which alternates between sending packets and dropping packets.

    A KeepDropQueue filter will alternate between keeping packets and dropping
    packets based on a keep/drop queue provided during its creation.  The
    queue is looped over unless a negative element is found; a negative count
    is never exhausted, so all further packets are kept (or dropped).

    If only_consider_transfer_chunks is true, packets that do not contain a
    valid transfer chunk are always forwarded and do not count against the
    queue.

    This filter is typically most practical when used with a packetizer so
    data can be dropped as distinct packets.

    Examples:

      keep_drop_queue = [3, 2]:
        Keeps 3 packets,
        Drops 2 packets,
        Keeps 3 packets,
        Drops 2 packets,
        ... [loops indefinitely]

      keep_drop_queue = [5, 99, 1, -1]:
        Keeps 5 packets,
        Drops 99 packets,
        Keeps 1 packet,
        Drops all further packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        keep_drop_queue: Iterable[int],
        only_consider_transfer_chunks: bool = False,
    ):
        super().__init__(send_data)
        self._keep_drop_queue = list(keep_drop_queue)
        self._loop_idx = 0
        self._current_count = self._keep_drop_queue[0]
        self._keep = True
        self._name = name
        self._only_consider_transfer_chunks = only_consider_transfer_chunks

    async def process(self, data: bytes) -> None:
        if self._only_consider_transfer_chunks:
            try:
                _extract_transfer_chunk(data)
            except Exception:
                await self.send_data(data)
                return

        # Move forward through the queue if needed.
        while self._current_count == 0:
            self._loop_idx += 1
            self._current_count = self._keep_drop_queue[
                self._loop_idx % len(self._keep_drop_queue)
            ]
            self._keep = not self._keep

        if self._current_count > 0:
            self._current_count -= 1

        if self._keep:
            await self.send_data(data)
            _LOG.info(f'{self._name} forwarded {len(data)} bytes of data')
        else:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')


class RateLimiter(Filter):
    """A filter which limits transmission rate.

    This filter delays transmission of each payload by ``len(data) / rate``
    seconds, so ``rate`` is expressed in bytes per second.
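
    For example, at ``rate=1000`` (bytes per second), a 500-byte payload is
    delayed by 0.5 seconds before being forwarded.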
    """

    def __init__(
        self, send_data: Callable[[bytes], Awaitable[None]], rate: float
    ):
        super().__init__(send_data)
        self._rate = rate

    async def process(self, data: bytes) -> None:
        delay = len(data) / self._rate
        await asyncio.sleep(delay)
        await self.send_data(data)


class DataTransposer(Filter):
    """A filter which occasionally transposes two chunks of data.

    This filter transposes data at the specified ``rate`` (a probability
    between 0.0 and 1.0).  It does this by holding a chunk to transpose until
    another chunk arrives.  The filter will not hold a chunk longer than
    ``timeout`` seconds.
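
    Example filter config entry (text proto; a sketch based on the config_pb2
    fields read in _handle_simplex_connection):

        data_transposer { rate: 0.005 timeout: 0.5 seed: 5 }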
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        rate: float,
        timeout: float,
        seed: int,
    ):
        super().__init__(send_data)
        self._name = name
        self._rate = rate
        self._timeout = timeout
        self._data_queue = asyncio.Queue()
        self._rng = random.Random(seed)
        self._transpose_task = asyncio.create_task(self._transpose_handler())

        _LOG.info(f'{name} DataTransposer initialized with seed {seed}')

    def __del__(self):
        _LOG.info(f'{self._name} cleaning up transpose task.')
        self._transpose_task.cancel()

    async def _transpose_handler(self):
        """Async task that handles the packet transposition and timeouts"""
        held_data: bytes | None = None
        while True:
            # Only use timeout if we have data held for transposition
            timeout = None if held_data is None else self._timeout
            try:
                data = await asyncio.wait_for(
                    self._data_queue.get(), timeout=timeout
                )

                if held_data is not None:
                    # If we have held data, send it out of order.
                    await self.send_data(data)
                    await self.send_data(held_data)
                    held_data = None
                else:
                    # Otherwise decide if we should transpose the current data.
                    if self._rng.uniform(0.0, 1.0) < self._rate:
                        _LOG.info(
                            f'{self._name} transposing {len(data)} bytes of data'
                        )
                        held_data = data
                    else:
                        await self.send_data(data)

            except asyncio.TimeoutError:
                # A timeout is only armed while a chunk is held, so held_data
                # is guaranteed to be non-None here.
                _LOG.info(f'{self._name} sending data in order due to timeout')
                await self.send_data(held_data)
                held_data = None

    async def process(self, data: bytes) -> None:
        # Queue data for processing by the transpose task.
        await self._data_queue.put(data)


class ServerFailure(Filter):
    """A filter to simulate the server stopping sending packets.

    ServerFailure takes a list of numbers of packets to send before
    dropping all subsequent packets until a TRANSFER_START packet
    is seen.  This process is repeated for each element in
    packets_before_failure_list.  After that list is exhausted,
    ServerFailure will send all packets.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.
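
    Example:

      packets_before_failure_list = [5, 10]:
        Sends 5 packets, then drops everything until a TRANSFER_START,
        Sends 10 packets, then drops everything until a TRANSFER_START,
        Sends all further packets.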
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        packets_before_failure_list: list[int],
        start_immediately: bool = False,
        only_consider_transfer_chunks: bool = False,
    ):
        super().__init__(send_data)
        self._name = name
        self._relay_packets = True
        self._packets_before_failure_list = packets_before_failure_list
        self._packets_before_failure = None
        self._only_consider_transfer_chunks = only_consider_transfer_chunks
        if start_immediately:
            self.advance_packets_before_failure()

    def advance_packets_before_failure(self):
        if self._packets_before_failure_list:
            self._packets_before_failure = (
                self._packets_before_failure_list.pop(0)
            )
        else:
            self._packets_before_failure = None

    async def process(self, data: bytes) -> None:
        if self._only_consider_transfer_chunks:
            try:
                _extract_transfer_chunk(data)
            except Exception:
                await self.send_data(data)
                return

        if self._packets_before_failure is None:
            await self.send_data(data)
        elif self._packets_before_failure > 0:
            self._packets_before_failure -= 1
            await self.send_data(data)

    def handle_event(self, event: Event) -> None:
        if event.type is EventType.TRANSFER_START:
            self.advance_packets_before_failure()


class WindowPacketDropper(Filter):
    """A filter to allow the same packet in each window to be dropped.

    WindowPacketDropper will drop the nth data chunk in each window, as
    specified by window_packet_to_drop (zero-indexed).  This process happens
    indefinitely for each window.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.
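
    For example, window_packet_to_drop = 0 drops the first data chunk of
    every window.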
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        window_packet_to_drop: int,
    ):
        super().__init__(send_data)
        self._name = name
        self._relay_packets = True
        self._window_packet_to_drop = window_packet_to_drop
        self._next_window_start_offset: int | None = 0
        self._window_packet = 0

    async def process(self, data: bytes) -> None:
        data_chunk = None
        try:
            chunk = _extract_transfer_chunk(data)
            if chunk.type is Chunk.Type.DATA:
                data_chunk = chunk
        except Exception:
            # Invalid / non-chunk data (e.g. text logs); ignore.
            pass

        # Only count transfer data chunks as part of a window.
        if data_chunk is not None:
            if data_chunk.offset == self._next_window_start_offset:
                # If a new window has been requested, wait until the first
                # chunk matching its requested offset to begin counting window
                # chunks. Any in-flight chunks from the previous window are
                # allowed through.
                self._window_packet = 0
                self._next_window_start_offset = None

            if self._window_packet != self._window_packet_to_drop:
                await self.send_data(data)

            self._window_packet += 1
        else:
            await self.send_data(data)

    def handle_event(self, event: Event) -> None:
        if event.type in (
            EventType.PARAMETERS_RETRANSMIT,
            EventType.PARAMETERS_CONTINUE,
            EventType.START_ACK_CONFIRMATION,
        ):
            # A new transmission window has been requested, starting at the
            # offset specified in the chunk. The receiver may already have data
            # from the previous window in-flight, so don't immediately reset
            # the window packet counter.
            self._next_window_start_offset = event.chunk.offset


class EventFilter(Filter):
    """A filter that inspects packets and send events to other filters.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that it can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        event_queue: asyncio.Queue,
    ):
        super().__init__(send_data)
        self._name = name
        self._queue = event_queue

    async def process(self, data: bytes) -> None:
        try:
            chunk = _extract_transfer_chunk(data)
            if chunk.type is Chunk.Type.START:
                await self._queue.put(Event(EventType.TRANSFER_START, chunk))
            elif chunk.type is Chunk.Type.START_ACK_CONFIRMATION:
                await self._queue.put(
                    Event(EventType.START_ACK_CONFIRMATION, chunk)
                )
            elif chunk.type is Chunk.Type.PARAMETERS_RETRANSMIT:
                await self._queue.put(
                    Event(EventType.PARAMETERS_RETRANSMIT, chunk)
                )
            elif chunk.type is Chunk.Type.PARAMETERS_CONTINUE:
                await self._queue.put(
                    Event(EventType.PARAMETERS_CONTINUE, chunk)
                )
        except Exception:
            # Silently ignore invalid packets
            pass

        await self.send_data(data)


def _extract_transfer_chunk(data: bytes) -> Chunk:
    """Gets a transfer Chunk from an HDLC frame containing an RPC packet.

    Raises an exception if a valid chunk does not exist.
    """

    decoder = decode.FrameDecoder()
    for frame in decoder.process(data):
        packet = packet_pb2.RpcPacket()
        packet.ParseFromString(frame.data)

        if packet.payload:
            raw_chunk = transfer_pb2.Chunk()
            raw_chunk.ParseFromString(packet.payload)
            return Chunk.from_message(raw_chunk)

        # The incoming data is expected to be HDLC-packetized, so only one
        # frame should exist.
        break

    raise ValueError("Invalid transfer chunk frame")


async def _handle_simplex_events(
    event_queue: asyncio.Queue, handlers: list[Callable[[Event], None]]
):
    while True:
        event = await event_queue.get()
        for handler in handlers:
            handler(event)


async def _handle_simplex_connection(
    name: str,
    filter_stack_config: list[config_pb2.FilterConfig],
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    inbound_event_queue: asyncio.Queue,
    outbound_event_queue: asyncio.Queue,
) -> None:
    """Handle a single direction of a bidirectional connection between
    server and client."""
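
    # Example filter stack configuration for one direction (text proto; a
    # sketch based on the config_pb2 fields parsed below):
    #
    #   client_filter_stack { hdlc_packetizer {} }
    #   client_filter_stack { data_dropper { rate: 0.01 seed: 5 } }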

    async def send(data: bytes):
        writer.write(data)
        await writer.drain()

    filter_stack = EventFilter(send, name, outbound_event_queue)

    event_handlers: list[Callable[[Event], None]] = []

    # Build the filter stack from the bottom up
    for config in reversed(filter_stack_config):
        filter_name = config.WhichOneof("filter")
        if filter_name == "hdlc_packetizer":
            filter_stack = HdlcPacketizer(filter_stack)
        elif filter_name == "data_dropper":
            data_dropper = config.data_dropper
            filter_stack = DataDropper(
                filter_stack, name, data_dropper.rate, data_dropper.seed
            )
        elif filter_name == "rate_limiter":
            filter_stack = RateLimiter(filter_stack, config.rate_limiter.rate)
        elif filter_name == "data_transposer":
            transposer = config.data_transposer
            filter_stack = DataTransposer(
                filter_stack,
                name,
                transposer.rate,
                transposer.timeout,
                transposer.seed,
            )
        elif filter_name == "server_failure":
            server_failure = config.server_failure
            filter_stack = ServerFailure(
                filter_stack,
                name,
                server_failure.packets_before_failure,
                server_failure.start_immediately,
                server_failure.only_consider_transfer_chunks,
            )
            event_handlers.append(filter_stack.handle_event)
        elif filter_name == "keep_drop_queue":
            keep_drop_queue = config.keep_drop_queue
            filter_stack = KeepDropQueue(
                filter_stack,
                name,
                keep_drop_queue.keep_drop_queue,
                keep_drop_queue.only_consider_transfer_chunks,
            )
        elif filter_name == "window_packet_dropper":
            window_packet_dropper = config.window_packet_dropper
            filter_stack = WindowPacketDropper(
                filter_stack, name, window_packet_dropper.window_packet_to_drop
            )
            event_handlers.append(filter_stack.handle_event)
        else:
            sys.exit(f'Unknown filter {filter_name}')

    event_task = asyncio.create_task(
        _handle_simplex_events(inbound_event_queue, event_handlers)
    )

    while True:
        # Arbitrarily chosen "page sized" read.
        data = await reader.read(4096)

        # An empty read indicates that the connection has been closed.
        if not data:
            _LOG.info(f'{name} connection closed.')
            return

        await filter_stack.process(data)


async def _handle_connection(
    server_port: int,
    config: config_pb2.ProxyConfig,
    client_reader: asyncio.StreamReader,
    client_writer: asyncio.StreamWriter,
) -> None:
    """Handle a connection between server and client."""

    client_addr = client_writer.get_extra_info('peername')
    _LOG.info(f'New client connection from {client_addr}')

    # Open a new connection to the server for each client connection.
    #
    # TODO(konkers): catch exception and close client writer
    server_reader, server_writer = await asyncio.open_connection(
        'localhost', server_port
    )
    _LOG.info('New connection opened to server')

    # Queues for the simplex connections to pass events to each other.
    server_event_queue = asyncio.Queue()
    client_event_queue = asyncio.Queue()

    # Instantiate two simplex handlers, one for each direction of the
    # connection.
    _, pending = await asyncio.wait(
        [
            asyncio.create_task(
                _handle_simplex_connection(
                    "client",
                    config.client_filter_stack,
                    client_reader,
                    server_writer,
                    server_event_queue,
                    client_event_queue,
                )
            ),
            asyncio.create_task(
                _handle_simplex_connection(
                    "server",
                    config.server_filter_stack,
                    server_reader,
                    client_writer,
                    client_event_queue,
                    server_event_queue,
                )
            ),
        ],
        return_when=asyncio.FIRST_COMPLETED,
    )

    # When one side terminates the connection, also terminate the other side
    for task in pending:
        task.cancel()

    for stream in [client_writer, server_writer]:
        stream.close()


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        '--server-port',
        type=int,
        required=True,
        help=(
            'Port of the integration test server.  The proxy will forward '
            'connections to this port.'
        ),
    )
    parser.add_argument(
        '--client-port',
        type=int,
        required=True,
        help=(
            'Port on which to listen for connections from the integration '
            'test client.'
        ),
    )

    return parser.parse_args()


def _init_logging(level: int) -> None:
    _LOG.setLevel(logging.DEBUG)
    log_to_stderr = logging.StreamHandler()
    log_to_stderr.setLevel(level)
    log_to_stderr.setFormatter(
        logging.Formatter(
            fmt='%(asctime)s.%(msecs)03d-%(levelname)s: %(message)s',
            datefmt='%H:%M:%S',
        )
    )

    _LOG.addHandler(log_to_stderr)


async def _main(server_port: int, client_port: int) -> None:
    _init_logging(logging.DEBUG)

    # Load config from stdin using synchronous IO
    text_config = sys.stdin.buffer.read()

    config = text_format.Parse(text_config, config_pb2.ProxyConfig())

    # Instantiate the TCP server.
    server_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    server_socket.setsockopt(
        socket.SOL_SOCKET, socket.SO_RCVBUF, _RECEIVE_BUFFER_SIZE
    )
    server_socket.bind(('', client_port))
    server = await asyncio.start_server(
        lambda reader, writer: _handle_connection(
            server_port, config, reader, writer
        ),
        limit=_RECEIVE_BUFFER_SIZE,
        sock=server_socket,
    )

    addrs = ', '.join(str(sock.getsockname()) for sock in server.sockets)
    _LOG.info(f'Listening for client connection on {addrs}')

    # Run the TCP server.
    async with server:
        await server.serve_forever()


if __name__ == '__main__':
    asyncio.run(_main(**vars(_parse_args())))
