Skip to content

Commit a24a421

Browse files
Merge pull request #241 from neutrinoceros/bug/boundary_registry_thread_safety
BUG: fix thread safety for `BoundaryRegistry.register`
2 parents 1bd2b7a + 4474eb2 commit a24a421

File tree

2 files changed

+60
-48
lines changed

2 files changed

+60
-48
lines changed

src/gpgi/_boundaries.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4+
from threading import Lock
45
from typing import TYPE_CHECKING, Any, Literal, cast
56

67
if TYPE_CHECKING:
@@ -26,6 +27,7 @@
2627
class BoundaryRegistry:
2728
def __init__(self) -> None:
2829
self._registry: dict[str, BoundaryRecipeT] = {}
30+
self._lock = Lock()
2931
for key, recipe in _base_registry.items():
3032
self.register(key, recipe, skip_validation=True)
3133

@@ -99,20 +101,21 @@ def register(
99101
multiple times either under the same key or another, unused key, is
100102
always safe so it does not raise.
101103
"""
102-
if key in self._registry:
103-
if recipe is self._registry[key]:
104-
return
105-
elif not allow_unsafe_override:
106-
raise ValueError(
107-
f"Another function is already registered with {key=!r}. "
108-
"If you meant to override the existing function, "
109-
"consider setting allow_unsafe_override=True"
110-
)
111-
112-
if not skip_validation:
113-
self._validate_recipe(recipe)
114-
115-
self._registry[key] = recipe
104+
with self._lock:
105+
if key in self._registry:
106+
if recipe is self._registry[key]:
107+
return
108+
elif not allow_unsafe_override:
109+
raise ValueError(
110+
f"Another function is already registered with {key=!r}. "
111+
"If you meant to override the existing function, "
112+
"consider setting allow_unsafe_override=True"
113+
)
114+
115+
if not skip_validation:
116+
self._validate_recipe(recipe)
117+
118+
self._registry[key] = recipe
116119

117120
def __getitem__(self, key: str) -> BoundaryRecipeT:
118121
return self._registry[key]

tests/test_concurrent.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,14 @@ def closure():
9595

9696
class TestBoundaryRegistry:
9797
def check(self, results):
98-
# Check results: verify that all threads get the same obj
99-
assert len(set(results)) == 1
98+
# only one thread can succeed registration, all others should raise.
99+
expected_msg = (
100+
"Another function is already registered with key='test'. "
101+
"If you meant to override the existing function, "
102+
"consider setting allow_unsafe_override=True"
103+
)
104+
assert len(results) == N_THREADS - 1
105+
assert results.count(expected_msg) == N_THREADS - 1
100106

101107
def test_concurrent_threading(self):
102108
# Defines a thread barrier that will be spawned before parallel execution
@@ -106,28 +112,32 @@ def test_concurrent_threading(self):
106112
# This object will be shared by all the threads.
107113
registry = BoundaryRegistry()
108114

109-
def test_recipe(
110-
same_side_active_layer,
111-
same_side_ghost_layer,
112-
opposite_side_active_layer,
113-
opposite_side_ghost_layer,
114-
weight_same_side_active_layer,
115-
weight_same_side_ghost_layer,
116-
weight_opposite_side_active_layer,
117-
weight_opposite_side_ghost_layer,
118-
side,
119-
metadata,
120-
): ...
121-
122115
results = []
123116

124117
def closure():
118+
def test_recipe(
119+
same_side_active_layer,
120+
same_side_ghost_layer,
121+
opposite_side_active_layer,
122+
opposite_side_ghost_layer,
123+
weight_same_side_active_layer,
124+
weight_same_side_ghost_layer,
125+
weight_opposite_side_active_layer,
126+
weight_opposite_side_ghost_layer,
127+
side,
128+
metadata,
129+
): ...
130+
125131
assert "test" not in registry
126132

127133
# Ensure that all threads reach this point before concurrent execution.
128134
barrier.wait()
129-
registry.register("test", test_recipe)
130-
results.append(registry["test"])
135+
try:
136+
registry.register("test", test_recipe)
137+
except ValueError as exc:
138+
msg, *_ = exc.args
139+
results.append(msg)
140+
assert "test" in registry
131141

132142
# Spawn n threads that call _setup_host_cell_index concurrently.
133143
workers = []
@@ -150,31 +160,30 @@ def test_concurrent_pool(self):
150160
# This object will be shared by all the threads.
151161
registry = BoundaryRegistry()
152162

153-
def test_recipe(
154-
same_side_active_layer,
155-
same_side_ghost_layer,
156-
opposite_side_active_layer,
157-
opposite_side_ghost_layer,
158-
weight_same_side_active_layer,
159-
weight_same_side_ghost_layer,
160-
weight_opposite_side_active_layer,
161-
weight_opposite_side_ghost_layer,
162-
side,
163-
metadata,
164-
): ...
165-
166-
results = []
167-
168163
def closure():
164+
def test_recipe(
165+
same_side_active_layer,
166+
same_side_ghost_layer,
167+
opposite_side_active_layer,
168+
opposite_side_ghost_layer,
169+
weight_same_side_active_layer,
170+
weight_same_side_ghost_layer,
171+
weight_opposite_side_active_layer,
172+
weight_opposite_side_ghost_layer,
173+
side,
174+
metadata,
175+
): ...
176+
169177
assert "test" not in registry
170178

171179
# Ensure that all threads reach this point before concurrent execution.
172180
barrier.wait()
173181
registry.register("test", test_recipe)
174-
results.append(registry["test"])
175182

176183
with ThreadPoolExecutor(max_workers=N_THREADS) as executor:
177184
futures = [executor.submit(closure) for _ in range(N_THREADS)]
178185

179-
results = [f.result() for f in futures]
186+
assert "test" in registry
187+
exceptions = [f.exception() for f in futures]
188+
results = [exc.args[0] for exc in exceptions if exc is not None]
180189
self.check(results)

0 commit comments

Comments
 (0)