Skip to content

Commit 02f9520

Browse files
committed
Fix bug in validation of multiple audiences
In a situation where multiple audiences are validated by the validator, the order of evaluation of the for-range loop affects the result. If we produce matches such as: ``` { "example.org": true, "example.com": false, } ``` and we configured the validator to expect a single match on audience, the code would either: 1. produce "token has invalid audience" if "example.org" was evaluated first 2. produce a passing result if "example.com" was evaluated first This commit fixes this bug, and adds a suite of tests as well as regression tests to prevent this issue in future.
1 parent 048854f commit 02f9520

File tree

2 files changed

+134
-30
lines changed

2 files changed

+134
-30
lines changed

validator.go

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package jwt
22

33
import (
44
"fmt"
5+
"slices"
56
"time"
67
)
78

@@ -235,46 +236,31 @@ func (v *Validator) verifyAudience(claims Claims, cmp []string, expectAllAud boo
235236
return err
236237
}
237238

238-
if len(aud) == 0 {
239+
// Check that aud exists and is not empty.
240+
if len(aud) == 0 || len(aud) == 1 && aud[0] == "" {
239241
return errorIfRequired(required, "aud")
240242
}
241243

242-
// use a var here to keep constant time compare when looping over a number of claims
243-
matching := make(map[string]bool, 0)
244-
245-
// build a matching hashmap out of the expected aud
246-
for _, expected := range cmp {
247-
matching[expected] = false
248-
}
249-
250-
// compare the expected aud with the actual aud in a constant time manner by looping over all actual values
251-
var stringClaims string
252-
for _, a := range aud {
253-
a := a
254-
_, ok := matching[a]
255-
if ok {
256-
matching[a] = true
244+
if !expectAllAud {
245+
for _, a := range aud {
246+
// If we only expect one match, we can stop early if we find a match
247+
if slices.Contains(cmp, a) {
248+
return nil
249+
}
257250
}
258251

259-
stringClaims = stringClaims + a
252+
return ErrTokenInvalidAudience
260253
}
261254

262-
// check if all expected auds are present
263-
result := true
264-
for _, match := range matching {
265-
if !expectAllAud && match {
266-
break
267-
} else if !match {
268-
result = false
255+
// Note that we are looping cmp here to ensure that all expected audiences
256+
// are present in the aud claim.
257+
for _, a := range cmp {
258+
if !slices.Contains(aud, a) {
259+
return ErrTokenInvalidAudience
269260
}
270261
}
271262

272-
// case where "" is sent in one or many aud claims
273-
if stringClaims == "" {
274-
return errorIfRequired(required, "aud")
275-
}
276-
277-
return errorIfFalse(result, ErrTokenInvalidAudience)
263+
return nil
278264
}
279265

280266
// verifyIssuer compares the iss claim in claims against cmp.

validator_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,121 @@ func Test_Validator_verifyIssuedAt(t *testing.T) {
261261
})
262262
}
263263
}
264+
265+
func Test_Validator_verifyAudience(t *testing.T) {
266+
type fields struct {
267+
expectedAud []string
268+
}
269+
type args struct {
270+
claims Claims
271+
cmp []string
272+
expectAllAud bool
273+
required bool
274+
}
275+
tests := []struct {
276+
name string
277+
fields fields
278+
args args
279+
wantErr error
280+
}{
281+
{
282+
name: "good without audience when expecting one aud match",
283+
fields: fields{expectedAud: []string{"example.com"}},
284+
args: args{
285+
claims: MapClaims{},
286+
cmp: []string{"example.com"},
287+
expectAllAud: false,
288+
required: false,
289+
},
290+
wantErr: nil,
291+
},
292+
{
293+
name: "good without audience when expecting all aud matches",
294+
fields: fields{expectedAud: []string{"example.com"}},
295+
args: args{
296+
claims: MapClaims{},
297+
cmp: []string{"example.com"},
298+
expectAllAud: true,
299+
required: false,
300+
},
301+
wantErr: nil,
302+
},
303+
{
304+
name: "audience matches",
305+
fields: fields{expectedAud: []string{"example.com"}},
306+
args: args{
307+
claims: RegisteredClaims{Audience: ClaimStrings{"example.com"}},
308+
cmp: []string{"example.com"},
309+
expectAllAud: false,
310+
required: true,
311+
},
312+
wantErr: nil,
313+
},
314+
{
315+
name: "audience matches with one value",
316+
fields: fields{expectedAud: []string{"example.org", "example.com"}},
317+
args: args{
318+
claims: RegisteredClaims{Audience: ClaimStrings{"example.com"}},
319+
cmp: []string{"example.org", "example.com"},
320+
expectAllAud: false,
321+
required: true,
322+
},
323+
wantErr: nil,
324+
},
325+
{
326+
name: "audience matches with all values",
327+
fields: fields{expectedAud: []string{"example.org", "example.com"}},
328+
args: args{
329+
claims: RegisteredClaims{Audience: ClaimStrings{"example.org", "example.com"}},
330+
cmp: []string{"example.org", "example.com"},
331+
expectAllAud: true,
332+
required: true,
333+
},
334+
wantErr: nil,
335+
},
336+
{
337+
name: "audience not matching",
338+
fields: fields{expectedAud: []string{"example.org", "example.com"}},
339+
args: args{
340+
claims: RegisteredClaims{Audience: ClaimStrings{"example.net"}},
341+
cmp: []string{"example.org", "example.com"},
342+
expectAllAud: false,
343+
required: true,
344+
},
345+
wantErr: ErrTokenInvalidAudience,
346+
},
347+
{
348+
name: "audience not matching all values",
349+
fields: fields{expectedAud: []string{"example.org", "example.com"}},
350+
args: args{
351+
claims: RegisteredClaims{Audience: ClaimStrings{"example.org", "example.net"}},
352+
cmp: []string{"example.org", "example.com"},
353+
expectAllAud: true,
354+
required: true,
355+
},
356+
wantErr: ErrTokenInvalidAudience,
357+
},
358+
{
359+
name: "audience missing when required",
360+
fields: fields{expectedAud: []string{"example.org", "example.com"}},
361+
args: args{
362+
claims: MapClaims{},
363+
cmp: []string{"example.org", "example.com"},
364+
expectAllAud: true,
365+
required: true,
366+
},
367+
wantErr: ErrTokenRequiredClaimMissing,
368+
},
369+
}
370+
for _, tt := range tests {
371+
t.Run(tt.name, func(t *testing.T) {
372+
v := &Validator{
373+
expectedAud: tt.fields.expectedAud,
374+
expectAllAud: tt.args.expectAllAud,
375+
}
376+
if err := v.verifyAudience(tt.args.claims, tt.args.cmp, tt.args.expectAllAud, tt.args.required); (err != nil) && !errors.Is(err, tt.wantErr) {
377+
t.Errorf("validator.verifyAudience() error = %v, wantErr %v", err, tt.wantErr)
378+
}
379+
})
380+
}
381+
}

0 commit comments

Comments
 (0)