Skip to content

Warn if peak mem exceeds allowed_mem #516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions cubed/extensions/mem_warn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import warnings
from collections import Counter

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback


class MemoryWarningCallback(Callback):
def on_compute_start(self, event):
# store ops keyed by name
self.ops = {}
for name, node in visit_nodes(event.dag, event.resume):
primitive_op = node["primitive_op"]
self.ops[name] = primitive_op

# count number of times each op exceeds allowed mem
self.counter = Counter()

def on_task_end(self, event):
allowed_mem = self.ops[event.name].allowed_mem
if (
event.peak_measured_mem_end is not None
and event.peak_measured_mem_end > allowed_mem
):
self.counter.update({event.name: 1})

def on_compute_end(self, event):
if sum(self.counter.values()) > 0:
exceeded = [
f"{k} ({v}/{self.ops[k].num_tasks})" for k, v in self.counter.items()
]
warnings.warn(
f"Peak memory usage exceeded allowed_mem when running tasks: {', '.join(exceeded)}",
UserWarning,
)
23 changes: 23 additions & 0 deletions cubed/tests/test_executor_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import cubed.array_api as xp
import cubed.random
from cubed.extensions.history import HistoryCallback
from cubed.extensions.mem_warn import MemoryWarningCallback
from cubed.extensions.rich import RichProgressBar
from cubed.extensions.timeline import TimelineVisualizationCallback
from cubed.extensions.tqdm import TqdmProgressBar
Expand Down Expand Up @@ -145,6 +146,28 @@ def test_callbacks_modal(spec, modal_executor):
fs.rm(tmp_path, recursive=True)


@pytest.mark.skipif(
platform.system() == "Windows", reason="measuring memory does not run on windows"
)
def test_mem_warn(tmp_path, executor):
if executor.name not in ("processes", "lithops"):
pytest.skip(f"{executor.name} executor does not support MemoryWarningCallback")

spec = cubed.Spec(tmp_path, allowed_mem=200_000_000, reserved_mem=100_000_000)
mem_warn = MemoryWarningCallback()

def func(a):
np.ones(100_000_000) # blow memory
return a

a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = cubed.map_blocks(func, a, dtype=a.dtype)
with pytest.raises(
UserWarning, match="Peak memory usage exceeded allowed_mem when running tasks"
):
b.compute(executor=executor, callbacks=[mem_warn])


def test_resume(spec, executor):
if executor.name == "beam":
pytest.skip(f"{executor.name} executor does not support resume")
Expand Down
4 changes: 3 additions & 1 deletion cubed/tests/test_mem_utilization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import cubed.random
from cubed.backend_array_api import namespace as nxp
from cubed.extensions.history import HistoryCallback
from cubed.extensions.mem_warn import MemoryWarningCallback
from cubed.runtime.executors.lithops import LithopsExecutor
from cubed.tests.utils import LITHOPS_LOCAL_CONFIG

Expand Down Expand Up @@ -277,12 +278,13 @@ def run_operation(tmp_path, name, result_array, *, optimize_function=None):
# result_array.visualize(f"cubed-{name}", optimize_function=optimize_function)
executor = LithopsExecutor(config=LITHOPS_LOCAL_CONFIG)
hist = HistoryCallback()
mem_warn = MemoryWarningCallback()
# use store=None to write to temporary zarr
cubed.to_zarr(
result_array,
store=None,
executor=executor,
callbacks=[hist],
callbacks=[hist, mem_warn],
optimize_function=optimize_function,
)

Expand Down
Loading