Skip to content

Commit 8e89b1e

Browse files
author
Dmitriy Matrenichev
committed
feat: add GetOrCreate and GetOrCall methods
Also add Len/Clear/Trunc utility methods. Signed-off-by: Dmitriy Matrenichev <[email protected]>
1 parent 7c7ccc3 commit 8e89b1e

File tree

3 files changed

+243
-1
lines changed

3 files changed

+243
-1
lines changed

containers/map.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,49 @@ func (m *ConcurrentMap[K, V]) Get(key K) (V, bool) {
2222
return val, ok
2323
}
2424

25+
// GetOrCreate returns the existing value for the key if present. Otherwise, it stores and returns the given value.
26+
// The loaded result is true if the value was loaded, false if stored.
27+
func (m *ConcurrentMap[K, V]) GetOrCreate(key K, val V) (V, bool) {
28+
m.mx.Lock()
29+
defer m.mx.Unlock()
30+
31+
if res, ok := m.m[key]; ok {
32+
return res, true
33+
}
34+
35+
if m.m == nil {
36+
m.m = map[K]V{}
37+
}
38+
39+
m.m[key] = val
40+
41+
return val, false
42+
}
43+
44+
// GetOrCall returns the existing value for the key if present. Otherwise, it calls fn, stores the result and returns it.
45+
// The loaded result is true if the value was loaded, false if it was created using fn.
46+
//
47+
// The main reason for this function is to avoid unnecessary allocations if you use pointer types as values, since
48+
// compiler cannot prove that the value does not escape if it's not stored.
49+
func (m *ConcurrentMap[K, V]) GetOrCall(key K, fn func() V) (V, bool) {
50+
m.mx.Lock()
51+
defer m.mx.Unlock()
52+
53+
if res, ok := m.m[key]; ok {
54+
return res, true
55+
}
56+
57+
if m.m == nil {
58+
m.m = map[K]V{}
59+
}
60+
61+
val := fn()
62+
63+
m.m[key] = val
64+
65+
return val, false
66+
}
67+
2568
// Set sets the value for the given key.
2669
func (m *ConcurrentMap[K, V]) Set(key K, val V) {
2770
m.mx.Lock()
@@ -46,6 +89,21 @@ func (m *ConcurrentMap[K, V]) Remove(key K) {
4689
delete(m.m, key)
4790
}
4891

92+
// RemoveAndGet removes the value for the given key and returns it if it exists.
93+
func (m *ConcurrentMap[K, V]) RemoveAndGet(key K) (V, bool) {
94+
m.mx.Lock()
95+
defer m.mx.Unlock()
96+
97+
if m.m == nil {
98+
return *new(V), false //nolint:gocritic
99+
}
100+
101+
val, ok := m.m[key]
102+
delete(m.m, key)
103+
104+
return val, ok
105+
}
106+
49107
// ForEach calls the given function for each key-value pair.
50108
func (m *ConcurrentMap[K, V]) ForEach(f func(K, V)) {
51109
m.mx.Lock()
@@ -56,6 +114,14 @@ func (m *ConcurrentMap[K, V]) ForEach(f func(K, V)) {
56114
}
57115
}
58116

117+
// Len returns the number of elements in the map.
118+
func (m *ConcurrentMap[K, V]) Len() int {
119+
m.mx.Lock()
120+
defer m.mx.Unlock()
121+
122+
return len(m.m)
123+
}
124+
59125
// Clear removes all key-value pairs.
60126
func (m *ConcurrentMap[K, V]) Clear() {
61127
m.mx.Lock()
@@ -65,3 +131,11 @@ func (m *ConcurrentMap[K, V]) Clear() {
65131
delete(m.m, k)
66132
}
67133
}
134+
135+
// Reset resets the underlying map.
136+
func (m *ConcurrentMap[K, V]) Reset() {
137+
m.mx.Lock()
138+
defer m.mx.Unlock()
139+
140+
m.m = nil
141+
}

containers/map_test.go

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,37 @@
55
package containers_test
66

77
import (
8+
"fmt"
9+
"math/rand"
810
"testing"
911

1012
"github.com/stretchr/testify/require"
1113

1214
"github.com/siderolabs/gen/containers"
15+
"github.com/siderolabs/gen/xsync"
1316
)
1417

1518
func TestConcurrentMap(t *testing.T) {
19+
t.Parallel()
20+
1621
t.Run("should return nothing if key doesnt exist", func(t *testing.T) {
22+
t.Parallel()
23+
1724
m := containers.ConcurrentMap[int, int]{}
1825
_, ok := m.Get(0)
1926
require.False(t, ok)
2027
})
2128

2229
t.Run("should remove nothing if map is empty", func(t *testing.T) {
30+
t.Parallel()
31+
2332
m := containers.ConcurrentMap[int, int]{}
2433
m.Remove(0)
2534
})
2635

2736
t.Run("should return setted value", func(t *testing.T) {
37+
t.Parallel()
38+
2839
m := containers.ConcurrentMap[int, int]{}
2940
m.Set(1, 1)
3041
val, ok := m.Get(1)
@@ -33,14 +44,34 @@ func TestConcurrentMap(t *testing.T) {
3344
})
3445

3546
t.Run("should remove value", func(t *testing.T) {
47+
t.Parallel()
48+
3649
m := containers.ConcurrentMap[int, int]{}
3750
m.Set(1, 1)
3851
m.Remove(1)
3952
_, ok := m.Get(1)
4053
require.False(t, ok)
54+
55+
m.Set(2, 2)
56+
got, ok := m.RemoveAndGet(2)
57+
require.True(t, ok)
58+
require.Equal(t, 2, got)
59+
60+
got, ok = m.RemoveAndGet(2)
61+
require.False(t, ok)
62+
require.Zero(t, got)
63+
64+
m.Reset()
65+
got, ok = m.RemoveAndGet(2)
66+
require.False(t, ok)
67+
require.Zero(t, got)
68+
69+
require.False(t, ok)
4170
})
4271

4372
t.Run("should call fn for every key", func(t *testing.T) {
73+
t.Parallel()
74+
4475
m := containers.ConcurrentMap[int, int]{}
4576
m.Set(1, 1)
4677
m.Set(2, 2)
@@ -52,4 +83,139 @@ func TestConcurrentMap(t *testing.T) {
5283
})
5384
require.Equal(t, 3, count)
5485
})
86+
87+
t.Run("should clear the map", func(t *testing.T) {
88+
t.Parallel()
89+
90+
m := containers.ConcurrentMap[int, int]{}
91+
m.Set(1, 1)
92+
93+
require.Equal(t, 1, m.Len())
94+
95+
m.Clear()
96+
97+
require.Equal(t, 0, m.Len())
98+
})
99+
100+
t.Run("should trunc the map", func(t *testing.T) {
101+
t.Parallel()
102+
103+
m := containers.ConcurrentMap[int, int]{}
104+
m.Set(1, 1)
105+
106+
require.Equal(t, 1, m.Len())
107+
108+
m.Reset()
109+
110+
require.Equal(t, 0, m.Len())
111+
})
112+
}
113+
114+
func TestConcurrentMap_GetOrCall(t *testing.T) {
115+
var m containers.ConcurrentMap[int, int]
116+
117+
t.Run("group", func(t *testing.T) {
118+
t.Run("try to insert value", func(t *testing.T) {
119+
parallelGetOrCall(t, &m, 100, 1000)
120+
})
121+
122+
t.Run("try to insert value #2", func(t *testing.T) {
123+
parallelGetOrCall(t, &m, 1000, 100)
124+
})
125+
})
126+
}
127+
128+
func parallelGetOrCall(t *testing.T, m *containers.ConcurrentMap[int, int], our, another int) {
129+
t.Parallel()
130+
131+
oneAnotherGet := false
132+
133+
for i := 0; i < 10000; i++ {
134+
key := int(rand.Int63n(10000))
135+
136+
res, ok := m.GetOrCall(key, func() int { return key * our })
137+
if ok {
138+
switch res {
139+
case key * our:
140+
case key * another:
141+
oneAnotherGet = true
142+
default:
143+
t.Fatalf("unexpected value %d", res)
144+
}
145+
}
146+
}
147+
148+
require.True(t, oneAnotherGet)
149+
}
150+
151+
func TestConcurrentMap_GetOrCreate(t *testing.T) {
152+
var m containers.ConcurrentMap[int, int]
153+
154+
t.Run("group", func(t *testing.T) {
155+
t.Run("try to insert value", func(t *testing.T) {
156+
parallelGetOrCreate(t, &m, 100, 1000)
157+
})
158+
159+
t.Run("try to insert value #2", func(t *testing.T) {
160+
parallelGetOrCreate(t, &m, 1000, 100)
161+
})
162+
})
163+
}
164+
165+
func parallelGetOrCreate(t *testing.T, m *containers.ConcurrentMap[int, int], our, another int) {
166+
t.Parallel()
167+
168+
oneAnotherGet := false
169+
170+
for i := 0; i < 10000; i++ {
171+
key := int(rand.Int63n(10000))
172+
173+
res, ok := m.GetOrCreate(key, key*our)
174+
if ok {
175+
switch res {
176+
case key * our:
177+
case key * another:
178+
oneAnotherGet = true
179+
default:
180+
t.Fatalf("unexpected value %d", res)
181+
}
182+
}
183+
}
184+
185+
require.True(t, oneAnotherGet)
186+
}
187+
188+
func Example_benchConcurrentMap() {
189+
var sink int
190+
191+
benchResult := testing.Benchmark(func(b *testing.B) {
192+
b.ReportAllocs()
193+
194+
var m containers.ConcurrentMap[int, *xsync.Once[int]]
195+
196+
for i := 0; i < b.N; i++ {
197+
variable := 0
198+
199+
res, _ := m.GetOrCall(10, func() *xsync.Once[int] {
200+
return &xsync.Once[int]{}
201+
})
202+
203+
sink = res.Do(func() int {
204+
variable++
205+
206+
return variable
207+
})
208+
}
209+
})
210+
211+
if benchResult.AllocsPerOp() > 0 {
212+
fmt.Println("this benchmark should not allocate memory")
213+
}
214+
215+
fmt.Println("ok")
216+
217+
// Output:
218+
// ok
219+
220+
_ = sink
55221
}

xsync/once.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
// Package xsync contains the additions to std sync package.
66
package xsync
77

8-
import "sync"
8+
import (
9+
"sync"
10+
)
911

1012
// Once is small wrapper around [sync.Once]. It stores the result inside.
1113
type Once[T any] struct {

0 commit comments

Comments
 (0)