#!/usr/bin/env python3
# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""device module unit tests"""

from contextlib import contextmanager
import logging
import queue
import threading
import time
import unittest

from pw_hdlc.rpc import RpcClient, HdlcRpcClient, CancellableReader


class QueueFile:
    """A fake file object backed by a queue for testing.

    The "operator" (the test) feeds the file through put_read_data(),
    cause_read_exc() and set_read_eof(). The "consumer" (the code under test)
    drains it with read(). Queued items are delivered in FIFO order.
    """

    # Unique marker queued by set_read_eof() to signal end-of-file.
    EOF = object()

    def __init__(self):
        # Operator puts; consumer gets
        self._q = queue.Queue()

        # Consumer side access only!
        self._readbuf = b''
        self._eof = False

    ###############
    # Consumer side

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()

    def _read_from_buf(self, size: int) -> bytes:
        """Removes and returns up to `size` bytes from the read buffer."""
        data = self._readbuf[:size]
        self._readbuf = self._readbuf[size:]
        return data

    def read(self, size: int = 1) -> bytes:
        """Reads up to `size` bytes, blocking if nothing is buffered.

        Bytes left over from a previously dequeued chunk are returned first.
        Otherwise this blocks until the operator queues an item and acts on
        it: data is returned (at most `size` bytes; the remainder is
        buffered), a queued exception is raised, and EOF makes this call and
        every later call return b''.

        Args:
          size: Maximum number of bytes to return. Must not be negative.

        Returns:
          Between 0 and `size` bytes. An empty result does not by itself mean
          EOF: an empty queued chunk or a `size` of 0 also yield b''.

        Raises:
          Exception: The exception queued by cause_read_exc() or close().
          TypeError: If an item of an unsupported type was queued.
        """
        # A zero-length read can always be satisfied immediately. It must
        # neither block on the queue nor consume a queued item.
        if size == 0:
            return b''

        # First try to get buffered data
        data = self._read_from_buf(size)
        assert len(data) <= size
        if data:
            return data

        # No more data in the buffer
        assert not self._readbuf

        if self._eof:
            return b''

        # Nothing in the buffer; block on the queue
        item = self._q.get()

        # NOTE: We can't call Queue.task_done() here because the reader hasn't
        # actually *acted* on the read item yet.

        # Queued data
        if isinstance(item, bytes):
            self._readbuf = item
            return self._read_from_buf(size)

        # Queued exception
        if isinstance(item, Exception):
            raise item

        # Report EOF
        if item is self.EOF:
            self._eof = True
            return b''

        raise TypeError('unexpected item type')

    def write(self, data: bytes) -> None:
        """Accepts and discards `data`; writes are not recorded."""

    #####################
    # Weird middle ground

    # It is a violation of most file-like object APIs for one thread to call
    # close() while another thread is calling read(). The behavior is
    # undefined.
    #
    # - On Linux, close() may wake up a select(), leaving the caller with a bad
    #   file descriptor (which could get reused!)
    # - Or the read() could continue to block indefinitely.
    #
    # We choose to cause a subsequent/parallel read to receive an exception.
    def close(self) -> None:
        """Queues an Exception('closed') for a pending or future read()."""
        self.cause_read_exc(Exception('closed'))

    ###############
    # Operator side

    def put_read_data(self, data: bytes) -> None:
        """Queues `data` to be returned by read()."""
        self._q.put(data)

    def cause_read_exc(self, exc: Exception) -> None:
        """Queues `exc` to be raised by read()."""
        self._q.put(exc)

    def set_read_eof(self) -> None:
        """Queues an end-of-file marker for read()."""
        self._q.put(self.EOF)

    def wait_for_drain(self, timeout=None) -> None:
        """Wait for the queue to drain (be fully consumed).

        Args:
          timeout: The maximum time (in seconds) to wait, or wait forever
            if None.

        Raises:
          TimeoutError: If timeout is given and has elapsed.
        """
        # It would be great to use Queue.join() here, but that requires the
        # consumer to call Queue.task_done(), and we can't do that because
        # the consumer of read() doesn't know anything about it.
        # Instead, we poll.  ¯\_(ツ)_/¯
        #
        # Use the monotonic clock so that wall-clock adjustments cannot
        # shorten or extend the timeout.
        start_time = time.monotonic()
        while not self._q.empty():
            if timeout is not None:
                elapsed = time.monotonic() - start_time
                if elapsed > timeout:
                    raise TimeoutError(f"Queue not empty after {elapsed} sec")
            time.sleep(0.1)


class QueueFileTest(unittest.TestCase):
    """Exercises the QueueFile fake used by the other tests in this file."""

    def test_read_data(self) -> None:
        """One queued chunk comes back from a read of the same size."""
        fake = QueueFile()
        fake.put_read_data(b'hello')
        self.assertEqual(fake.read(5), b'hello')

    def test_read_data_multi_read(self) -> None:
        """One queued chunk can be consumed by several smaller reads."""
        fake = QueueFile()
        fake.put_read_data(b'helloworld')
        for expected in (b'hello', b'world'):
            self.assertEqual(fake.read(5), expected)

    def test_read_data_multi_put(self) -> None:
        """Separately queued chunks are read back in the order queued."""
        fake = QueueFile()
        for chunk in (b'hello', b'world'):
            fake.put_read_data(chunk)
        for expected in (b'hello', b'world'):
            self.assertEqual(fake.read(5), expected)

    def test_read_eof(self) -> None:
        """A queued EOF makes read() return an empty bytes object."""
        fake = QueueFile()
        fake.set_read_eof()
        self.assertEqual(fake.read(5), b'')

    def test_read_exception(self) -> None:
        """A queued exception is raised out of read()."""
        fake = QueueFile()
        text = 'test exception'
        fake.cause_read_exc(ValueError(text))
        with self.assertRaisesRegex(ValueError, text):
            fake.read(5)

    def test_wait_for_drain_works(self) -> None:
        """wait_for_drain() returns once the only queued item was read."""
        fake = QueueFile()
        fake.put_read_data(b'hello')
        # A single read empties the queue; the unread bytes stay in the
        # fake's read buffer, which wait_for_drain() does not look at.
        fake.read()
        try:
            # Timeout is arbitrary; will return immediately.
            fake.wait_for_drain(0.1)
        except TimeoutError:
            self.fail("wait_for_drain raised TimeoutError")

    def test_wait_for_drain_raises(self) -> None:
        """wait_for_drain() times out while an item is still queued."""
        fake = QueueFile()
        fake.put_read_data(b'hello')
        # Nothing reads the item, so the queue never empties and the
        # (arbitrary) timeout is certain to expire.
        with self.assertRaises(TimeoutError):
            fake.wait_for_drain(0.1)


class Sentinel:
    """Marker object whose repr is the fixed string 'Sentinel'."""

    def __repr__(self) -> str:
        return 'Sentinel'


class _QueueReader(CancellableReader):
    """CancellableReader that cancels a read by closing its base object."""

    def cancel_read(self) -> None:
        """Closes the object stored in `_base_obj`.

        NOTE(review): `_base_obj` is set by CancellableReader, presumably to
        the object passed to the constructor -- confirm against pw_hdlc.rpc.
        """
        wrapped = self._base_obj
        wrapped.close()


def _get_client(file) -> RpcClient:
    """Returns an HdlcRpcClient that reads from `file` via a _QueueReader.

    The client is created with empty `paths_or_modules` and `channels` lists.
    """
    reader = _QueueReader(file)
    return HdlcRpcClient(reader, paths_or_modules=[], channels=[])


# Timeout, in seconds, passed to QueueFile.wait_for_drain() by the
# HdlcRpcClient tests in this file.
# This should take <10ms but we'll wait up to 1000x longer.
_QUEUE_DRAIN_TIMEOUT = 10.0


class HdlcRpcClientTest(unittest.TestCase):
    """Tests the pw_hdlc.rpc.HdlcRpcClient class."""

    # NOTE: Stream EOF is deliberately not tested here: Serial.read() can
    # return an empty result if configured with timeout != None, and the
    # reader thread will continue in this case.

    def test_clean_close_after_stream_close(self) -> None:
        """Assert RpcClient closes cleanly when stream closes."""
        # See b/293595266.
        stream = QueueFile()

        # Exit order matters: the client context exits first, then the
        # stream is closed, and only then are the captured logs checked.
        with self.assert_no_hdlc_rpc_error_logs(), stream:
            with _get_client(stream):
                # Feed an empty chunk and wait for it to be consumed so that
                # the reader thread is known to be blocked on read() rather
                # than having exited immediately.
                stream.put_read_data(b'')
                stream.wait_for_drain(_QUEUE_DRAIN_TIMEOUT)

            # At this point RpcClient.__exit__ has called stop() on the
            # reader thread, which is still blocked on stream.read().

        # Leaving the block above called QueueFile.close(), triggering an
        # exception in the blocking read() (by implementation choice). The
        # reader should handle it by *not* logging it and exiting
        # immediately.

        self.assert_no_background_threads_running()

    def test_device_handles_read_exception(self) -> None:
        """Assert RpcClient closes cleanly when read raises an exception."""
        # See b/293595266.
        stream = QueueFile()

        rpc_logger = logging.getLogger('pw_hdlc.rpc')
        injected = Exception('boom')
        with self.assertLogs(rpc_logger, level=logging.ERROR) as captured:
            with _get_client(stream):
                # Make read() raise. The reader is expected to log the
                # exception and exit immediately.
                stream.cause_read_exc(injected)
                stream.wait_for_drain(_QUEUE_DRAIN_TIMEOUT)

        # Exactly one error record, carrying the injected exception.
        self.assertEqual(len(captured.records), 1)
        record = captured.records[0]
        self.assertIsNotNone(record.exc_info)
        assert record.exc_info is not None  # for mypy
        self.assertEqual(record.exc_info[1], injected)

        self.assert_no_background_threads_running()

    @contextmanager
    def assert_no_hdlc_rpc_error_logs(self):
        """Context manager asserting no 'pw_hdlc.rpc' errors are logged.

        Yields the assertLogs() context. On exit, the test fails if any
        record other than the sentinel logged here was captured.
        """
        marker = Sentinel()
        rpc_logger = logging.getLogger('pw_hdlc.rpc')
        with self.assertLogs(rpc_logger, level=logging.ERROR) as captured:
            # TODO: b/294861320 - use assertNoLogs() in Python 3.10+
            # TestCase.assertNoLogs() is what we really want, but it is not
            # available until Python 3.10. assertLogs() fails when nothing
            # is logged, so log one known record up front and verify below
            # that it is the only one captured.
            rpc_logger.error(marker)

            yield captured

        logged = [entry.msg for entry in captured.records]
        self.assertEqual(logged, [marker])

    def assert_no_background_threads_running(self):
        """Asserts the current thread is the only thread still alive."""
        alive = threading.enumerate()
        self.assertEqual(alive, [threading.current_thread()])


# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
