Skip to content

Commit 165e445

Browse files
Fix panic when calling driver.Close concurrently (#626)
* Fix panic when calling driver.Close concurrently Doing so might well still yield unexpected or undesired results, but at least it shouldn't cause a panic. * TestKit: close resources on disconnect & catch backend panics * fixup! TestKit: close resources on disconnect & catch backend panics * fixup! TestKit: close resources on disconnect & catch backend panics --------- Co-authored-by: Stephen Cathcart <[email protected]>
1 parent b401c3d commit 165e445

File tree

11 files changed

+108
-23
lines changed

11 files changed

+108
-23
lines changed

neo4j/driver.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type Driver interface {
3737
// or error describing the problem.
3838
VerifyConnectivity() error
3939
// Close the driver and all underlying connections
40+
// This function may not be called while the driver is in use (i.e., concurrently).
4041
Close() error
4142
// IsEncrypted determines whether the driver communication with the server
4243
// is encrypted. This is a static check. The function can also be called on

neo4j/internal/bolt/bolt3.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,10 @@ func (b *bolt3) ServerName() string {
151151
return b.serverName
152152
}
153153

154+
func (b *bolt3) ConnId() string {
155+
return b.connId
156+
}
157+
154158
func (b *bolt3) ServerVersion() string {
155159
return b.serverVersion
156160
}

neo4j/internal/bolt/bolt4.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,10 @@ func (b *bolt4) ServerName() string {
174174
return b.serverName
175175
}
176176

177+
func (b *bolt4) ConnId() string {
178+
return b.connId
179+
}
180+
177181
func (b *bolt4) ServerVersion() string {
178182
return b.serverVersion
179183
}

neo4j/internal/bolt/bolt5.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ func (b *bolt5) ServerName() string {
185185
return b.serverName
186186
}
187187

188+
func (b *bolt5) ConnId() string {
189+
return b.connId
190+
}
191+
188192
func (b *bolt5) ServerVersion() string {
189193
return b.serverVersion
190194
}

neo4j/internal/db/connection.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ type Connection interface {
124124
Bookmark() string
125125
// ServerName returns the name of the remote server
126126
ServerName() string
127+
// ConnId returns the connection id as assigned by the server ("" if not available)
128+
ConnId() string
127129
// ServerVersion returns the server version on pattern Neo4j/1.2.3
128130
ServerVersion() string
129131
// IsAlive returns true if the connection is fully functional.

neo4j/internal/pool/pool.go

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,22 @@ func (p *Pool) Close(ctx context.Context) {
108108
p.queueMut.Unlock()
109109
// Go through each server and close all connections to it
110110
p.serversMut.Lock()
111-
for n, s := range p.servers {
112-
s.closeAll(ctx, p.closeConnection)
113-
delete(p.servers, n)
111+
pendingConnections := 0
112+
for _, s := range p.servers {
113+
s.startClosing(ctx, p.closeConnection)
114+
pendingConnections += s.size()
114115
}
115116
p.serversMut.Unlock()
116-
p.log.Infof(log.Pool, p.logId, "Closed")
117+
if pendingConnections == 0 {
118+
p.log.Infof(log.Pool, p.logId, "Closed")
119+
} else {
120+
p.log.Warnf(
121+
log.Pool,
122+
p.logId,
123+
"Called close with %d in-flight connections (will be closed when work is done).",
124+
pendingConnections,
125+
)
126+
}
117127
}
118128

119129
// For testing
@@ -194,8 +204,8 @@ func (p *Pool) Borrow(
194204
auth *idb.ReAuthToken,
195205
) (idb.Connection, error) {
196206
for {
197-
if p.closed {
198-
return nil, &errorutil.PoolClosed{}
207+
if err := p.checkClosed(); err != nil {
208+
return nil, err
199209
}
200210
serverNames := getServerNames()
201211
if len(serverNames) == 0 {
@@ -294,6 +304,10 @@ func (p *Pool) tryBorrow(
294304
var unlock = new(sync.Once)
295305
defer unlock.Do(p.serversMut.Unlock)
296306

307+
if err := p.checkClosed(); err != nil {
308+
return nil, err
309+
}
310+
297311
srv := p.servers[serverName]
298312
for {
299313
if srv != nil {
@@ -351,6 +365,13 @@ func (p *Pool) tryBorrow(
351365
return c, nil
352366
}
353367

368+
func (p *Pool) checkClosed() error {
369+
if p.closed {
370+
return &errorutil.PoolClosed{}
371+
}
372+
return nil
373+
}
374+
354375
func (p *Pool) unreg(ctx context.Context, serverName string, c idb.Connection, now time.Time) {
355376
p.serversMut.Lock()
356377
defer p.serversMut.Unlock()
@@ -396,13 +417,19 @@ func (p *Pool) closeConnection(ctx context.Context, c idb.Connection) {
396417
func (p *Pool) Return(ctx context.Context, c idb.Connection) {
397418
if p.closed {
398419
p.log.Warnf(log.Pool, p.logId, "Trying to return connection to closed pool")
399-
return
400420
}
401421

402422
// Get the name of the server that the connection belongs to.
403423
serverName := c.ServerName()
404424
isAlive := c.IsAlive()
405-
p.log.Debugf(log.Pool, p.logId, "Returning connection to %s {alive:%t}", serverName, isAlive)
425+
p.log.Debugf(
426+
log.Pool,
427+
p.logId,
428+
"Returning connection %s to %s {alive:%t}",
429+
c.ConnId(),
430+
serverName,
431+
isAlive,
432+
)
406433

407434
// If the connection is dead, remove all other idle connections on the same server that older
408435
// or of the same age as the dead connection, otherwise perform normal cleanup of old connections

neo4j/internal/pool/server.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,6 @@ func (s *server) removeIdleOlderThan(ctx context.Context, now time.Time, maxAge
196196
}
197197
}
198198

199-
func (s *server) closeAll(ctx context.Context, close closeFunc) {
200-
s.closeAndEmptyConnections(ctx, &s.idle, close)
201-
// Closing the busy connections could mean here that we do close from another thread.
202-
s.closeAndEmptyConnections(ctx, &s.busy, close)
203-
}
204-
205199
func (s *server) executeForAllConnections(callback func(c db.Connection)) {
206200
for item := s.busy.Front(); item != nil; item = item.Next() {
207201
callback(item.Value.(db.Connection))

neo4j/internal/pool/server_test.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,16 +133,24 @@ func TestServer(ot *testing.T) {
133133
}
134134
})
135135

136-
ot.Run("closeAll clears all connections", func(t *testing.T) {
136+
ot.Run("startClosing sets closing flag", func(t *testing.T) {
137137
s := NewServer()
138138
// Register and return three connections
139139
_, _ = populateServer(s, time.Now(), 3, 3)
140-
s.closeAll(context.Background(), closeConnection)
141-
if s.idle.Len() != 0 {
142-
t.Errorf("Expected 0 idle connections, found %d", s.idle.Len())
143-
}
144-
if s.busy.Len() != 0 {
145-
t.Errorf("Expected 0 busy connections, found %d", s.busy.Len())
140+
s.startClosing(context.Background(), closeConnection)
141+
testutil.AssertTrue(t, s.closing)
142+
})
143+
144+
ot.Run("closing flag makes returnBusy close connections", func(t *testing.T) {
145+
s := NewServer()
146+
// Register and return three connections
147+
_, busy_connections := populateServer(s, time.Now(), 0, 3)
148+
s.startClosing(context.Background(), closeConnection)
149+
for i, c := range busy_connections {
150+
testutil.AssertFalse(t, c.Closed)
151+
s.returnBusy(context.Background(), c, closeConnection)
152+
testutil.AssertTrue(t, c.Closed)
153+
testutil.AssertIntEqual(t, s.busy.Len(), 2-i)
146154
}
147155
})
148156
}

neo4j/internal/testutil/connfake.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ type ConnFake struct {
7575
ReAuthHook func(context.Context, *idb.ReAuthToken) error
7676
SsrEnabled bool
7777
PinHomeDatabaseCallback func(context.Context, string)
78+
Closed bool
7879
}
7980

8081
func (c *ConnFake) Connect(
@@ -92,6 +93,10 @@ func (c *ConnFake) ServerName() string {
9293
return c.Name
9394
}
9495

96+
func (c *ConnFake) ConnId() string {
97+
return "bolt-1"
98+
}
99+
95100
func (c *ConnFake) IsAlive() bool {
96101
return c.Alive
97102
}
@@ -108,6 +113,7 @@ func (c *ConnFake) ForceReset(context.Context) {
108113
}
109114

110115
func (c *ConnFake) Close(ctx context.Context) {
116+
c.Closed = true
111117
}
112118

113119
func (c *ConnFake) Birthdate() time.Time {

testkit-backend/backend.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,29 @@ func (b *backend) writeLineLocked(s string) error {
190190

191191
// Reads and writes to the socket until it is closed
192192
func (b *backend) serve() {
193+
defer b.close()
193194
for b.process() {
194195
}
195196
}
196197

198+
func (b *backend) close() {
199+
b.closed = true
200+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
201+
defer cancel()
202+
for k, tx := range b.explicitTransactions {
203+
_ = tx.Close(ctx)
204+
delete(b.explicitTransactions, k)
205+
}
206+
for k, session := range b.sessionStates {
207+
_ = session.session.Close(ctx)
208+
delete(b.sessionStates, k)
209+
}
210+
for k, driver := range b.drivers {
211+
_ = driver.Close(ctx)
212+
delete(b.drivers, k)
213+
}
214+
}
215+
197216
func (b *backend) setError(err error) string {
198217
id := b.nextId()
199218
b.recordedErrors[id] = err

0 commit comments

Comments
 (0)