From 33a2a3d77f75c27019d8f7586ba2b32c6f9088f0 Mon Sep 17 00:00:00 2001 From: Tom White Date: Sat, 20 Jul 2024 12:57:24 +0100 Subject: [PATCH 1/3] Warn if peak mem exceeds `allowed_mem` --- cubed/extensions/mem_warn.py | 35 +++++++++++++++++++++++++++ cubed/tests/test_executor_features.py | 20 +++++++++++++++ cubed/tests/test_mem_utilization.py | 4 ++- 3 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 cubed/extensions/mem_warn.py diff --git a/cubed/extensions/mem_warn.py b/cubed/extensions/mem_warn.py new file mode 100644 index 000000000..4d1f5487a --- /dev/null +++ b/cubed/extensions/mem_warn.py @@ -0,0 +1,35 @@ +import warnings +from collections import Counter + +from cubed.runtime.pipeline import visit_nodes +from cubed.runtime.types import Callback + + +class MemoryWarningCallback(Callback): + def on_compute_start(self, event): + # store ops keyed by name + self.ops = {} + for name, node in visit_nodes(event.dag, event.resume): + primitive_op = node["primitive_op"] + self.ops[name] = primitive_op + + # count number of times each op exceeds allowed mem + self.counter = Counter() + + def on_task_end(self, event): + allowed_mem = self.ops[event.name].allowed_mem + if ( + event.peak_measured_mem_end is not None + and event.peak_measured_mem_end > allowed_mem + ): + self.counter.update({event.name: 1}) + + def on_compute_end(self, event): + if self.counter.total() > 0: + exceeded = [ + f"{k} ({v}/{self.ops[k].num_tasks})" for k, v in self.counter.items() + ] + warnings.warn( + f"Peak memory usage exceeded allowed_mem when running tasks: {', '.join(exceeded)}", + UserWarning, + ) diff --git a/cubed/tests/test_executor_features.py b/cubed/tests/test_executor_features.py index 60bb397b4..6e85402d5 100644 --- a/cubed/tests/test_executor_features.py +++ b/cubed/tests/test_executor_features.py @@ -10,6 +10,7 @@ import cubed.array_api as xp import cubed.random from cubed.extensions.history import HistoryCallback +from cubed.extensions.mem_warn import MemoryWarningCallback from cubed.extensions.rich import RichProgressBar from cubed.extensions.timeline import TimelineVisualizationCallback from cubed.extensions.tqdm import TqdmProgressBar @@ -145,6 +146,25 @@ def test_callbacks_modal(spec, modal_executor): fs.rm(tmp_path, recursive=True) +def test_mem_warn(tmp_path, executor): + if executor.name not in ("processes", "lithops"): + pytest.skip(f"{executor.name} executor does not support MemoryWarningCallback") + + spec = cubed.Spec(tmp_path, allowed_mem=200_000_000, reserved_mem=100_000_000) + mem_warn = MemoryWarningCallback() + + def func(a): + np.ones(100_000_000) # blow memory + return a + + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) + b = cubed.map_blocks(func, a, dtype=a.dtype) + with pytest.raises( + UserWarning, match="Peak memory usage exceeded allowed_mem when running tasks" + ): + b.compute(executor=executor, callbacks=[mem_warn]) + + def test_resume(spec, executor): if executor.name == "beam": pytest.skip(f"{executor.name} executor does not support resume") diff --git a/cubed/tests/test_mem_utilization.py b/cubed/tests/test_mem_utilization.py index 853f993f8..85f067c88 100644 --- a/cubed/tests/test_mem_utilization.py +++ b/cubed/tests/test_mem_utilization.py @@ -14,6 +14,7 @@ import cubed.random from cubed.backend_array_api import namespace as nxp from cubed.extensions.history import HistoryCallback +from cubed.extensions.mem_warn import MemoryWarningCallback from cubed.runtime.executors.lithops import LithopsExecutor from cubed.tests.utils import LITHOPS_LOCAL_CONFIG @@ -277,12 +278,13 @@ def run_operation(tmp_path, name, result_array, *, optimize_function=None): # result_array.visualize(f"cubed-{name}", optimize_function=optimize_function) executor = LithopsExecutor(config=LITHOPS_LOCAL_CONFIG) hist = HistoryCallback() + mem_warn = MemoryWarningCallback() # use store=None to write to temporary zarr cubed.to_zarr( result_array, store=None, executor=executor, - callbacks=[hist], + callbacks=[hist, mem_warn], optimize_function=optimize_function, ) From cdaa022969c855100d8bd2d574769f36d31a7d02 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 22 Jul 2024 20:13:11 +0100 Subject: [PATCH 2/3] Remove usage of Python 3.10 API (Counter.total) --- cubed/extensions/mem_warn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cubed/extensions/mem_warn.py b/cubed/extensions/mem_warn.py index 4d1f5487a..03f5b2bf5 100644 --- a/cubed/extensions/mem_warn.py +++ b/cubed/extensions/mem_warn.py @@ -25,7 +25,7 @@ def on_task_end(self, event): self.counter.update({event.name: 1}) def on_compute_end(self, event): - if self.counter.total() > 0: + if sum(self.counter.values()) > 0: exceeded = [ f"{k} ({v}/{self.ops[k].num_tasks})" for k, v in self.counter.items() ] From b36ed52b4fad297592c850c8a77d40540198b8d9 Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 23 Jul 2024 09:05:28 +0100 Subject: [PATCH 3/3] Don't run mem warn test on Windows --- cubed/tests/test_executor_features.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cubed/tests/test_executor_features.py b/cubed/tests/test_executor_features.py index 6e85402d5..1f99ff130 100644 --- a/cubed/tests/test_executor_features.py +++ b/cubed/tests/test_executor_features.py @@ -146,6 +146,9 @@ def test_callbacks_modal(spec, modal_executor): fs.rm(tmp_path, recursive=True) +@pytest.mark.skipif( + platform.system() == "Windows", reason="measuring memory does not run on windows" +) def test_mem_warn(tmp_path, executor): if executor.name not in ("processes", "lithops"): pytest.skip(f"{executor.name} executor does not support MemoryWarningCallback")