#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests the pw_rpc.console_tools.console module."""

import types
import unittest

import pw_status

from pw_protobuf_compiler import python_protos
import pw_rpc
from pw_rpc import callback_client
from pw_rpc.console_tools.console import (
    CommandHelper,
    Context,
    ClientInfo,
    alias_deprecated_command,
)


class TestCommandHelper(unittest.TestCase):
    """Tests console_tools.console.CommandHelper."""

    def setUp(self) -> None:
        self._commands = {'command_a': 'A', 'command_B': 'B'}
        self._variables = {'hello': 1, 'world': 2}
        self._helper = CommandHelper(
            self._commands, self._variables, 'The header', 'The footer'
        )

    def test_help_contents(self) -> None:
        """help() starts with the header and includes the footer and names."""
        help_contents = self._helper.help()

        self.assertTrue(help_contents.startswith('The header'))
        self.assertIn('The footer', help_contents)

        for var_name in self._variables:
            self.assertIn(var_name, help_contents)

        for cmd_name in self._commands:
            self.assertIn(cmd_name, help_contents)

    def test_repr_is_help(self) -> None:
        """repr() of the helper is exactly its help() text."""
        self.assertEqual(repr(self._helper), self._helper.help())


_PROTO = """\
syntax = "proto3";

package the.pkg;

message SomeMessage {
  uint32 magic_number = 1;

    message AnotherMessage {
      string payload = 1;
    }

}

service Service {
  rpc Unary(SomeMessage) returns (SomeMessage.AnotherMessage);
}
"""


class TestConsoleContext(unittest.TestCase):
    """Tests console_tools.console.Context."""

    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(_PROTO)

        # The default client has two channels so tests can retarget between
        # them with set_target().
        rpc_client = pw_rpc.Client.from_modules(
            callback_client.Impl(),
            [
                pw_rpc.Channel(1, lambda _: None),
                pw_rpc.Channel(2, lambda _: None),
            ],
            self._protos.modules(),
        )
        self._info = ClientInfo('the_client', object(), rpc_client)

    def _context(self, *extra_clients: 'ClientInfo') -> 'Context':
        """Returns a Context whose default client is self._info.

        Any extra_clients are registered after self._info.
        """
        return Context(
            [self._info, *extra_clients],
            default_client=self._info.client,
            protos=self._protos,
        )

    def test_sets_expected_variables(self) -> None:
        namespace = self._context().variables()

        self.assertIn('set_target', namespace)

        self.assertIsInstance(namespace['help'], CommandHelper)
        self.assertIs(namespace['python_help'], help)
        self.assertIs(pw_status.Status, namespace['Status'])
        self.assertIs(self._info.client, namespace['the_client'])

    def test_set_target_switches_between_clients(self) -> None:
        first_channel = self._info.rpc_client.channel(1).channel

        second_channel = pw_rpc.Channel(99, lambda _: None)
        other_info = ClientInfo(
            'other_client',
            object(),
            pw_rpc.Client.from_modules(
                callback_client.Impl(),
                [second_channel],
                self._protos.modules(),
            ),
        )

        context = self._context(other_info)

        def unary_channel():
            # Re-query variables() each time so the current target is seen.
            return context.variables()['the'].pkg.Service.Unary.channel

        # The RPC stub must follow whichever client is currently targeted.
        self.assertIs(unary_channel(), first_channel)
        context.set_target(other_info.client)
        self.assertIs(unary_channel(), second_channel)

    def test_default_client_must_be_in_clients(self) -> None:
        # A default_client that is not one of the listed clients is rejected.
        with self.assertRaises(ValueError):
            Context(
                [self._info],
                default_client='something else',
                protos=self._protos,
            )

    def test_set_target_invalid_channel(self) -> None:
        context = self._context()

        # Only channels 1 and 2 exist on the default client.
        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_non_default_channel(self) -> None:
        rpc_client = self._info.rpc_client
        channel_1 = rpc_client.channel(1).channel
        channel_2 = rpc_client.channel(2).channel

        context = self._context()
        # variables() is fetched only once: the same mapping is expected to
        # reflect later set_target() calls.
        namespace = context.variables()

        self.assertIs(namespace['the'].pkg.Service.Unary.channel, channel_1)

        context.set_target(self._info.client, 2)

        self.assertIs(namespace['the'].pkg.Service.Unary.channel, channel_2)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_requires_client_object(self) -> None:
        context = self._context()

        # Passing self._info.rpc_client instead of self._info.client raises...
        with self.assertRaises(ValueError):
            context.set_target(self._info.rpc_client)

        # ...while self._info.client itself is accepted without error.
        context.set_target(self._info.client)

    def test_derived_context(self) -> None:
        derived_calls: list[bool] = []

        class DerivedContext(Context):
            def set_target(
                self,
                unused_selected_client,
                unused_channel_id: int | None = None,
            ) -> None:
                derived_calls.append(True)

        context = DerivedContext(
            client_info=[self._info],
            default_client=self._info.client,
            protos=self._protos,
        )
        # The console's set_target variable must dispatch to the override.
        context.variables()['set_target'](self._info.client)
        self.assertTrue(derived_calls)


class TestAliasDeprecatedCommand(unittest.TestCase):
    """Tests console_tools.console.alias_deprecated_command."""

    def test_wraps_command_to_new_package(self) -> None:
        scope = {'abc': types.SimpleNamespace(command=lambda: 123)}
        alias_deprecated_command(scope, 'xyz.one.two.three', 'abc.command')

        # 'xyz' was absent before the call; the full alias path must now
        # resolve to a callable that forwards to abc.command.
        self.assertEqual(scope['xyz'].one.two.three(), 123)

    def test_wraps_command_to_existing_package(self) -> None:
        scope = {
            'abc': types.SimpleNamespace(NewCmd=lambda: 456),
            'one': types.SimpleNamespace(),
        }
        alias_deprecated_command(scope, 'one.two.OldCmd', 'abc.NewCmd')

        # 'one' already existed; the alias is added beneath it.
        self.assertEqual(scope['one'].two.OldCmd(), 456)

    def test_error_if_new_command_does_not_exist(self) -> None:
        # Neither namespace defines NewCmd, so the alias target is missing.
        scope = {
            'abc': types.SimpleNamespace(),
            'one': types.SimpleNamespace(),
        }

        with self.assertRaises(AttributeError):
            alias_deprecated_command(scope, 'one.two.OldCmd', 'abc.NewCmd')


if __name__ == '__main__':
    # Run this module's tests when the file is executed directly.
    unittest.main(module='__main__')
