Skip to content

Commit 71988fe

Browse files
desertaxle and claude committed
Fix TASK_SOURCE cache policy for remote execution with cloudpickle
This PR fixes an issue where the `TASK_SOURCE` cache policy fails when used with remote execution decorators like `@ecs` that use cloudpickle to ship code to remote environments. The problem: `inspect.getsource()` fails on cloudpickled functions because the original source file doesn't exist on the remote machine. The previous fallback to hashing `__code__.co_code` (bytecode) led to inconsistent cache keys because bytecode varies across Python versions and the code object contains unstable attributes like `co_locals` that change between serialization/deserialization cycles. The solution: Store the function's source code on the `Task` object during initialization so it survives cloudpickle serialization. The `TaskSource.compute_key()` method now checks for this stored source code first before falling back to `inspect.getsource()`. This ensures stable, deterministic cache keys regardless of execution environment. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent f14ad4d commit 71988fe

File tree

4 files changed

+146
-10
lines changed

4 files changed

+146
-10
lines changed

src/prefect/cache_policies.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,13 @@ def compute_key(
291291
) -> Optional[str]:
292292
if not task_ctx:
293293
return None
294+
295+
# Use stored source code if available (works after cloudpickle serialization)
296+
source_code = getattr(task_ctx.task, "source_code", None)
297+
if source_code is not None:
298+
return hash_objects(source_code, raise_on_failure=True)
299+
300+
# Fall back to inspect.getsource for local execution
294301
try:
295302
lines = inspect.getsource(task_ctx.task)
296303
except TypeError:

src/prefect/tasks.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,14 @@ def __init__(
456456
update_wrapper(self, fn)
457457
self.fn = fn
458458

459+
# Capture source code for cache key computation
460+
# This is stored on the task so it survives cloudpickle serialization
461+
# to remote environments where the source file is not available
462+
try:
463+
self.source_code: str | None = inspect.getsource(fn)
464+
except (TypeError, OSError):
465+
self.source_code = None
466+
459467
# the task is considered async if its function is async or an async
460468
# generator
461469
self.isasync: bool = inspect.iscoroutinefunction(

tests/test_cache_policies.py

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -283,35 +283,100 @@ def my_func(x):
283283

284284
assert key != new_key
285285

286-
def test_source_fallback_behavior(self):
286+
def test_uses_stored_source_code(self):
287+
"""Test that TaskSource uses stored source_code attribute when available."""
287288
policy = TaskSource()
288289

289-
def task_a_fn():
290-
pass
290+
mock_task_a = MagicMock()
291+
mock_task_b = MagicMock()
292+
293+
# Set different source code on each mock task
294+
mock_task_a.source_code = "def task_a():\n return 'a'"
295+
mock_task_b.source_code = "def task_b():\n return 'b'"
296+
297+
task_ctx_a = TaskRunContext.model_construct(task=mock_task_a)
298+
task_ctx_b = TaskRunContext.model_construct(task=mock_task_b)
299+
300+
key_a = policy.compute_key(
301+
task_ctx=task_ctx_a, inputs=None, flow_parameters=None
302+
)
303+
key_b = policy.compute_key(
304+
task_ctx=task_ctx_b, inputs=None, flow_parameters=None
305+
)
306+
307+
# Keys should be generated and different for different source code
308+
assert key_a is not None
309+
assert key_b is not None
310+
assert key_a != key_b
311+
312+
def test_stored_source_code_stability(self):
313+
"""Test that the same source code produces the same key consistently."""
314+
policy = TaskSource()
315+
316+
mock_task = MagicMock()
317+
mock_task.source_code = "def my_task():\n return 'hello'"
291318

292-
def task_b_fn():
293-
return 1
319+
task_ctx = TaskRunContext.model_construct(task=mock_task)
320+
321+
key1 = policy.compute_key(task_ctx=task_ctx, inputs=None, flow_parameters=None)
322+
key2 = policy.compute_key(task_ctx=task_ctx, inputs=None, flow_parameters=None)
323+
324+
# Same source code should produce same key
325+
assert key1 == key2
326+
327+
def test_fallback_when_source_code_is_none(self):
328+
"""Test that TaskSource returns None when source_code is None and getsource fails."""
329+
policy = TaskSource()
294330

295331
mock_task_a = MagicMock()
296332
mock_task_b = MagicMock()
297333

298-
mock_task_a.fn = task_a_fn
299-
mock_task_b.fn = task_b_fn
334+
mock_task_a.source_code = None
335+
mock_task_b.source_code = None
300336

301337
task_ctx_a = TaskRunContext.model_construct(task=mock_task_a)
302338
task_ctx_b = TaskRunContext.model_construct(task=mock_task_b)
303339

340+
# When source_code is None and getsource fails, policy returns None
304341
for os_error_msg in {"could not get source code", "source code not available"}:
305-
with patch("inspect.getsource", side_effect=OSError(os_error_msg)):
342+
with patch(
343+
"prefect.cache_policies.inspect.getsource",
344+
side_effect=OSError(os_error_msg),
345+
):
306346
fallback_key_a = policy.compute_key(
307347
task_ctx=task_ctx_a, inputs=None, flow_parameters=None
308348
)
309349
fallback_key_b = policy.compute_key(
310350
task_ctx=task_ctx_b, inputs=None, flow_parameters=None
311351
)
312352

313-
assert fallback_key_a and fallback_key_b
314-
assert fallback_key_a != fallback_key_b
353+
# Without stored source and without getsource, returns None
354+
assert fallback_key_a is None
355+
assert fallback_key_b is None
356+
357+
def test_returns_none_when_no_source_available(self):
358+
"""Test that TaskSource returns None when neither stored source nor getsource works."""
359+
policy = TaskSource()
360+
361+
def task_fn():
362+
return "test"
363+
364+
mock_task = MagicMock()
365+
mock_task.source_code = None
366+
mock_task.fn = task_fn
367+
368+
task_ctx = TaskRunContext.model_construct(task=mock_task)
369+
370+
# When source_code is None and getsource raises TypeError, returns None
371+
with patch(
372+
"prefect.cache_policies.inspect.getsource",
373+
side_effect=TypeError("not a function"),
374+
):
375+
key = policy.compute_key(
376+
task_ctx=task_ctx, inputs=None, flow_parameters=None
377+
)
378+
379+
assert key is None
315380

316381

317382
class TestDefaultPolicy:

tests/test_tasks.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,62 @@ def __call__(self, x):
161161
assert tt.task_key.startswith("Funky-")
162162

163163

164+
class TestTaskSourceCode:
165+
def test_source_code_captured_for_function(self):
166+
@task
167+
def my_task():
168+
return 42
169+
170+
assert my_task.source_code is not None
171+
assert "def my_task" in my_task.source_code
172+
assert "return 42" in my_task.source_code
173+
174+
def test_source_code_is_none_for_callable_object(self):
175+
class MyCallable:
176+
def __call__(self):
177+
return 42
178+
179+
callable_obj = MyCallable()
180+
my_task = Task(fn=callable_obj)
181+
182+
# Callable objects don't have source code accessible via inspect.getsource
183+
assert my_task.source_code is None
184+
185+
def test_source_code_survives_cloudpickle(self):
186+
import cloudpickle
187+
188+
@task
189+
def my_task():
190+
return "hello"
191+
192+
# Verify source code is captured
193+
original_source = my_task.source_code
194+
assert original_source is not None
195+
assert "def my_task" in original_source
196+
197+
# Serialize and deserialize the task
198+
pickled = cloudpickle.dumps(my_task)
199+
restored_task = cloudpickle.loads(pickled)
200+
201+
# Source code should survive serialization
202+
assert restored_task.source_code == original_source
203+
204+
def test_source_code_different_for_different_tasks(self):
205+
@task
206+
def task_a():
207+
return "a"
208+
209+
@task
210+
def task_b():
211+
return "b"
212+
213+
assert task_a.source_code is not None
214+
assert task_b.source_code is not None
215+
assert task_a.source_code != task_b.source_code
216+
assert "task_a" in task_a.source_code
217+
assert "task_b" in task_b.source_code
218+
219+
164220
class TestTaskRunName:
165221
def test_run_name_default(self):
166222
@task

0 commit comments

Comments
 (0)