#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Unit test for proxy.py"""

import abc
import asyncio
from struct import pack
import time
import unittest

from pigweed.pw_rpc.internal import packet_pb2
from pigweed.pw_transfer import transfer_pb2
from pw_hdlc import encode
from pw_transfer.chunk import Chunk, ProtocolVersion

import proxy


class MockRng(abc.ABC):
    """Deterministic stand-in for an RNG that replays canned values.

    Only ``uniform()`` is provided. Values are consumed from the END of
    ``results``, so the last element of the list is handed out first. The
    list is stored by reference and shrinks by one entry per call.
    """

    def __init__(self, results: list[float]):
        self._results = results

    def uniform(self, from_val: float, to_val: float) -> float:
        """Returns the next canned value mapped onto the requested range.

        The stored value is treated as a fraction of the span between the
        two bounds: ``fraction * (to_val - from_val) + from_val``.

        Raises:
            IndexError: if no canned values remain.
        """
        span = to_val - from_val
        fraction = self._results.pop()
        return fraction * span + from_val


class ProxyTest(unittest.IsolatedAsyncioTestCase):
    async def test_transposer_simple(self):
        """Checks that DataTransposer emits two packets in swapped order.

        The transposer's RNG is replaced with a MockRng so its decisions are
        deterministic. Note that MockRng hands out its values starting from
        the end of the list.
        """
        sent_packets: list[bytes] = []
        new_packets_event: asyncio.Event = asyncio.Event()

        # Async helper so DataTransposer can await on it.
        async def append(packets: list[bytes], data: bytes):
            packets.append(data)
            # Notify that a new packet was "sent".
            new_packets_event.set()

        transposer = proxy.DataTransposer(
            lambda data: append(sent_packets, data),
            name="test",
            rate=0.5,
            timeout=100,
            seed=1234567890,
        )
        transposer._rng = MockRng([0.6, 0.4])
        await transposer.process(b'aaaaaaaaaa')
        await transposer.process(b'bbbbbbbbbb')

        expected_packets = [b'bbbbbbbbbb', b'aaaaaaaaaa']
        while True:
            # Wait for new packets with a generous timeout.
            try:
                await asyncio.wait_for(new_packets_event.wait(), timeout=60.0)
            except asyncio.TimeoutError:
                # asyncio.TimeoutError is only an alias of the builtin
                # TimeoutError from Python 3.11 on; naming the asyncio class
                # keeps this handler effective on older interpreters too.
                self.fail(
                    f'Timeout waiting for data.  Packets sent: {sent_packets}'
                )

            # Re-arm the event. Without this, every later wait() returns
            # immediately and the loop spins instead of actually waiting for
            # the next packet.
            new_packets_event.clear()

            # Only assert the sent packets are correct once we've seen at
            # least the expected number; extra packets then fail right away
            # instead of after the timeout.
            if len(sent_packets) >= len(expected_packets):
                self.assertEqual(sent_packets, expected_packets)
                return

    async def test_transposer_timeout(self):
        """Checks that packets come out in order once the timeout elapses."""
        delivered: list[bytes] = []

        # Coroutine sink handed to the filter under test; it records every
        # packet the filter forwards.
        async def record(data: bytes) -> None:
            delivered.append(data)

        transposer = proxy.DataTransposer(
            record,
            name="test",
            rate=0.5,
            timeout=0.100,
            seed=1234567890,
        )
        transposer._rng = MockRng([0.4, 0.6])

        first, second = b'aaaaaaaaaa', b'bbbbbbbbbb'
        await transposer.process(first)

        # Even though this should be transposed, there is no following data so
        # the transposer should time out and send this in order.
        await transposer.process(second)

        # Give the transposer time to time out.
        await asyncio.sleep(0.5)

        self.assertEqual(delivered, [first, second])

    async def test_server_failure(self):
        """Checks how many packets ServerFailure forwards in each round.

        Five packets are fed in per round, and a TRANSFER_START event is
        delivered to the filter after every round.
        """
        forwarded: list[bytes] = []

        # Coroutine sink handed to the filter under test; it records every
        # packet the filter forwards.
        async def record(data: bytes) -> None:
            forwarded.append(data)

        configured_counts = [1, 2, 3]
        server_failure = proxy.ServerFailure(
            record,
            name="test",
            packets_before_failure_list=list(configured_counts),
            start_immediately=True,
        )

        # The filter was only given the counts above. One more round is added
        # to check the case where no packets are dropped.
        expected_counts = configured_counts + [5]

        packets = [str(number).encode() for number in range(1, 6)]

        for expected_count in expected_counts:
            forwarded.clear()
            for packet in packets:
                await server_failure.process(packet)
            self.assertEqual(len(forwarded), expected_count)

            start_chunk = Chunk(ProtocolVersion.VERSION_TWO, Chunk.Type.START)
            server_failure.handle_event(
                proxy.Event(proxy.EventType.TRANSFER_START, start_chunk)
            )

    async def test_server_failure_transfer_chunks_only(self):
        """Checks ServerFailure with only_consider_transfer_chunks=True.

        Expects the first two transfer chunks and every non-chunk packet to
        be forwarded, and all later transfer chunks to be missing.
        """
        forwarded: list[bytes] = []

        # Coroutine sink handed to the filter under test; it records every
        # packet the filter forwards.
        async def record(data: bytes) -> None:
            forwarded.append(data)

        server_failure = proxy.ServerFailure(
            record,
            name="test",
            packets_before_failure_list=[2],
            start_immediately=True,
            only_consider_transfer_chunks=True,
        )

        data_chunk = Chunk(
            ProtocolVersion.VERSION_TWO, Chunk.Type.DATA, data=b'1'
        )
        chunk_frame = _encode_rpc_frame(data_chunk)

        incoming = [
            b'1',
            b'2',
            chunk_frame,  # 1
            b'3',
            chunk_frame,  # 2
            b'4',
            b'5',
            chunk_frame,  # Transfer chunks should be dropped starting here.
            chunk_frame,
            b'6',
            b'7',
            chunk_frame,
        ]
        expected = [
            b'1',
            b'2',
            chunk_frame,
            b'3',
            chunk_frame,
            b'4',
            b'5',
            b'6',
            b'7',
        ]

        for packet in incoming:
            await server_failure.process(packet)

        self.assertEqual(forwarded, expected)

    async def test_keep_drop_queue_loop(self):
        """Checks KeepDropQueue output for the keep/drop list [2, 1, 3].

        Nine packets are fed in; the expected output below shows the pattern
        being applied again after its last entry is consumed.
        """
        forwarded: list[bytes] = []

        # Coroutine sink handed to the filter under test; it records every
        # packet the filter forwards.
        async def record(data: bytes) -> None:
            forwarded.append(data)

        keep_drop_queue = proxy.KeepDropQueue(
            record,
            name="test",
            keep_drop_queue=[2, 1, 3],
        )

        # b'1' through b'9'.
        incoming = [str(number).encode() for number in range(1, 10)]
        expected = [b'1', b'2', b'4', b'5', b'6', b'9']

        for packet in incoming:
            await keep_drop_queue.process(packet)

        self.assertEqual(forwarded, expected)

    async def test_keep_drop_queue(self):
        """Checks KeepDropQueue output for the keep/drop list [2, 1, 1, -1].

        Nine packets are fed in and only b'1', b'2' and b'4' are expected to
        come out the other side.
        """
        forwarded: list[bytes] = []

        # Coroutine sink handed to the filter under test; it records every
        # packet the filter forwards.
        async def record(data: bytes) -> None:
            forwarded.append(data)

        keep_drop_queue = proxy.KeepDropQueue(
            record,
            name="test",
            keep_drop_queue=[2, 1, 1, -1],
        )

        # b'1' through b'9'.
        incoming = [str(number).encode() for number in range(1, 10)]
        expected = [b'1', b'2', b'4']

        for packet in incoming:
            await keep_drop_queue.process(packet)

        self.assertEqual(forwarded, expected)

    async def test_keep_drop_queue_transfer_chunks_only(self):
        """Checks KeepDropQueue with only_consider_transfer_chunks=True.

        Every non-chunk packet is expected to be forwarded; the keep/drop
        annotations on the input list mark what is expected to happen to each
        transfer chunk.
        """
        forwarded: list[bytes] = []

        # Coroutine sink handed to the filter under test; it records every
        # packet the filter forwards.
        async def record(data: bytes) -> None:
            forwarded.append(data)

        keep_drop_queue = proxy.KeepDropQueue(
            record,
            name="test",
            keep_drop_queue=[2, 1, 1, -1],
            only_consider_transfer_chunks=True,
        )

        data_chunk = Chunk(
            ProtocolVersion.VERSION_TWO, Chunk.Type.DATA, data=b'1'
        )
        chunk_frame = _encode_rpc_frame(data_chunk)

        incoming = [
            b'1',
            chunk_frame,  # keep
            b'2',
            chunk_frame,  # keep
            b'3',
            b'4',
            b'5',
            chunk_frame,  # drop
            b'6',
            b'7',
            chunk_frame,  # keep
            chunk_frame,  # drop
            b'8',
            chunk_frame,  # drop
            b'9',
            chunk_frame,  # drop
            chunk_frame,  # drop
            b'10',
        ]
        expected = [
            b'1',
            chunk_frame,
            b'2',
            chunk_frame,
            b'3',
            b'4',
            b'5',
            b'6',
            b'7',
            chunk_frame,
            b'8',
            b'9',
            b'10',
        ]

        for packet in incoming:
            await keep_drop_queue.process(packet)

        self.assertEqual(forwarded, expected)

    async def test_window_packet_dropper(self):
        """Checks WindowPacketDropper configured with window_packet_to_drop=0.

        The same five data chunks are sent in each round and the first one is
        expected to be missing from the output every time. A parameters event
        is delivered to the filter after each round.
        """
        forwarded: list[bytes] = []

        # Coroutine sink handed to the filter under test; it records every
        # packet the filter forwards.
        async def record(data: bytes) -> None:
            forwarded.append(data)

        window_packet_dropper = proxy.WindowPacketDropper(
            record,
            name="test",
            window_packet_to_drop=0,
        )

        # Five data chunks carrying b'1' through b'5', all in session 1.
        frames = [
            _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.DATA,
                    data=str(number).encode(),
                    session_id=1,
                )
            )
            for number in range(1, 6)
        ]
        expected = frames[1:]

        # Test each event twice to ensure the filter does not have issues
        # on new window boundaries.
        event_kinds = [
            (
                proxy.EventType.PARAMETERS_RETRANSMIT,
                Chunk.Type.PARAMETERS_RETRANSMIT,
            ),
            (
                proxy.EventType.PARAMETERS_CONTINUE,
                Chunk.Type.PARAMETERS_CONTINUE,
            ),
        ] * 2
        events = [
            proxy.Event(
                event_type, Chunk(ProtocolVersion.VERSION_TWO, chunk_type)
            )
            for event_type, chunk_type in event_kinds
        ]

        for event in events:
            forwarded.clear()
            for frame in frames:
                await window_packet_dropper.process(frame)
            self.assertEqual(forwarded, expected)
            window_packet_dropper.handle_event(event)

    async def test_window_packet_dropper_extra_in_flight_packets(self):
        """Checks WindowPacketDropper when an event arrives mid-stream.

        The filter is configured with window_packet_to_drop=1. A
        PARAMETERS_RETRANSMIT event with offset=1 is delivered right after the
        second packet is processed, followed by further data chunks. The
        packets at indices 1 and 4 are expected to be missing from the output.
        """
        sent_packets: list[bytes] = []

        # Coroutine sink handed to the filter under test; it records every
        # packet the filter forwards.
        async def append(packets: list[bytes], data: bytes):
            packets.append(data)

        window_packet_dropper = proxy.WindowPacketDropper(
            lambda data: append(sent_packets, data),
            name="test",
            window_packet_to_drop=1,
        )

        # (data, offset) of each data chunk, in the order they are sent.
        # Offsets 1 and 2 appear twice.
        chunk_specs = [
            (b'1', 0),
            (b'2', 1),
            (b'3', 2),
            (b'2', 1),
            (b'3', 2),
            (b'4', 3),
        ]
        packets = [
            _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.DATA,
                    data=data,
                    offset=offset,
                )
            )
            for data, offset in chunk_specs
        ]

        # Event (if any) to deliver after the packet at the same index has
        # been processed.
        events = [
            None,
            proxy.Event(
                proxy.EventType.PARAMETERS_RETRANSMIT,
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.PARAMETERS_RETRANSMIT,
                    offset=1,
                ),
            ),
            None,
            None,
            None,
            None,
        ]

        for packet, event in zip(packets, events):
            await window_packet_dropper.process(packet)
            if event is not None:
                window_packet_dropper.handle_event(event)

        expected_packets = [packets[0], packets[2], packets[3], packets[5]]
        self.assertEqual(sent_packets, expected_packets)

    async def test_event_filter(self):
        """Checks which packets make EventFilter put an event on its queue.

        Each step pairs an input packet with the event type expected on the
        queue right after that packet is processed (None for no event).
        """
        forwarded: list[bytes] = []

        # Coroutine sink handed to the filter under test; it records every
        # packet the filter forwards.
        async def record(data: bytes) -> None:
            forwarded.append(data)

        queue = asyncio.Queue()

        event_filter = proxy.EventFilter(
            record,
            name="test",
            event_queue=queue,
        )

        # A serialized RPC request; unlike the chunks below it is not passed
        # through _encode_rpc_frame().
        request = packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.REQUEST,
            channel_id=101,
            service_id=1001,
            method_id=100001,
        ).SerializeToString()

        def start_frame(session_id: int) -> bytes:
            return _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.START,
                    session_id=session_id,
                )
            )

        def data_frame(session_id: int, payload: bytes) -> bytes:
            return _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.DATA,
                    session_id=session_id,
                    data=payload,
                )
            )

        steps = [
            (request, None),
            (start_frame(1), proxy.EventType.TRANSFER_START),
            (data_frame(1, b'3'), None),
            (data_frame(1, b'3'), None),
            (request, None),
            (start_frame(2), proxy.EventType.TRANSFER_START),
            (data_frame(2, b'4'), None),
            (data_frame(2, b'5'), None),
        ]

        for packet, expected_event_type in steps:
            await event_filter.process(packet)
            # Nothing is awaited between the emptiness check and the read, so
            # the queue cannot change in between.
            if queue.empty():
                event_type = None
            else:
                event_type = queue.get_nowait().type
            self.assertEqual(event_type, expected_event_type)


def _encode_rpc_frame(chunk: Chunk) -> bytes:
    """Serializes a chunk into an RPC packet wrapped in an HDLC UI frame.

    The chunk's protobuf message becomes the payload of a SERVER_STREAM
    RpcPacket with fixed channel, service and method IDs.
    """
    payload = chunk.to_message().SerializeToString()
    rpc_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.SERVER_STREAM,
        channel_id=101,
        service_id=1001,
        method_id=100001,
        payload=payload,
    )
    # NOTE(review): 73 is presumably the HDLC address expected by the proxy;
    # confirm against pw_hdlc / proxy.py.
    return encode.ui_frame(73, rpc_packet.SerializeToString())


# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
