Skip to content

Commit 41d6c1b

Browse files
ADR-026 Home Database Cache (#618)
* ADR-026 Home Database Cache --------- Co-authored-by: Robsdedude <[email protected]> Co-authored-by: Robsdedude <[email protected]>
1 parent ee9efc8 commit 41d6c1b

23 files changed

+1284
-243
lines changed

neo4j/directrouter.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@ func (r *directRouter) InvalidateReader(string, string) {}
3434

3535
func (r *directRouter) InvalidateServer(string) {}
3636

37-
func (r *directRouter) GetOrUpdateReaders(context.Context, func(context.Context) ([]string, error), string, *db.ReAuthToken, log.BoltLogger) ([]string, error) {
37+
func (r *directRouter) GetOrUpdateReaders(context.Context, func(context.Context) ([]string, error), db.DatabaseSelection, *db.ReAuthToken, log.BoltLogger, func(string)) ([]string, error) {
3838
return []string{r.address}, nil
3939
}
4040

4141
func (r *directRouter) Readers(string) []string {
4242
return []string{r.address}
4343
}
4444

45-
func (r *directRouter) GetOrUpdateWriters(context.Context, func(context.Context) ([]string, error), string, *db.ReAuthToken, log.BoltLogger) ([]string, error) {
45+
func (r *directRouter) GetOrUpdateWriters(context.Context, func(context.Context) ([]string, error), db.DatabaseSelection, *db.ReAuthToken, log.BoltLogger, func(string)) ([]string, error) {
4646
return []string{r.address}, nil
4747
}
4848

neo4j/driver_with_context.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ package neo4j
2121
import (
2222
"context"
2323
"fmt"
24+
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/homedb"
2425
"net/url"
2526
"strings"
2627
"sync"
@@ -219,8 +220,14 @@ func NewDriverWithContext(target string, auth auth.TokenManager, configurers ...
219220
d.connector.RoutingContext = routingContext
220221
d.connector.Config = d.config
221222

223+
// Create cache for home database
224+
d.cache, err = homedb.NewCache(homedb.DefaultCacheMaxSize)
225+
if err != nil {
226+
return nil, err
227+
}
228+
222229
// Let the pool use the same log ID as the driver to simplify log reading.
223-
d.pool = pool.New(d.config, d.connector.Connect, d.log, d.logId)
230+
d.pool = pool.New(d.config, d.connector.Connect, d.log, d.logId, d.cache)
224231

225232
if !routing {
226233
d.router = &directRouter{address: address}
@@ -295,12 +302,12 @@ type sessionRouter interface {
295302
// note: bookmarks are lazily supplied, only when a new routing table needs to be fetched
296303
// this is needed because custom bookmark managers may provide bookmarks from external systems
297304
// they should not be called when it is not needed (e.g. when a routing table is cached)
298-
GetOrUpdateReaders(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error)
305+
GetOrUpdateReaders(ctx context.Context, bookmarks func(context.Context) ([]string, error), dbSelection idb.DatabaseSelection, auth *idb.ReAuthToken, boltLogger log.BoltLogger, onRoutingTableUpdated func(string)) ([]string, error)
299306
// Readers returns the list of servers that can serve reads on the requested database.
300307
Readers(database string) []string
301308
// GetOrUpdateWriters returns the list of servers that can serve writes on the requested database.
302309
// note: bookmarks are lazily supplied, see Readers documentation to learn why
303-
GetOrUpdateWriters(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error)
310+
GetOrUpdateWriters(ctx context.Context, bookmarks func(context.Context) ([]string, error), dbSelection idb.DatabaseSelection, auth *idb.ReAuthToken, boltLogger log.BoltLogger, onRoutingTableUpdated func(string)) ([]string, error)
304311
// Writers returns the list of servers that can serve writes on the requested database.
305312
Writers(database string) []string
306313
// GetNameOfDefaultDatabase returns the name of the default database for the specified user.
@@ -329,15 +336,14 @@ type driverWithContext struct {
329336
// this is *not* used by default by user-created session (see NewSession)
330337
executeQueryBookmarkManager BookmarkManager
331338
auth auth.TokenManager
339+
cache *homedb.Cache
332340
}
333341

334342
func (d *driverWithContext) Target() url.URL {
335343
return *d.target
336344
}
337345

338-
// TODO 6.0: remove unused Context parameter
339-
340-
func (d *driverWithContext) NewSession(_ context.Context, config SessionConfig) SessionWithContext {
346+
func (d *driverWithContext) NewSession(ctx context.Context, config SessionConfig) SessionWithContext {
341347
if config.DatabaseName == "" {
342348
config.DatabaseName = idb.DefaultDatabase
343349
}
@@ -363,7 +369,7 @@ func (d *driverWithContext) NewSession(_ context.Context, config SessionConfig)
363369
return &erroredSessionWithContext{
364370
err: &UsageError{Message: "Trying to create session on closed driver"}}
365371
}
366-
return newSessionWithContext(d.config, config, d.router, d.pool, d.log, reAuthToken)
372+
return newSessionWithContext(ctx, d.config, config, d.router, d.pool, d.cache, d.log, reAuthToken)
367373
}
368374

369375
func (d *driverWithContext) VerifyConnectivity(ctx context.Context) error {

neo4j/driver_with_context_testkit.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,16 @@ func ForceRoutingTableUpdate(d DriverWithContext, database string, bookmarks []s
4343
FromSession: false,
4444
ForceReAuth: false,
4545
}
46-
_, err := driver.router.GetOrUpdateReaders(ctx, getBookmarks, database, auth, logger)
46+
dbSelection := idb.DatabaseSelection{Name: database}
47+
_, err := driver.router.GetOrUpdateReaders(ctx, getBookmarks, dbSelection, auth, logger, func(db string) {
48+
if dbSelection.Name == "" {
49+
dbSelection.Name = db
50+
}
51+
})
4752
if err != nil {
4853
return errorutil.WrapError(err)
4954
}
50-
_, err = driver.router.GetOrUpdateWriters(ctx, getBookmarks, database, auth, logger)
55+
_, err = driver.router.GetOrUpdateWriters(ctx, getBookmarks, dbSelection, auth, logger, nil)
5156
return errorutil.WrapError(err)
5257
}
5358

neo4j/internal/bolt/bolt3.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,3 +872,11 @@ func (b *bolt3) GetCurrentAuth() (auth.TokenManager, iauth.Token) {
872872
func (b *bolt3) Telemetry(telemetry.API, func()) {
873873
// TELEMETRY not support by this protocol version, so we ignore it.
874874
}
875+
876+
func (b *bolt3) SetPinHomeDatabaseCallback(func(context.Context, string)) {
877+
// Home database not supported by this protocol version, so we ignore it.
878+
}
879+
880+
func (b *bolt3) IsSsrEnabled() bool {
881+
return false
882+
}

neo4j/internal/bolt/bolt4.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,14 @@ func (b *bolt4) Telemetry(telemetry.API, func()) {
985985
// TELEMETRY not support by this protocol version, so we ignore it.
986986
}
987987

988+
func (b *bolt4) SetPinHomeDatabaseCallback(func(context.Context, string)) {
989+
// Home database not supported by this protocol version, so we ignore it.
990+
}
991+
992+
func (b *bolt4) IsSsrEnabled() bool {
993+
return false
994+
}
995+
988996
func (b *bolt4) helloResponseHandler(checkUtcPatch bool) responseHandler {
989997
return b.expectedSuccessHandler(b.onHelloSuccess(checkUtcPatch))
990998
}

neo4j/internal/bolt/bolt5.go

Lines changed: 68 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ const (
5050
// Default fetch size
5151
const bolt5FetchSize = 1000
5252

53+
const (
54+
telemetryEnabledHintName = "telemetry.enabled"
55+
ssrEnabledHintName = "ssr.enabled"
56+
)
57+
5358
type internalTx5 struct {
5459
mode idb.AccessMode
5560
bookmarks []string
@@ -93,28 +98,30 @@ func (i *internalTx5) toMeta(logger log.Logger, logId string, version db.Protoco
9398
}
9499

95100
type bolt5 struct {
96-
state int
97-
txId idb.TxHandle
98-
streams openstreams
99-
conn io.ReadWriteCloser
100-
serverName string
101-
queue messageQueue
102-
connId string
103-
logId string
104-
serverVersion string
105-
bookmark string // Last bookmark
106-
birthDate time.Time
107-
log log.Logger
108-
databaseName string
109-
err error // Last fatal error
110-
minor int
111-
lastQid int64 // Last seen qid
112-
idleDate time.Time
113-
auth map[string]any
114-
authManager auth.TokenManager
115-
resetAuth bool
116-
errorListener ConnectionErrorListener
117-
telemetryEnabled bool
101+
state int
102+
txId idb.TxHandle
103+
streams openstreams
104+
conn io.ReadWriteCloser
105+
serverName string
106+
queue messageQueue
107+
connId string
108+
logId string
109+
serverVersion string
110+
bookmark string // Last bookmark
111+
birthDate time.Time
112+
log log.Logger
113+
databaseName string
114+
err error // Last fatal error
115+
minor int
116+
lastQid int64 // Last seen qid
117+
idleDate time.Time
118+
auth map[string]any
119+
authManager auth.TokenManager
120+
resetAuth bool
121+
errorListener ConnectionErrorListener
122+
telemetryEnabled bool
123+
ssrEnabled bool
124+
pinHomeDatabaseCallback func(context.Context, string)
118125
}
119126

120127
func NewBolt5(
@@ -322,7 +329,7 @@ func (b *bolt5) TxBegin(
322329
notificationConfig: txConfig.NotificationConfig,
323330
}
324331

325-
b.queue.appendBegin(tx.toMeta(b.log, b.logId, b.Version()), b.beginResponseHandler())
332+
b.queue.appendBegin(tx.toMeta(b.log, b.logId, b.Version()), b.beginResponseHandler(ctx))
326333
if syncMessages {
327334
if b.queue.send(ctx); b.err != nil {
328335
return 0, b.err
@@ -562,7 +569,7 @@ func (b *bolt5) run(ctx context.Context, cypher string, params map[string]any, r
562569
fetchSize := b.normalizeFetchSize(rawFetchSize)
563570
stream := &stream{fetchSize: fetchSize}
564571
b.Version()
565-
b.queue.appendRun(cypher, params, tx.toMeta(b.log, b.logId, b.Version()), b.runResponseHandler(stream))
572+
b.queue.appendRun(cypher, params, tx.toMeta(b.log, b.logId, b.Version()), b.runResponseHandler(ctx, stream))
566573
b.queue.appendPullN(fetchSize, b.pullResponseHandler(stream))
567574
if b.queue.send(ctx); b.err != nil {
568575
return nil, b.err
@@ -850,6 +857,14 @@ func (b *bolt5) SetBoltLogger(boltLogger log.BoltLogger) {
850857
b.queue.setBoltLogger(boltLogger)
851858
}
852859

860+
func (b *bolt5) SetPinHomeDatabaseCallback(callback func(context.Context, string)) {
861+
b.pinHomeDatabaseCallback = callback
862+
}
863+
864+
func (b *bolt5) IsSsrEnabled() bool {
865+
return b.ssrEnabled
866+
}
867+
853868
func (b *bolt5) ReAuth(ctx context.Context, auth *idb.ReAuthToken) error {
854869
if b.minor == 0 {
855870
return b.fallbackReAuth(ctx, auth)
@@ -998,12 +1013,19 @@ func (b *bolt5) routeResponseHandler(table **idb.RoutingTable) responseHandler {
9981013
})
9991014
}
10001015

1001-
func (b *bolt5) beginResponseHandler() responseHandler {
1002-
return b.expectedSuccessHandler(onSuccessNoOp)
1016+
func (b *bolt5) beginResponseHandler(ctx context.Context) responseHandler {
1017+
return b.expectedSuccessHandler(func(beginSuccess *success) {
1018+
if b.pinHomeDatabaseCallback != nil && beginSuccess.db != "" {
1019+
b.pinHomeDatabaseCallback(ctx, beginSuccess.db)
1020+
}
1021+
})
10031022
}
10041023

1005-
func (b *bolt5) runResponseHandler(stream *stream) responseHandler {
1024+
func (b *bolt5) runResponseHandler(ctx context.Context, stream *stream) responseHandler {
10061025
return b.expectedSuccessHandler(func(runSuccess *success) {
1026+
if b.pinHomeDatabaseCallback != nil && runSuccess.db != "" {
1027+
b.pinHomeDatabaseCallback(ctx, runSuccess.db)
1028+
}
10071029
stream.attached = true
10081030
stream.keys = runSuccess.fields
10091031
stream.qid = runSuccess.qid
@@ -1124,7 +1146,8 @@ func (b *bolt5) onHelloSuccess(helloSuccess *success) {
11241146
b.logId = connectionLogId
11251147
b.queue.setLogId(connectionLogId)
11261148
b.initializeReadTimeoutHint(helloSuccess.configurationHints)
1127-
b.initializeTelemetryHint(helloSuccess.configurationHints)
1149+
b.initializeTelemetryEnabledHint(helloSuccess.configurationHints)
1150+
b.initializeSsrEnabledHint(helloSuccess.configurationHints)
11281151
}
11291152

11301153
func (b *bolt5) onCommitSuccess(commitSuccess *success) {
@@ -1172,19 +1195,30 @@ func (b *bolt5) initializeReadTimeoutHint(hints map[string]any) {
11721195
b.queue.in.connReadTimeout = time.Duration(readTimeout) * time.Second
11731196
}
11741197

1175-
const readTelemetryHintName = "telemetry.enabled"
1198+
func (b *bolt5) initializeTelemetryEnabledHint(hints map[string]any) {
1199+
telemetryEnabledHint, ok := hints[telemetryEnabledHintName]
1200+
if !ok {
1201+
return
1202+
}
1203+
telemetryEnabled, ok := telemetryEnabledHint.(bool)
1204+
if !ok {
1205+
b.log.Infof(log.Bolt5, b.logId, `invalid %q value: %v, ignoring hint. Only boolean values are accepted`, telemetryEnabledHintName, telemetryEnabledHint)
1206+
return
1207+
}
1208+
b.telemetryEnabled = telemetryEnabled
1209+
}
11761210

1177-
func (b *bolt5) initializeTelemetryHint(hints map[string]any) {
1178-
readTelemetryHint, ok := hints[readTelemetryHintName]
1211+
func (b *bolt5) initializeSsrEnabledHint(hints map[string]any) {
1212+
ssrEnabledHint, ok := hints[ssrEnabledHintName]
11791213
if !ok {
11801214
return
11811215
}
1182-
readTelemetry, ok := readTelemetryHint.(bool)
1216+
ssrEnabled, ok := ssrEnabledHint.(bool)
11831217
if !ok {
1184-
b.log.Infof(log.Bolt5, b.logId, `invalid %q value: %v, ignoring hint. Only boolean values are accepted`, readTelemetryHintName, readTelemetryHint)
1218+
b.log.Infof(log.Bolt5, b.logId, `invalid %q value: %v, ignoring hint. Only boolean values are accepted`, ssrEnabledHintName, ssrEnabledHint)
11851219
return
11861220
}
1187-
b.telemetryEnabled = readTelemetry
1221+
b.ssrEnabled = ssrEnabled
11881222
}
11891223

11901224
func (b *bolt5) extractSummary(success *success, stream *stream) *db.Summary {

neo4j/internal/bolt/connect.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func (p *protocolVersion) formatProtocol() string {
4747
// new manifest-style negotiation.
4848
var versions = [4]protocolVersion{
4949
{major: 0xFF, minor: 0x01, back: 0x00}, // Bolt manifest marker
50-
{major: 5, minor: 7, back: 7},
50+
{major: 5, minor: 8, back: 8},
5151
{major: 4, minor: 4, back: 2},
5252
{major: 3, minor: 0, back: 0},
5353
}

neo4j/internal/db/connection.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ type Connection interface {
170170
GetCurrentAuth() (auth.TokenManager, iauth.Token)
171171
// Telemetry sends telemetry information about the API usage to the server.
172172
Telemetry(api telemetry.API, onSuccess func())
173+
// SetPinHomeDatabaseCallback registers a callback to update the session's cached home database.
174+
// The callback is triggered on successful BEGIN or RUN responses containing a database name.
175+
SetPinHomeDatabaseCallback(callback func(ctx context.Context, database string))
176+
// IsSsrEnabled returns true if the connection supports Server-Side Routing.
177+
IsSsrEnabled() bool
173178
}
174179

175180
type RoutingTable struct {
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [https://neo4j.com]
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* https://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package db
19+
20+
// DatabaseSelection encapsulates the database name and whether it is guessed.
21+
type DatabaseSelection struct {
22+
Name string
23+
IsHomeDbGuess bool
24+
}

0 commit comments

Comments
 (0)