Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions src/gpgi/_boundaries.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import warnings
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal, cast

Expand Down Expand Up @@ -60,13 +59,59 @@ def _validate_recipe(recipe: BoundaryRecipeT) -> None:
)

def register(
self, key: str, recipe: BoundaryRecipeT, *, skip_validation: bool = False
self,
key: str,
recipe: BoundaryRecipeT,
*,
skip_validation: bool = False,
allow_unsafe_override: bool = False,
) -> None:
"""
Register a new boundary function.

Parameters
----------
key: str
A unique identifier (ideally a meaningful name) to associate with
the function.

recipe: Callable
A function matching the signature (order and names of arguments) of
gpgi's builtin boundary recipes.

skip_validation: bool, optional, keyword-only (default: False)
If set to True, signature validation is skipped.
This is meant to allow bypassing hypothetical bugs in the validation
routine.

allow_unsafe_override: bool, optional, keyword-only (default: False)
If set to True, registering a new function under an existing key
will not raise an exception. Note however that doing so is not
thread-safe.

Raises
------
ValueError:
- if skip_validation==False and the signature of the recipe doesn't meet
the requirements.
- if allow_unsafe_override==False and a new function is being registered
under an already used key. Registering the same exact function under
multiple times either under the same key or another, unused key, is
always safe so it does not raise.
"""
if key in self._registry:
if recipe is self._registry[key]:
return
elif not allow_unsafe_override:
raise ValueError(
f"Another function is already registered with {key=!r}. "
"If you meant to override the existing function, "
"consider setting allow_unsafe_override=True"
)

if not skip_validation:
self._validate_recipe(recipe)

if key in self._registry:
warnings.warn(f"Overriding existing method {key!r}", stacklevel=2)
self._registry[key] = recipe

def __getitem__(self, key: str) -> BoundaryRecipeT:
Expand Down
51 changes: 51 additions & 0 deletions tests/test_boundary_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest

from gpgi._boundaries import BoundaryRegistry


def test_boundary_register_overrides():
registry = BoundaryRegistry()

def test_recipe1(
same_side_active_layer,
same_side_ghost_layer,
opposite_side_active_layer,
opposite_side_ghost_layer,
weight_same_side_active_layer,
weight_same_side_ghost_layer,
weight_opposite_side_active_layer,
weight_opposite_side_ghost_layer,
side,
metadata,
): ...
def test_recipe2(
same_side_active_layer,
same_side_ghost_layer,
opposite_side_active_layer,
opposite_side_ghost_layer,
weight_same_side_active_layer,
weight_same_side_ghost_layer,
weight_opposite_side_active_layer,
weight_opposite_side_ghost_layer,
side,
metadata,
): ...

registry.register("test1", test_recipe1)
assert registry["test1"] is test_recipe1

# registering the same function a second time shouldn't raise
registry.register("test1", test_recipe1)

with pytest.raises(
ValueError,
match="Another function is already registered with key='test1'",
):
registry.register("test1", test_recipe2)

# check that we raised in time to preserve state
assert registry["test1"] is test_recipe1

# if we explicitly allow unsafe mutations, this should not raise
registry.register("test1", test_recipe2, allow_unsafe_override=True)
assert registry["test1"] is test_recipe2
88 changes: 88 additions & 0 deletions tests/test_concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np

import gpgi
from gpgi._boundaries import BoundaryRegistry

prng = np.random.default_rng()

Expand Down Expand Up @@ -90,3 +91,90 @@ def closure():

results = [f.result() for f in futures]
self.check(results)


class TestBoundaryRegistry:
def check(self, results):
# Check results: verify that all threads get the same obj
assert len(set(results)) == 1

def test_concurrent_threading(self):
# Defines a thread barrier that will be spawned before parallel execution
# this increases the probability of concurrent access clashes.
barrier = threading.Barrier(N_THREADS)

# This object will be shared by all the threads.
registry = BoundaryRegistry()

def test_recipe(
same_side_active_layer,
same_side_ghost_layer,
opposite_side_active_layer,
opposite_side_ghost_layer,
weight_same_side_active_layer,
weight_same_side_ghost_layer,
weight_opposite_side_active_layer,
weight_opposite_side_ghost_layer,
side,
metadata,
): ...

results = []

def closure():
assert "test" not in registry

# Ensure that all threads reach this point before concurrent execution.
barrier.wait()
registry.register("test", test_recipe)
results.append(registry["test"])

# Spawn n threads that call _setup_host_cell_index concurrently.
workers = []
for _ in range(0, N_THREADS):
workers.append(threading.Thread(target=closure))

for worker in workers:
worker.start()

for worker in workers:
worker.join()

self.check(results)

def test_concurrent_pool(self):
# Defines a thread barrier that will be spawned before parallel execution
# this increases the probability of concurrent access clashes.
barrier = threading.Barrier(N_THREADS)

# This object will be shared by all the threads.
registry = BoundaryRegistry()

def test_recipe(
same_side_active_layer,
same_side_ghost_layer,
opposite_side_active_layer,
opposite_side_ghost_layer,
weight_same_side_active_layer,
weight_same_side_ghost_layer,
weight_opposite_side_active_layer,
weight_opposite_side_ghost_layer,
side,
metadata,
): ...

results = []

def closure():
assert "test" not in registry

# Ensure that all threads reach this point before concurrent execution.
barrier.wait()
registry.register("test", test_recipe)
results.append(registry["test"])

with ThreadPoolExecutor(max_workers=N_THREADS) as executor:
futures = [executor.submit(closure) for _ in range(N_THREADS)]

results = [f.result() for f in futures]
self.check(results)
48 changes: 0 additions & 48 deletions tests/test_deposit.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,54 +396,6 @@ def _my_recipe(a, b, c, d, e, f):
ds.boundary_recipes.register("my", _my_recipe)


def test_warn_register_override(capsys):
nx = ny = 64
nparticles = 100

prng = np.random.RandomState(0)
ds = gpgi.load(
geometry="cartesian",
grid={
"cell_edges": {
"x": np.linspace(-1, 1, nx),
"y": np.linspace(-1, 1, ny),
},
},
particles={
"coordinates": {
"x": 2 * (prng.normal(0.5, 0.25, nparticles) % 1 - 0.5),
"y": 2 * (prng.normal(0.5, 0.25, nparticles) % 1 - 0.5),
},
"fields": {
"mass": np.ones(nparticles),
},
},
metadata={"fac": 1},
)

def _my_recipe(
same_side_active_layer,
same_side_ghost_layer,
opposite_side_active_layer,
opposite_side_ghost_layer,
weight_same_side_active_layer,
weight_same_side_ghost_layer,
weight_opposite_side_active_layer,
weight_opposite_side_ghost_layer,
side,
metadata,
):
print("gotcha")
return same_side_active_layer * metadata["fac"]

with pytest.warns(UserWarning, match="Overriding existing method 'open'"):
ds.boundary_recipes.register("open", _my_recipe)

ds.deposit("mass", method="tsc")
out, err = capsys.readouterr()
assert out == "gotcha\n" * 4


def test_register_custom_boundary_recipe(sample_2D_dataset):
def _my_recipe(
same_side_active_layer,
Expand Down