Skip to content

Commit 6065146

Browse files
authored
feat(iavl): add Pin and UnsafeBytes design for managing mmaps (#25657)
1 parent f2d4a98 commit 6065146

File tree

7 files changed

+190
-44
lines changed

7 files changed

+190
-44
lines changed

iavl/internal/mem_node.go

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ func (node *MemNode) Version() uint32 {
4343
}
4444

4545
// Key implements the Node interface.
46-
func (node *MemNode) Key() ([]byte, error) {
47-
return node.key, nil
46+
func (node *MemNode) Key() (UnsafeBytes, error) {
47+
return WrapSafeBytes(node.key), nil
4848
}
4949

5050
// Value implements the Node interface.
51-
func (node *MemNode) Value() ([]byte, error) {
52-
return node.value, nil
51+
func (node *MemNode) Value() (UnsafeBytes, error) {
52+
return WrapSafeBytes(node.value), nil
5353
}
5454

5555
// Left implements the Node interface.
@@ -63,8 +63,8 @@ func (node *MemNode) Right() *NodePointer {
6363
}
6464

6565
// Hash implements the Node interface.
66-
func (node *MemNode) Hash() []byte {
67-
return node.hash
66+
func (node *MemNode) Hash() UnsafeBytes {
67+
return WrapSafeBytes(node.hash)
6868
}
6969

7070
// MutateBranch implements the Node interface.
@@ -76,35 +76,37 @@ func (node *MemNode) MutateBranch(version uint32) (*MemNode, error) {
7676
}
7777

7878
// Get implements the Node interface.
79-
func (node *MemNode) Get(key []byte) (value []byte, index int64, err error) {
79+
func (node *MemNode) Get(key []byte) (value UnsafeBytes, index int64, err error) {
8080
if node.IsLeaf() {
8181
switch bytes.Compare(node.key, key) {
8282
case -1:
83-
return nil, 1, nil
83+
return UnsafeBytes{}, 1, nil
8484
case 1:
85-
return nil, 0, nil
85+
return UnsafeBytes{}, 0, nil
8686
default:
87-
return node.value, 0, nil
87+
return WrapSafeBytes(node.value), 0, nil
8888
}
8989
}
9090

9191
if bytes.Compare(key, node.key) < 0 {
92-
leftNode, err := node.left.Resolve()
92+
leftNode, pin, err := node.left.Resolve()
93+
defer pin.Unpin()
9394
if err != nil {
94-
return nil, 0, err
95+
return UnsafeBytes{}, 0, err
9596
}
9697

9798
return leftNode.Get(key)
9899
}
99100

100-
rightNode, err := node.right.Resolve()
101+
rightNode, pin, err := node.right.Resolve()
102+
defer pin.Unpin()
101103
if err != nil {
102-
return nil, 0, err
104+
return UnsafeBytes{}, 0, err
103105
}
104106

105107
value, index, err = rightNode.Get(key)
106108
if err != nil {
107-
return nil, 0, err
109+
return UnsafeBytes{}, 0, err
108110
}
109111

110112
index += node.size - rightNode.Size()

iavl/internal/mem_node_test.go

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,16 @@ func TestMemNode_Getters(t *testing.T) {
1111
right := NewNodePointer(&MemNode{})
1212
nodeId := NewNodeID(true, 5, 10)
1313

14+
testKey := []byte("testkey")
15+
testValue := []byte("testvalue")
16+
testHash := []byte("testhash")
1417
node := &MemNode{
1518
height: 3,
1619
version: 7,
1720
size: 42,
18-
key: []byte("testkey"),
19-
value: []byte("testvalue"),
20-
hash: []byte("testhash"),
21+
key: testKey,
22+
value: testValue,
23+
hash: testHash,
2124
left: left,
2225
right: right,
2326
nodeId: nodeId,
@@ -29,16 +32,16 @@ func TestMemNode_Getters(t *testing.T) {
2932
require.Equal(t, int64(42), node.Size())
3033
require.Equal(t, left, node.Left())
3134
require.Equal(t, right, node.Right())
32-
require.Equal(t, []byte("testhash"), node.Hash())
35+
require.Equal(t, testHash, node.Hash().UnsafeBytes())
3336
require.Equal(t, nodeId, node.ID())
3437

3538
key, err := node.Key()
3639
require.NoError(t, err)
37-
require.Equal(t, []byte("testkey"), key)
40+
require.Equal(t, testKey, key.UnsafeBytes())
3841

3942
value, err := node.Value()
4043
require.NoError(t, err)
41-
require.Equal(t, []byte("testvalue"), value)
44+
require.Equal(t, testValue, value.UnsafeBytes())
4245
}
4346

4447
func TestMemNode_IsLeaf(t *testing.T) {
@@ -98,12 +101,14 @@ func TestMemNode_String(t *testing.T) {
98101
}
99102

100103
func TestMemNode_MutateBranch(t *testing.T) {
104+
key := []byte("key")
105+
origHash := []byte("origHash")
101106
original := &MemNode{
102107
height: 2,
103108
version: 5,
104109
size: 10,
105-
key: []byte("key"),
106-
hash: []byte("oldhash"),
110+
key: key,
111+
hash: origHash,
107112
left: NewNodePointer(&MemNode{}),
108113
right: NewNodePointer(&MemNode{}),
109114
}
@@ -113,13 +118,13 @@ func TestMemNode_MutateBranch(t *testing.T) {
113118

114119
// Version updated, hash cleared
115120
require.Equal(t, uint32(12), mutated.Version())
116-
require.Nil(t, mutated.Hash())
121+
require.Nil(t, mutated.Hash().UnsafeBytes())
117122

118123
// Other fields preserved
119124
require.Equal(t, original.Height(), mutated.Height())
120125
require.Equal(t, original.Size(), mutated.Size())
121-
key, _ := mutated.Key()
122-
require.Equal(t, []byte("key"), key)
126+
key2, _ := mutated.Key()
127+
require.Equal(t, key, key2.UnsafeBytes())
123128
require.Equal(t, original.Left(), mutated.Left())
124129
require.Equal(t, original.Right(), mutated.Right())
125130

@@ -128,7 +133,7 @@ func TestMemNode_MutateBranch(t *testing.T) {
128133

129134
// Original unchanged
130135
require.Equal(t, uint32(5), original.Version())
131-
require.Equal(t, []byte("oldhash"), original.Hash())
136+
require.Equal(t, origHash, original.Hash().UnsafeBytes())
132137
}
133138

134139
func TestMemNode_Get_Leaf(t *testing.T) {
@@ -180,7 +185,7 @@ func TestMemNode_Get_Leaf(t *testing.T) {
180185
}
181186
val, idx, err := node.Get([]byte(tt.searchKey))
182187
require.NoError(t, err)
183-
require.Equal(t, tt.wantValue, val)
188+
require.Equal(t, tt.wantValue, val.UnsafeBytes())
184189
require.Equal(t, tt.wantIndex, idx)
185190
})
186191
}
@@ -254,7 +259,7 @@ func TestMemNode_Get_Branch(t *testing.T) {
254259
t.Run(tt.name, func(t *testing.T) {
255260
val, idx, err := root.Get([]byte(tt.searchKey))
256261
require.NoError(t, err)
257-
require.Equal(t, tt.wantValue, val)
262+
require.Equal(t, tt.wantValue, val.UnsafeBytes())
258263
require.Equal(t, tt.wantIndex, idx)
259264
})
260265
}
@@ -312,7 +317,7 @@ func TestMemNode_Get_DeeperTree(t *testing.T) {
312317
t.Run(tt.searchKey, func(t *testing.T) {
313318
val, idx, err := root.Get([]byte(tt.searchKey))
314319
require.NoError(t, err)
315-
require.Equal(t, tt.wantValue, val)
320+
require.Equal(t, tt.wantValue, val.UnsafeBytes())
316321
require.Equal(t, tt.wantIndex, idx)
317322
})
318323
}

iavl/internal/node.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@ type Node interface {
1212
IsLeaf() bool
1313

1414
// Key returns the key of this node.
15-
Key() ([]byte, error)
15+
Key() (UnsafeBytes, error)
1616

17-
// Value returns the value of this node.
18-
// Calling this on a non-leaf node will return nil and possibly an error.
19-
Value() ([]byte, error)
17+
// Value returns the value of this node. It is an error to call this method on non-leaf nodes.
18+
Value() (UnsafeBytes, error)
2019

2120
// Left returns a pointer to the left child node.
2221
// If this is called on a leaf node, it returns nil.
@@ -28,7 +27,7 @@ type Node interface {
2827

2928
// Hash returns the hash of this node.
3029
// Hash may or may not have been computed yet.
31-
Hash() []byte
30+
Hash() UnsafeBytes
3231

3332
// Height returns the height of the subtree rooted at this node.
3433
Height() uint8
@@ -45,7 +44,7 @@ type Node interface {
4544
// The index is the 0-based position where the key exists or would be inserted
4645
// in sorted order among all leaf keys in this subtree. This is useful for
4746
// range queries and determining a key's position even when it doesn't exist.
48-
Get(key []byte) (value []byte, index int64, err error)
47+
Get(key []byte) (value UnsafeBytes, index int64, err error)
4948

5049
// MutateBranch creates a mutable copy of this branch node created at the specified version.
5150
// Since this is an immutable tree, whenever we need to modify a branch node, we should call this method

iavl/internal/node_pointer.go

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ import (
77

88
// NodePointer is a pointer to a Node, which may be either in-memory, on-disk or both.
99
type NodePointer struct {
10-
mem atomic.Pointer[MemNode]
11-
changeset *Changeset
12-
fileIdx uint32 // absolute index in file, 1-based, zero means we don't have an offset
13-
id NodeID
10+
mem atomic.Pointer[MemNode]
11+
// changeset *Changeset // commented out to satisfy the linter; will be uncommented in a future PR when we wire it up
12+
fileIdx uint32 // absolute index in file, 1-based, zero means we don't have an offset
13+
id NodeID
1414
}
1515

1616
// NewNodePointer creates a new NodePointer pointing to the given in-memory node.
@@ -20,13 +20,23 @@ func NewNodePointer(memNode *MemNode) *NodePointer {
2020
return n
2121
}
2222

23-
// Resolve resolves the NodePointer to a Node, loading from memory or disk as necessary.
24-
func (p *NodePointer) Resolve() (Node, error) {
23+
// Resolve resolves the NodePointer to a Node, loading from memory or disk as necessary.
24+
// It also returns a Pin which MUST be unpinned after the caller is done using the node.
25+
// Resolve will ALWAYS return a valid Pin even if there is an error. For clarity and
26+
// consistency it is recommended to introduce a defer pin.Unpin() immediately after
27+
// calling Resolve and BEFORE checking the error return value like this:
28+
//
29+
// node, pin, err := nodePointer.Resolve()
30+
// defer pin.Unpin()
31+
// if err != nil {
32+
// // handle error
33+
// }
34+
func (p *NodePointer) Resolve() (Node, Pin, error) {
2535
mem := p.mem.Load()
2636
if mem != nil {
27-
return mem, nil
37+
return mem, NoopPin{}, nil
2838
}
29-
return p.changeset.Resolve(p.id, p.fileIdx)
39+
return nil, NoopPin{}, fmt.Errorf("node not in memory and on-disk loading will be implemented in a future PR")
3040
}
3141

3242
// String implements the fmt.Stringer interface.

iavl/internal/pin.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package internal
2+
3+
// Pin represents a handle that pins some memory-mapped file data in memory.
4+
// When the Pin is released via Unpin(), the data may be unmapped from memory.
5+
// Pin must be used to ensure that any UnsafeBytes obtained from memory-mapped
6+
// data remains valid while in use.
7+
// The caller must ensure that Unpin() is called exactly once
8+
// for each Pin obtained. It is recommended to use the following pattern:
9+
//
10+
// node, pin, err := nodePointer.Resolve()
11+
// defer pin.Unpin()
12+
// if err != nil {
13+
// // handle error
14+
// }
15+
//
16+
// When we are using arrays that point directly into memory-mapped files, these arrays
17+
// are not part of the normal Go garbage-collected memory. We must map and unmap
18+
// these regions of memory explicitly. Pin represents a commitment to keep the memory
19+
// mapped at least until Unpin() is called. During normal operation, changeset files
20+
// will be mapped and unmapped as needed either because the file size has grown, we have
21+
// compacted a changeset, or simply to manage open file descriptors.
22+
// Under the hood, pins use a reference-counting mechanism to keep track of how many
23+
// active users there are of a particular memory-mapped region.
24+
type Pin interface {
25+
// Unpin releases the Pin, allowing the underlying memory to be unmapped.
26+
// Implementors should ensure that Unpin() is idempotent and only unpins the
27+
// memory once even if called multiple times.
28+
Unpin()
29+
}
30+
31+
// NoopPin is a Pin that does nothing on Unpin().
32+
type NoopPin struct{}
33+
34+
// Unpin implements the Pin interface.
35+
func (NoopPin) Unpin() {}

iavl/internal/unsafe_bytes.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package internal
2+
3+
// UnsafeBytes wraps a byte slice that may or not be a direct reference to
4+
// a memory-mapped file.
5+
// Generally, an unsafe byte slice cannot be expected to live longer than the
6+
// Pin on the object it was obtained from.
7+
// As long as it is pinned, it is safe to use the UnsafeBytes() method to get
8+
// the underlying byte slice without copying.
9+
// If the byte slice needs to be retained beyond the Pin's lifetime, the
10+
// SafeCopy() method must be used to get a safe copy of the byte slice.
11+
type UnsafeBytes struct {
12+
bz []byte
13+
safe bool
14+
}
15+
16+
// WrapUnsafeBytes wraps an unsafe byte slice as UnsafeBytes, indicating that
17+
// it is unsafe to retain beyond the lifetime of its Pin without copying.
18+
// Use this function when you are wrapping a byte slice obtained from a memory-mapped file.
19+
func WrapUnsafeBytes(bz []byte) UnsafeBytes {
20+
return UnsafeBytes{bz: bz, safe: false}
21+
}
22+
23+
// WrapSafeBytes wraps a safe byte slice as UnsafeBytes, indicating that
24+
// it is safe to use without copying.
25+
// Use this method when you are wrapping a byte slice that is known to be safe,
26+
// e.g., a byte slice allocated in regular garbage-collected memory.
27+
func WrapSafeBytes(bz []byte) UnsafeBytes {
28+
return UnsafeBytes{bz: bz, safe: true}
29+
}
30+
31+
// IsNil returns true if the underlying byte slice is nil.
32+
func (ub UnsafeBytes) IsNil() bool {
33+
return ub.bz == nil
34+
}
35+
36+
// UnsafeBytes returns the underlying byte slice without copying.
37+
// The caller must ensure that the byte slice is not used beyond the lifetime
38+
// of the Pin on the object it was obtained from.
39+
func (ub UnsafeBytes) UnsafeBytes() []byte {
40+
return ub.bz
41+
}
42+
43+
// SafeCopy returns a safe copy of the underlying byte slice.
44+
// If the underlying byte slice is already safe or nil, it is returned as is.
45+
// If the underlying byte slice is unsafe, a copy is made and returned.
46+
func (ub UnsafeBytes) SafeCopy() []byte {
47+
if ub.safe {
48+
return ub.bz
49+
}
50+
if ub.bz == nil {
51+
return nil
52+
}
53+
copied := make([]byte, len(ub.bz))
54+
copy(copied, ub.bz)
55+
return copied
56+
}

iavl/internal/unsafe_bytes_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package internal
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
func TestUnsafeBytes(t *testing.T) {
10+
hello := []byte("hello")
11+
unsafe := WrapUnsafeBytes(hello)
12+
require.False(t, unsafe.IsNil())
13+
require.Equal(t, hello, unsafe.UnsafeBytes())
14+
safeCopy := unsafe.SafeCopy()
15+
require.Equal(t, hello, safeCopy)
16+
require.NotSame(t, &hello[0], &safeCopy[0]) // different underlying array
17+
18+
safe := WrapSafeBytes(hello)
19+
require.False(t, safe.IsNil())
20+
require.Equal(t, hello, safe.UnsafeBytes())
21+
safeCopy2 := safe.SafeCopy()
22+
require.Equal(t, hello, safeCopy2)
23+
require.Same(t, &hello[0], &safeCopy2[0]) // same underlying array
24+
25+
nilUnsafe := WrapUnsafeBytes(nil)
26+
require.True(t, nilUnsafe.IsNil())
27+
require.Nil(t, nilUnsafe.UnsafeBytes())
28+
require.Nil(t, nilUnsafe.SafeCopy())
29+
30+
nilSafe := WrapSafeBytes(nil)
31+
require.True(t, nilSafe.IsNil())
32+
require.Nil(t, nilSafe.UnsafeBytes())
33+
require.Nil(t, nilSafe.SafeCopy())
34+
35+
nilInit := UnsafeBytes{}
36+
require.True(t, nilInit.IsNil())
37+
require.Nil(t, nilInit.UnsafeBytes())
38+
require.Nil(t, nilInit.SafeCopy())
39+
}

0 commit comments

Comments
 (0)