Skip to content

Commit 90a025f

Browse files
committed
Adding canonical Keyfunc functions for RSA, ECDSA, EdDSA and HMAC
This PR adds ready-to-use keyfunc functions for the various signing methods. This should simplify a lot of standard use-cases and also includes a proper signing method check.
1 parent b357385 commit 90a025f

File tree

9 files changed

+133
-38
lines changed

9 files changed

+133
-38
lines changed

ecdsa.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,10 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) ([]byte
132132
return nil, err
133133
}
134134
}
135+
136+
// ECDSAPublicKey represents a [Keyfunc] that returns the ECDSA key specified in
137+
// key. Furthermore, it checks, whether the signing method matches
138+
// [SigningMethodECDSA].
139+
func ECDSAPublicKey(key *ecdsa.PublicKey) Keyfunc {
140+
return secureKeyFunc(key, []string{"ES256", "ES384", "ES512"})
141+
}

ed25519.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,10 @@ func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) ([]by
7878

7979
return sig, nil
8080
}
81+
82+
// Ed25519PublicKey represents a [Keyfunc] that returns the Ed25519 key
83+
// specified in key. Furthermore, it checks, whether the signing method matches
84+
// [SigningMethodEdDSA].
85+
func Ed25519PublicKey(key ed25519.PublicKey) Keyfunc {
86+
return secureKeyFunc(key, []string{"EdDSA"})
87+
}

example_test.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@ func ExampleParseWithClaims_customClaimsType() {
8080
jwt.RegisteredClaims
8181
}
8282

83-
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
84-
return []byte("AllYourBase"), nil
85-
})
83+
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, jwt.PresharedKey([]byte("AllYourBase")))
8684

8785
if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
8886
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
@@ -103,9 +101,11 @@ func ExampleParseWithClaims_validationOptions() {
103101
jwt.RegisteredClaims
104102
}
105103

106-
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
107-
return []byte("AllYourBase"), nil
108-
}, jwt.WithLeeway(5*time.Second))
104+
token, err := jwt.ParseWithClaims(
105+
tokenString, &MyCustomClaims{},
106+
jwt.PresharedKey([]byte("AllYourBase")),
107+
jwt.WithLeeway(5*time.Second),
108+
)
109109

110110
if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
111111
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
@@ -138,9 +138,10 @@ func (m MyCustomClaims) Validate() error {
138138
func ExampleParseWithClaims_customValidation() {
139139
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA"
140140

141-
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
142-
return []byte("AllYourBase"), nil
143-
}, jwt.WithLeeway(5*time.Second))
141+
token, err := jwt.ParseWithClaims(
142+
tokenString, &MyCustomClaims{},
143+
jwt.PresharedKey([]byte("AllYourBase")),
144+
jwt.WithLeeway(5*time.Second))
144145

145146
if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
146147
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
@@ -156,9 +157,7 @@ func ExampleParse_errorChecking() {
156157
// Token from another example. This token is expired
157158
var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"
158159

159-
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
160-
return []byte("AllYourBase"), nil
161-
})
160+
token, err := jwt.Parse(tokenString, jwt.PresharedKey([]byte("AllYourBase")))
162161

163162
if token.Valid {
164163
fmt.Println("You look nice today")

hmac.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,8 @@ func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) ([]byte,
8787

8888
return nil, ErrInvalidKeyType
8989
}
90+
91+
// PresharedKey represents a [Keyfunc] that simply returns the key specified in the byte slice.
92+
func PresharedKey(key []byte) Keyfunc {
93+
return secureKeyFunc(key, []string{"HS256", "HS384", "HS512"})
94+
}

parser.go

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"bytes"
55
"encoding/base64"
66
"encoding/json"
7-
"fmt"
87
"strings"
98
)
109

@@ -60,17 +59,8 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
6059

6160
// Verify signing method is in the required set
6261
if p.validMethods != nil {
63-
var signingMethodValid = false
64-
var alg = token.Method.Alg()
65-
for _, m := range p.validMethods {
66-
if m == alg {
67-
signingMethodValid = true
68-
break
69-
}
70-
}
71-
if !signingMethodValid {
72-
// signing method is not in the listed set
73-
return token, newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid)
62+
if err = token.hasValidSigningMethod(p.validMethods); err != nil {
63+
return token, err
7464
}
7565
}
7666

rsa.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,10 @@ func (m *SigningMethodRSA) Sign(signingString string, key interface{}) ([]byte,
9191
return nil, err
9292
}
9393
}
94+
95+
// RSAPublicKey represents a [Keyfunc] that returns the RSA key specified in
96+
// key. Furthermore, it checks, whether the signing method matches
97+
// [SigningMethodRSA].
98+
func RSAPublicKey(key *rsa.PublicKey) Keyfunc {
99+
return secureKeyFunc(key, []string{"RS256", "RS384", "RS512"})
100+
}

test/helpers.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package test
22

33
import (
44
"crypto"
5+
"crypto/ecdsa"
56
"crypto/rsa"
67
"os"
78

@@ -56,7 +57,7 @@ func LoadECPrivateKeyFromDisk(location string) crypto.PrivateKey {
5657
return key
5758
}
5859

59-
func LoadECPublicKeyFromDisk(location string) crypto.PublicKey {
60+
func LoadECPublicKeyFromDisk(location string) *ecdsa.PublicKey {
6061
keyData, e := os.ReadFile(location)
6162
if e != nil {
6263
panic(e.Error())

token.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package jwt
33
import (
44
"encoding/base64"
55
"encoding/json"
6+
"fmt"
67
)
78

89
// Keyfunc will be used by the Parse methods as a callback function to supply
@@ -81,3 +82,36 @@ func (t *Token) SigningString() (string, error) {
8182
func (*Token) EncodeSegment(seg []byte) string {
8283
return base64.RawURLEncoding.EncodeToString(seg)
8384
}
85+
86+
// hasValidSigningMethod is a utility function that checks, if the signing
87+
// method of the token is included in the validMethods slice.
88+
func (token *Token) hasValidSigningMethod(validMethods []string) error {
89+
var signingMethodValid = false
90+
var alg = token.Method.Alg()
91+
for _, m := range validMethods {
92+
if m == alg {
93+
signingMethodValid = true
94+
break
95+
}
96+
}
97+
98+
if !signingMethodValid {
99+
// signing method is not in the listed set
100+
return newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid)
101+
}
102+
103+
return nil
104+
}
105+
106+
// secureKeyFunc returns a secure [Keyfunc] for the specified key that also
107+
// includes a signing method check.
108+
func secureKeyFunc(key any, validMethods []string) Keyfunc {
109+
return func(t *Token) (interface{}, error) {
110+
// Check, if the signing method matches
111+
if err := t.hasValidSigningMethod(validMethods); err != nil {
112+
return nil, err
113+
}
114+
115+
return key, nil
116+
}
117+
}

token_test.go

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
package jwt_test
1+
package jwt
22

33
import (
4+
"errors"
5+
"reflect"
46
"testing"
5-
6-
"github.com/golang-jwt/jwt/v5"
77
)
88

99
func TestToken_SigningString(t1 *testing.T) {
1010
type fields struct {
1111
Raw string
12-
Method jwt.SigningMethod
12+
Method SigningMethod
1313
Header map[string]interface{}
14-
Claims jwt.Claims
14+
Claims Claims
1515
Signature []byte
1616
Valid bool
1717
}
@@ -25,12 +25,12 @@ func TestToken_SigningString(t1 *testing.T) {
2525
name: "",
2626
fields: fields{
2727
Raw: "",
28-
Method: jwt.SigningMethodHS256,
28+
Method: SigningMethodHS256,
2929
Header: map[string]interface{}{
3030
"typ": "JWT",
31-
"alg": jwt.SigningMethodHS256.Alg(),
31+
"alg": SigningMethodHS256.Alg(),
3232
},
33-
Claims: jwt.RegisteredClaims{},
33+
Claims: RegisteredClaims{},
3434
Valid: false,
3535
},
3636
want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30",
@@ -39,7 +39,7 @@ func TestToken_SigningString(t1 *testing.T) {
3939
}
4040
for _, tt := range tests {
4141
t1.Run(tt.name, func(t1 *testing.T) {
42-
t := &jwt.Token{
42+
t := &Token{
4343
Raw: tt.fields.Raw,
4444
Method: tt.fields.Method,
4545
Header: tt.fields.Header,
@@ -60,13 +60,13 @@ func TestToken_SigningString(t1 *testing.T) {
6060
}
6161

6262
func BenchmarkToken_SigningString(b *testing.B) {
63-
t := &jwt.Token{
64-
Method: jwt.SigningMethodHS256,
63+
t := &Token{
64+
Method: SigningMethodHS256,
6565
Header: map[string]interface{}{
6666
"typ": "JWT",
67-
"alg": jwt.SigningMethodHS256.Alg(),
67+
"alg": SigningMethodHS256.Alg(),
6868
},
69-
Claims: jwt.RegisteredClaims{},
69+
Claims: RegisteredClaims{},
7070
}
7171
b.Run("BenchmarkToken_SigningString", func(b *testing.B) {
7272
b.ResetTimer()
@@ -76,3 +76,48 @@ func BenchmarkToken_SigningString(b *testing.B) {
7676
}
7777
})
7878
}
79+
80+
func Test_secureKeyFunc(t *testing.T) {
81+
type fields struct {
82+
token *Token
83+
}
84+
type args struct {
85+
key any
86+
validMethods []string
87+
}
88+
tests := []struct {
89+
name string
90+
fields fields
91+
args args
92+
wantKey any
93+
wantErr error
94+
}{
95+
{
96+
name: "invalid method",
97+
fields: fields{&Token{Header: map[string]interface{}{"alg": "RS512"}, Method: SigningMethodRS512}},
98+
args: args{key: []byte("mysecret"), validMethods: []string{"HS256"}},
99+
wantKey: nil,
100+
wantErr: ErrTokenSignatureInvalid,
101+
},
102+
{
103+
name: "correct method",
104+
fields: fields{&Token{Header: map[string]interface{}{"alg": "HS256"}, Method: SigningMethodHS256}},
105+
args: args{key: []byte("mysecret"), validMethods: []string{"HS256"}},
106+
wantKey: []byte("mysecret"),
107+
wantErr: nil,
108+
},
109+
}
110+
for _, tt := range tests {
111+
t.Run(tt.name, func(t *testing.T) {
112+
keyfunc := secureKeyFunc(tt.args.key, tt.args.validMethods)
113+
gotKey, gotErr := keyfunc(tt.fields.token)
114+
115+
if !reflect.DeepEqual(gotKey, tt.wantKey) {
116+
t.Errorf("secureKeyFunc() key = %v, want %v", gotKey, tt.wantKey)
117+
}
118+
if (gotErr != nil) && !errors.Is(gotErr, tt.wantErr) {
119+
t.Errorf("secureKeyFunc() err = %v, want %v", gotErr, tt.wantErr)
120+
}
121+
})
122+
}
123+
}

0 commit comments

Comments
 (0)