// Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // //////////////////////////////////////////////////////////////////////////////// package jwt import ( "encoding/base64" "testing" "time" "github.com/google/go-cmp/cmp" "google.golang.org/protobuf/proto" "github.com/google/tink/go/core/registry" "github.com/google/tink/go/subtle/random" jwtmacpb "github.com/google/tink/go/proto/jwt_hmac_go_proto" tinkpb "github.com/google/tink/go/proto/tink_go_proto" ) type jwtKeyManagerTestCase struct { tag string keyFormat *jwtmacpb.JwtHmacKeyFormat key *jwtmacpb.JwtHmacKey } const ( typeURL = "type.googleapis.com/google.crypto.tink.JwtHmacKey" ) func generateKeyFormat(keySize uint32, algorithm jwtmacpb.JwtHmacAlgorithm) *jwtmacpb.JwtHmacKeyFormat { return &jwtmacpb.JwtHmacKeyFormat{ KeySize: keySize, Algorithm: algorithm, } } func TestDoesSupport(t *testing.T) { km, err := registry.GetKeyManager(typeURL) if err != nil { t.Errorf("registry.GetKeyManager(%q) error = %v, want nil", typeURL, err) } if !km.DoesSupport(typeURL) { t.Errorf("km.DoesSupport(%q) = false, want true", typeURL) } } func TestTypeURL(t *testing.T) { km, err := registry.GetKeyManager(typeURL) if err != nil { t.Errorf("registry.GetKeyManager(%q) error = %v, want nil", typeURL, err) } if km.TypeURL() != typeURL { t.Errorf("km.TypeURL() = %q, want %q", km.TypeURL(), typeURL) } } var invalidKeyFormatTestCases = []jwtKeyManagerTestCase{ { tag: "invalid hash algorithm", keyFormat: generateKeyFormat(32, jwtmacpb.JwtHmacAlgorithm_HS_UNKNOWN), }, { tag: "invalid HS256 key size", keyFormat: generateKeyFormat(31, jwtmacpb.JwtHmacAlgorithm_HS256), }, { tag: "invalid HS384 key size", keyFormat: generateKeyFormat(47, jwtmacpb.JwtHmacAlgorithm_HS384), }, { tag: "invalid HS512 key size", keyFormat: generateKeyFormat(63, jwtmacpb.JwtHmacAlgorithm_HS512), }, { tag: "empty key format", keyFormat: &jwtmacpb.JwtHmacKeyFormat{}, }, { tag: "nil key format", keyFormat: nil, }, } func TestNewKeyInvalidFormatFails(t *testing.T) { for _, tc := range invalidKeyFormatTestCases { t.Run(tc.tag, func(t *testing.T) { km, err := registry.GetKeyManager(typeURL) if err != nil { t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) } serializedKeyFormat, err := proto.Marshal(tc.keyFormat) if err != nil { t.Errorf("serializing key format: %v", err) } if _, err := km.NewKey(serializedKeyFormat); err == nil { t.Errorf("km.NewKey() err = nil, want error") } }) } } func TestNewDataInvalidFormatFails(t *testing.T) { for _, tc := range invalidKeyFormatTestCases { t.Run(tc.tag, func(t *testing.T) { km, err := registry.GetKeyManager(typeURL) if err != nil { t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) } serializedKeyFormat, err := proto.Marshal(tc.keyFormat) if err != nil { t.Errorf("serializing key format: %v", err) } if _, err := km.NewKeyData(serializedKeyFormat); err == nil { t.Errorf("km.NewKey() err = nil, want error") } }) } } var validKeyFormatTestCases = []jwtKeyManagerTestCase{ { tag: "SHA256 hash algorithm", keyFormat: generateKeyFormat(32, jwtmacpb.JwtHmacAlgorithm_HS256), }, { tag: "SHA384 hash algorithm", keyFormat: generateKeyFormat(48, jwtmacpb.JwtHmacAlgorithm_HS384), }, { tag: "SHA512 hash algorithm", keyFormat: generateKeyFormat(64, jwtmacpb.JwtHmacAlgorithm_HS512), }, } func TestNewKey(t *testing.T) { for _, tc := range validKeyFormatTestCases { t.Run(tc.tag, func(t *testing.T) { km, err := registry.GetKeyManager(typeURL) if err != nil { t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) } serializedKeyFormat, err := proto.Marshal(tc.keyFormat) if err != nil { t.Errorf("serializing key format: %v", err) } k, err := km.NewKey(serializedKeyFormat) if err != nil { t.Errorf("km.NewKey() err = %v, want nil", err) } key, ok := k.(*jwtmacpb.JwtHmacKey) if !ok { t.Errorf("key isn't of type JwtHmacKey") } if key.Algorithm != tc.keyFormat.Algorithm { t.Errorf("k.Algorithm = %v, want %v", key.Algorithm, tc.keyFormat.Algorithm) } if len(key.KeyValue) != int(tc.keyFormat.KeySize) { t.Errorf("len(key.KeyValue) = %d, want %d", len(key.KeyValue), tc.keyFormat.KeySize) } }) } } func TestNewKeyData(t *testing.T) { for _, tc := range validKeyFormatTestCases { t.Run(tc.tag, func(t *testing.T) { km, err := registry.GetKeyManager(typeURL) if err != nil { t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) } serializedKeyFormat, err := proto.Marshal(tc.keyFormat) if err != nil { t.Errorf("serializing key format: %v", err) } k, err := km.NewKeyData(serializedKeyFormat) if err != nil { t.Errorf("km.NewKeyData() err = %v, want nil", err) } if k.GetTypeUrl() != typeURL { t.Errorf("k.GetTypeUrl() = %q, want %q", k.GetTypeUrl(), typeURL) } if k.GetKeyMaterialType() != tinkpb.KeyData_SYMMETRIC { t.Errorf("k.GetKeyMaterialType() = %q, want %q", k.GetKeyMaterialType(), tinkpb.KeyData_SYMMETRIC) } }) } } func generateKey(keySize, version uint32, algorithm jwtmacpb.JwtHmacAlgorithm, kid *jwtmacpb.JwtHmacKey_CustomKid) *jwtmacpb.JwtHmacKey { return &jwtmacpb.JwtHmacKey{ KeyValue: random.GetRandomBytes(keySize), Algorithm: algorithm, CustomKid: kid, Version: version, } } func TestGetPrimitiveWithValidKeys(t *testing.T) { rawJWT, err := NewRawJWT(&RawJWTOptions{WithoutExpiration: true, Audiences: []string{"tink-aud"}}) if err != nil { t.Fatalf("NewRawJWT() err = %v, want nil", err) } validator, err := NewValidator(&ValidatorOpts{AllowMissingExpiration: true, ExpectedAudience: refString("tink-aud")}) if err != nil { t.Fatalf("NewValidator() err = %v, want nil", err) } for _, tc := range []jwtKeyManagerTestCase{ { tag: "SHA256 hash algorithm", key: generateKey(32, 0, jwtmacpb.JwtHmacAlgorithm_HS256, nil), }, { tag: "SHA384 hash algorithm", key: generateKey(48, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil), }, { tag: "SHA512 hash algorithm", key: generateKey(64, 0, jwtmacpb.JwtHmacAlgorithm_HS512, nil), }, { tag: "with custom kid", key: generateKey(64, 0, jwtmacpb.JwtHmacAlgorithm_HS512, &jwtmacpb.JwtHmacKey_CustomKid{Value: "1235"}), }, } { t.Run(tc.tag, func(t *testing.T) { km, err := registry.GetKeyManager(typeURL) if err != nil { t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) } serializedKey, err := proto.Marshal(tc.key) if err != nil { t.Errorf("serializing key format: %v", err) } p, err := km.Primitive(serializedKey) if err != nil { t.Errorf("km.Primitive() err = %v, want nil", err) } primitive, ok := p.(*macWithKID) if !ok { t.Errorf("primitive isn't of type: macWithKID") } compact, err := primitive.ComputeMACAndEncodeWithKID(rawJWT, nil) if err != nil { t.Errorf("ComputeMACAndEncodeWithKID() err = %v, want nil", err) } verifiedJWT, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, nil) if err != nil { t.Errorf("VerifyMACAndDecodeWithKID() err = %v, want nil", err) } audiences, err := verifiedJWT.Audiences() if err != nil { t.Errorf("verifiedJWT.Audiences() err = %v, want nil", err) } if !cmp.Equal(audiences, []string{"tink-aud"}) { t.Errorf("verifiedJWT.Audiences() = %q, want ['tink-aud']", audiences) } }) } } func TestGetPrimitiveWithInvalidKeys(t *testing.T) { for _, tc := range []jwtKeyManagerTestCase{ { tag: "HS256", key: generateKey(31, 0, jwtmacpb.JwtHmacAlgorithm_HS256, nil), }, { tag: "HS384", key: generateKey(47, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil), }, { tag: "HS512", key: generateKey(63, 0, jwtmacpb.JwtHmacAlgorithm_HS512, nil), }, } { t.Run(tc.tag, func(t *testing.T) { km, err := registry.GetKeyManager(typeURL) if err != nil { t.Fatalf("registry.GetKeyManager(%q) err=%q, want nil", typeURL, err) } serializedKey, err := proto.Marshal(tc.key) if err != nil { t.Fatalf("proto.Marshal(tc.key) err =%q, want nil", err) } _, err = km.Primitive(serializedKey) if err == nil { t.Error("km.Primitive(serializedKey) err = nil, want error") } }) } } func TestSpecyfingCustomKIDAndTINKKIDFails(t *testing.T) { // key and compact are examples from: https://datatracker.ietf.org/doc/html/rfc7515#appendix-A.1.1 compact := "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" rawKey, err := base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString("AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow") if err != nil { t.Fatalf("failed decoding test key: %v", err) } key := &jwtmacpb.JwtHmacKey{ KeyValue: rawKey, Algorithm: jwtmacpb.JwtHmacAlgorithm_HS256, CustomKid: &jwtmacpb.JwtHmacKey_CustomKid{Value: "1235"}, Version: 0, } km, err := registry.GetKeyManager(typeURL) if err != nil { t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) } serializedKey, err := proto.Marshal(key) if err != nil { t.Errorf("serializing key format: %v", err) } p, err := km.Primitive(serializedKey) if err != nil { t.Errorf("km.Primitive() err = %v, want nil", err) } primitive, ok := p.(*macWithKID) if !ok { t.Errorf("primitive isn't of type: macWithKID") } rawJWT, err := NewRawJWT(&RawJWTOptions{WithoutExpiration: true}) if err != nil { t.Errorf("creating new RawJWT: %v", err) } opts := &ValidatorOpts{ ExpectedTypeHeader: refString("JWT"), ExpectedIssuer: refString("joe"), FixedNow: time.Unix(12345, 0), } validator, err := NewValidator(opts) if err != nil { t.Errorf("creating new JWTValidator: %v", err) } if _, err := primitive.ComputeMACAndEncodeWithKID(rawJWT, refString("4566")); err == nil { t.Errorf("primitive.ComputeMACAndEncodeWithKID() err = nil, want error") } if _, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, refString("4566")); err == nil { t.Errorf("primitive.VerifyMACAndDecodeWithKID(kid = 4566) err = nil, want error") } // Verify success without KID if _, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, nil); err != nil { t.Errorf("primitive.VerifyMACAndDecodeWithKID(kid = nil) err = %v, want nil", err) } } func TestGetPrimitiveWithInvalidKeyFails(t *testing.T) { for _, tc := range []jwtKeyManagerTestCase{ { tag: "empty key", key: &jwtmacpb.JwtHmacKey{}, }, { tag: "nil key", key: nil, }, { tag: "unsupported hash algorithm", key: generateKey(32, 0, jwtmacpb.JwtHmacAlgorithm_HS_UNKNOWN, nil), }, { tag: "short key length", key: generateKey(20, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil), }, { tag: "unsupported version", key: generateKey(48, 1, jwtmacpb.JwtHmacAlgorithm_HS384, nil), }, } { t.Run(tc.tag, func(t *testing.T) { km, err := registry.GetKeyManager(typeURL) if err != nil { t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) } serializedKey, err := proto.Marshal(tc.key) if err != nil { t.Errorf("serializing key format: %v", err) } if _, err := km.Primitive(serializedKey); err == nil { t.Errorf("km.Primitive() err = nil, want error") } }) } } func TestGeneratesDifferentKeys(t *testing.T) { km, err := registry.GetKeyManager(typeURL) if err != nil { t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) } serializedKeyFormat, err := proto.Marshal(generateKeyFormat(32, jwtmacpb.JwtHmacAlgorithm_HS256)) if err != nil { t.Errorf("serializing key format: %v", err) } k1, err := km.NewKey(serializedKeyFormat) if err != nil { t.Errorf("km.NewKey() err = %v, want nil", err) } k2, err := km.NewKey(serializedKeyFormat) if err != nil { t.Errorf("km.NewKey() err = %v, want nil", err) } key1, ok := k1.(*jwtmacpb.JwtHmacKey) if !ok { t.Errorf("k1 isn't of type JwtHmacKey") } key2, ok := k2.(*jwtmacpb.JwtHmacKey) if !ok { t.Errorf("k2 isn't of type JwtHmacKey") } if cmp.Equal(key1.GetKeyValue(), key2.GetKeyValue()) { t.Errorf("key material should differ") } }