# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements tink primitives from gRPC testing_api stubs."""

import datetime
import io
import json
from typing import BinaryIO, Dict, Optional, Mapping, Tuple

import tink
from tink import aead
from tink import daead
from tink import hybrid
from tink import jwt
from tink import mac
from tink import prf
from tink import signature as tink_signature
from tink import streaming_aead

from tink.proto import tink_pb2
from protos import testing_api_pb2
from protos import testing_api_pb2_grpc


def key_template(stub: testing_api_pb2_grpc.KeysetStub,
                 template_name: str) -> tink_pb2.KeyTemplate:
  """Looks up a key template by name on the testing server.

  Args:
    stub: gRPC stub for the Keyset service.
    template_name: name of the template to request.

  Returns:
    The KeyTemplate parsed from the serialized template in the response.

  Raises:
    tink.TinkError: if the response carries a non-empty error.
  """
  reply = stub.GetTemplate(
      testing_api_pb2.KeysetTemplateRequest(template_name=template_name))
  if not reply.err:
    return tink_pb2.KeyTemplate.FromString(reply.key_template)
  raise tink.TinkError(reply.err)


def new_keyset(stub: testing_api_pb2_grpc.KeysetStub,
               template: tink_pb2.KeyTemplate) -> bytes:
  """Asks the testing server to generate a fresh keyset from a template.

  Args:
    stub: gRPC stub for the Keyset service.
    template: key template; it is sent to the server in serialized form.

  Returns:
    The serialized keyset returned by the server.

  Raises:
    tink.TinkError: if the response carries a non-empty error.
  """
  serialized_template = template.SerializeToString()
  reply = stub.Generate(
      testing_api_pb2.KeysetGenerateRequest(template=serialized_template))
  if not reply.err:
    return reply.keyset
  raise tink.TinkError(reply.err)


def public_keyset(stub: testing_api_pb2_grpc.KeysetStub,
                  private_keyset: bytes) -> bytes:
  """Derives the public keyset for a private keyset via the testing server.

  Args:
    stub: gRPC stub for the Keyset service.
    private_keyset: serialized private keyset.

  Returns:
    The serialized public keyset returned by the server.

  Raises:
    tink.TinkError: if the response carries a non-empty error.
  """
  reply = stub.Public(
      testing_api_pb2.KeysetPublicRequest(private_keyset=private_keyset))
  if not reply.err:
    return reply.public_keyset
  raise tink.TinkError(reply.err)


def keyset_to_json(
    stub: testing_api_pb2_grpc.KeysetStub,
    keyset: bytes) -> str:
  """Converts a binary keyset to its JSON form using the testing server.

  Args:
    stub: gRPC stub for the Keyset service.
    keyset: serialized (binary) keyset.

  Returns:
    The JSON keyset string returned by the server.

  Raises:
    tink.TinkError: if the response carries a non-empty error.
  """
  reply = stub.ToJson(testing_api_pb2.KeysetToJsonRequest(keyset=keyset))
  if not reply.err:
    return reply.json_keyset
  raise tink.TinkError(reply.err)


def keyset_from_json(
    stub: testing_api_pb2_grpc.KeysetStub,
    json_keyset: str) -> bytes:
  """Converts a JSON keyset to its binary form using the testing server.

  Args:
    stub: gRPC stub for the Keyset service.
    json_keyset: keyset in JSON form.

  Returns:
    The serialized (binary) keyset returned by the server.

  Raises:
    tink.TinkError: if the response carries a non-empty error.
  """
  reply = stub.FromJson(
      testing_api_pb2.KeysetFromJsonRequest(json_keyset=json_keyset))
  if not reply.err:
    return reply.keyset
  raise tink.TinkError(reply.err)


def keyset_read_encrypted(stub: testing_api_pb2_grpc.KeysetStub,
                          encrypted_keyset: bytes, master_keyset: bytes,
                          associated_data: Optional[bytes],
                          keyset_reader_type: str) -> bytes:
  """Decrypts an encrypted keyset on the testing server.

  Args:
    stub: gRPC stub for the Keyset service.
    encrypted_keyset: the encrypted keyset to read.
    master_keyset: serialized keyset used to decrypt `encrypted_keyset`.
    associated_data: optional associated data; when None the field is left
      unset on the request.
    keyset_reader_type: name of a KeysetReaderType enum value; converted
      with KeysetReaderType.Value.

  Returns:
    The decrypted, serialized keyset.

  Raises:
    tink.TinkError: if the response carries a non-empty error.
  """
  reader_type = testing_api_pb2.KeysetReaderType.Value(keyset_reader_type)
  request = testing_api_pb2.KeysetReadEncryptedRequest(
      encrypted_keyset=encrypted_keyset,
      master_keyset=master_keyset,
      keyset_reader_type=reader_type)
  # Only populate the wrapper field when a value was actually supplied.
  if associated_data is not None:
    request.associated_data.value = associated_data
  reply = stub.ReadEncrypted(request)
  if not reply.err:
    return reply.keyset
  raise tink.TinkError(reply.err)


def keyset_write_encrypted(stub: testing_api_pb2_grpc.KeysetStub, keyset: bytes,
                           master_keyset: bytes,
                           associated_data: Optional[bytes],
                           keyset_writer_type: str) -> bytes:
  """Encrypts a keyset on the testing server.

  Args:
    stub: gRPC stub for the Keyset service.
    keyset: serialized keyset to encrypt.
    master_keyset: serialized keyset used to encrypt `keyset`.
    associated_data: optional associated data; when None the field is left
      unset on the request.
    keyset_writer_type: name of a KeysetWriterType enum value; converted
      with KeysetWriterType.Value.

  Returns:
    The encrypted keyset returned by the server.

  Raises:
    tink.TinkError: if the response carries a non-empty error.
  """
  writer_type = testing_api_pb2.KeysetWriterType.Value(keyset_writer_type)
  request = testing_api_pb2.KeysetWriteEncryptedRequest(
      keyset=keyset,
      master_keyset=master_keyset,
      keyset_writer_type=writer_type)
  # Only populate the wrapper field when a value was actually supplied.
  if associated_data is not None:
    request.associated_data.value = associated_data
  reply = stub.WriteEncrypted(request)
  if not reply.err:
    return reply.encrypted_keyset
  raise tink.TinkError(reply.err)


def jwk_set_to_keyset(stub: testing_api_pb2_grpc.JwtStub,
                      jwk_set: str) -> bytes:
  """Converts a JWK set into a Tink keyset via the testing server.

  Args:
    stub: gRPC stub for the Jwt service.
    jwk_set: the JWK set as a string.

  Returns:
    The serialized keyset returned by the server.

  Raises:
    tink.TinkError: if the response carries a non-empty error.
  """
  reply = stub.FromJwkSet(testing_api_pb2.JwtFromJwkSetRequest(jwk_set=jwk_set))
  if not reply.err:
    return reply.keyset
  raise tink.TinkError(reply.err)


def jwk_set_from_keyset(stub: testing_api_pb2_grpc.JwtStub,
                        keyset: bytes) -> str:
  """Converts a Tink keyset into a JWK set via the testing server.

  Args:
    stub: gRPC stub for the Jwt service.
    keyset: serialized keyset.

  Returns:
    The JWK set string returned by the server.

  Raises:
    tink.TinkError: if the response carries a non-empty error.
  """
  reply = stub.ToJwkSet(testing_api_pb2.JwtToJwkSetRequest(keyset=keyset))
  if not reply.err:
    return reply.jwk_set
  raise tink.TinkError(reply.err)


class Aead(aead.Aead):
  """Aead primitive whose operations are executed by a remote testing server.

  Every call sends the serialized keyset (with optional annotations) through
  the AEAD gRPC stub. A non-empty `err` in a response becomes a
  tink.TinkError.
  """

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.AeadStub,
               keyset: bytes, annotations: Optional[Dict[str, str]]) -> None:
    """Stores the connection details and asks the server to create the primitive.

    Args:
      lang: language label; stored as `self.lang`.
      stub: gRPC stub for the AEAD service.
      keyset: serialized keyset used for every operation.
      annotations: optional annotations forwarded with the keyset.

    Raises:
      tink.TinkError: if the server fails to create the primitive.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    self._annotations = annotations
    reply = self._stub.Create(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if reply.err:
      raise tink.TinkError(reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Builds the AnnotatedKeyset message attached to every request."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._keyset, annotations=self._annotations)

  def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes:
    """Encrypts `plaintext` remotely, authenticating `associated_data`."""
    reply = self._stub.Encrypt(
        testing_api_pb2.AeadEncryptRequest(
            annotated_keyset=self._annotated_keyset(),
            plaintext=plaintext,
            associated_data=associated_data))
    if reply.err:
      raise tink.TinkError(reply.err)
    return reply.ciphertext

  def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes:
    """Decrypts `ciphertext` remotely, checking `associated_data`."""
    reply = self._stub.Decrypt(
        testing_api_pb2.AeadDecryptRequest(
            annotated_keyset=self._annotated_keyset(),
            ciphertext=ciphertext,
            associated_data=associated_data))
    if reply.err:
      raise tink.TinkError(reply.err)
    return reply.plaintext


class DeterministicAead(daead.DeterministicAead):
  """DeterministicAead primitive executed by a remote testing server.

  Every call sends the serialized keyset (with optional annotations) through
  the DeterministicAead gRPC stub. A non-empty `err` in a response becomes a
  tink.TinkError.
  """

  def __init__(self, lang: str,
               stub: testing_api_pb2_grpc.DeterministicAeadStub, keyset: bytes,
               annotations: Optional[Dict[str, str]]) -> None:
    """Stores the connection details and asks the server to create the primitive.

    Args:
      lang: language label; stored as `self.lang`.
      stub: gRPC stub for the DeterministicAead service.
      keyset: serialized keyset used for every operation.
      annotations: optional annotations forwarded with the keyset.

    Raises:
      tink.TinkError: if the server fails to create the primitive.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    self._annotations = annotations
    reply = self._stub.Create(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if reply.err:
      raise tink.TinkError(reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Builds the AnnotatedKeyset message attached to every request."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._keyset, annotations=self._annotations)

  def encrypt_deterministically(self, plaintext: bytes,
                                associated_data: bytes) -> bytes:
    """Deterministically encrypts `plaintext` on the server."""
    reply = self._stub.EncryptDeterministically(
        testing_api_pb2.DeterministicAeadEncryptRequest(
            annotated_keyset=self._annotated_keyset(),
            plaintext=plaintext,
            associated_data=associated_data))
    if reply.err:
      raise tink.TinkError(reply.err)
    return reply.ciphertext

  def decrypt_deterministically(self, ciphertext: bytes,
                                associated_data: bytes) -> bytes:
    """Decrypts a deterministically encrypted `ciphertext` on the server."""
    reply = self._stub.DecryptDeterministically(
        testing_api_pb2.DeterministicAeadDecryptRequest(
            annotated_keyset=self._annotated_keyset(),
            ciphertext=ciphertext,
            associated_data=associated_data))
    if reply.err:
      raise tink.TinkError(reply.err)
    return reply.plaintext


class StreamingAead(streaming_aead.StreamingAead):
  """StreamingAead primitive executed by a remote testing server.

  The streams are not processed incrementally. The whole input stream is read
  into memory, sent to the server in one request, and the result is returned
  wrapped in an in-memory io.BytesIO.
  """

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.StreamingAeadStub,
               keyset: bytes) -> None:
    """Stores the connection details and asks the server to create the primitive.

    Args:
      lang: language label; stored as `self.lang`.
      stub: gRPC stub for the StreamingAead service.
      keyset: serialized keyset used for every operation.

    Raises:
      tink.TinkError: if the server fails to create the primitive.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    reply = self._stub.Create(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if reply.err:
      raise tink.TinkError(reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Builds the (annotation-free) AnnotatedKeyset sent with every request."""
    return testing_api_pb2.AnnotatedKeyset(serialized_keyset=self._keyset)

  def new_encrypting_stream(self, plaintext: BinaryIO,
                            associated_data: bytes) -> BinaryIO:
    """Reads all of `plaintext`, encrypts it remotely, returns the ciphertext."""
    reply = self._stub.Encrypt(
        testing_api_pb2.StreamingAeadEncryptRequest(
            annotated_keyset=self._annotated_keyset(),
            plaintext=plaintext.read(),
            associated_data=associated_data))
    if reply.err:
      raise tink.TinkError(reply.err)
    return io.BytesIO(reply.ciphertext)

  def new_decrypting_stream(self, ciphertext: BinaryIO,
                            associated_data: bytes) -> BinaryIO:
    """Reads all of `ciphertext`, decrypts it remotely, returns the plaintext."""
    reply = self._stub.Decrypt(
        testing_api_pb2.StreamingAeadDecryptRequest(
            annotated_keyset=self._annotated_keyset(),
            ciphertext=ciphertext.read(),
            associated_data=associated_data))
    if reply.err:
      raise tink.TinkError(reply.err)
    return io.BytesIO(reply.plaintext)


class Mac(mac.Mac):
  """Mac primitive whose operations are executed by a remote testing server.

  Every call sends the serialized keyset (with optional annotations) through
  the MAC gRPC stub. A non-empty `err` in a response becomes a
  tink.TinkError.
  """

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.MacStub,
               keyset: bytes, annotations: Optional[Dict[str, str]]) -> None:
    """Stores the connection details and asks the server to create the primitive.

    Args:
      lang: language label; stored as `self.lang`.
      stub: gRPC stub for the MAC service.
      keyset: serialized keyset used for every operation.
      annotations: optional annotations forwarded with the keyset.

    Raises:
      tink.TinkError: if the server fails to create the primitive.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    self._annotations = annotations
    reply = self._stub.Create(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if reply.err:
      raise tink.TinkError(reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Builds the AnnotatedKeyset message attached to every request."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._keyset, annotations=self._annotations)

  def compute_mac(self, data: bytes) -> bytes:
    """Computes the MAC tag of `data` on the server."""
    reply = self._stub.ComputeMac(
        testing_api_pb2.ComputeMacRequest(
            annotated_keyset=self._annotated_keyset(), data=data))
    if reply.err:
      raise tink.TinkError(reply.err)
    return reply.mac_value

  def verify_mac(self, mac_value: bytes, data: bytes) -> None:
    """Verifies `mac_value` over `data` on the server; raises on failure."""
    reply = self._stub.VerifyMac(
        testing_api_pb2.VerifyMacRequest(
            annotated_keyset=self._annotated_keyset(),
            mac_value=mac_value,
            data=data))
    if reply.err:
      raise tink.TinkError(reply.err)


class HybridEncrypt(hybrid.HybridEncrypt):
  """HybridEncrypt primitive executed by a remote testing server."""

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.HybridStub,
               public_handle: bytes,
               annotations: Optional[Dict[str, str]]) -> None:
    """Stores the connection details and asks the server to create the primitive.

    Args:
      lang: language label; stored as `self.lang`.
      stub: gRPC stub for the Hybrid service.
      public_handle: serialized public keyset used for encryption.
      annotations: optional annotations forwarded with the keyset.

    Raises:
      tink.TinkError: if the server fails to create the primitive.
    """
    self.lang = lang
    self._stub = stub
    self._public_handle = public_handle
    self._annotations = annotations
    reply = self._stub.CreateHybridEncrypt(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if reply.err:
      raise tink.TinkError(reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Builds the AnnotatedKeyset for the public keyset."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._public_handle, annotations=self._annotations)

  def encrypt(self, plaintext: bytes, context_info: bytes) -> bytes:
    """Hybrid-encrypts `plaintext` with `context_info` on the server."""
    reply = self._stub.Encrypt(
        testing_api_pb2.HybridEncryptRequest(
            public_annotated_keyset=self._annotated_keyset(),
            plaintext=plaintext,
            context_info=context_info))
    if reply.err:
      raise tink.TinkError(reply.err)
    return reply.ciphertext


class HybridDecrypt(hybrid.HybridDecrypt):
  """HybridDecrypt primitive executed by a remote testing server."""

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.HybridStub,
               private_handle: bytes,
               annotations: Optional[Dict[str, str]]) -> None:
    """Stores the connection details and asks the server to create the primitive.

    Args:
      lang: language label; stored as `self.lang`.
      stub: gRPC stub for the Hybrid service.
      private_handle: serialized private keyset used for decryption.
      annotations: optional annotations forwarded with the keyset.

    Raises:
      tink.TinkError: if the server fails to create the primitive.
    """
    self.lang = lang
    self._stub = stub
    self._private_handle = private_handle
    self._annotations = annotations
    reply = self._stub.CreateHybridDecrypt(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if reply.err:
      raise tink.TinkError(reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Builds the AnnotatedKeyset for the private keyset."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._private_handle, annotations=self._annotations)

  def decrypt(self, ciphertext: bytes, context_info: bytes) -> bytes:
    """Hybrid-decrypts `ciphertext` with `context_info` on the server."""
    reply = self._stub.Decrypt(
        testing_api_pb2.HybridDecryptRequest(
            private_annotated_keyset=self._annotated_keyset(),
            ciphertext=ciphertext,
            context_info=context_info))
    if reply.err:
      raise tink.TinkError(reply.err)
    return reply.plaintext


class PublicKeySign(tink_signature.PublicKeySign):
  """PublicKeySign primitive executed by a remote testing server."""

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.SignatureStub,
               private_handle: bytes,
               annotations: Optional[Dict[str, str]]) -> None:
    """Stores the connection details and asks the server to create the primitive.

    Args:
      lang: language label; stored as `self.lang`.
      stub: gRPC stub for the Signature service.
      private_handle: serialized private keyset used for signing.
      annotations: optional annotations forwarded with the keyset.

    Raises:
      tink.TinkError: if the server fails to create the primitive.
    """
    self.lang = lang
    self._stub = stub
    self._private_handle = private_handle
    self._annotations = annotations
    reply = self._stub.CreatePublicKeySign(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if reply.err:
      raise tink.TinkError(reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Builds the AnnotatedKeyset for the private keyset."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._private_handle, annotations=self._annotations)

  def sign(self, data: bytes) -> bytes:
    """Signs `data` on the server and returns the signature."""
    reply = self._stub.Sign(
        testing_api_pb2.SignatureSignRequest(
            private_annotated_keyset=self._annotated_keyset(), data=data))
    if reply.err:
      raise tink.TinkError(reply.err)
    return reply.signature


class PublicKeyVerify(tink_signature.PublicKeyVerify):
  """PublicKeyVerify primitive executed by a remote testing server."""

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.SignatureStub,
               public_handle: bytes,
               annotations: Optional[Dict[str, str]]) -> None:
    """Stores the connection details and asks the server to create the primitive.

    Args:
      lang: language label; stored as `self.lang`.
      stub: gRPC stub for the Signature service.
      public_handle: serialized public keyset used for verification.
      annotations: optional annotations forwarded with the keyset.

    Raises:
      tink.TinkError: if the server fails to create the primitive.
    """
    self.lang = lang
    self._stub = stub
    self._public_handle = public_handle
    self._annotations = annotations
    reply = self._stub.CreatePublicKeyVerify(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if reply.err:
      raise tink.TinkError(reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Builds the AnnotatedKeyset for the public keyset."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._public_handle, annotations=self._annotations)

  def verify(self, signature: bytes, data: bytes) -> None:  # pytype: disable=signature-mismatch  # overriding-return-type-checks
    """Verifies `signature` over `data` on the server; raises on failure."""
    reply = self._stub.Verify(
        testing_api_pb2.SignatureVerifyRequest(
            public_annotated_keyset=self._annotated_keyset(),
            signature=signature,
            data=data))
    if reply.err:
      raise tink.TinkError(reply.err)


class _Prf(prf.Prf):
  """A single PRF, identified by key id, inside a remotely executed PrfSet.

  Unlike the other wrappers, construction does not contact the server. Only
  `compute` issues a request.
  """

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.PrfSetStub,
               keyset: bytes, key_id: int,
               annotations: Optional[Dict[str, str]]) -> None:
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    self._key_id = key_id
    self._annotations = annotations

  def compute(self, input_data: bytes, output_length: int) -> bytes:
    """Evaluates this key's PRF on `input_data` via the server.

    Args:
      input_data: the PRF input.
      output_length: requested output length, forwarded to the server.

    Returns:
      The PRF output returned by the server.

    Raises:
      tink.TinkError: if the response carries a non-empty error.
    """
    annotated_keyset = testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._keyset, annotations=self._annotations)
    reply = self._stub.Compute(
        testing_api_pb2.PrfSetComputeRequest(
            annotated_keyset=annotated_keyset,
            key_id=self._key_id,
            input_data=input_data,
            output_length=output_length))
    if reply.err:
      raise tink.TinkError(reply.err)
    return reply.output


class PrfSet(prf.PrfSet):
  """PrfSet primitive executed by a remote testing server.

  The server is asked to create the primitive at construction time. The key
  ids, and the primary key id, are fetched lazily on first use and cached on
  the instance.
  """

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.PrfSetStub,
               keyset: bytes, annotations: Optional[Dict[str, str]]) -> None:
    """Stores the connection details and asks the server to create the primitive.

    Args:
      lang: language label; stored as `self.lang`.
      stub: gRPC stub for the PrfSet service.
      keyset: serialized keyset used for every operation.
      annotations: optional annotations forwarded with the keyset.

    Raises:
      tink.TinkError: if the server fails to create the primitive.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    self._key_ids_initialized = False
    self._primary_key_id = None
    self._prfs = None
    self._annotations = annotations
    reply = self._stub.Create(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if reply.err:
      raise tink.TinkError(reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Builds the AnnotatedKeyset message attached to every request."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._keyset, annotations=self._annotations)

  def _initialize_key_ids(self) -> None:
    """Fetches the key ids from the server once and builds one _Prf per id."""
    if self._key_ids_initialized:
      return
    reply = self._stub.KeyIds(
        testing_api_pb2.PrfSetKeyIdsRequest(
            annotated_keyset=self._annotated_keyset()))
    if reply.err:
      raise tink.TinkError(reply.err)
    self._primary_key_id = reply.output.primary_key_id
    self._prfs = {
        key_id: _Prf(self.lang, self._stub, self._keyset, key_id,
                     self._annotations)
        for key_id in reply.output.key_id
    }
    self._key_ids_initialized = True

  def primary_id(self) -> int:
    """Returns the primary key id reported by the server."""
    self._initialize_key_ids()
    return self._primary_key_id

  def all(self) -> Mapping[int, prf.Prf]:
    """Returns a copy of the key-id -> Prf mapping."""
    self._initialize_key_ids()
    return self._prfs.copy()

  def primary(self) -> prf.Prf:
    """Returns the Prf for the primary key id."""
    self._initialize_key_ids()
    return self._prfs[self._primary_key_id]


def split_datetime(dt: datetime.datetime) -> Tuple[int, int]:
  """Splits a datetime into (seconds, nanos) since the Unix epoch.

  The result follows the protobuf Timestamp convention: `seconds` is the
  floor of the offset from the epoch, and `nanos` is always in
  [0, 999_999_999]. It is computed with exact integer timedelta arithmetic
  instead of float `timestamp()`, so microseconds are preserved at any date
  and pre-epoch instants do not yield negative nanos.

  Naive datetimes are interpreted as system local time, the same way
  `datetime.timestamp()` and `datetime.astimezone()` treat them.

  Args:
    dt: the datetime to split.

  Returns:
    A (seconds, nanos) tuple.
  """
  epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
  delta = dt.astimezone(datetime.timezone.utc) - epoch
  # timedelta normalizes to days/seconds/microseconds with 0 <= seconds <
  # 86400 and 0 <= microseconds < 1e6, which gives floor semantics for free.
  seconds = delta.days * 86400 + delta.seconds
  nanos = delta.microseconds * 1000
  return (seconds, nanos)


def to_datetime(seconds: int, nanos: int) -> datetime.datetime:
  """Builds a UTC datetime from (seconds, nanos) since the Unix epoch.

  It adds a timedelta to the epoch instead of calling
  `fromtimestamp(seconds + nanos / 1e9)`. That avoids the float rounding that
  loses microseconds for large timestamps, and it does not depend on the
  platform's supported timestamp range (for example, negative timestamps).
  Nanoseconds are rounded to the nearest microsecond (round-half-even), the
  finest resolution `datetime` supports.

  Args:
    seconds: whole seconds since the epoch (may be negative).
    nanos: additional nanoseconds.

  Returns:
    A timezone-aware datetime in UTC.
  """
  epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
  return epoch + datetime.timedelta(seconds=seconds, microseconds=nanos / 1000)


def raw_jwt_to_proto(raw_jwt: jwt.RawJwt) -> testing_api_pb2.JwtToken:
  """Converts a jwt.RawJwt into a JwtToken proto.

  Only header fields and claims that are present on `raw_jwt` are set on the
  proto. Timestamps are split into (seconds, nanos) with split_datetime.
  Each custom claim is written to exactly one value field, chosen by its
  Python type. Dicts and lists are serialized with json.dumps. Values of any
  other type are skipped.

  Args:
    raw_jwt: the token to convert.

  Returns:
    The populated JwtToken proto.
  """
  raw_token = testing_api_pb2.JwtToken()
  if raw_jwt.has_type_header():
    raw_token.type_header.value = raw_jwt.type_header()
  if raw_jwt.has_issuer():
    raw_token.issuer.value = raw_jwt.issuer()
  if raw_jwt.has_subject():
    raw_token.subject.value = raw_jwt.subject()
  if raw_jwt.has_audiences():
    raw_token.audiences.extend(raw_jwt.audiences())
  if raw_jwt.has_jwt_id():
    raw_token.jwt_id.value = raw_jwt.jwt_id()
  if raw_jwt.has_expiration():
    seconds, nanos = split_datetime(raw_jwt.expiration())
    raw_token.expiration.seconds = seconds
    raw_token.expiration.nanos = nanos
  if raw_jwt.has_not_before():
    seconds, nanos = split_datetime(raw_jwt.not_before())
    raw_token.not_before.seconds = seconds
    raw_token.not_before.nanos = nanos
  if raw_jwt.has_issued_at():
    seconds, nanos = split_datetime(raw_jwt.issued_at())
    raw_token.issued_at.seconds = seconds
    raw_token.issued_at.nanos = nanos
  for name in raw_jwt.custom_claim_names():
    value = raw_jwt.custom_claim(name)
    # `bool` is a subclass of `int`, so it must be tested before the numeric
    # case. Otherwise True/False would first be written as a number and only
    # be correct if the later bool assignment happened to overwrite it.
    if value is None:
      raw_token.custom_claims[name].null_value = testing_api_pb2.NULL_VALUE
    elif isinstance(value, bool):
      raw_token.custom_claims[name].bool_value = value
    elif isinstance(value, (int, float)):
      raw_token.custom_claims[name].number_value = value
    elif isinstance(value, str):
      raw_token.custom_claims[name].string_value = value
    elif isinstance(value, dict):
      raw_token.custom_claims[name].json_object_value = json.dumps(value)
    elif isinstance(value, list):
      raw_token.custom_claims[name].json_array_value = json.dumps(value)
  return raw_token


def proto_to_verified_jwt(
    token: testing_api_pb2.JwtToken) -> jwt.VerifiedJwt:
  """Rebuilds a jwt.VerifiedJwt from a JwtToken proto.

  Optional string fields and timestamps that are absent from `token` become
  None. When there is no expiration, `without_expiration=True` is passed to
  jwt.new_raw_jwt. Audiences become None when the repeated field is empty.
  Each custom claim is decoded from the value field(s) set on it; JSON
  objects and arrays are parsed with json.loads.

  Args:
    token: the JwtToken proto to convert.

  Returns:
    A jwt.VerifiedJwt holding the same header and claims.
  """

  def wrapped_string(field_name: str) -> Optional[str]:
    if not token.HasField(field_name):
      return None
    return getattr(token, field_name).value

  def timestamp(field_name: str) -> Optional[datetime.datetime]:
    if not token.HasField(field_name):
      return None
    stamp = getattr(token, field_name)
    return to_datetime(stamp.seconds, stamp.nanos)

  type_header = wrapped_string('type_header')
  issuer = wrapped_string('issuer')
  subject = wrapped_string('subject')
  jwt_id = wrapped_string('jwt_id')
  audiences = list(token.audiences) if token.audiences else None
  expiration = timestamp('expiration')
  not_before = timestamp('not_before')
  issued_at = timestamp('issued_at')

  claim_decoders = (
      ('null_value', lambda _: None),
      ('number_value', lambda claim: claim.number_value),
      ('string_value', lambda claim: claim.string_value),
      ('bool_value', lambda claim: claim.bool_value),
      ('json_object_value', lambda claim: json.loads(claim.json_object_value)),
      ('json_array_value', lambda claim: json.loads(claim.json_array_value)),
  )
  custom_claims = {}
  for name in token.custom_claims:
    claim = token.custom_claims[name]
    # Every matching field is applied in order, so a later match overwrites
    # an earlier one; claims with no field set are omitted.
    for field_name, decode in claim_decoders:
      if claim.HasField(field_name):
        custom_claims[name] = decode(claim)

  raw_jwt = jwt.new_raw_jwt(
      type_header=type_header,
      issuer=issuer,
      subject=subject,
      audiences=audiences,
      jwt_id=jwt_id,
      expiration=expiration,
      without_expiration=expiration is None,
      not_before=not_before,
      issued_at=issued_at,
      custom_claims=custom_claims)
  return jwt.VerifiedJwt._create(raw_jwt)  # pylint: disable=protected-access


def jwt_validator_to_proto(
    validator: jwt.JwtValidator) -> testing_api_pb2.JwtValidator:
  """Converts a jwt.JwtValidator into a JwtValidator proto.

  Expected values are only set when the validator has them. The ignore/allow
  flags are always copied. The clock skew is sent as whole seconds, and a
  fixed "now", if any, is split into (seconds, nanos) with split_datetime.

  Args:
    validator: the validator to convert.

  Returns:
    The populated JwtValidator proto.
  """
  proto_validator = testing_api_pb2.JwtValidator()
  if validator.has_expected_type_header():
    proto_validator.expected_type_header.value = (
        validator.expected_type_header())
  if validator.has_expected_issuer():
    proto_validator.expected_issuer.value = validator.expected_issuer()
  if validator.has_expected_audience():
    proto_validator.expected_audience.value = validator.expected_audience()
  proto_validator.ignore_type_header = validator.ignore_type_header()
  proto_validator.ignore_issuer = validator.ignore_issuer()
  proto_validator.ignore_audience = validator.ignore_audiences()
  proto_validator.allow_missing_expiration = (
      validator.allow_missing_expiration())
  proto_validator.expect_issued_in_the_past = (
      validator.expect_issued_in_the_past())
  # `timedelta.seconds` is only the seconds *component* (0..86399). It drops
  # whole days and wraps negative durations, so send the total duration
  # (truncated to whole seconds) instead.
  proto_validator.clock_skew.seconds = int(
      validator.clock_skew().total_seconds())
  if validator.has_fixed_now():
    seconds, nanos = split_datetime(validator.fixed_now())
    proto_validator.now.seconds = seconds
    proto_validator.now.nanos = nanos
  return proto_validator


class JwtMac:
  """JWT MAC primitive executed by a remote testing server."""

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.JwtStub,
               keyset: bytes) -> None:
    """Stores the connection details and asks the server to create the primitive.

    Args:
      lang: language label; stored as `self.lang`.
      stub: gRPC stub for the Jwt service.
      keyset: serialized keyset used for every operation.

    Raises:
      tink.TinkError: if the server fails to create the primitive.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    reply = self._stub.CreateJwtMac(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if reply.err:
      raise tink.TinkError(reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Builds the (annotation-free) AnnotatedKeyset sent with every request."""
    return testing_api_pb2.AnnotatedKeyset(serialized_keyset=self._keyset)

  def compute_mac_and_encode(self, raw_jwt: jwt.RawJwt) -> str:
    """MACs `raw_jwt` on the server and returns the compact serialization."""
    reply = self._stub.ComputeMacAndEncode(
        testing_api_pb2.JwtSignRequest(
            annotated_keyset=self._annotated_keyset(),
            raw_jwt=raw_jwt_to_proto(raw_jwt)))
    if reply.err:
      raise tink.TinkError(reply.err)
    return reply.signed_compact_jwt

  def verify_mac_and_decode(self, signed_compact_jwt: str,
                            validator: jwt.JwtValidator) -> jwt.VerifiedJwt:
    """Verifies and decodes a MACed JWT in compact serialization.

    Args:
      signed_compact_jwt: the MACed JWT in compact serialization.
      validator: validator sent to the server (converted with
        jwt_validator_to_proto).

    Returns:
      The verified token, rebuilt with proto_to_verified_jwt.

    Raises:
      tink.TinkError: if verification or validation fails.
    """
    reply = self._stub.VerifyMacAndDecode(
        testing_api_pb2.JwtVerifyRequest(
            annotated_keyset=self._annotated_keyset(),
            validator=jwt_validator_to_proto(validator),
            signed_compact_jwt=signed_compact_jwt))
    if reply.err:
      raise tink.TinkError(reply.err)
    return proto_to_verified_jwt(reply.verified_jwt)


class JwtPublicKeySign:
  """JWT signing primitive executed by a remote testing server."""

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.JwtStub,
               keyset: bytes) -> None:
    """Stores the connection details and asks the server to create the primitive.

    Args:
      lang: language label; stored as `self.lang`.
      stub: gRPC stub for the Jwt service.
      keyset: serialized private keyset used for signing.

    Raises:
      tink.TinkError: if the server fails to create the primitive.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    reply = self._stub.CreateJwtPublicKeySign(
        testing_api_pb2.CreationRequest(
            annotated_keyset=testing_api_pb2.AnnotatedKeyset(
                serialized_keyset=self._keyset)))
    if reply.err:
      raise tink.TinkError(reply.err)

  def sign_and_encode(self, raw_jwt: jwt.RawJwt) -> str:
    """Signs `raw_jwt` on the server and returns the compact serialization."""
    annotated_keyset = testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._keyset)
    reply = self._stub.PublicKeySignAndEncode(
        testing_api_pb2.JwtSignRequest(
            annotated_keyset=annotated_keyset,
            raw_jwt=raw_jwt_to_proto(raw_jwt)))
    if reply.err:
      raise tink.TinkError(reply.err)
    return reply.signed_compact_jwt


class JwtPublicKeyVerify:
  """JWT signature-verification primitive executed by a remote testing server."""

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.JwtStub,
               keyset: bytes) -> None:
    """Stores the connection details and asks the server to create the primitive.

    Args:
      lang: language label; stored as `self.lang`.
      stub: gRPC stub for the Jwt service.
      keyset: serialized public keyset used for verification.

    Raises:
      tink.TinkError: if the server fails to create the primitive.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    reply = self._stub.CreateJwtPublicKeyVerify(
        testing_api_pb2.CreationRequest(
            annotated_keyset=testing_api_pb2.AnnotatedKeyset(
                serialized_keyset=self._keyset)))
    if reply.err:
      raise tink.TinkError(reply.err)

  def verify_and_decode(self, signed_compact_jwt: str,
                        validator: jwt.JwtValidator) -> jwt.VerifiedJwt:
    """Verifies and decodes a signed JWT in compact serialization.

    Args:
      signed_compact_jwt: the signed JWT in compact serialization.
      validator: validator sent to the server (converted with
        jwt_validator_to_proto).

    Returns:
      The verified token, rebuilt with proto_to_verified_jwt.

    Raises:
      tink.TinkError: if verification or validation fails.
    """
    annotated_keyset = testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._keyset)
    reply = self._stub.PublicKeyVerifyAndDecode(
        testing_api_pb2.JwtVerifyRequest(
            annotated_keyset=annotated_keyset,
            validator=jwt_validator_to_proto(validator),
            signed_compact_jwt=signed_compact_jwt))
    if reply.err:
      raise tink.TinkError(reply.err)
    return proto_to_verified_jwt(reply.verified_jwt)
