Skip to content

Commit 8644b87

Browse files
committed
Wrap SSO token expiration error into user-facing error
1 parent 0c44746 commit 8644b87

File tree

7 files changed

+221
-11
lines changed

7 files changed

+221
-11
lines changed

neo4j/error.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,17 @@ func IsTransactionExecutionLimit(err error) bool {
110110
return is
111111
}
112112

113+
// TokenExpiredError represent errors caused by the driver not being able to connect to Neo4j services,
114+
// or lost connections.
115+
type TokenExpiredError struct {
116+
Code string
117+
Message string
118+
}
119+
120+
func (e *TokenExpiredError) Error() string {
121+
return fmt.Sprintf("TokenExpiredError: %s (%s)", e.Code, e.Message)
122+
}
123+
113124
func wrapError(err error) error {
114125
if err == nil {
115126
return nil
@@ -134,6 +145,10 @@ func wrapError(err error) error {
134145
// most likely due to read timeout configuration hint being applied
135146
return &ConnectivityError{inner: err}
136147
}
148+
case *db.Neo4jError:
149+
if e.Code == "Neo.ClientError.Security.TokenExpired" {
150+
return &TokenExpiredError{Code: e.Code, Message: e.Msg}
151+
}
137152
}
138153
return err
139154
}

neo4j/internal/bolt/bolt4_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,4 +871,29 @@ func TestBolt4(ot *testing.T) {
871871
assertBoltState(t, bolt4_dead, bolt)
872872
AssertError(t, err)
873873
})
874+
875+
ot.Run("Immediately expired authentication token error triggers a connection failure", func(t *testing.T) {
876+
bolt, cleanup := connectToServer(t, func(srv *bolt4server) {
877+
srv.accept(4)
878+
srv.sendFailureMsg("Neo.ClientError.Security.TokenExpired", "SSO token is... expired")
879+
})
880+
defer cleanup()
881+
882+
_, err := bolt.Run(db.Command{Cypher: "MATCH (n) RETURN n"}, db.TxConfig{Mode: db.ReadMode})
883+
assertBoltState(t, bolt4_failed, bolt)
884+
AssertError(t, err)
885+
})
886+
887+
ot.Run("Expired authentication token error after run triggers a connection failure", func(t *testing.T) {
888+
bolt, cleanup := connectToServer(t, func(srv *bolt4server) {
889+
srv.accept(4)
890+
srv.waitForRun()
891+
srv.sendFailureMsg("Neo.ClientError.Security.TokenExpired", "SSO token is... expired")
892+
})
893+
defer cleanup()
894+
895+
_, err := bolt.Run(db.Command{Cypher: "MATCH (n) RETURN n"}, db.TxConfig{Mode: db.ReadMode})
896+
assertBoltState(t, bolt4_failed, bolt)
897+
AssertError(t, err)
898+
})
874899
}

neo4j/internal/testutil/asserts.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
package testutil
2222

2323
import (
24+
"fmt"
2425
"io"
2526
"reflect"
2627
"strings"
@@ -81,10 +82,15 @@ func AssertNoError(t *testing.T, err error) {
8182
}
8283
}
8384

85+
func AssertErrorMessageContains(t *testing.T, err error, msg string, args ...interface{}) {
86+
AssertError(t, err)
87+
AssertStringContain(t, err.Error(), fmt.Sprintf(msg, args...))
88+
}
89+
8490
func AssertError(t *testing.T, err error) {
8591
t.Helper()
8692
if err == nil {
87-
t.Fatal("Expected an error but it wasn't")
93+
t.Fatal("Expected an error but none occurred")
8894
}
8995
}
9096

neo4j/internal/testutil/connfake.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,11 @@ type ConnFake struct {
4848
Table *db.RoutingTable
4949
Err error
5050
Id int
51+
TxBeginErr error
5152
TxBeginHandle db.TxHandle
53+
RunErr error
5254
RunStream db.StreamHandle
55+
RunTxErr error
5356
RunTxStream db.StreamHandle
5457
Nexts []Next
5558
Bookm string
@@ -112,7 +115,7 @@ func (c *ConnFake) GetRoutingTable(context map[string]string, bookmarks []string
112115

113116
func (c *ConnFake) TxBegin(txConfig db.TxConfig) (db.TxHandle, error) {
114117
c.RecordedTxs = append(c.RecordedTxs, RecordedTx{Origin: "TxBegin", Mode: txConfig.Mode, Bookmarks: txConfig.Bookmarks, Timeout: txConfig.Timeout, Meta: txConfig.Meta})
115-
return c.TxBeginHandle, c.Err
118+
return c.TxBeginHandle, c.TxBeginErr
116119
}
117120

118121
func (c *ConnFake) TxRollback(tx db.TxHandle) error {
@@ -129,11 +132,11 @@ func (c *ConnFake) TxCommit(tx db.TxHandle) error {
129132
func (c *ConnFake) Run(runCommand db.Command, txConfig db.TxConfig) (db.StreamHandle, error) {
130133

131134
c.RecordedTxs = append(c.RecordedTxs, RecordedTx{Origin: "Run", Mode: txConfig.Mode, Bookmarks: txConfig.Bookmarks, Timeout: txConfig.Timeout, Meta: txConfig.Meta})
132-
return c.RunStream, c.Err
135+
return c.RunStream, c.RunErr
133136
}
134137

135138
func (c *ConnFake) RunTx(tx db.TxHandle, runCommand db.Command) (db.StreamHandle, error) {
136-
return c.RunTxStream, c.Err
139+
return c.RunTxStream, c.RunTxErr
137140
}
138141

139142
func (c *ConnFake) Keys(streamHandle db.StreamHandle) ([]string, error) {

neo4j/session.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ func (s *session) BeginTransaction(configurers ...func(*TransactionConfig)) (Tra
208208
})
209209
if err != nil {
210210
s.pool.Return(conn)
211-
return nil, err
211+
return nil, wrapError(err)
212212
}
213213

214214
// Create transaction wrapper

neo4j/session_test.go

Lines changed: 155 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ func TestSession(st *testing.T) {
6262
return &router, &pool, sess
6363
}
6464

65+
tokenExpiredErr := &db.Neo4jError{Code: "Neo.ClientError.Security.TokenExpired", Msg: "oopsie whoopsie"}
66+
6567
st.Run("Retry mechanism", func(rt *testing.T) {
6668
// Checks that retries occur on database error and that it stops retrying after a certain
6769
// amount of time and that connections are returned to pool upon failure.
@@ -141,17 +143,17 @@ func TestSession(st *testing.T) {
141143
dirtyBookmarks := []string{"", "b1", "", "b2", ""}
142144
cleanBookmarks := []string{"b1", "b2"}
143145
_, pool, sess := createSessionWithBookmarks(dirtyBookmarks)
144-
conn := &ConnFake{Alive: true, Err: errors.New("Make all fail")}
146+
err := errors.New("make all fail")
147+
conn := &ConnFake{Alive: true, RunErr: err, TxBeginErr: err}
145148
pool.BorrowConn = conn
146149

147-
// All of these assume that Err on ConnFake fails the operations
148150
sess.Run("cypher", nil)
149151
sess.BeginTransaction()
150152
sess.ReadTransaction(func(tx Transaction) (interface{}, error) {
151-
return nil, errors.New("somehting")
153+
return nil, errors.New("something")
152154
})
153155
sess.WriteTransaction(func(tx Transaction) (interface{}, error) {
154-
return nil, errors.New("somehting")
156+
return nil, errors.New("something")
155157
})
156158
AssertLen(t, conn.RecordedTxs, 4)
157159
for _, rtx := range conn.RecordedTxs {
@@ -297,6 +299,98 @@ func TestSession(st *testing.T) {
297299
_, err = sess.Run("cypher", nil)
298300
assertUsageError(t, err)
299301
})
302+
303+
bt.Run("Token expiration in session run after errored connection acquisition", func(t *testing.T) {
304+
_, pool, sess := createSession()
305+
pool.BorrowErr = tokenExpiredErr
306+
307+
_, err := sess.Run("cypher", map[string]interface{}{})
308+
309+
assertTokenExpiredError(t, err)
310+
})
311+
312+
bt.Run("Token expiration after run", func(t *testing.T) {
313+
_, pool, sess := createSession()
314+
conn := &ConnFake{Alive: true, RunErr: tokenExpiredErr}
315+
pool.BorrowConn = conn
316+
317+
_, err := sess.Run("cypher", map[string]interface{}{})
318+
319+
assertTokenExpiredError(t, err)
320+
})
321+
322+
bt.Run("Token expiration after result collect call", func(t *testing.T) {
323+
_, pool, sess := createSession()
324+
conn := &ConnFake{Alive: true, Nexts: []Next{{Err: tokenExpiredErr}}}
325+
pool.BorrowConn = conn
326+
327+
result, err := sess.Run("cypher", map[string]interface{}{})
328+
AssertNil(t, err)
329+
_, err = result.Collect()
330+
331+
assertTokenExpiredError(t, err)
332+
})
333+
334+
bt.Run("Token expiration after result consume call", func(t *testing.T) {
335+
_, pool, sess := createSession()
336+
conn := &ConnFake{Alive: true, ConsumeErr: tokenExpiredErr}
337+
pool.BorrowConn = conn
338+
339+
result, err := sess.Run("cypher", map[string]interface{}{})
340+
AssertNil(t, err)
341+
_, err = result.Consume()
342+
343+
assertTokenExpiredError(t, err)
344+
})
345+
346+
bt.Run("Token expiration after result consume next and err call", func(t *testing.T) {
347+
_, pool, sess := createSession()
348+
conn := &ConnFake{Alive: true, Nexts: []Next{{Err: tokenExpiredErr}}}
349+
pool.BorrowConn = conn
350+
351+
result, err := sess.Run("cypher", map[string]interface{}{})
352+
AssertNil(t, err)
353+
_ = result.Next()
354+
err = result.Err()
355+
356+
assertTokenExpiredError(t, err)
357+
})
358+
359+
bt.Run("Token expiration after result single record extraction", func(t *testing.T) {
360+
_, pool, sess := createSession()
361+
conn := &ConnFake{Alive: true, Nexts: []Next{{Err: tokenExpiredErr}}}
362+
pool.BorrowConn = conn
363+
364+
result, err := sess.Run("cypher", map[string]interface{}{})
365+
AssertNil(t, err)
366+
_, err = result.Single()
367+
368+
assertTokenExpiredError(t, err)
369+
})
370+
371+
bt.Run("Token expiration after write transaction function", func(t *testing.T) {
372+
_, pool, sess := createSession()
373+
conn := &ConnFake{Alive: true}
374+
pool.BorrowConn = conn
375+
376+
_, err := sess.WriteTransaction(func(tx Transaction) (interface{}, error) {
377+
return nil, tokenExpiredErr
378+
})
379+
380+
assertTokenExpiredError(t, err)
381+
})
382+
383+
bt.Run("Token expiration after read transaction function", func(t *testing.T) {
384+
_, pool, sess := createSession()
385+
conn := &ConnFake{Alive: true}
386+
pool.BorrowConn = conn
387+
388+
_, err := sess.ReadTransaction(func(tx Transaction) (interface{}, error) {
389+
return nil, tokenExpiredErr
390+
})
391+
392+
assertTokenExpiredError(t, err)
393+
})
300394
})
301395

302396
st.Run("Explicit transaction", func(bt *testing.T) {
@@ -342,6 +436,57 @@ func TestSession(st *testing.T) {
342436
_, err := sess.BeginTransaction()
343437
AssertNoError(t, err)
344438
})
439+
440+
bt.Run("Token expiration after transaction begin", func(t *testing.T) {
441+
_, pool, sess := createSession()
442+
conn := &ConnFake{Alive: true, TxBeginErr: tokenExpiredErr}
443+
pool.BorrowConn = conn
444+
445+
tx, err := sess.BeginTransaction()
446+
447+
AssertNil(t, tx)
448+
assertTokenExpiredError(t, err)
449+
})
450+
451+
bt.Run("Token expiration after transaction run", func(t *testing.T) {
452+
_, pool, sess := createSession()
453+
conn := &ConnFake{Alive: true, RunTxErr: tokenExpiredErr}
454+
pool.BorrowConn = conn
455+
456+
tx, err := sess.BeginTransaction()
457+
AssertNil(t, err)
458+
_, err = tx.Run("cypher", map[string]interface{}{})
459+
460+
assertTokenExpiredError(t, err)
461+
})
462+
463+
bt.Run("Token expiration after transaction commit", func(t *testing.T) {
464+
_, pool, sess := createSession()
465+
conn := &ConnFake{Alive: true, TxCommitErr: tokenExpiredErr}
466+
pool.BorrowConn = conn
467+
468+
tx, err := sess.BeginTransaction()
469+
AssertNil(t, err)
470+
_, err = tx.Run("cypher", map[string]interface{}{})
471+
AssertNil(t, err)
472+
err = tx.Commit()
473+
474+
assertTokenExpiredError(t, err)
475+
})
476+
477+
bt.Run("Token expiration after transaction rollback", func(t *testing.T) {
478+
_, pool, sess := createSession()
479+
conn := &ConnFake{Alive: true, TxRollbackErr: tokenExpiredErr}
480+
pool.BorrowConn = conn
481+
482+
tx, err := sess.BeginTransaction()
483+
AssertNil(t, err)
484+
_, err = tx.Run("cypher", map[string]interface{}{})
485+
AssertNil(t, err)
486+
err = tx.Rollback()
487+
488+
assertTokenExpiredError(t, err)
489+
})
345490
})
346491

347492
st.Run("Close", func(ct *testing.T) {
@@ -367,3 +512,9 @@ func TestSession(st *testing.T) {
367512
})
368513
})
369514
}
515+
516+
func assertTokenExpiredError(t *testing.T, err error) {
517+
AssertSameType(t, err, &TokenExpiredError{})
518+
AssertErrorMessageContains(t, err, "Neo.ClientError.Security.TokenExpired")
519+
AssertErrorMessageContains(t, err, "oopsie whoopsie")
520+
}

testkit-backend/backend.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,15 @@ func (b *backend) writeError(err error) {
115115
// track of this error so that we can reuse the real thing within a retryable tx
116116
fmt.Printf("Error: %s (%T)\n", err.Error(), err)
117117
code := ""
118+
tokenErr, isTokenExpiredErr := err.(*neo4j.TokenExpiredError)
119+
if isTokenExpiredErr {
120+
code = tokenErr.Code
121+
}
118122
if neo4j.IsNeo4jError(err) {
119123
code = err.(*db.Neo4jError).Code
120124
}
121-
isDriverError := neo4j.IsNeo4jError(err) ||
125+
isDriverError := isTokenExpiredErr ||
126+
neo4j.IsNeo4jError(err) ||
122127
neo4j.IsUsageError(err) ||
123128
neo4j.IsConnectivityError(err) ||
124129
neo4j.IsTransactionExecutionLimit(err) ||
@@ -353,10 +358,14 @@ func (b *backend) handleRequest(req map[string]interface{}) {
353358
authTokenMap := data["authorizationToken"].(map[string]interface{})["data"].(map[string]interface{})
354359
switch authTokenMap["scheme"] {
355360
case "basic":
361+
realm, ok := authTokenMap["realm"].(string)
362+
if !ok {
363+
realm = ""
364+
}
356365
authToken = neo4j.BasicAuth(
357366
authTokenMap["principal"].(string),
358367
authTokenMap["credentials"].(string),
359-
authTokenMap["realm"].(string))
368+
realm)
360369
case "kerberos":
361370
authToken = neo4j.KerberosAuth(authTokenMap["ticket"].(string))
362371
case "bearer":
@@ -572,6 +581,7 @@ func (b *backend) handleRequest(req map[string]interface{}) {
572581
"Optimization:ConnectionReuse",
573582
"Optimization:ImplicitDefaultArguments",
574583
"Optimization:PullPipelining",
584+
"Feature:Auth:Bearer",
575585
},
576586
})
577587

0 commit comments

Comments
 (0)