Skip to content

Commit 50f59a9

Browse files
committed
Merge branch '4.4' into optional-realm-in-testkit-auth-token
2 parents 0e482e5 + bafdae0 commit 50f59a9

File tree

12 files changed

+311
-54
lines changed

12 files changed

+311
-54
lines changed

neo4j/auth_tokens.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ const keyScheme = "scheme"
2828
const schemeNone = "none"
2929
const schemeBasic = "basic"
3030
const schemeKerberos = "kerberos"
31+
const schemeBearer = "bearer"
3132
const keyPrincipal = "principal"
3233
const keyCredentials = "credentials"
3334
const keyRealm = "realm"
@@ -70,6 +71,18 @@ func KerberosAuth(ticket string) AuthToken {
7071
return token
7172
}
7273

74+
// BearerAuth generates an authentication token with the provided base-64 value generated by a Single Sign-On provider
75+
func BearerAuth(token string) AuthToken {
76+
result := AuthToken{
77+
tokens: map[string]interface{}{
78+
keyScheme: schemeBearer,
79+
keyCredentials: token,
80+
},
81+
}
82+
83+
return result
84+
}
85+
7386
// CustomAuth generates a custom authentication token with provided parameters
7487
func CustomAuth(scheme string, username string, password string, realm string, parameters map[string]interface{}) AuthToken {
7588
tokens := map[string]interface{}{

neo4j/error.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,17 @@ func IsTransactionExecutionLimit(err error) bool {
110110
return is
111111
}
112112

113+
// TokenExpiredError represent errors caused by the driver not being able to connect to Neo4j services,
114+
// or lost connections.
115+
type TokenExpiredError struct {
116+
Code string
117+
Message string
118+
}
119+
120+
func (e *TokenExpiredError) Error() string {
121+
return fmt.Sprintf("TokenExpiredError: %s (%s)", e.Code, e.Message)
122+
}
123+
113124
func wrapError(err error) error {
114125
if err == nil {
115126
return nil
@@ -134,6 +145,10 @@ func wrapError(err error) error {
134145
// most likely due to read timeout configuration hint being applied
135146
return &ConnectivityError{inner: err}
136147
}
148+
case *db.Neo4jError:
149+
if e.Code == "Neo.ClientError.Security.TokenExpired" {
150+
return &TokenExpiredError{Code: e.Code, Message: e.Msg}
151+
}
137152
}
138153
return err
139154
}

neo4j/internal/bolt/bolt4.go

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -74,25 +74,26 @@ func (i *internalTx4) toMeta() map[string]interface{} {
7474
}
7575

7676
type bolt4 struct {
77-
state int
78-
txId db.TxHandle
79-
streams openstreams
80-
conn net.Conn
81-
serverName string
82-
out outgoing
83-
in incoming
84-
connId string
85-
logId string
86-
serverVersion string
87-
tfirst int64 // Time that server started streaming
88-
pendingTx internalTx4 // Stashed away when tx started explicitly
89-
hasPendingTx bool
90-
bookmark string // Last bookmark
91-
birthDate time.Time
92-
log log.Logger
93-
databaseName string
94-
err error // Last fatal error
95-
minor int
77+
state int
78+
txId db.TxHandle
79+
streams openstreams
80+
conn net.Conn
81+
serverName string
82+
out outgoing
83+
in incoming
84+
connId string
85+
logId string
86+
serverVersion string
87+
tfirst int64 // Time that server started streaming
88+
pendingTx internalTx4 // Stashed away when tx started explicitly
89+
hasPendingTx bool
90+
bookmark string // Last bookmark
91+
birthDate time.Time
92+
log log.Logger
93+
databaseName string
94+
err error // Last fatal error
95+
minor int
96+
lastQid int64 // Last seen qid
9697
}
9798

9899
func NewBolt4(serverName string, conn net.Conn, log log.Logger, boltLog log.BoltLogger) *bolt4 {
@@ -156,10 +157,8 @@ func (b *bolt4) setError(err error, fatal bool) {
156157
b.state = bolt4_failed
157158
}
158159

159-
neo4jErr, _ := err.(*db.Neo4jError)
160160
// Increase severity even if it was a previous error
161-
// Treat expired auth as fatal so that pool is cleaned up of old connections
162-
if fatal || (neo4jErr != nil && neo4jErr.Code == "Status.Security.AuthorizationExpired") {
161+
if fatal {
163162
b.state = bolt4_dead
164163
}
165164

@@ -170,7 +169,8 @@ func (b *bolt4) setError(err error, fatal bool) {
170169
}
171170

172171
// Do not log big cypher statements as errors
173-
if neo4jErr != nil && neo4jErr.Classification() == "ClientError" {
172+
neo4jErr, casted := err.(*db.Neo4jError)
173+
if casted && neo4jErr.Classification() == "ClientError" {
174174
b.log.Debugf(log.Bolt4, b.logId, "%s", err)
175175
} else {
176176
b.log.Error(log.Bolt4, b.logId, err)
@@ -199,9 +199,12 @@ func (b *bolt4) receiveSuccess() *success {
199199

200200
switch v := msg.(type) {
201201
case *success:
202+
if v.qid > -1 {
203+
b.lastQid = v.qid
204+
}
202205
return v
203206
case *db.Neo4jError:
204-
b.setError(v, false)
207+
b.setError(v, isFatalError(v))
205208
return nil
206209
default:
207210
// Unexpected message received
@@ -433,7 +436,7 @@ func (b *bolt4) discardStream() {
433436
// already sent a discard.
434437
discarded = true
435438
stream.fetchSize = -1
436-
if b.state == bolt4_streamingtx {
439+
if b.state == bolt4_streamingtx && stream.qid != b.lastQid {
437440
b.out.appendDiscardNQid(stream.fetchSize, stream.qid)
438441
} else {
439442
b.out.appendDiscardN(stream.fetchSize)
@@ -464,7 +467,12 @@ func (b *bolt4) sendPullN() {
464467
b.out.appendPullN(b.streams.curr.fetchSize)
465468
b.out.send(b.conn)
466469
} else if b.state == bolt4_streamingtx {
467-
b.out.appendPullNQid(b.streams.curr.fetchSize, b.streams.curr.qid)
470+
fetchSize := b.streams.curr.fetchSize
471+
if b.streams.curr.qid == b.lastQid {
472+
b.out.appendPullN(fetchSize)
473+
} else {
474+
b.out.appendPullNQid(fetchSize, b.streams.curr.qid)
475+
}
468476
b.out.send(b.conn)
469477
}
470478
}
@@ -796,7 +804,7 @@ func (b *bolt4) receiveNext() (*db.Record, bool, *db.Summary) {
796804
b.checkStreams()
797805
return nil, false, sum
798806
case *db.Neo4jError:
799-
b.setError(x, false) // Will detach the stream
807+
b.setError(x, isFatalError(x)) // Will detach the stream
800808
return nil, false, nil
801809
default:
802810
// Unknown territory
@@ -974,3 +982,8 @@ func (b *bolt4) initializeReadTimeoutHint(hints map[string]interface{}) {
974982
}
975983
b.in.connReadTimeout = time.Duration(readTimeout) * time.Second
976984
}
985+
986+
func isFatalError(err *db.Neo4jError) bool {
987+
// Treat expired auth as fatal so that pool is cleaned up of old connections
988+
return err != nil && err.Code == "Status.Security.AuthorizationExpired"
989+
}

neo4j/internal/bolt/bolt4_test.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,8 @@ func TestBolt4(ot *testing.T) {
333333
srv.send(msgRecord, []interface{}{"v1"})
334334
// ... and the batch summary
335335
srv.send(msgSuccess, map[string]interface{}{"has_more": true})
336-
// Wait for the discard message
337-
srv.waitForDiscardNAndQid(-1, int(qid))
336+
// Wait for the discard message (no need for qid since the last executed query is discarded)
337+
srv.waitForDiscardN(-1)
338338
// Respond to discard with has more to indicate that there are more records
339339
srv.send(msgSuccess, map[string]interface{}{"has_more": true})
340340
// Wait for the commit
@@ -871,4 +871,29 @@ func TestBolt4(ot *testing.T) {
871871
assertBoltState(t, bolt4_dead, bolt)
872872
AssertError(t, err)
873873
})
874+
875+
ot.Run("Immediately expired authentication token error triggers a connection failure", func(t *testing.T) {
876+
bolt, cleanup := connectToServer(t, func(srv *bolt4server) {
877+
srv.accept(4)
878+
srv.sendFailureMsg("Neo.ClientError.Security.TokenExpired", "SSO token is... expired")
879+
})
880+
defer cleanup()
881+
882+
_, err := bolt.Run(db.Command{Cypher: "MATCH (n) RETURN n"}, db.TxConfig{Mode: db.ReadMode})
883+
assertBoltState(t, bolt4_failed, bolt)
884+
AssertError(t, err)
885+
})
886+
887+
ot.Run("Expired authentication token error after run triggers a connection failure", func(t *testing.T) {
888+
bolt, cleanup := connectToServer(t, func(srv *bolt4server) {
889+
srv.accept(4)
890+
srv.waitForRun()
891+
srv.sendFailureMsg("Neo.ClientError.Security.TokenExpired", "SSO token is... expired")
892+
})
893+
defer cleanup()
894+
895+
_, err := bolt.Run(db.Command{Cypher: "MATCH (n) RETURN n"}, db.TxConfig{Mode: db.ReadMode})
896+
assertBoltState(t, bolt4_failed, bolt)
897+
AssertError(t, err)
898+
})
874899
}

neo4j/internal/bolt/bolt_logging.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,11 @@ type loggedSuccess struct {
5353
TLast string `json:"t_last,omitempty"`
5454
HasMore bool `json:"has_more,omitempy"`
5555
Db string `json:"db,omitempty"`
56+
Qid int64 `json:"qid,omitempty"`
5657
}
5758

5859
func (s loggableSuccess) String() string {
59-
return serializeTrace(loggedSuccess{
60+
success := loggedSuccess{
6061
Server: s.server,
6162
ConnectionId: s.connectionId,
6263
Fields: s.fields,
@@ -65,7 +66,12 @@ func (s loggableSuccess) String() string {
6566
TLast: formatOmittingZero(s.tlast),
6667
HasMore: s.hasMore,
6768
Db: s.db,
68-
})
69+
}
70+
if s.qid > -1 {
71+
success.Qid = s.qid
72+
}
73+
return serializeTrace(success)
74+
6975
}
7076

7177
func formatOmittingZero(i int64) string {

neo4j/internal/testutil/asserts.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
package testutil
2222

2323
import (
24+
"fmt"
2425
"io"
2526
"reflect"
2627
"strings"
@@ -81,10 +82,15 @@ func AssertNoError(t *testing.T, err error) {
8182
}
8283
}
8384

85+
func AssertErrorMessageContains(t *testing.T, err error, msg string, args ...interface{}) {
86+
AssertError(t, err)
87+
AssertStringContain(t, err.Error(), fmt.Sprintf(msg, args...))
88+
}
89+
8490
func AssertError(t *testing.T, err error) {
8591
t.Helper()
8692
if err == nil {
87-
t.Fatal("Expected an error but it wasn't")
93+
t.Fatal("Expected an error but none occurred")
8894
}
8995
}
9096

neo4j/internal/testutil/connfake.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,11 @@ type ConnFake struct {
4848
Table *db.RoutingTable
4949
Err error
5050
Id int
51+
TxBeginErr error
5152
TxBeginHandle db.TxHandle
53+
RunErr error
5254
RunStream db.StreamHandle
55+
RunTxErr error
5356
RunTxStream db.StreamHandle
5457
Nexts []Next
5558
Bookm string
@@ -112,7 +115,7 @@ func (c *ConnFake) GetRoutingTable(context map[string]string, bookmarks []string
112115

113116
func (c *ConnFake) TxBegin(txConfig db.TxConfig) (db.TxHandle, error) {
114117
c.RecordedTxs = append(c.RecordedTxs, RecordedTx{Origin: "TxBegin", Mode: txConfig.Mode, Bookmarks: txConfig.Bookmarks, Timeout: txConfig.Timeout, Meta: txConfig.Meta})
115-
return c.TxBeginHandle, c.Err
118+
return c.TxBeginHandle, c.TxBeginErr
116119
}
117120

118121
func (c *ConnFake) TxRollback(tx db.TxHandle) error {
@@ -129,11 +132,11 @@ func (c *ConnFake) TxCommit(tx db.TxHandle) error {
129132
func (c *ConnFake) Run(runCommand db.Command, txConfig db.TxConfig) (db.StreamHandle, error) {
130133

131134
c.RecordedTxs = append(c.RecordedTxs, RecordedTx{Origin: "Run", Mode: txConfig.Mode, Bookmarks: txConfig.Bookmarks, Timeout: txConfig.Timeout, Meta: txConfig.Meta})
132-
return c.RunStream, c.Err
135+
return c.RunStream, c.RunErr
133136
}
134137

135138
func (c *ConnFake) RunTx(tx db.TxHandle, runCommand db.Command) (db.StreamHandle, error) {
136-
return c.RunTxStream, c.Err
139+
return c.RunTxStream, c.RunTxErr
137140
}
138141

139142
func (c *ConnFake) Keys(streamHandle db.StreamHandle) ([]string, error) {

neo4j/resultsummary.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ const (
4444
type ResultSummary interface {
4545
// Server returns basic information about the server where the statement is carried out.
4646
Server() ServerInfo
47-
// Statement returns statement that has been executed.
47+
// Deprecated: since 4.4, will be removed in 5.0. Use Query instead
4848
Statement() Statement
49+
// Query returns the query that has been executed.
50+
Query() Query
4951
// StatementType returns type of statement that has been executed.
5052
StatementType() StatementType
5153
// Counters returns statistics counts for the statement.
@@ -96,10 +98,16 @@ type Counters interface {
9698
}
9799

98100
type Statement interface {
101+
Query
102+
}
103+
104+
type Query interface {
99105
// Text returns the statement's text.
100106
Text() string
101-
// Params returns the statement's parameters.
107+
// Deprecated: since 4.4, will be removed in 5.0. Use Parameters instead
102108
Params() map[string]interface{}
109+
// Parameters returns the statement's parameters.
110+
Parameters() map[string]interface{}
103111
}
104112

105113
// ServerInfo contains basic information of the server.
@@ -219,6 +227,10 @@ func (s *resultSummary) Statement() Statement {
219227
return s
220228
}
221229

230+
func (s *resultSummary) Query() Query {
231+
return s
232+
}
233+
222234
func (s *resultSummary) StatementType() StatementType {
223235
return StatementType(s.sum.StmntType)
224236
}
@@ -228,6 +240,10 @@ func (s *resultSummary) Text() string {
228240
}
229241

230242
func (s *resultSummary) Params() map[string]interface{} {
243+
return s.Parameters()
244+
}
245+
246+
func (s *resultSummary) Parameters() map[string]interface{} {
231247
return s.params
232248
}
233249

neo4j/session.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ func (s *session) BeginTransaction(configurers ...func(*TransactionConfig)) (Tra
208208
})
209209
if err != nil {
210210
s.pool.Return(conn)
211-
return nil, err
211+
return nil, wrapError(err)
212212
}
213213

214214
// Create transaction wrapper
@@ -458,9 +458,7 @@ func (s *session) Close() error {
458458
var err error
459459

460460
if s.txExplicit != nil {
461-
s.txExplicit.Close()
462-
err = &UsageError{Message: "Closing session with a pending transaction"}
463-
s.log.Warnf(log.Session, s.logId, err.Error())
461+
err = s.txExplicit.Close()
464462
}
465463

466464
if s.txAuto != nil {

0 commit comments

Comments
 (0)