# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import enum
import logging
import os
import struct
import time

import click

from bumble import l2cap
from bumble.core import (
    BT_BR_EDR_TRANSPORT,
    BT_LE_TRANSPORT,
    BT_L2CAP_PROTOCOL_ID,
    BT_RFCOMM_PROTOCOL_ID,
    UUID,
    CommandTimeoutError,
)
from bumble.colors import color
from bumble.device import Connection, ConnectionParametersPreferences, Device, Peer
from bumble.gatt import Characteristic, CharacteristicValue, Service
from bumble.hci import (
    HCI_LE_1M_PHY,
    HCI_LE_2M_PHY,
    HCI_LE_CODED_PHY,
    HCI_CENTRAL_ROLE,
    HCI_PERIPHERAL_ROLE,
    HCI_Constant,
    HCI_Error,
    HCI_StatusError,
)
from bumble.sdp import (
    SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
    SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
    SDP_PUBLIC_BROWSE_ROOT,
    SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
    SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
    DataElement,
    ServiceAttribute,
)
from bumble.transport import open_transport_or_link
import bumble.rfcomm
import bumble.core
from bumble.utils import AsyncRunner
from bumble.pairing import PairingConfig


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_CENTRAL_ADDRESS = 'F0:F0:F0:F0:F0:F0'
DEFAULT_CENTRAL_NAME = 'Speed Central'
DEFAULT_PERIPHERAL_ADDRESS = 'F1:F1:F1:F1:F1:F1'
DEFAULT_PERIPHERAL_NAME = 'Speed Peripheral'

SPEED_SERVICE_UUID = '50DB505C-8AC4-4738-8448-3B1D9CC09CC5'
SPEED_TX_UUID = 'E789C754-41A1-45F4-A948-A0A1A90DBA53'
SPEED_RX_UUID = '016A2CC7-E14B-4819-935F-1F56EAE4098D'

DEFAULT_RFCOMM_UUID = 'E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'
DEFAULT_L2CAP_PSM = 128
DEFAULT_L2CAP_MAX_CREDITS = 128
DEFAULT_L2CAP_MTU = 1024
DEFAULT_L2CAP_MPS = 1024

DEFAULT_LINGER_TIME = 1.0
DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0

DEFAULT_RFCOMM_CHANNEL = 8
DEFAULT_RFCOMM_MTU = 2048


# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def parse_packet(packet):
    if len(packet) < 1:
        logging.info(
            color(f'!!! Packet too short (got {len(packet)} bytes, need >= 1)', 'red')
        )
        raise ValueError('packet too short')

    try:
        packet_type = PacketType(packet[0])
    except ValueError:
        logging.info(color(f'!!! Invalid packet type 0x{packet[0]:02X}', 'red'))
        raise

    return (packet_type, packet[1:])


def parse_packet_sequence(packet_data):
    if len(packet_data) < 5:
        logging.info(
            color(
                f'!!!Packet too short (got {len(packet_data)} bytes, need >= 5)',
                'red',
            )
        )
        raise ValueError('packet too short')
    return struct.unpack_from('>bI', packet_data, 0)


def le_phy_name(phy_id):
    return {HCI_LE_1M_PHY: '1M', HCI_LE_2M_PHY: '2M', HCI_LE_CODED_PHY: 'CODED'}.get(
        phy_id, HCI_Constant.le_phy_name(phy_id)
    )


def print_connection(connection):
    params = []
    if connection.transport == BT_LE_TRANSPORT:
        params.append(
            'PHY='
            f'TX:{le_phy_name(connection.phy.tx_phy)}/'
            f'RX:{le_phy_name(connection.phy.rx_phy)}'
        )

        params.append(
            'DL=('
            f'TX:{connection.data_length[0]}/{connection.data_length[1]},'
            f'RX:{connection.data_length[2]}/{connection.data_length[3]}'
            ')'
        )

        params.append(
            'Parameters='
            f'{connection.parameters.connection_interval * 1.25:.2f}/'
            f'{connection.parameters.peripheral_latency}/'
            f'{connection.parameters.supervision_timeout * 10} '
        )

        params.append(f'MTU={connection.att_mtu}')

    else:
        params.append(f'Role={HCI_Constant.role_name(connection.role)}')

    logging.info(color('@@@ Connection: ', 'yellow') + ' '.join(params))


def make_sdp_records(channel):
    return {
        0x00010001: [
            ServiceAttribute(
                SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
                DataElement.unsigned_integer_32(0x00010001),
            ),
            ServiceAttribute(
                SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
                DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
            ),
            ServiceAttribute(
                SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
                DataElement.sequence([DataElement.uuid(UUID(DEFAULT_RFCOMM_UUID))]),
            ),
            ServiceAttribute(
                SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
                DataElement.sequence(
                    [
                        DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
                        DataElement.sequence(
                            [
                                DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
                                DataElement.unsigned_integer_8(channel),
                            ]
                        ),
                    ]
                ),
            ),
        ]
    }


def log_stats(title, stats):
    stats_min = min(stats)
    stats_max = max(stats)
    stats_avg = sum(stats) / len(stats)
    logging.info(
        color(
            (
                f'### {title} stats: '
                f'min={stats_min:.2f}, '
                f'max={stats_max:.2f}, '
                f'average={stats_avg:.2f}'
            ),
            'cyan',
        )
    )


async def switch_roles(connection, role):
    target_role = HCI_CENTRAL_ROLE if role == "central" else HCI_PERIPHERAL_ROLE
    if connection.role != target_role:
        logging.info(f'{color("### Switching roles to:", "cyan")} {role}')
        try:
            await connection.switch_role(target_role)
            logging.info(color('### Role switch complete', 'cyan'))
        except HCI_Error as error:
            logging.info(f'{color("### Role switch failed:", "red")} {error}')


class PacketType(enum.IntEnum):
    RESET = 0
    SEQUENCE = 1
    ACK = 2


PACKET_FLAG_LAST = 1


# -----------------------------------------------------------------------------
# Sender
# -----------------------------------------------------------------------------
class Sender:
    def __init__(
        self,
        packet_io,
        start_delay,
        repeat,
        repeat_delay,
        pace,
        packet_size,
        packet_count,
    ):
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.repeat = repeat
        self.repeat_delay = repeat_delay
        self.pace = pace
        self.start_time = 0
        self.bytes_sent = 0
        self.stats = []
        self.done = asyncio.Event()

    def reset(self):
        pass

    async def run(self):
        logging.info(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        logging.info(color('--- Go!', 'blue'))

        for run in range(self.repeat + 1):
            self.done.clear()

            if run > 0 and self.repeat and self.repeat_delay:
                logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green'))
                await asyncio.sleep(self.repeat_delay)

            if self.tx_start_delay:
                logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
                await asyncio.sleep(self.tx_start_delay)

            logging.info(color('=== Sending RESET', 'magenta'))
            await self.packet_io.send_packet(bytes([PacketType.RESET]))
            self.start_time = time.time()
            self.bytes_sent = 0
            for tx_i in range(self.tx_packet_count):
                packet_flags = (
                    PACKET_FLAG_LAST if tx_i == self.tx_packet_count - 1 else 0
                )
                packet = struct.pack(
                    '>bbI',
                    PacketType.SEQUENCE,
                    packet_flags,
                    tx_i,
                ) + bytes(self.tx_packet_size - 6 - self.packet_io.overhead_size)
                logging.info(
                    color(
                        f'Sending packet {tx_i}: {self.tx_packet_size} bytes', 'yellow'
                    )
                )
                self.bytes_sent += len(packet)
                await self.packet_io.send_packet(packet)

                if self.pace is None:
                    continue

                if self.pace > 0:
                    await asyncio.sleep(self.pace / 1000)
                else:
                    await self.packet_io.drain()

            await self.done.wait()

            run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
            logging.info(color(f'=== {run_counter} Done!', 'magenta'))

            if self.repeat:
                log_stats('Run', self.stats)

        if self.repeat:
            logging.info(color('--- End of runs', 'blue'))

    def on_packet_received(self, packet):
        try:
            packet_type, _ = parse_packet(packet)
        except ValueError:
            return

        if packet_type == PacketType.ACK:
            elapsed = time.time() - self.start_time
            average_tx_speed = self.bytes_sent / elapsed
            self.stats.append(average_tx_speed)
            logging.info(
                color(
                    f'@@@ Received ACK. Speed: average={average_tx_speed:.4f}'
                    f' ({self.bytes_sent} bytes in {elapsed:.2f} seconds)',
                    'green',
                )
            )
            self.done.set()


# -----------------------------------------------------------------------------
# Receiver
# -----------------------------------------------------------------------------
class Receiver:
    expected_packet_index: int
    start_timestamp: float
    last_timestamp: float

    def __init__(self, packet_io, linger):
        self.reset()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.linger = linger
        self.done = asyncio.Event()

    def reset(self):
        self.expected_packet_index = 0
        self.measurements = [(time.time(), 0)]
        self.total_bytes_received = 0

    def on_packet_received(self, packet):
        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        if packet_type == PacketType.RESET:
            logging.info(color('=== Received RESET', 'magenta'))
            self.reset()
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return
        logging.info(
            f'<<< Received packet {packet_index}: '
            f'flags=0x{packet_flags:02X}, '
            f'{len(packet) + self.packet_io.overhead_size} bytes'
        )

        if packet_index != self.expected_packet_index:
            logging.info(
                color(
                    f'!!! Unexpected packet, expected {self.expected_packet_index} '
                    f'but received {packet_index}'
                )
            )

        now = time.time()
        elapsed_since_start = now - self.measurements[0][0]
        elapsed_since_last = now - self.measurements[-1][0]
        self.measurements.append((now, len(packet)))
        self.total_bytes_received += len(packet)
        instant_rx_speed = len(packet) / elapsed_since_last
        average_rx_speed = self.total_bytes_received / elapsed_since_start
        window = self.measurements[-64:]
        windowed_rx_speed = sum(measurement[1] for measurement in window[1:]) / (
            window[-1][0] - window[0][0]
        )
        logging.info(
            color(
                'Speed: '
                f'instant={instant_rx_speed:.4f}, '
                f'windowed={windowed_rx_speed:.4f}, '
                f'average={average_rx_speed:.4f}',
                'yellow',
            )
        )

        self.expected_packet_index = packet_index + 1

        if packet_flags & PACKET_FLAG_LAST:
            AsyncRunner.spawn(
                self.packet_io.send_packet(
                    struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
                )
            )
            logging.info(color('@@@ Received last packet', 'green'))
            if not self.linger:
                self.done.set()

    async def run(self):
        await self.done.wait()
        logging.info(color('=== Done!', 'magenta'))


# -----------------------------------------------------------------------------
# Ping
# -----------------------------------------------------------------------------
class Ping:
    def __init__(
        self,
        packet_io,
        start_delay,
        repeat,
        repeat_delay,
        pace,
        packet_size,
        packet_count,
    ):
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.repeat = repeat
        self.repeat_delay = repeat_delay
        self.pace = pace
        self.done = asyncio.Event()
        self.current_packet_index = 0
        self.ping_sent_time = 0.0
        self.latencies = []
        self.min_stats = []
        self.max_stats = []
        self.avg_stats = []

    def reset(self):
        pass

    async def run(self):
        logging.info(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        logging.info(color('--- Go!', 'blue'))

        for run in range(self.repeat + 1):
            self.done.clear()

            if run > 0 and self.repeat and self.repeat_delay:
                logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green'))
                await asyncio.sleep(self.repeat_delay)

            if self.tx_start_delay:
                logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
                await asyncio.sleep(self.tx_start_delay)

            logging.info(color('=== Sending RESET', 'magenta'))
            await self.packet_io.send_packet(bytes([PacketType.RESET]))

            self.current_packet_index = 0
            self.latencies = []
            await self.send_next_ping()

            await self.done.wait()

            min_latency = min(self.latencies)
            max_latency = max(self.latencies)
            avg_latency = sum(self.latencies) / len(self.latencies)
            logging.info(
                color(
                    '@@@ Latencies: '
                    f'min={min_latency:.2f}, '
                    f'max={max_latency:.2f}, '
                    f'average={avg_latency:.2f}'
                )
            )

            self.min_stats.append(min_latency)
            self.max_stats.append(max_latency)
            self.avg_stats.append(avg_latency)

            run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
            logging.info(color(f'=== {run_counter} Done!', 'magenta'))

            if self.repeat:
                log_stats('Min Latency', self.min_stats)
                log_stats('Max Latency', self.max_stats)
                log_stats('Average Latency', self.avg_stats)

        if self.repeat:
            logging.info(color('--- End of runs', 'blue'))

    async def send_next_ping(self):
        if self.pace:
            await asyncio.sleep(self.pace / 1000)

        packet = struct.pack(
            '>bbI',
            PacketType.SEQUENCE,
            (
                PACKET_FLAG_LAST
                if self.current_packet_index == self.tx_packet_count - 1
                else 0
            ),
            self.current_packet_index,
        ) + bytes(self.tx_packet_size - 6)
        logging.info(color(f'Sending packet {self.current_packet_index}', 'yellow'))
        self.ping_sent_time = time.time()
        await self.packet_io.send_packet(packet)

    def on_packet_received(self, packet):
        elapsed = time.time() - self.ping_sent_time

        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return

        if packet_type == PacketType.ACK:
            latency = elapsed * 1000
            self.latencies.append(latency)
            logging.info(
                color(
                    f'<<< Received ACK [{packet_index}], latency={latency:.2f}ms',
                    'green',
                )
            )

            if packet_index == self.current_packet_index:
                self.current_packet_index += 1
            else:
                logging.info(
                    color(
                        f'!!! Unexpected packet, expected {self.current_packet_index} '
                        f'but received {packet_index}'
                    )
                )

        if packet_flags & PACKET_FLAG_LAST:
            self.done.set()
            return

        AsyncRunner.spawn(self.send_next_ping())


# -----------------------------------------------------------------------------
# Pong
# -----------------------------------------------------------------------------
class Pong:
    expected_packet_index: int

    def __init__(self, packet_io, linger):
        self.reset()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.linger = linger
        self.done = asyncio.Event()

    def reset(self):
        self.expected_packet_index = 0

    def on_packet_received(self, packet):
        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        if packet_type == PacketType.RESET:
            logging.info(color('=== Received RESET', 'magenta'))
            self.reset()
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return
        logging.info(
            color(
                f'<<< Received packet {packet_index}: '
                f'flags=0x{packet_flags:02X}, {len(packet)} bytes',
                'green',
            )
        )

        if packet_index != self.expected_packet_index:
            logging.info(
                color(
                    f'!!! Unexpected packet, expected {self.expected_packet_index} '
                    f'but received {packet_index}'
                )
            )

        self.expected_packet_index = packet_index + 1

        AsyncRunner.spawn(
            self.packet_io.send_packet(
                struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
            )
        )

        if packet_flags & PACKET_FLAG_LAST and not self.linger:
            self.done.set()

    async def run(self):
        await self.done.wait()
        logging.info(color('=== Done!', 'magenta'))


# -----------------------------------------------------------------------------
# GattClient
# -----------------------------------------------------------------------------
class GattClient:
    def __init__(self, _device, att_mtu=None):
        self.att_mtu = att_mtu
        self.speed_rx = None
        self.speed_tx = None
        self.packet_listener = None
        self.ready = asyncio.Event()
        self.overhead_size = 0

    async def on_connection(self, connection):
        peer = Peer(connection)

        if self.att_mtu:
            logging.info(color(f'*** Requesting MTU update: {self.att_mtu}', 'blue'))
            await peer.request_mtu(self.att_mtu)

        logging.info(color('*** Discovering services...', 'blue'))
        await peer.discover_services()

        speed_services = peer.get_services_by_uuid(SPEED_SERVICE_UUID)
        if not speed_services:
            logging.info(color('!!! Speed Service not found', 'red'))
            return
        speed_service = speed_services[0]
        logging.info(color('*** Discovering characteristics...', 'blue'))
        await speed_service.discover_characteristics()

        speed_txs = speed_service.get_characteristics_by_uuid(SPEED_TX_UUID)
        if not speed_txs:
            logging.info(color('!!! Speed TX not found', 'red'))
            return
        self.speed_tx = speed_txs[0]

        speed_rxs = speed_service.get_characteristics_by_uuid(SPEED_RX_UUID)
        if not speed_rxs:
            logging.info(color('!!! Speed RX not found', 'red'))
            return
        self.speed_rx = speed_rxs[0]

        logging.info(color('*** Subscribing to RX', 'blue'))
        await self.speed_rx.subscribe(self.on_packet_received)

        logging.info(color('*** Discovery complete', 'blue'))

        connection.on('disconnection', self.on_disconnection)
        self.ready.set()

    def on_disconnection(self, _):
        self.ready.clear()

    def on_packet_received(self, packet):
        if self.packet_listener:
            self.packet_listener.on_packet_received(packet)

    async def send_packet(self, packet):
        await self.speed_tx.write_value(packet)

    async def drain(self):
        pass


# -----------------------------------------------------------------------------
# GattServer
# -----------------------------------------------------------------------------
class GattServer:
    def __init__(self, device):
        self.device = device
        self.packet_listener = None
        self.ready = asyncio.Event()
        self.overhead_size = 0

        # Setup the GATT service
        self.speed_tx = Characteristic(
            SPEED_TX_UUID,
            Characteristic.Properties.WRITE,
            Characteristic.WRITEABLE,
            CharacteristicValue(write=self.on_tx_write),
        )
        self.speed_rx = Characteristic(
            SPEED_RX_UUID, Characteristic.Properties.NOTIFY, 0
        )

        speed_service = Service(
            SPEED_SERVICE_UUID,
            [self.speed_tx, self.speed_rx],
        )
        device.add_services([speed_service])

        self.speed_rx.on('subscription', self.on_rx_subscription)

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        self.ready.clear()

    def on_rx_subscription(self, _connection, notify_enabled, _indicate_enabled):
        if notify_enabled:
            logging.info(color('*** RX subscription', 'blue'))
            self.ready.set()
        else:
            logging.info(color('*** RX un-subscription', 'blue'))
            self.ready.clear()

    def on_tx_write(self, _, value):
        if self.packet_listener:
            self.packet_listener.on_packet_received(value)

    async def send_packet(self, packet):
        await self.device.notify_subscribers(self.speed_rx, packet)

    async def drain(self):
        pass


# -----------------------------------------------------------------------------
# StreamedPacketIO
# -----------------------------------------------------------------------------
class StreamedPacketIO:
    def __init__(self):
        self.packet_listener = None
        self.io_sink = None
        self.rx_packet = b''
        self.rx_packet_header = b''
        self.rx_packet_need = 0
        self.overhead_size = 2

    def on_packet(self, packet):
        while packet:
            if self.rx_packet_need:
                chunk = packet[: self.rx_packet_need]
                self.rx_packet += chunk
                packet = packet[len(chunk) :]
                self.rx_packet_need -= len(chunk)
                if not self.rx_packet_need:
                    # Packet completed
                    if self.packet_listener:
                        self.packet_listener.on_packet_received(self.rx_packet)

                    self.rx_packet = b''
                    self.rx_packet_header = b''
            else:
                # Expect the next packet
                header_bytes_needed = 2 - len(self.rx_packet_header)
                header_bytes = packet[:header_bytes_needed]
                self.rx_packet_header += header_bytes
                if len(self.rx_packet_header) != 2:
                    return
                packet = packet[len(header_bytes) :]
                self.rx_packet_need = struct.unpack('>H', self.rx_packet_header)[0]

    async def send_packet(self, packet):
        if not self.io_sink:
            logging.info(color('!!! No sink, dropping packet', 'red'))
            return

        # pylint: disable-next=not-callable
        self.io_sink(struct.pack('>H', len(packet)) + packet)


# -----------------------------------------------------------------------------
# L2capClient
# -----------------------------------------------------------------------------
class L2capClient(StreamedPacketIO):
    def __init__(
        self,
        _device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.psm = psm
        self.max_credits = max_credits
        self.mtu = mtu
        self.mps = mps
        self.l2cap_channel = None
        self.ready = asyncio.Event()

    async def on_connection(self, connection: Connection) -> None:
        connection.on('disconnection', self.on_disconnection)

        # Connect a new L2CAP channel
        logging.info(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
        try:
            l2cap_channel = await connection.create_l2cap_channel(
                spec=l2cap.LeCreditBasedChannelSpec(
                    psm=self.psm,
                    max_credits=self.max_credits,
                    mtu=self.mtu,
                    mps=self.mps,
                )
            )
            logging.info(color(f'*** L2CAP channel: {l2cap_channel}', 'cyan'))
        except Exception as error:
            logging.info(color(f'!!! Connection failed: {error}', 'red'))
            return

        self.io_sink = l2cap_channel.write
        self.l2cap_channel = l2cap_channel
        l2cap_channel.on('close', self.on_l2cap_close)
        l2cap_channel.sink = self.on_packet

        self.ready.set()

    def on_disconnection(self, _):
        pass

    def on_l2cap_close(self):
        logging.info(color('*** L2CAP channel closed', 'red'))

    async def drain(self):
        assert self.l2cap_channel
        await self.l2cap_channel.drain()


# -----------------------------------------------------------------------------
# L2capServer
# -----------------------------------------------------------------------------
class L2capServer(StreamedPacketIO):
    def __init__(
        self,
        device: Device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.l2cap_channel = None
        self.ready = asyncio.Event()

        # Listen for incoming L2CAP connections
        device.create_l2cap_server(
            spec=l2cap.LeCreditBasedChannelSpec(
                psm=psm, mtu=mtu, mps=mps, max_credits=max_credits
            ),
            handler=self.on_l2cap_channel,
        )
        logging.info(
            color(f'### Listening for L2CAP connection on PSM {psm}', 'yellow')
        )

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        pass

    def on_l2cap_channel(self, l2cap_channel):
        logging.info(color(f'*** L2CAP channel: {l2cap_channel}', 'cyan'))

        self.io_sink = l2cap_channel.write
        self.l2cap_channel = l2cap_channel
        l2cap_channel.on('close', self.on_l2cap_close)
        l2cap_channel.sink = self.on_packet

        self.ready.set()

    def on_l2cap_close(self):
        logging.info(color('*** L2CAP channel closed', 'red'))
        self.l2cap_channel = None

    async def drain(self):
        assert self.l2cap_channel
        await self.l2cap_channel.drain()


# -----------------------------------------------------------------------------
# RfcommClient
# -----------------------------------------------------------------------------
class RfcommClient(StreamedPacketIO):
    def __init__(
        self,
        device,
        channel,
        uuid,
        l2cap_mtu,
        max_frame_size,
        initial_credits,
        max_credits,
        credits_threshold,
    ):
        super().__init__()
        self.device = device
        self.channel = channel
        self.uuid = uuid
        self.l2cap_mtu = l2cap_mtu
        self.max_frame_size = max_frame_size
        self.initial_credits = initial_credits
        self.max_credits = max_credits
        self.credits_threshold = credits_threshold
        self.rfcomm_session = None
        self.ready = asyncio.Event()

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

        # Find the channel number if not specified
        channel = self.channel
        if channel == 0:
            logging.info(
                color(f'@@@ Discovering channel number from UUID {self.uuid}', 'cyan')
            )
            channel = await bumble.rfcomm.find_rfcomm_channel_with_uuid(
                connection, self.uuid
            )
            logging.info(color(f'@@@ Channel number = {channel}', 'cyan'))
            if channel == 0:
                logging.info(color('!!! No RFComm service with this UUID found', 'red'))
                await connection.disconnect()
                return

        # Create a client and start it
        logging.info(color('*** Starting RFCOMM client...', 'blue'))
        rfcomm_options = {}
        if self.l2cap_mtu:
            rfcomm_options['l2cap_mtu'] = self.l2cap_mtu
        rfcomm_client = bumble.rfcomm.Client(connection, **rfcomm_options)
        rfcomm_mux = await rfcomm_client.start()
        logging.info(color('*** Started', 'blue'))

        logging.info(color(f'### Opening session for channel {channel}...', 'yellow'))
        try:
            dlc_options = {}
            if self.max_frame_size is not None:
                dlc_options['max_frame_size'] = self.max_frame_size
            if self.initial_credits is not None:
                dlc_options['initial_credits'] = self.initial_credits
            rfcomm_session = await rfcomm_mux.open_dlc(channel, **dlc_options)
            logging.info(color(f'### Session open: {rfcomm_session}', 'yellow'))
            if self.max_credits is not None:
                rfcomm_session.rx_max_credits = self.max_credits
            if self.credits_threshold is not None:
                rfcomm_session.rx_credits_threshold = self.credits_threshold

        except bumble.core.ConnectionError as error:
            logging.info(color(f'!!! Session open failed: {error}', 'red'))
            await rfcomm_mux.disconnect()
            return

        rfcomm_session.sink = self.on_packet
        self.io_sink = rfcomm_session.write
        self.rfcomm_session = rfcomm_session

        self.ready.set()

    def on_disconnection(self, _):
        pass

    async def drain(self):
        assert self.rfcomm_session
        await self.rfcomm_session.drain()


# -----------------------------------------------------------------------------
# RfcommServer
# -----------------------------------------------------------------------------
class RfcommServer(StreamedPacketIO):
    def __init__(
        self,
        device,
        channel,
        l2cap_mtu,
        max_frame_size,
        initial_credits,
        max_credits,
        credits_threshold,
    ):
        super().__init__()
        self.max_credits = max_credits
        self.credits_threshold = credits_threshold
        self.dlc = None
        self.ready = asyncio.Event()

        # Create and register a server
        server_options = {}
        if l2cap_mtu:
            server_options['l2cap_mtu'] = l2cap_mtu
        rfcomm_server = bumble.rfcomm.Server(device, **server_options)

        # Listen for incoming DLC connections
        dlc_options = {}
        if max_frame_size is not None:
            dlc_options['max_frame_size'] = max_frame_size
        if initial_credits is not None:
            dlc_options['initial_credits'] = initial_credits
        channel_number = rfcomm_server.listen(self.on_dlc, channel, **dlc_options)

        # Setup the SDP to advertise this channel
        device.sdp_service_records = make_sdp_records(channel_number)

        logging.info(
            color(
                f'### Listening for RFComm connection on channel {channel_number}',
                'yellow',
            )
        )

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        pass

    def on_dlc(self, dlc):
        logging.info(color(f'*** DLC connected: {dlc}', 'blue'))
        if self.credits_threshold is not None:
            dlc.rx_threshold = self.credits_threshold
        if self.max_credits is not None:
            dlc.rx_max_credits = self.max_credits
        dlc.sink = self.on_packet
        self.io_sink = dlc.write
        self.dlc = dlc
        if self.max_credits is not None:
            dlc.rx_max_credits = self.max_credits
        if self.credits_threshold is not None:
            dlc.rx_credits_threshold = self.credits_threshold

    async def drain(self):
        assert self.dlc
        await self.dlc.drain()


# -----------------------------------------------------------------------------
# Central
# -----------------------------------------------------------------------------
class Central(Connection.Listener):
    def __init__(
        self,
        transport,
        peripheral_address,
        classic,
        role_factory,
        mode_factory,
        connection_interval,
        phy,
        authenticate,
        encrypt,
        extended_data_length,
        role_switch,
    ):
        super().__init__()
        self.transport = transport
        self.peripheral_address = peripheral_address
        self.classic = classic
        self.role_factory = role_factory
        self.mode_factory = mode_factory
        self.authenticate = authenticate
        self.encrypt = encrypt or authenticate
        self.extended_data_length = extended_data_length
        self.role_switch = role_switch
        self.device = None
        self.connection = None

        if phy:
            self.phy = {
                '1m': HCI_LE_1M_PHY,
                '2m': HCI_LE_2M_PHY,
                'coded': HCI_LE_CODED_PHY,
            }[phy]
        else:
            self.phy = None

        if connection_interval:
            connection_parameter_preferences = ConnectionParametersPreferences()
            connection_parameter_preferences.connection_interval_min = (
                connection_interval
            )
            connection_parameter_preferences.connection_interval_max = (
                connection_interval
            )

            # Preferences for the 1M PHY are always set.
            self.connection_parameter_preferences = {
                HCI_LE_1M_PHY: connection_parameter_preferences,
            }

            if self.phy not in (None, HCI_LE_1M_PHY):
                # Add an connections parameters entry for this PHY.
                self.connection_parameter_preferences[self.phy] = (
                    connection_parameter_preferences
                )
        else:
            self.connection_parameter_preferences = None

    async def run(self):
        logging.info(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            logging.info(color('>>> Connected', 'green'))

            central_address = DEFAULT_CENTRAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_CENTRAL_NAME, central_address, hci_source, hci_sink
            )
            mode = self.mode_factory(self.device)
            role = self.role_factory(mode)
            self.device.classic_enabled = self.classic

            # Set up a pairing config factory with minimal requirements.
            self.device.pairing_config_factory = lambda _: PairingConfig(
                sc=False, mitm=False, bonding=False
            )

            await self.device.power_on()

            if self.classic:
                await self.device.set_discoverable(False)
                await self.device.set_connectable(False)

            logging.info(
                color(f'### Connecting to {self.peripheral_address}...', 'cyan')
            )
            try:
                self.connection = await self.device.connect(
                    self.peripheral_address,
                    connection_parameters_preferences=self.connection_parameter_preferences,
                    transport=BT_BR_EDR_TRANSPORT if self.classic else BT_LE_TRANSPORT,
                )
            except CommandTimeoutError:
                logging.info(color('!!! Connection timed out', 'red'))
                return
            except bumble.core.ConnectionError as error:
                logging.info(color(f'!!! Connection error: {error}', 'red'))
                return
            except HCI_StatusError as error:
                logging.info(color(f'!!! Connection failed: {error.error_name}'))
                return
            logging.info(color('### Connected', 'cyan'))
            self.connection.listener = self
            print_connection(self.connection)

            # Switch roles if needed.
            if self.role_switch:
                await switch_roles(self.connection, self.role_switch)

            # Wait a bit after the connection, some controllers aren't very good when
            # we start sending data right away while some connection parameters are
            # updated post connection
            await asyncio.sleep(DEFAULT_POST_CONNECTION_WAIT_TIME)

            # Request a new data length if requested
            if self.extended_data_length:
                logging.info(color('+++ Requesting extended data length', 'cyan'))
                await self.connection.set_data_length(
                    self.extended_data_length[0], self.extended_data_length[1]
                )

            # Authenticate if requested
            if self.authenticate:
                # Request authentication
                logging.info(color('*** Authenticating...', 'cyan'))
                await self.connection.authenticate()
                logging.info(color('*** Authenticated', 'cyan'))

            # Encrypt if requested
            if self.encrypt:
                # Enable encryption
                logging.info(color('*** Enabling encryption...', 'cyan'))
                await self.connection.encrypt()
                logging.info(color('*** Encryption on', 'cyan'))

            # Set the PHY if requested
            if self.phy is not None:
                try:
                    await self.connection.set_phy(
                        tx_phys=[self.phy], rx_phys=[self.phy]
                    )
                except HCI_Error as error:
                    logging.info(
                        color(
                            f'!!! Unable to set the PHY: {error.error_name}', 'yellow'
                        )
                    )

            await mode.on_connection(self.connection)

            await role.run()
            await asyncio.sleep(DEFAULT_LINGER_TIME)
            await self.connection.disconnect()

    def on_disconnection(self, reason):
        logging.info(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None

    def on_connection_parameters_update(self):
        print_connection(self.connection)

    def on_connection_phy_update(self):
        print_connection(self.connection)

    def on_connection_att_mtu_update(self):
        print_connection(self.connection)

    def on_connection_data_length_change(self):
        print_connection(self.connection)

    def on_role_change(self):
        print_connection(self.connection)


# -----------------------------------------------------------------------------
# Peripheral
# -----------------------------------------------------------------------------
class Peripheral(Device.Listener, Connection.Listener):
    def __init__(
        self,
        transport,
        role_factory,
        mode_factory,
        classic,
        extended_data_length,
        role_switch,
    ):
        self.transport = transport
        self.classic = classic
        self.role_factory = role_factory
        self.mode_factory = mode_factory
        self.extended_data_length = extended_data_length
        self.role_switch = role_switch
        self.role = None
        self.mode = None
        self.device = None
        self.connection = None
        self.connected = asyncio.Event()

    async def run(self):
        logging.info(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            logging.info(color('>>> Connected', 'green'))

            peripheral_address = DEFAULT_PERIPHERAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_PERIPHERAL_NAME, peripheral_address, hci_source, hci_sink
            )
            self.device.listener = self
            self.mode = self.mode_factory(self.device)
            self.role = self.role_factory(self.mode)
            self.device.classic_enabled = self.classic

            # Set up a pairing config factory with minimal requirements.
            self.device.pairing_config_factory = lambda _: PairingConfig(
                sc=False, mitm=False, bonding=False
            )

            await self.device.power_on()

            if self.classic:
                await self.device.set_discoverable(True)
                await self.device.set_connectable(True)
            else:
                await self.device.start_advertising(auto_restart=True)

            if self.classic:
                logging.info(
                    color(
                        '### Waiting for connection on'
                        f' {self.device.public_address}...',
                        'cyan',
                    )
                )
            else:
                logging.info(
                    color(
                        f'### Waiting for connection on {peripheral_address}...',
                        'cyan',
                    )
                )

            await self.connected.wait()
            logging.info(color('### Connected', 'cyan'))
            print_connection(self.connection)

            await self.mode.on_connection(self.connection)
            await self.role.run()
            await asyncio.sleep(DEFAULT_LINGER_TIME)

    def on_connection(self, connection):
        connection.listener = self
        self.connection = connection
        self.connected.set()

        # Stop being discoverable and connectable
        if self.classic:
            AsyncRunner.spawn(self.device.set_discoverable(False))
            AsyncRunner.spawn(self.device.set_connectable(False))

        # Request a new data length if needed
        if not self.classic and self.extended_data_length:
            logging.info("+++ Requesting extended data length")
            AsyncRunner.spawn(
                connection.set_data_length(
                    self.extended_data_length[0], self.extended_data_length[1]
                )
            )

        # Switch roles if needed.
        if self.role_switch:
            AsyncRunner.spawn(switch_roles(connection, self.role_switch))

    def on_disconnection(self, reason):
        logging.info(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None
        self.role.reset()

        if self.classic:
            AsyncRunner.spawn(self.device.set_discoverable(True))
            AsyncRunner.spawn(self.device.set_connectable(True))

    def on_connection_parameters_update(self):
        print_connection(self.connection)

    def on_connection_phy_update(self):
        print_connection(self.connection)

    def on_connection_att_mtu_update(self):
        print_connection(self.connection)

    def on_connection_data_length_change(self):
        print_connection(self.connection)

    def on_role_change(self):
        print_connection(self.connection)


# -----------------------------------------------------------------------------
def create_mode_factory(ctx, default_mode):
    mode = ctx.obj['mode']
    if mode is None:
        mode = default_mode

    def create_mode(device):
        if mode == 'gatt-client':
            return GattClient(device, att_mtu=ctx.obj['att_mtu'])

        if mode == 'gatt-server':
            return GattServer(device)

        if mode == 'l2cap-client':
            return L2capClient(
                device,
                psm=ctx.obj['l2cap_psm'],
                mtu=ctx.obj['l2cap_mtu'],
                mps=ctx.obj['l2cap_mps'],
                max_credits=ctx.obj['l2cap_max_credits'],
            )

        if mode == 'l2cap-server':
            return L2capServer(
                device,
                psm=ctx.obj['l2cap_psm'],
                mtu=ctx.obj['l2cap_mtu'],
                mps=ctx.obj['l2cap_mps'],
                max_credits=ctx.obj['l2cap_max_credits'],
            )

        if mode == 'rfcomm-client':
            return RfcommClient(
                device,
                channel=ctx.obj['rfcomm_channel'],
                uuid=ctx.obj['rfcomm_uuid'],
                l2cap_mtu=ctx.obj['rfcomm_l2cap_mtu'],
                max_frame_size=ctx.obj['rfcomm_max_frame_size'],
                initial_credits=ctx.obj['rfcomm_initial_credits'],
                max_credits=ctx.obj['rfcomm_max_credits'],
                credits_threshold=ctx.obj['rfcomm_credits_threshold'],
            )

        if mode == 'rfcomm-server':
            return RfcommServer(
                device,
                channel=ctx.obj['rfcomm_channel'],
                l2cap_mtu=ctx.obj['rfcomm_l2cap_mtu'],
                max_frame_size=ctx.obj['rfcomm_max_frame_size'],
                initial_credits=ctx.obj['rfcomm_initial_credits'],
                max_credits=ctx.obj['rfcomm_max_credits'],
                credits_threshold=ctx.obj['rfcomm_credits_threshold'],
            )

        raise ValueError('invalid mode')

    return create_mode


# -----------------------------------------------------------------------------
def create_role_factory(ctx, default_role):
    role = ctx.obj['role']
    if role is None:
        role = default_role

    def create_role(packet_io):
        if role == 'sender':
            return Sender(
                packet_io,
                start_delay=ctx.obj['start_delay'],
                repeat=ctx.obj['repeat'],
                repeat_delay=ctx.obj['repeat_delay'],
                pace=ctx.obj['pace'],
                packet_size=ctx.obj['packet_size'],
                packet_count=ctx.obj['packet_count'],
            )

        if role == 'receiver':
            return Receiver(packet_io, ctx.obj['linger'])

        if role == 'ping':
            return Ping(
                packet_io,
                start_delay=ctx.obj['start_delay'],
                repeat=ctx.obj['repeat'],
                repeat_delay=ctx.obj['repeat_delay'],
                pace=ctx.obj['pace'],
                packet_size=ctx.obj['packet_size'],
                packet_count=ctx.obj['packet_count'],
            )

        if role == 'pong':
            return Pong(packet_io, ctx.obj['linger'])

        raise ValueError('invalid role')

    return create_role


# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
@click.group()
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.option('--role', type=click.Choice(['sender', 'receiver', 'ping', 'pong']))
@click.option(
    '--mode',
    type=click.Choice(
        [
            'gatt-client',
            'gatt-server',
            'l2cap-client',
            'l2cap-server',
            'rfcomm-client',
            'rfcomm-server',
        ]
    ),
)
@click.option(
    '--att-mtu',
    metavar='MTU',
    type=click.IntRange(23, 517),
    help='GATT MTU (gatt-client mode)',
)
@click.option(
    '--extended-data-length',
    help='Request a data length upon connection, specified as tx_octets/tx_time',
)
@click.option(
    '--role-switch',
    type=click.Choice(['central', 'peripheral']),
    help='Request role switch upon connection (central or peripheral)',
)
@click.option(
    '--rfcomm-channel',
    type=int,
    default=DEFAULT_RFCOMM_CHANNEL,
    help='RFComm channel to use',
)
@click.option(
    '--rfcomm-uuid',
    default=DEFAULT_RFCOMM_UUID,
    help='RFComm service UUID to use (ignored if --rfcomm-channel is not 0)',
)
@click.option(
    '--rfcomm-l2cap-mtu',
    type=int,
    help='RFComm L2CAP MTU',
)
@click.option(
    '--rfcomm-max-frame-size',
    type=int,
    help='RFComm maximum frame size',
)
@click.option(
    '--rfcomm-initial-credits',
    type=int,
    help='RFComm initial credits',
)
@click.option(
    '--rfcomm-max-credits',
    type=int,
    help='RFComm max credits',
)
@click.option(
    '--rfcomm-credits-threshold',
    type=int,
    help='RFComm credits threshold',
)
@click.option(
    '--l2cap-psm',
    type=int,
    default=DEFAULT_L2CAP_PSM,
    help='L2CAP PSM to use',
)
@click.option(
    '--l2cap-mtu',
    type=int,
    default=DEFAULT_L2CAP_MTU,
    help='L2CAP MTU to use',
)
@click.option(
    '--l2cap-mps',
    type=int,
    default=DEFAULT_L2CAP_MPS,
    help='L2CAP MPS to use',
)
@click.option(
    '--l2cap-max-credits',
    type=int,
    default=DEFAULT_L2CAP_MAX_CREDITS,
    help='L2CAP maximum number of credits allowed for the peer',
)
@click.option(
    '--packet-size',
    '-s',
    metavar='SIZE',
    type=click.IntRange(8, 8192),
    default=500,
    help='Packet size (client or ping role)',
)
@click.option(
    '--packet-count',
    '-c',
    metavar='COUNT',
    type=int,
    default=10,
    help='Packet count (client or ping role)',
)
@click.option(
    '--start-delay',
    '-sd',
    metavar='SECONDS',
    type=int,
    default=1,
    help='Start delay (client or ping role)',
)
@click.option(
    '--repeat',
    metavar='N',
    type=int,
    default=0,
    help=(
        'Repeat the run N times (client and ping roles)'
        '(0, which is the fault, to run just once) '
    ),
)
@click.option(
    '--repeat-delay',
    metavar='SECONDS',
    type=int,
    default=1,
    help=('Delay, in seconds, between repeats'),
)
@click.option(
    '--pace',
    metavar='MILLISECONDS',
    type=int,
    default=0,
    help=(
        'Wait N milliseconds between packets '
        '(0, which is the fault, to send as fast as possible) '
    ),
)
@click.option(
    '--linger',
    is_flag=True,
    help="Don't exit at the end of a run (server and pong roles)",
)
@click.pass_context
def bench(
    ctx,
    device_config,
    role,
    mode,
    att_mtu,
    extended_data_length,
    role_switch,
    packet_size,
    packet_count,
    start_delay,
    repeat,
    repeat_delay,
    pace,
    linger,
    rfcomm_channel,
    rfcomm_uuid,
    rfcomm_l2cap_mtu,
    rfcomm_max_frame_size,
    rfcomm_initial_credits,
    rfcomm_max_credits,
    rfcomm_credits_threshold,
    l2cap_psm,
    l2cap_mtu,
    l2cap_mps,
    l2cap_max_credits,
):
    ctx.ensure_object(dict)
    ctx.obj['device_config'] = device_config
    ctx.obj['role'] = role
    ctx.obj['mode'] = mode
    ctx.obj['att_mtu'] = att_mtu
    ctx.obj['rfcomm_channel'] = rfcomm_channel
    ctx.obj['rfcomm_uuid'] = rfcomm_uuid
    ctx.obj['rfcomm_l2cap_mtu'] = rfcomm_l2cap_mtu
    ctx.obj['rfcomm_max_frame_size'] = rfcomm_max_frame_size
    ctx.obj['rfcomm_initial_credits'] = rfcomm_initial_credits
    ctx.obj['rfcomm_max_credits'] = rfcomm_max_credits
    ctx.obj['rfcomm_credits_threshold'] = rfcomm_credits_threshold
    ctx.obj['l2cap_psm'] = l2cap_psm
    ctx.obj['l2cap_mtu'] = l2cap_mtu
    ctx.obj['l2cap_mps'] = l2cap_mps
    ctx.obj['l2cap_max_credits'] = l2cap_max_credits
    ctx.obj['packet_size'] = packet_size
    ctx.obj['packet_count'] = packet_count
    ctx.obj['start_delay'] = start_delay
    ctx.obj['repeat'] = repeat
    ctx.obj['repeat_delay'] = repeat_delay
    ctx.obj['pace'] = pace
    ctx.obj['linger'] = linger
    ctx.obj['extended_data_length'] = (
        [int(x) for x in extended_data_length.split('/')]
        if extended_data_length
        else None
    )
    ctx.obj['role_switch'] = role_switch
    ctx.obj['classic'] = mode in ('rfcomm-client', 'rfcomm-server')


@bench.command()
@click.argument('transport')
@click.option(
    '--peripheral',
    'peripheral_address',
    metavar='ADDRESS_OR_NAME',
    default=DEFAULT_PERIPHERAL_ADDRESS,
    help='Address or name to connect to',
)
@click.option(
    '--connection-interval',
    '--ci',
    metavar='CONNECTION_INTERVAL',
    type=int,
    help='Connection interval (in ms)',
)
@click.option('--phy', type=click.Choice(['1m', '2m', 'coded']), help='PHY to use')
@click.option('--authenticate', is_flag=True, help='Authenticate (RFComm only)')
@click.option('--encrypt', is_flag=True, help='Encrypt the connection (RFComm only)')
@click.pass_context
def central(
    ctx, transport, peripheral_address, connection_interval, phy, authenticate, encrypt
):
    """Run as a central (initiates the connection)"""
    role_factory = create_role_factory(ctx, 'sender')
    mode_factory = create_mode_factory(ctx, 'gatt-client')
    classic = ctx.obj['classic']

    async def run_central():
        await Central(
            transport,
            peripheral_address,
            classic,
            role_factory,
            mode_factory,
            connection_interval,
            phy,
            authenticate,
            encrypt or authenticate,
            ctx.obj['extended_data_length'],
            ctx.obj['role_switch'],
        ).run()

    asyncio.run(run_central())


@bench.command()
@click.argument('transport')
@click.pass_context
def peripheral(ctx, transport):
    """Run as a peripheral (waits for a connection)"""
    role_factory = create_role_factory(ctx, 'receiver')
    mode_factory = create_mode_factory(ctx, 'gatt-server')

    async def run_peripheral():
        await Peripheral(
            transport,
            role_factory,
            mode_factory,
            ctx.obj['classic'],
            ctx.obj['extended_data_length'],
            ctx.obj['role_switch'],
        ).run()

    asyncio.run(run_peripheral())


def main():
    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
    bench()


# -----------------------------------------------------------------------------
if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
