Skip to content

Commit 6d0139f

Browse files
committed
Verify functions now return errors instead of bool
1 parent e1aaf03 commit 6d0139f

File tree

4 files changed

+132
-84
lines changed

4 files changed

+132
-84
lines changed

map_claims.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package jwt
22

33
import (
44
"encoding/json"
5+
"fmt"
56
)
67

78
// MapClaims is a claims type that uses the map[string]interface{} for JSON decoding.
@@ -60,7 +61,7 @@ func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) {
6061
return newNumericDateFromSeconds(v), nil
6162
}
6263

63-
return nil, ErrInvalidType
64+
return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
6465
}
6566

6667
// parseClaimsString tries to parse a key in the map claims type as a
@@ -76,7 +77,7 @@ func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
7677
for _, a := range v {
7778
vs, ok := a.(string)
7879
if !ok {
79-
return nil, ErrInvalidType
80+
return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
8081
}
8182
cs = append(cs, vs)
8283
}
@@ -101,7 +102,7 @@ func (m MapClaims) parseString(key string) (string, error) {
101102

102103
iss, ok = raw.(string)
103104
if !ok {
104-
return "", ErrInvalidType
105+
return "", newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
105106
}
106107

107108
return iss, nil

none.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package jwt
22

3-
import "fmt"
4-
53
// SigningMethodNone implements the none signing method. This is required by the spec
64
// but you probably should never use it.
75
var SigningMethodNone *signingMethodNone
@@ -15,7 +13,7 @@ type unsafeNoneMagicConstant string
1513

1614
func init() {
1715
SigningMethodNone = &signingMethodNone{}
18-
NoneSignatureTypeDisallowedError = fmt.Errorf("%w: 'none' signature type is not allowed", ErrTokenUnverifiable)
16+
NoneSignatureTypeDisallowedError = newError("'none' signature type is not allowed", ErrTokenUnverifiable)
1917

2018
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod {
2119
return SigningMethodNone
@@ -35,7 +33,7 @@ func (m *signingMethodNone) Verify(signingString, signature string, key interfac
3533
}
3634
// If signing method is none, signature must be an empty string
3735
if signature != "" {
38-
return fmt.Errorf("%w: 'none' signing method with non-empty signature", ErrTokenUnverifiable)
36+
return newError("'none' signing method with non-empty signature", ErrTokenUnverifiable)
3937
}
4038

4139
// Accept 'none' signing method.

parser_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,6 @@ func TestParser_Parse(t *testing.T) {
334334

335335
// Parse the token
336336
var token *jwt.Token
337-
//var ve *jwt.ValidationError
338337
var err error
339338
var parser = data.parser
340339
if parser == nil {

validator.go

Lines changed: 126 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,32 @@ package jwt
22

33
import (
44
"crypto/subtle"
5+
"fmt"
56
"time"
67
)
78

9+
// ClaimsValidator is an interface that can be implemented by custom claims who
10+
// wish to execute any additional claims validation based on
11+
// application-specific logic. The Validate function is then executed in
12+
// addition to the regular claims validation and any error returned is appended
13+
// to the final validation result.
14+
//
15+
// type MyCustomClaims struct {
16+
// Foo string `json:"foo"`
17+
// jwt.RegisteredClaims
18+
// }
19+
//
20+
// func (m MyCustomClaims) Validate() error {
21+
// if m.Foo != "bar" {
22+
// return errors.New("must be foobar")
23+
// }
24+
// return nil
25+
// }
26+
type ClaimsValidator interface {
27+
Claims
28+
Validate() error
29+
}
30+
831
// validator is the core of the new Validation API. It is automatically used by
932
// a [Parser] during parsing and can be modified with various parser options.
1033
//
@@ -46,11 +69,12 @@ func newValidator(opts ...ParserOption) *validator {
4669
}
4770

4871
// Validate validates the given claims. It will also perform any custom
49-
// validation if claims implements the CustomValidator interface.
72+
// validation if claims implements the [ClaimsValidator] interface.
5073
func (v *validator) Validate(claims Claims) error {
5174
var (
5275
now time.Time
53-
errs []error = make([]error, 0)
76+
errs []error = make([]error, 0, 6)
77+
err error
5478
)
5579

5680
// Check, if we have a time func
@@ -61,42 +85,48 @@ func (v *validator) Validate(claims Claims) error {
6185
}
6286

6387
// We always need to check the expiration time, but usage of the claim
64-
// itself is OPTIONAL
65-
if !v.VerifyExpiresAt(claims, now, false) {
66-
errs = append(errs, ErrTokenExpired)
88+
// itself is OPTIONAL.
89+
if err = v.verifyExpiresAt(claims, now, false); err != nil {
90+
errs = append(errs, err)
6791
}
6892

6993
// We always need to check not-before, but usage of the claim itself is
70-
// OPTIONAL
71-
if !v.VerifyNotBefore(claims, now, false) {
72-
errs = append(errs, ErrTokenNotValidYet)
94+
// OPTIONAL.
95+
if err = v.verifyNotBefore(claims, now, false); err != nil {
96+
errs = append(errs, err)
7397
}
7498

7599
// Check issued-at if the option is enabled
76-
if v.verifyIat && !v.VerifyIssuedAt(claims, now, false) {
77-
errs = append(errs, ErrTokenUsedBeforeIssued)
100+
if v.verifyIat {
101+
if err = v.verifyIssuedAt(claims, now, false); err != nil {
102+
errs = append(errs, err)
103+
}
78104
}
79105

80106
// If we have an expected audience, we also require the audience claim
81-
if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) {
82-
errs = append(errs, ErrTokenInvalidAudience)
107+
if v.expectedAud != "" {
108+
if err = v.verifyAudience(claims, v.expectedAud, true); err != nil {
109+
errs = append(errs, err)
110+
}
83111
}
84112

85113
// If we have an expected issuer, we also require the issuer claim
86-
if v.expectedIss != "" && !v.VerifyIssuer(claims, v.expectedIss, true) {
87-
errs = append(errs, ErrTokenInvalidIssuer)
114+
if v.expectedIss != "" {
115+
if err = v.verifyIssuer(claims, v.expectedIss, true); err != nil {
116+
errs = append(errs, err)
117+
}
88118
}
89119

90120
// If we have an expected subject, we also require the subject claim
91-
if v.expectedSub != "" && !v.VerifySubject(claims, v.expectedSub, true) {
92-
errs = append(errs, ErrTokenInvalidSubject)
121+
if v.expectedSub != "" {
122+
if err = v.verifySubject(claims, v.expectedSub, true); err != nil {
123+
errs = append(errs, ErrTokenInvalidSubject)
124+
}
93125
}
94126

95127
// Finally, we want to give the claim itself some possibility to do some
96128
// additional custom validation based on a custom Validate function.
97-
cvt, ok := claims.(interface {
98-
Validate() error
99-
})
129+
cvt, ok := claims.(ClaimsValidator)
100130
if ok {
101131
if err := cvt.Validate(); err != nil {
102132
errs = append(errs, err)
@@ -110,84 +140,84 @@ func (v *validator) Validate(claims Claims) error {
110140
return joinErrors(errs)
111141
}
112142

113-
// VerifyExpiresAt compares the exp claim in claims against cmp. This function
114-
// will return true if cmp < exp. Additional leeway is taken into account.
143+
// verifyExpiresAt compares the exp claim in claims against cmp. This function
144+
// will succeed if cmp < exp. Additional leeway is taken into account.
115145
//
116-
// If exp is not set, it will return true if the claim is not required,
117-
// otherwise false will be returned.
146+
// If exp is not set, it will succeed if the claim is not required,
147+
// otherwise ErrTokenRequiredClaimMissing will be returned.
118148
//
119149
// Additionally, if any error occurs while retrieving the claim, e.g., when its
120-
// the wrong type, false will be returned.
121-
func (v *validator) VerifyExpiresAt(claims Claims, cmp time.Time, required bool) bool {
150+
// the wrong type, an ErrTokenUnverifiable error will be returned.
151+
func (v *validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool) error {
122152
exp, err := claims.GetExpirationTime()
123153
if err != nil {
124-
return false
154+
return err
125155
}
126156

127-
if exp != nil {
128-
return cmp.Before((exp.Time).Add(+v.leeway))
129-
} else {
130-
return !required
157+
if exp == nil {
158+
return errorIfRequired(required, "exp")
131159
}
160+
161+
return errorIfFalse(cmp.Before((exp.Time).Add(+v.leeway)), ErrTokenExpired)
132162
}
133163

134-
// VerifyIssuedAt compares the iat claim in claims against cmp. This function
135-
// will return true if cmp >= iat. Additional leeway is taken into account.
164+
// verifyIssuedAt compares the iat claim in claims against cmp. This function
165+
// will succeed if cmp >= iat. Additional leeway is taken into account.
136166
//
137-
// If iat is not set, it will return true if the claim is not required,
138-
// otherwise false will be returned.
167+
// If iat is not set, it will succeed if the claim is not required,
168+
// otherwise ErrTokenRequiredClaimMissing will be returned.
139169
//
140170
// Additionally, if any error occurs while retrieving the claim, e.g., when its
141-
// the wrong type, false will be returned.
142-
func (v *validator) VerifyIssuedAt(claims Claims, cmp time.Time, required bool) bool {
171+
// the wrong type, an ErrTokenUnverifiable error will be returned.
172+
func (v *validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool) error {
143173
iat, err := claims.GetIssuedAt()
144174
if err != nil {
145-
return false
175+
return err
146176
}
147177

148-
if iat != nil {
149-
return !cmp.Before(iat.Add(-v.leeway))
150-
} else {
151-
return !required
178+
if iat == nil {
179+
return errorIfRequired(required, "iat")
152180
}
181+
182+
return errorIfFalse(!cmp.Before(iat.Add(-v.leeway)), ErrTokenUsedBeforeIssued)
153183
}
154184

155-
// VerifyNotBefore compares the nbf claim in claims against cmp. This function
185+
// verifyNotBefore compares the nbf claim in claims against cmp. This function
156186
// will return true if cmp >= nbf. Additional leeway is taken into account.
157187
//
158-
// If nbf is not set, it will return true if the claim is not required,
159-
// otherwise false will be returned.
188+
// If nbf is not set, it will succeed if the claim is not required,
189+
// otherwise ErrTokenRequiredClaimMissing will be returned.
160190
//
161191
// Additionally, if any error occurs while retrieving the claim, e.g., when its
162-
// the wrong type, false will be returned.
163-
func (v *validator) VerifyNotBefore(claims Claims, cmp time.Time, required bool) bool {
192+
// the wrong type, an ErrTokenUnverifiable error will be returned.
193+
func (v *validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) error {
164194
nbf, err := claims.GetNotBefore()
165195
if err != nil {
166-
return false
196+
return err
167197
}
168198

169-
if nbf != nil {
170-
return !cmp.Before(nbf.Add(-v.leeway))
171-
} else {
172-
return !required
199+
if nbf == nil {
200+
return errorIfRequired(required, "nbf")
173201
}
202+
203+
return errorIfFalse(!cmp.Before(nbf.Add(-v.leeway)), ErrTokenNotValidYet)
174204
}
175205

176-
// VerifyAudience compares the aud claim against cmp.
206+
// verifyAudience compares the aud claim against cmp.
177207
//
178-
// If aud is not set or an empty list, it will return true if the claim is not
179-
// required, otherwise false will be returned.
208+
// If aud is not set or an empty list, it will succeed if the claim is not required,
209+
// otherwise ErrTokenRequiredClaimMissing will be returned.
180210
//
181211
// Additionally, if any error occurs while retrieving the claim, e.g., when its
182-
// the wrong type, false will be returned.
183-
func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) bool {
212+
// the wrong type, an ErrTokenUnverifiable error will be returned.
213+
func (v *validator) verifyAudience(claims Claims, cmp string, required bool) error {
184214
aud, err := claims.GetAudience()
185215
if err != nil {
186-
return false
216+
return err
187217
}
188218

189219
if len(aud) == 0 {
190-
return !required
220+
return errorIfRequired(required, "aud")
191221
}
192222

193223
// use a var here to keep constant time compare when looping over a number of claims
@@ -203,48 +233,68 @@ func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) boo
203233

204234
// case where "" is sent in one or many aud claims
205235
if stringClaims == "" {
206-
return !required
236+
return errorIfRequired(required, "aud")
207237
}
208238

209-
return result
239+
return errorIfFalse(result, ErrTokenInvalidAudience)
210240
}
211241

212-
// VerifyIssuer compares the iss claim in claims against cmp.
242+
// verifyIssuer compares the iss claim in claims against cmp.
213243
//
214-
// If iss is not set, it will return true if the claim is not required,
215-
// otherwise false will be returned.
244+
// If iss is not set, it will succeed if the claim is not required,
245+
// otherwise ErrTokenRequiredClaimMissing will be returned.
216246
//
217247
// Additionally, if any error occurs while retrieving the claim, e.g., when its
218-
// the wrong type, false will be returned.
219-
func (v *validator) VerifyIssuer(claims Claims, cmp string, required bool) bool {
248+
// the wrong type, an ErrTokenUnverifiable error will be returned.
249+
func (v *validator) verifyIssuer(claims Claims, cmp string, required bool) error {
220250
iss, err := claims.GetIssuer()
221251
if err != nil {
222-
return false
252+
return err
223253
}
224254

225255
if iss == "" {
226-
return !required
256+
return errorIfRequired(required, "iss")
227257
}
228258

229-
return iss == cmp
259+
return errorIfFalse(iss == cmp, ErrTokenInvalidIssuer)
230260
}
231261

232-
// VerifySubject compares the sub claim against cmp.
262+
// verifySubject compares the sub claim against cmp.
233263
//
234-
// If sub is not set, it will return true if the claim is not required,
235-
// otherwise false will be returned.
264+
// If sub is not set, it will succeed if the claim is not required,
265+
// otherwise ErrTokenRequiredClaimMissing will be returned.
236266
//
237267
// Additionally, if any error occurs while retrieving the claim, e.g., when its
238-
// the wrong type, false will be returned.
239-
func (v *validator) VerifySubject(claims Claims, cmp string, required bool) bool {
268+
// the wrong type, an ErrTokenUnverifiable error will be returned.
269+
func (v *validator) verifySubject(claims Claims, cmp string, required bool) error {
240270
sub, err := claims.GetSubject()
241271
if err != nil {
242-
return false
272+
return err
243273
}
244274

245275
if sub == "" {
246-
return !required
276+
return errorIfRequired(required, "sub")
247277
}
248278

249-
return sub == cmp
279+
return errorIfFalse(sub == cmp, ErrTokenInvalidIssuer)
280+
}
281+
282+
// errorIfFalse returns the error specified in err, if the value is true.
283+
// Otherwise, nil is returned.
284+
func errorIfFalse(value bool, err error) error {
285+
if value {
286+
return nil
287+
} else {
288+
return err
289+
}
290+
}
291+
292+
// errorIfRequired returns an ErrTokenRequiredClaimMissing error if required is
293+
// true. Otherwise, nil is returned.
294+
func errorIfRequired(required bool, claim string) error {
295+
if required {
296+
return newError(fmt.Sprintf("%s claim is required", claim), ErrTokenRequiredClaimMissing)
297+
} else {
298+
return nil
299+
}
250300
}

0 commit comments

Comments
 (0)