# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tink.testing.python.jwt_service."""

from absl.testing import absltest
import grpc

from tink import jwt

from protos import testing_api_pb2
import jwt_service
import services


class DummyServicerContext(grpc.ServicerContext):
  """Do-nothing grpc.ServicerContext for calling servicers in-process.

  The tests in this module invoke servicer methods directly instead of going
  through a gRPC channel, so they need some context object to pass in. Every
  method here ignores its arguments and returns None.

  NOTE(review): unlike a real server context, abort() and abort_with_status()
  return normally here instead of raising.
  """

  def is_active(self):
    """Always returns None."""

  def time_remaining(self):
    """Always returns None."""

  def cancel(self):
    """Does nothing."""

  def add_callback(self, callback):
    """Discards `callback` without registering or calling it."""
    del callback  # Unused.

  def invocation_metadata(self):
    """Always returns None."""

  def peer(self):
    """Always returns None."""

  def peer_identities(self):
    """Always returns None."""

  def peer_identity_key(self):
    """Always returns None."""

  def auth_context(self):
    """Always returns None."""

  def set_compression(self, compression):
    """Ignores the requested compression."""
    del compression  # Unused.

  def send_initial_metadata(self, initial_metadata):
    """Discards the initial metadata."""
    del initial_metadata  # Unused.

  def set_trailing_metadata(self, trailing_metadata):
    """Discards the trailing metadata."""
    del trailing_metadata  # Unused.

  def abort(self, code, details):
    """Ignores the abort request and returns normally."""
    del code, details  # Unused.

  def abort_with_status(self, status):
    """Ignores the abort request and returns normally."""
    del status  # Unused.

  def set_code(self, code):
    """Discards the status code."""
    del code  # Unused.

  def set_details(self, details):
    """Discards the status details."""
    del details  # Unused.

  def disable_next_message_compression(self):
    """Does nothing."""


class JwtServiceTest(absltest.TestCase):
  """In-process tests for jwt_service.JwtServicer.

  Servicer methods are called directly with a DummyServicerContext rather
  than through a gRPC channel. Keysets are produced by services.KeysetServicer
  so the tests exercise the same request/response protos a remote client
  would use.
  """

  _ctx = DummyServicerContext()

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    # Register the JWT MAC and signature primitives so that keysets created
    # from JWT templates can be used by the servicers under test.
    jwt.register_jwt_mac()
    jwt.register_jwt_signature()

  def _generate_keyset(self, template):
    """Generates a fresh serialized keyset from `template`.

    Args:
      template: a key template proto, e.g. jwt.jwt_hs256_template().

    Returns:
      The serialized keyset returned by KeysetServicer.Generate.
    """
    keyset_servicer = services.KeysetServicer()
    gen_request = testing_api_pb2.KeysetGenerateRequest(
        template=template.SerializeToString())
    gen_response = keyset_servicer.Generate(gen_request, self._ctx)
    # On failure, surface the servicer's error string in the test output.
    self.assertEqual(
        gen_response.WhichOneof('result'), 'keyset', msg=gen_response.err)
    return gen_response.keyset

  def _public_keyset(self, private_keyset):
    """Returns the serialized public keyset derived from `private_keyset`."""
    keyset_servicer = services.KeysetServicer()
    pub_request = testing_api_pb2.KeysetPublicRequest(
        private_keyset=private_keyset)
    pub_response = keyset_servicer.Public(pub_request, self._ctx)
    self.assertEqual(
        pub_response.WhichOneof('result'), 'public_keyset',
        msg=pub_response.err)
    return pub_response.public_keyset

  def test_create_jwt_mac(self):
    jwt_servicer = jwt_service.JwtServicer()
    keyset = self._generate_keyset(jwt.jwt_hs256_template())

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=keyset))
    creation_response = jwt_servicer.CreateJwtMac(
        creation_request, self._ctx)
    self.assertEmpty(creation_response.err)

  def test_create_jwt_mac_broken_keyset(self):
    jwt_servicer = jwt_service.JwtServicer()

    # b'\x80' is not a valid serialized keyset, so creation must report an
    # error instead of raising.
    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=b'\x80'))
    creation_response = jwt_servicer.CreateJwtMac(creation_request, self._ctx)
    self.assertNotEmpty(creation_response.err)

  def test_generate_compute_verify_mac(self):
    jwt_servicer = jwt_service.JwtServicer()
    keyset = self._generate_keyset(jwt.jwt_hs256_template())

    comp_request = testing_api_pb2.JwtSignRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=keyset))
    comp_request.raw_jwt.issuer.value = 'issuer'
    comp_request.raw_jwt.subject.value = 'subject'
    comp_request.raw_jwt.custom_claims['myclaim'].bool_value = True
    comp_request.raw_jwt.expiration.seconds = 1334
    comp_request.raw_jwt.expiration.nanos = 123000000

    comp_response = jwt_servicer.ComputeMacAndEncode(comp_request, self._ctx)
    self.assertEqual(
        comp_response.WhichOneof('result'), 'signed_compact_jwt',
        msg=comp_response.err)
    signed_compact_jwt = comp_response.signed_compact_jwt

    verify_request = testing_api_pb2.JwtVerifyRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=keyset),
        signed_compact_jwt=signed_compact_jwt)
    verify_request.validator.expected_issuer.value = 'issuer'
    # Validate at t=1234s, i.e. before the 1334s expiration set above.
    verify_request.validator.now.seconds = 1234
    verify_response = jwt_servicer.VerifyMacAndDecode(verify_request, self._ctx)
    self.assertEqual(
        verify_response.WhichOneof('result'), 'verified_jwt',
        msg=verify_response.err)
    self.assertEqual(verify_response.verified_jwt.issuer.value, 'issuer')
    self.assertEqual(verify_response.verified_jwt.subject.value, 'subject')
    self.assertEqual(verify_response.verified_jwt.expiration.seconds, 1334)
    # The sub-second part of the expiration is not preserved through the
    # encode/verify round trip.
    self.assertEqual(verify_response.verified_jwt.expiration.nanos, 0)

  def test_generate_compute_verify_mac_without_expiration(self):
    jwt_servicer = jwt_service.JwtServicer()
    keyset = self._generate_keyset(jwt.jwt_hs256_template())

    comp_request = testing_api_pb2.JwtSignRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=keyset))
    comp_request.raw_jwt.issuer.value = 'issuer'

    comp_response = jwt_servicer.ComputeMacAndEncode(comp_request, self._ctx)
    self.assertEqual(
        comp_response.WhichOneof('result'), 'signed_compact_jwt',
        msg=comp_response.err)
    signed_compact_jwt = comp_response.signed_compact_jwt

    verify_request = testing_api_pb2.JwtVerifyRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=keyset),
        signed_compact_jwt=signed_compact_jwt)
    verify_request.validator.expected_issuer.value = 'issuer'
    # The token above has no expiration claim, so the validator must allow it.
    verify_request.validator.allow_missing_expiration = True
    verify_response = jwt_servicer.VerifyMacAndDecode(verify_request, self._ctx)
    self.assertEqual(
        verify_response.WhichOneof('result'), 'verified_jwt',
        msg=verify_response.err)
    self.assertEqual(verify_response.verified_jwt.issuer.value, 'issuer')

  def test_create_public_key_sign(self):
    jwt_servicer = jwt_service.JwtServicer()
    private_keyset = self._generate_keyset(jwt.jwt_es256_template())

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=private_keyset))
    creation_response = jwt_servicer.CreateJwtPublicKeySign(
        creation_request, self._ctx)
    self.assertEmpty(creation_response.err)

  def test_create_public_key_sign_bad_keyset(self):
    jwt_servicer = jwt_service.JwtServicer()

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=b'\x80'))
    creation_response = jwt_servicer.CreateJwtPublicKeySign(
        creation_request, self._ctx)
    self.assertNotEmpty(creation_response.err)

  def test_create_public_key_verify(self):
    jwt_servicer = jwt_service.JwtServicer()
    private_keyset = self._generate_keyset(jwt.jwt_es256_template())
    public_keyset = self._public_keyset(private_keyset)

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=public_keyset))
    creation_response = jwt_servicer.CreateJwtPublicKeyVerify(
        creation_request, self._ctx)
    self.assertEmpty(creation_response.err)

  def test_create_public_key_verify_bad_keyset(self):
    jwt_servicer = jwt_service.JwtServicer()

    creation_request = testing_api_pb2.CreationRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=b'\x80'))
    creation_response = jwt_servicer.CreateJwtPublicKeyVerify(
        creation_request, self._ctx)
    self.assertNotEmpty(creation_response.err)

  def test_generate_sign_export_import_verify_signature(self):
    jwt_servicer = jwt_service.JwtServicer()
    private_keyset = self._generate_keyset(jwt.jwt_es256_template())

    # Sign a token with the private keyset.
    comp_request = testing_api_pb2.JwtSignRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=private_keyset))
    comp_request.raw_jwt.issuer.value = 'issuer'
    comp_request.raw_jwt.subject.value = 'subject'
    comp_request.raw_jwt.custom_claims['myclaim'].bool_value = True
    comp_response = jwt_servicer.PublicKeySignAndEncode(comp_request, self._ctx)
    self.assertEqual(
        comp_response.WhichOneof('result'), 'signed_compact_jwt',
        msg=comp_response.err)
    signed_compact_jwt = comp_response.signed_compact_jwt

    # Export the public keyset to a JWK set and import it back.
    public_keyset = self._public_keyset(private_keyset)

    to_jwkset_request = testing_api_pb2.JwtToJwkSetRequest(keyset=public_keyset)
    to_jwkset_response = jwt_servicer.ToJwkSet(to_jwkset_request, self._ctx)
    self.assertEqual(
        to_jwkset_response.WhichOneof('result'), 'jwk_set',
        msg=to_jwkset_response.err)

    self.assertStartsWith(to_jwkset_response.jwk_set, '{"keys":[{"')

    from_jwkset_request = testing_api_pb2.JwtFromJwkSetRequest(
        jwk_set=to_jwkset_response.jwk_set)
    from_jwkset_response = jwt_servicer.FromJwkSet(
        from_jwkset_request, self._ctx)
    self.assertEqual(
        from_jwkset_response.WhichOneof('result'), 'keyset',
        msg=from_jwkset_response.err)

    # Verify the original token with the re-imported public keyset.
    verify_request = testing_api_pb2.JwtVerifyRequest(
        annotated_keyset=testing_api_pb2.AnnotatedKeyset(
            serialized_keyset=from_jwkset_response.keyset),
        signed_compact_jwt=signed_compact_jwt)
    verify_request.validator.expected_issuer.value = 'issuer'
    verify_request.validator.allow_missing_expiration = True
    verify_response = jwt_servicer.PublicKeyVerifyAndDecode(
        verify_request, self._ctx)
    self.assertEqual(
        verify_response.WhichOneof('result'), 'verified_jwt',
        msg=verify_response.err)
    self.assertEqual(verify_response.verified_jwt.issuer.value, 'issuer')

  def test_to_jwk_set_with_invalid_keyset_fails(self):
    jwt_servicer = jwt_service.JwtServicer()

    to_jwkset_request = testing_api_pb2.JwtToJwkSetRequest(keyset=b'invalid')
    jwkset_response = jwt_servicer.ToJwkSet(to_jwkset_request, self._ctx)
    self.assertEqual(jwkset_response.WhichOneof('result'), 'err')

  def test_from_jwk_set_with_invalid_jwk_set_fails(self):
    jwt_servicer = jwt_service.JwtServicer()

    from_jwkset_request = testing_api_pb2.JwtFromJwkSetRequest(
        jwk_set='invalid')
    from_jwkset_response = jwt_servicer.FromJwkSet(from_jwkset_request,
                                                   self._ctx)
    self.assertEqual(from_jwkset_response.WhichOneof('result'), 'err')


if __name__ == '__main__':
  # Run the test cases in this module when it is executed as a script.
  absltest.main()
