@@ -95,8 +95,14 @@ def closure():
9595
9696class TestBoundaryRegistry :
9797 def check (self , results ):
98- # Check results: verify that all threads get the same obj
99- assert len (set (results )) == 1
98+ # only one thread can succeed registration, all others should raise.
99+ expected_msg = (
100+ "Another function is already registered with key='test'. "
101+ "If you meant to override the existing function, "
102+ "consider setting allow_unsafe_override=True"
103+ )
104+ assert len (results ) == N_THREADS - 1
105+ assert results .count (expected_msg ) == N_THREADS - 1
100106
101107 def test_concurrent_threading (self ):
102108 # Defines a thread barrier that will be spawned before parallel execution
@@ -106,28 +112,32 @@ def test_concurrent_threading(self):
106112 # This object will be shared by all the threads.
107113 registry = BoundaryRegistry ()
108114
109- def test_recipe (
110- same_side_active_layer ,
111- same_side_ghost_layer ,
112- opposite_side_active_layer ,
113- opposite_side_ghost_layer ,
114- weight_same_side_active_layer ,
115- weight_same_side_ghost_layer ,
116- weight_opposite_side_active_layer ,
117- weight_opposite_side_ghost_layer ,
118- side ,
119- metadata ,
120- ): ...
121-
122115 results = []
123116
124117 def closure ():
118+ def test_recipe (
119+ same_side_active_layer ,
120+ same_side_ghost_layer ,
121+ opposite_side_active_layer ,
122+ opposite_side_ghost_layer ,
123+ weight_same_side_active_layer ,
124+ weight_same_side_ghost_layer ,
125+ weight_opposite_side_active_layer ,
126+ weight_opposite_side_ghost_layer ,
127+ side ,
128+ metadata ,
129+ ): ...
130+
125131 assert "test" not in registry
126132
127133 # Ensure that all threads reach this point before concurrent execution.
128134 barrier .wait ()
129- registry .register ("test" , test_recipe )
130- results .append (registry ["test" ])
135+ try :
136+ registry .register ("test" , test_recipe )
137+ except ValueError as exc :
138+ msg , * _ = exc .args
139+ results .append (msg )
140+ assert "test" in registry
131141
132142 # Spawn n threads that call _setup_host_cell_index concurrently.
133143 workers = []
@@ -150,31 +160,30 @@ def test_concurrent_pool(self):
150160 # This object will be shared by all the threads.
151161 registry = BoundaryRegistry ()
152162
153- def test_recipe (
154- same_side_active_layer ,
155- same_side_ghost_layer ,
156- opposite_side_active_layer ,
157- opposite_side_ghost_layer ,
158- weight_same_side_active_layer ,
159- weight_same_side_ghost_layer ,
160- weight_opposite_side_active_layer ,
161- weight_opposite_side_ghost_layer ,
162- side ,
163- metadata ,
164- ): ...
165-
166- results = []
167-
168163 def closure ():
164+ def test_recipe (
165+ same_side_active_layer ,
166+ same_side_ghost_layer ,
167+ opposite_side_active_layer ,
168+ opposite_side_ghost_layer ,
169+ weight_same_side_active_layer ,
170+ weight_same_side_ghost_layer ,
171+ weight_opposite_side_active_layer ,
172+ weight_opposite_side_ghost_layer ,
173+ side ,
174+ metadata ,
175+ ): ...
176+
169177 assert "test" not in registry
170178
171179 # Ensure that all threads reach this point before concurrent execution.
172180 barrier .wait ()
173181 registry .register ("test" , test_recipe )
174- results .append (registry ["test" ])
175182
176183 with ThreadPoolExecutor (max_workers = N_THREADS ) as executor :
177184 futures = [executor .submit (closure ) for _ in range (N_THREADS )]
178185
179- results = [f .result () for f in futures ]
186+ assert "test" in registry
187+ exceptions = [f .exception () for f in futures ]
188+ results = [exc .args [0 ] for exc in exceptions if exc is not None ]
180189 self .check (results )
0 commit comments