Skip to content

Commit bafdae0

Browse files
authored
Implement new bearer authentication scheme for SSO support
* Implement new bearer authentication scheme for SSO support * Refactor expired auth error handling * Wrap SSO token expiration error into user-facing error
1 parent c7cff16 commit bafdae0

File tree

9 files changed

+246
-17
lines changed

9 files changed

+246
-17
lines changed

neo4j/auth_tokens.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ const keyScheme = "scheme"
2828
const schemeNone = "none"
2929
const schemeBasic = "basic"
3030
const schemeKerberos = "kerberos"
31+
const schemeBearer = "bearer"
3132
const keyPrincipal = "principal"
3233
const keyCredentials = "credentials"
3334
const keyRealm = "realm"
@@ -67,6 +68,18 @@ func KerberosAuth(ticket string) AuthToken {
6768
return token
6869
}
6970

71+
// BearerAuth generates an authentication token with the provided base-64 value generated by a Single Sign-On provider
72+
func BearerAuth(token string) AuthToken {
73+
result := AuthToken{
74+
tokens: map[string]interface{}{
75+
keyScheme: schemeBearer,
76+
keyCredentials: token,
77+
},
78+
}
79+
80+
return result
81+
}
82+
7083
// CustomAuth generates a custom authentication token with provided parameters
7184
func CustomAuth(scheme string, username string, password string, realm string, parameters map[string]interface{}) AuthToken {
7285
tokens := map[string]interface{}{

neo4j/error.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,17 @@ func IsTransactionExecutionLimit(err error) bool {
110110
return is
111111
}
112112

113+
// TokenExpiredError represent errors caused by the driver not being able to connect to Neo4j services,
114+
// or lost connections.
115+
type TokenExpiredError struct {
116+
Code string
117+
Message string
118+
}
119+
120+
func (e *TokenExpiredError) Error() string {
121+
return fmt.Sprintf("TokenExpiredError: %s (%s)", e.Code, e.Message)
122+
}
123+
113124
func wrapError(err error) error {
114125
if err == nil {
115126
return nil
@@ -134,6 +145,10 @@ func wrapError(err error) error {
134145
// most likely due to read timeout configuration hint being applied
135146
return &ConnectivityError{inner: err}
136147
}
148+
case *db.Neo4jError:
149+
if e.Code == "Neo.ClientError.Security.TokenExpired" {
150+
return &TokenExpiredError{Code: e.Code, Message: e.Msg}
151+
}
137152
}
138153
return err
139154
}

neo4j/internal/bolt/bolt4.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,8 @@ func (b *bolt4) setError(err error, fatal bool) {
157157
b.state = bolt4_failed
158158
}
159159

160-
neo4jErr, _ := err.(*db.Neo4jError)
161160
// Increase severity even if it was a previous error
162-
// Treat expired auth as fatal so that pool is cleaned up of old connections
163-
if fatal || (neo4jErr != nil && neo4jErr.Code == "Status.Security.AuthorizationExpired") {
161+
if fatal {
164162
b.state = bolt4_dead
165163
}
166164

@@ -171,7 +169,8 @@ func (b *bolt4) setError(err error, fatal bool) {
171169
}
172170

173171
// Do not log big cypher statements as errors
174-
if neo4jErr != nil && neo4jErr.Classification() == "ClientError" {
172+
neo4jErr, casted := err.(*db.Neo4jError)
173+
if casted && neo4jErr.Classification() == "ClientError" {
175174
b.log.Debugf(log.Bolt4, b.logId, "%s", err)
176175
} else {
177176
b.log.Error(log.Bolt4, b.logId, err)
@@ -205,7 +204,7 @@ func (b *bolt4) receiveSuccess() *success {
205204
}
206205
return v
207206
case *db.Neo4jError:
208-
b.setError(v, false)
207+
b.setError(v, isFatalError(v))
209208
return nil
210209
default:
211210
// Unexpected message received
@@ -805,7 +804,7 @@ func (b *bolt4) receiveNext() (*db.Record, bool, *db.Summary) {
805804
b.checkStreams()
806805
return nil, false, sum
807806
case *db.Neo4jError:
808-
b.setError(x, false) // Will detach the stream
807+
b.setError(x, isFatalError(x)) // Will detach the stream
809808
return nil, false, nil
810809
default:
811810
// Unknown territory
@@ -983,3 +982,8 @@ func (b *bolt4) initializeReadTimeoutHint(hints map[string]interface{}) {
983982
}
984983
b.in.connReadTimeout = time.Duration(readTimeout) * time.Second
985984
}
985+
986+
func isFatalError(err *db.Neo4jError) bool {
987+
// Treat expired auth as fatal so that pool is cleaned up of old connections
988+
return err != nil && err.Code == "Status.Security.AuthorizationExpired"
989+
}

neo4j/internal/bolt/bolt4_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,4 +871,29 @@ func TestBolt4(ot *testing.T) {
871871
assertBoltState(t, bolt4_dead, bolt)
872872
AssertError(t, err)
873873
})
874+
875+
ot.Run("Immediately expired authentication token error triggers a connection failure", func(t *testing.T) {
876+
bolt, cleanup := connectToServer(t, func(srv *bolt4server) {
877+
srv.accept(4)
878+
srv.sendFailureMsg("Neo.ClientError.Security.TokenExpired", "SSO token is... expired")
879+
})
880+
defer cleanup()
881+
882+
_, err := bolt.Run(db.Command{Cypher: "MATCH (n) RETURN n"}, db.TxConfig{Mode: db.ReadMode})
883+
assertBoltState(t, bolt4_failed, bolt)
884+
AssertError(t, err)
885+
})
886+
887+
ot.Run("Expired authentication token error after run triggers a connection failure", func(t *testing.T) {
888+
bolt, cleanup := connectToServer(t, func(srv *bolt4server) {
889+
srv.accept(4)
890+
srv.waitForRun()
891+
srv.sendFailureMsg("Neo.ClientError.Security.TokenExpired", "SSO token is... expired")
892+
})
893+
defer cleanup()
894+
895+
_, err := bolt.Run(db.Command{Cypher: "MATCH (n) RETURN n"}, db.TxConfig{Mode: db.ReadMode})
896+
assertBoltState(t, bolt4_failed, bolt)
897+
AssertError(t, err)
898+
})
874899
}

neo4j/internal/testutil/asserts.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
package testutil
2222

2323
import (
24+
"fmt"
2425
"io"
2526
"reflect"
2627
"strings"
@@ -81,10 +82,15 @@ func AssertNoError(t *testing.T, err error) {
8182
}
8283
}
8384

85+
func AssertErrorMessageContains(t *testing.T, err error, msg string, args ...interface{}) {
86+
AssertError(t, err)
87+
AssertStringContain(t, err.Error(), fmt.Sprintf(msg, args...))
88+
}
89+
8490
func AssertError(t *testing.T, err error) {
8591
t.Helper()
8692
if err == nil {
87-
t.Fatal("Expected an error but it wasn't")
93+
t.Fatal("Expected an error but none occurred")
8894
}
8995
}
9096

neo4j/internal/testutil/connfake.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,11 @@ type ConnFake struct {
4848
Table *db.RoutingTable
4949
Err error
5050
Id int
51+
TxBeginErr error
5152
TxBeginHandle db.TxHandle
53+
RunErr error
5254
RunStream db.StreamHandle
55+
RunTxErr error
5356
RunTxStream db.StreamHandle
5457
Nexts []Next
5558
Bookm string
@@ -112,7 +115,7 @@ func (c *ConnFake) GetRoutingTable(context map[string]string, bookmarks []string
112115

113116
func (c *ConnFake) TxBegin(txConfig db.TxConfig) (db.TxHandle, error) {
114117
c.RecordedTxs = append(c.RecordedTxs, RecordedTx{Origin: "TxBegin", Mode: txConfig.Mode, Bookmarks: txConfig.Bookmarks, Timeout: txConfig.Timeout, Meta: txConfig.Meta})
115-
return c.TxBeginHandle, c.Err
118+
return c.TxBeginHandle, c.TxBeginErr
116119
}
117120

118121
func (c *ConnFake) TxRollback(tx db.TxHandle) error {
@@ -129,11 +132,11 @@ func (c *ConnFake) TxCommit(tx db.TxHandle) error {
129132
func (c *ConnFake) Run(runCommand db.Command, txConfig db.TxConfig) (db.StreamHandle, error) {
130133

131134
c.RecordedTxs = append(c.RecordedTxs, RecordedTx{Origin: "Run", Mode: txConfig.Mode, Bookmarks: txConfig.Bookmarks, Timeout: txConfig.Timeout, Meta: txConfig.Meta})
132-
return c.RunStream, c.Err
135+
return c.RunStream, c.RunErr
133136
}
134137

135138
func (c *ConnFake) RunTx(tx db.TxHandle, runCommand db.Command) (db.StreamHandle, error) {
136-
return c.RunTxStream, c.Err
139+
return c.RunTxStream, c.RunTxErr
137140
}
138141

139142
func (c *ConnFake) Keys(streamHandle db.StreamHandle) ([]string, error) {

neo4j/session.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ func (s *session) BeginTransaction(configurers ...func(*TransactionConfig)) (Tra
208208
})
209209
if err != nil {
210210
s.pool.Return(conn)
211-
return nil, err
211+
return nil, wrapError(err)
212212
}
213213

214214
// Create transaction wrapper

0 commit comments

Comments
 (0)