# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Security grpc interface."""

import asyncio
import logging
from typing import AsyncGenerator
from typing import AsyncIterator

from floss.pandora.floss import adapter_client
from floss.pandora.floss import floss_enums
from floss.pandora.floss import utils
from floss.pandora.server import bluetooth as bluetooth_module
from google.protobuf import empty_pb2
from google.protobuf import wrappers_pb2
import grpc
from pandora import security_grpc_aio
from pandora import security_pb2


class SecurityService(security_grpc_aio.SecurityServicer):
    """Service to trigger Bluetooth Host security pairing procedures.

    This class implements the Pandora bluetooth test interfaces,
    where the meta class definition is automatically generated by the protobuf.
    The interface definition can be found in:
    https://cs.android.com/android/platform/superproject/+/main:external
    /pandora/bt-test-interfaces/pandora/security.proto

    Attributes:
        bluetooth: Bluetooth facade used to query and drive the adapter.
        manually_confirm: While False, the observer registered in __init__ accepts consent and
            passkey-confirmation requests on its own. OnPairing sets it to True so that the gRPC
            client answers pairing requests instead.
        on_pairing_count: Number of OnPairing streams started so far; used to tag log lines.
        pairing_observer: The auto-confirming observer registered in __init__.
    """

    def __init__(self, bluetooth: bluetooth_module.Bluetooth):
        """Stores the Bluetooth facade and registers the auto-confirming pairing observer.

        Args:
            bluetooth: Facade exposing the adapter client and the bonding/pairing helpers.
        """
        self.bluetooth = bluetooth
        self.manually_confirm = False
        self.on_pairing_count = 0

        class PairingObserver(adapter_client.BluetoothCallbacks):
            """Observer to observe pairing events."""

            def __init__(self, client, security: security_grpc_aio.SecurityServicer):
                # client: the adapter client used to send the pairing confirmation.
                # security: the owning service, consulted for its manually_confirm flag.
                self.client = client
                self.security = security

            @utils.glib_callback()
            def on_ssp_request(self, remote_device, class_of_device, variant, passkey):
                # An OnPairing stream has taken over: leave the answer to the gRPC client.
                if self.security.manually_confirm:
                    return

                logging.info("Security: on_ssp_request variant: %s passkey: %s", variant, passkey)
                address, _ = remote_device

                if variant in (floss_enums.PairingVariant.CONSENT, floss_enums.PairingVariant.PASSKEY_CONFIRMATION):
                    self.client.set_pairing_confirmation(address,
                                                         True,
                                                         method_callback=self.on_set_pairing_confirmation)

            @utils.glib_callback()
            def on_set_pairing_confirmation(self, err, result):
                if err or not result:
                    logging.info('Security: on_set_pairing_confirmation failed. err: %s result: %s', err, result)

        observer = PairingObserver(self.bluetooth.adapter_client, self)
        name = utils.create_observer_name(observer)
        self.bluetooth.adapter_client.register_callback_observer(name, observer)
        self.pairing_observer = observer

    async def _wait_for_bond(self, address):
        """Waits for the next conclusive bond state change of a device.

        A temporary observer is registered on the adapter client and removed again once a result
        is available or the wait is interrupted.

        Args:
            address: Address of the remote device whose bond state is observed.

        Returns:
            True if the device reported the BONDED state, False if bonding failed (the reason is
            logged).
        """

        def resolve(future, outcome):
            # Scheduled onto the event loop below. A later callback, or a cancelled waiter, must
            # not hit an already completed future, as set_result() would raise InvalidStateError.
            if not future.done():
                future.set_result(outcome)

        class BondingObserver(adapter_client.BluetoothCallbacks):
            """Observer to observe the bond state."""

            def __init__(self, task):
                self.task = task

            @utils.glib_callback()
            def on_bond_state_changed(self, status, address, state):
                if address != self.task['address']:
                    return

                if status != floss_enums.BtStatus.SUCCESS:
                    outcome = (False, f'Status: {status}')
                elif state == floss_enums.BondState.BONDED:
                    outcome = (True, None)
                elif state == floss_enums.BondState.NOT_BONDED:
                    outcome = (False, f'Status: {status}, State: {state}')
                else:
                    # Neither bonded nor failed yet: keep waiting for a conclusive state.
                    return

                future = self.task['wait_bond']
                future.get_loop().call_soon_threadsafe(resolve, future, outcome)

        # Everything the cleanup needs is bound before entering the try block, so a failure during
        # setup surfaces as itself instead of as an UnboundLocalError raised from `finally`.
        wait_bond = asyncio.get_running_loop().create_future()
        observer = BondingObserver({'wait_bond': wait_bond, 'address': address})
        name = utils.create_observer_name(observer)
        self.bluetooth.adapter_client.register_callback_observer(name, observer)
        try:
            is_bonded, reason = await wait_bond
        finally:
            self.bluetooth.adapter_client.unregister_callback_observer(name, observer)

        if not is_bonded:
            logging.error('Failed to bond to the address: %s, reason: %s', address, reason)
        return is_bonded

    async def wait_le_security_level(self, level, address):
        """Waits until the low-energy link to a device is bonded, then checks the security level.

        Args:
            level: Requested level, one of the security_pb2.LE_LEVEL* values.
            address: Address of the remote device.

        Returns:
            True if the requested level is reached. False if it is not reached, or if the level is
            unsupported (LE_LEVEL4) or invalid.
        """
        if level == security_pb2.LE_LEVEL1:
            return True
        if level == security_pb2.LE_LEVEL4:
            logging.error('wait_le_security_level: Low-energy level 4 not supported.')
            return False
        if level not in (security_pb2.LE_LEVEL2, security_pb2.LE_LEVEL3):
            # Reject before waiting: an invalid level can never be reached, so there is no point
            # in blocking on a bond first.
            logging.error('wait_le_security_level: Invalid security level %s.', level)
            return False

        if self.bluetooth.is_bonded(address):
            is_bonded = True
        else:
            is_bonded = await self._wait_for_bond(address)

        is_encrypted = self.bluetooth.is_encrypted(address)
        if level == security_pb2.LE_LEVEL2:
            return is_encrypted
        return is_encrypted and is_bonded

    async def wait_classic_security_level(self, level, address):
        """Waits until the classic link to a device is bonded, then checks the security level.

        Args:
            level: Requested level, one of the security_pb2.LEVEL* values.
            address: Address of the remote device.

        Returns:
            True if the requested level is reached. False if it is not reached, or if the level is
            not supported (anything other than LEVEL0, LEVEL1 and LEVEL2).
        """
        if level == security_pb2.LEVEL0:
            return True
        if level not in (security_pb2.LEVEL1, security_pb2.LEVEL2):
            # Matches Secure(), which rejects LEVEL3 and above; no need to wait for a bond first.
            logging.error('wait_classic_security_level: Classic level %s not supported', level)
            return False

        if self.bluetooth.is_bonded(address):
            is_bonded = True
        else:
            is_bonded = await self._wait_for_bond(address)

        is_encrypted = self.bluetooth.is_encrypted(address)
        if level == security_pb2.LEVEL1:
            return not is_encrypted or is_bonded
        return is_encrypted and is_bonded

    async def OnPairing(self, request: AsyncIterator[security_pb2.PairingEventAnswer],
                        context: grpc.ServicerContext) -> AsyncGenerator[security_pb2.PairingEvent, None]:
        """Streams pairing events to the client and applies the answers it sends back.

        Adapter pairing callbacks are queued and translated into PairingEvent messages. A
        background task consumes the request stream and forwards each answer to the adapter.

        Args:
            request: Stream of answers to previously emitted pairing events.
            context: gRPC servicer context (unused).

        Yields:
            One PairingEvent per pairing callback that can be mapped to a Pandora event.
        """
        logging.info('OnPairing')
        on_pairing_id = self.on_pairing_count
        self.on_pairing_count = self.on_pairing_count + 1

        class PairingObserver(adapter_client.BluetoothCallbacks):
            """Observer to observe all pairing events."""

            def __init__(self, loop: asyncio.AbstractEventLoop, task):
                self.loop = loop
                self.task = task

            @utils.glib_callback()
            def on_ssp_request(self, remote_device, class_of_device, variant, passkey):
                address, name = remote_device

                result = (address, name, variant, passkey)
                asyncio.run_coroutine_threadsafe(self.task['pairing_events'].put(result), self.loop)

            @utils.glib_callback()
            def on_pin_request(self, remote_device, cod, min_16_digit):
                address, name = remote_device

                if min_16_digit:
                    variant = floss_enums.PairingVariant.PIN_16_DIGITS_ENTRY
                else:
                    variant = floss_enums.PairingVariant.PIN_ENTRY
                result = (address, name, variant, min_16_digit)
                asyncio.run_coroutine_threadsafe(self.task['pairing_events'].put(result), self.loop)

            @utils.glib_callback()
            def on_pin_display(self, remote_device, pincode):
                address, name = remote_device

                variant = floss_enums.PairingVariant.PIN_NOTIFICATION
                result = (address, name, variant, pincode)
                asyncio.run_coroutine_threadsafe(self.task['pairing_events'].put(result), self.loop)

        async def streaming_answers():
            """Applies every answer received on the request stream until it ends."""
            while True:
                logging.info('OnPairing[%s]: Wait for pairing answer...', on_pairing_id)
                try:
                    pairing_answer = await utils.anext(request)
                except StopAsyncIteration:
                    # The client closed its side of the stream; end quietly rather than leaving
                    # an unretrieved exception on the background task.
                    logging.info('OnPairing[%s]: Pairing answer stream closed.', on_pairing_id)
                    return

                answer = pairing_answer.WhichOneof('answer')
                address = utils.address_from(pairing_answer.event.address)
                logging.info('OnPairing[%s]: Pairing answer: %s address: %s', on_pairing_id, answer, address)

                # `answer` is only the name of the populated oneof field; the values themselves
                # live on the message.
                if answer == 'confirm':
                    self.bluetooth.set_pairing_confirmation(address, pairing_answer.confirm)
                elif answer == 'passkey':
                    self.bluetooth.set_pin(address, True, list(str(pairing_answer.passkey).zfill(6).encode()))
                elif answer == 'pin':
                    self.bluetooth.set_pin(address, True, list(pairing_answer.pin))

        observers = []
        # Bound before the try block so the cleanup below is safe even if setup fails early.
        streaming_answers_task = None
        try:
            # NOTE(review): this flag is never reset once the stream ends, so the automatic
            # confirmation installed by __init__ stays disabled afterwards - confirm intended.
            self.manually_confirm = True

            pairing_events = asyncio.Queue()
            observer = PairingObserver(asyncio.get_running_loop(), {'pairing_events': pairing_events})
            observer_name = utils.create_observer_name(observer)
            self.bluetooth.adapter_client.register_callback_observer(observer_name, observer)
            observers.append((observer_name, observer))

            streaming_answers_task = asyncio.create_task(streaming_answers())

            while True:
                logging.info('OnPairing[%s]: Wait for pairing events...', on_pairing_id)
                address, device_name, variant, *variables = await pairing_events.get()
                logging.info('OnPairing[%s]: Pairing event: address: %s, name: %s, variant: %s, variables: %s',
                             on_pairing_id, address, device_name, variant, variables)

                event = security_pb2.PairingEvent()
                event.address = utils.address_to(address)

                # SSP
                if variant == floss_enums.PairingVariant.PASSKEY_CONFIRMATION:
                    [passkey] = variables
                    event.numeric_comparison = passkey
                elif variant == floss_enums.PairingVariant.PASSKEY_ENTRY:
                    event.passkey_entry_request.CopyFrom(empty_pb2.Empty())
                elif variant == floss_enums.PairingVariant.CONSENT:
                    event.just_works.CopyFrom(empty_pb2.Empty())
                elif variant == floss_enums.PairingVariant.PASSKEY_NOTIFICATION:
                    [passkey] = variables
                    event.passkey_entry_notification = passkey
                # Legacy
                elif variant == floss_enums.PairingVariant.PIN_ENTRY:
                    transport = self.bluetooth.get_remote_type(address)

                    if transport == floss_enums.BtTransport.BREDR:
                        event.pin_code_request.CopyFrom(empty_pb2.Empty())
                    elif transport == floss_enums.BtTransport.LE:
                        event.passkey_entry_request.CopyFrom(empty_pb2.Empty())
                    else:
                        logging.error('Cannot determine pairing variant from unknown transport.')
                        continue
                elif variant == floss_enums.PairingVariant.PIN_16_DIGITS_ENTRY:
                    event.pin_code_request.CopyFrom(empty_pb2.Empty())
                elif variant == floss_enums.PairingVariant.PIN_NOTIFICATION:
                    transport = self.bluetooth.get_remote_type(address)
                    [pincode] = variables

                    if transport == floss_enums.BtTransport.BREDR:
                        event.pin_code_notification = pincode.encode()
                    elif transport == floss_enums.BtTransport.LE:
                        event.passkey_entry_notification = int(pincode)
                    else:
                        logging.error('Cannot determine pairing variant from unknown transport.')
                        continue
                else:
                    logging.error('Unknown pairing variant: %s', variant)
                    continue

                yield event
        finally:
            if streaming_answers_task is not None:
                streaming_answers_task.cancel()
            for observer_name, observer in observers:
                self.bluetooth.adapter_client.unregister_callback_observer(observer_name, observer)

    async def Secure(self, request: security_pb2.SecureRequest,
                     context: grpc.ServicerContext) -> security_pb2.SecureResponse:
        """Starts bonding if needed and waits for the requested security level.

        Args:
            request: Carries the connection and the requested `le` or `classic` level.
            context: gRPC servicer context, used to abort on invalid requests.

        Returns:
            A SecureResponse with `success` set when the level was reached, `not_reached` otherwise.
            The call is aborted with INVALID_ARGUMENT when the level field matching the connection
            transport is missing, the level is unsupported, or the transport is unknown.
        """
        connection = utils.connection_from(request.connection)
        address = connection.address
        transport = connection.transport

        # context.abort() raises, so nothing after an abort call in a branch is executed.
        if transport == floss_enums.BtTransport.LE:
            if not request.HasField('le'):
                await context.abort(grpc.StatusCode.INVALID_ARGUMENT, 'Request le field must be set.')
            if request.le == security_pb2.LE_LEVEL1:
                security_level_reached = True
            elif request.le == security_pb2.LE_LEVEL4:
                await context.abort(grpc.StatusCode.INVALID_ARGUMENT, 'Low-energy security level 4 is not supported.')
            else:
                if not self.bluetooth.is_bonded(address):
                    self.bluetooth.create_bond(address, transport)
                security_level_reached = await self.wait_le_security_level(request.le, address)
        elif transport == floss_enums.BtTransport.BREDR:
            if not request.HasField('classic'):
                await context.abort(grpc.StatusCode.INVALID_ARGUMENT, 'Request classic field must be set.')
            if request.classic == security_pb2.LEVEL0:
                security_level_reached = True
            elif request.classic >= security_pb2.LEVEL3:
                await context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                                    'Classic security level 3 and above is not supported.')
            else:
                if not self.bluetooth.is_bonded(address):
                    self.bluetooth.create_bond(address, transport)
                security_level_reached = await self.wait_classic_security_level(request.classic, address)
        else:
            await context.abort(grpc.StatusCode.INVALID_ARGUMENT, f'Invalid bluetooth transport type: {transport}.')

        secure_response = security_pb2.SecureResponse()
        if security_level_reached:
            secure_response.success.CopyFrom(empty_pb2.Empty())
        else:
            secure_response.not_reached.CopyFrom(empty_pb2.Empty())
        return secure_response

    async def WaitSecurity(self, request: security_pb2.WaitSecurityRequest,
                           context: grpc.ServicerContext) -> security_pb2.WaitSecurityResponse:
        """Waits for the requested security level without initiating bonding.

        Args:
            request: Carries the connection and the requested `classic` or `le` level. When the
                `classic` field is not set the request is handled as a low-energy one.
            context: gRPC servicer context (unused).

        Returns:
            A WaitSecurityResponse with `success` set when the level was reached, `pairing_failure`
            otherwise.
        """
        address = utils.connection_from(request.connection).address

        if request.HasField('classic'):
            security_level_reached = await self.wait_classic_security_level(request.classic, address)
        else:
            security_level_reached = await self.wait_le_security_level(request.le, address)

        wait_security_response = security_pb2.WaitSecurityResponse()
        if security_level_reached:
            wait_security_response.success.CopyFrom(empty_pb2.Empty())
        else:
            wait_security_response.pairing_failure.CopyFrom(empty_pb2.Empty())
        return wait_security_response


class SecurityStorageService(security_grpc_aio.SecurityStorageServicer):
    """Service to trigger Bluetooth Host security persistent storage procedures.

    This class implements the Pandora bluetooth test interfaces,
    where the meta class definition is automatically generated by the protobuf.
    The interface definition can be found in:
    https://cs.android.com/android/platform/superproject/+/main:external
    /pandora/bt-test-interfaces/pandora/security.proto
    """

    def __init__(self, bluetooth: bluetooth_module.Bluetooth):
        """Stores the Bluetooth facade used to query and remove bonds.

        Args:
            bluetooth: Facade exposing the adapter client and the bonding helpers.
        """
        self.bluetooth = bluetooth

    async def IsBonded(self, request: security_pb2.IsBondedRequest,
                       context: grpc.ServicerContext) -> wrappers_pb2.BoolValue:
        """Reports whether the device named in the request is bonded.

        Args:
            request: Carries the address of the remote device.
            context: gRPC servicer context (unused).

        Returns:
            A BoolValue wrapping the result of the facade's is_bonded() for that address.
        """
        address = utils.address_from(request.address)
        is_bonded = self.bluetooth.is_bonded(address)
        return wrappers_pb2.BoolValue(value=is_bonded)

    async def DeleteBond(self, request: security_pb2.DeleteBondRequest,
                         context: grpc.ServicerContext) -> empty_pb2.Empty:
        """Removes the bond with a device and waits for the adapter to report the outcome.

        Args:
            request: Carries the address of the remote device.
            context: gRPC servicer context, used to abort when the removal fails.

        Returns:
            An Empty message once the device is not bonded (immediately if it never was). The call
            is aborted with INVALID_ARGUMENT when the adapter reports a failure.
        """

        def resolve(future, outcome):
            # Scheduled onto the event loop below. A later callback, or a cancelled waiter, must
            # not hit an already completed future, as set_result() would raise InvalidStateError.
            if not future.done():
                future.set_result(outcome)

        class BondingObserver(adapter_client.BluetoothCallbacks):
            """Observer to observe the bond state"""

            def __init__(self, task):
                self.task = task

            @utils.glib_callback()
            def on_bond_state_changed(self, status, address, state):
                if address != self.task['address']:
                    return

                if status != floss_enums.BtStatus.SUCCESS:
                    outcome = (False, f'{address} failed to remove bond. Status: {status},'
                               f' State: {state}')
                elif state == floss_enums.BondState.NOT_BONDED:
                    outcome = (True, None)
                else:
                    outcome = (False, f'{address} failed on remove_bond, got bond state {state},'
                               f' want {floss_enums.BondState.NOT_BONDED}')

                future = self.task['remove_bond']
                future.get_loop().call_soon_threadsafe(resolve, future, outcome)

        address = utils.address_from(request.address)
        if not self.bluetooth.is_bonded(address):
            return empty_pb2.Empty()

        # Everything the cleanup needs is bound before entering the try block, so a failure during
        # setup surfaces as itself instead of as an UnboundLocalError raised from `finally`.
        remove_bond = asyncio.get_running_loop().create_future()
        observer = BondingObserver({'remove_bond': remove_bond, 'address': address})
        name = utils.create_observer_name(observer)
        self.bluetooth.adapter_client.register_callback_observer(name, observer)
        try:
            # The observer is registered first so the resulting state change cannot be missed.
            self.bluetooth.remove_bond(address)
            success, reason = await remove_bond
        finally:
            self.bluetooth.adapter_client.unregister_callback_observer(name, observer)

        if not success:
            await context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                                f'Failed to remove bond of address: {address}. Reason: {reason}.')
        return empty_pb2.Empty()
