Skip to content

Commit 50b6dd2

Browse files
JustinPan-goog and Orbax Authors
authored and committed
No public description
PiperOrigin-RevId: 882877701
1 parent 306ffab commit 50b6dd2

File tree

5 files changed

+73
-33
lines changed

5 files changed

+73
-33
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,18 @@ will return `StepMetadata`, and will otherwise return `RootMetadata`.
2323
handlers using `StepMetadata.item_handlers` and the global `HandlerTypeRegistry`
2424
if no args are provided.
2525
- `CompositeCheckpointHandler.metadata()` now returns `StepMetadata`.
26+
- Doubled the default timeout from 600 to 1200 seconds (20 minutes) in
27+
`AsyncOptions`; `timeout_secs` is now a mandatory parameter with a default
28+
value of 1200 seconds (20 minutes) in `AsyncCheckpointer`.
2629

2730
### Fixed
2831

2932
- Fixed `get_device_memory` issue on TPU 7x devices where the device kind string
3033
was consistently reported without a space, causing a ValueError.
34+
- Fixed hanging in `AsyncCheckpointer` if timeout occurs during save. Remaining
35+
time is now calculated and applied to commit operations and synchronization
36+
barriers, ensuring that all async operations time out instead of hanging if
37+
preceding operations consume most of the timeout budget.
3138

3239
## [0.1.7] - 2022-03-29
3340

checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""AsyncCheckpointer."""
1616

17+
import datetime
1718
import sys
1819
import threading
1920
import time
@@ -69,23 +70,27 @@ def _background_wait_for_commit_futures(
6970
on_commit_callback: Callable[[], None],
7071
*,
7172
barrier_sync_key_prefix: str,
72-
sync_fn: Callable[[str], None],
73+
sync_fn: Callable[[str, int], None],
74+
timeout_secs: int,
7375
primary_host: int | None,
7476
):
7577
"""A function to be run in a background thread that waits for futures."""
7678
current_process = multihost.process_index()
7779
current_thread_id = threading.current_thread().name
7880
process_count = jax.process_count()
7981
logging.info(
80-
'[process=%s][thread=%s] Background save thread started.',
82+
'[process=%s][thread=%s] Background save thread started. Deadline for'
83+
' this save operation is %s',
8184
current_process,
8285
current_thread_id,
86+
datetime.datetime.now() + datetime.timedelta(seconds=timeout_secs),
8387
)
8488
thread_start_time = time.time()
8589

8690
# Wait for commit operations to complete.
87-
for commit_future in commit_futures:
88-
commit_future.result()
91+
future.ChainedFuture(commit_futures, cb=lambda: None).result(
92+
timeout=timeout_secs
93+
)
8994
commit_duration_secs = time.time() - thread_start_time
9095
logging.info(
9196
'[process=%s][thread=%s] %d Handler Commit operations completed. Time'
@@ -109,30 +114,48 @@ def _background_wait_for_commit_futures(
109114
# All processes will wait at the barrier. When all processes are at the
110115
# barrier, the barrier will be satisfied. If not, then it will timeout.
111116
try:
117+
time_remaining_secs = future.get_remaining_time(
118+
thread_start_time, timeout_secs
119+
)
112120
sync_fn(
113121
multihost.unique_barrier_key(
114122
'async_write_complete',
115123
prefix=barrier_sync_key_prefix,
116124
suffix=f'{directory.name}',
117-
)
125+
),
126+
int(time_remaining_secs * 1000),
118127
)
119128
except jax.errors.JaxRuntimeError as e:
120129
if sys.version_info >= (3, 11):
121130
if 'DEADLINE_EXCEEDED' in str(e):
122131
_add_deadline_exceeded_notes(e)
123-
raise
132+
raise TimeoutError(
133+
'Timed out while waiting for async_write_complete barrier.'
134+
) from e
124135

125136
if utils.is_primary_host(primary_host):
126137
on_commit_callback()
127138
if process_count > 1:
128139
# Block until process 0 completes on_commit_callback.
129-
sync_fn(
130-
multihost.unique_barrier_key(
131-
'async_commit_complete',
132-
prefix=barrier_sync_key_prefix,
133-
suffix=f'{directory.name}',
134-
)
135-
)
140+
try:
141+
time_remaining_secs = future.get_remaining_time(
142+
thread_start_time, timeout_secs
143+
)
144+
sync_fn(
145+
multihost.unique_barrier_key(
146+
'async_commit_complete',
147+
prefix=barrier_sync_key_prefix,
148+
suffix=f'{directory.name}',
149+
),
150+
int(time_remaining_secs * 1000),
151+
)
152+
except jax.errors.JaxRuntimeError as e:
153+
if sys.version_info >= (3, 11):
154+
if 'DEADLINE_EXCEEDED' in str(e):
155+
_add_deadline_exceeded_notes(e)
156+
raise TimeoutError(
157+
'Timed out while waiting for async_commit_complete barrier.'
158+
) from e
136159

137160
thread_duration_secs = time.time() - thread_start_time
138161
jax.monitoring.record_event_duration_secs(
@@ -163,11 +186,10 @@ def __init__(
163186
self,
164187
*,
165188
barrier_sync_fn: multihost.BarrierSyncFn,
166-
timeout_secs: int | None = None,
189+
timeout_secs: int,
167190
primary_host: Optional[int] = 0,
168191
barrier_sync_key_prefix: Optional[str] = None,
169192
):
170-
timeout_secs = timeout_secs or multihost.coordination_timeout()
171193
if timeout_secs <= 0:
172194
raise ValueError(
173195
f'Timeout must be positive, but got {timeout_secs} seconds.'
@@ -188,9 +210,8 @@ def __init__(
188210
self._thread = None
189211
self._exception = None
190212

191-
timeout_in_ms = self._timeout_secs * 1000
192-
self._sync_fn: Callable[[str], None] = lambda key: barrier_sync_fn(
193-
key=key, timeout_ms=timeout_in_ms
213+
self._sync_fn: Callable[[str, int], None] = (
214+
lambda key, timeout_ms: barrier_sync_fn(key=key, timeout_ms=timeout_ms)
194215
)
195216

196217
def __del__(self):
@@ -216,6 +237,7 @@ def _thread_func(
216237
on_commit_callback,
217238
barrier_sync_key_prefix=self._barrier_sync_key_prefix,
218239
sync_fn=self._sync_fn,
240+
timeout_secs=self._timeout_secs,
219241
primary_host=self._primary_host,
220242
)
221243
except Exception as e: # pylint: disable=broad-exception-caught

checkpoint/orbax/checkpoint/_src/futures/future.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -170,14 +170,14 @@ class Future(Protocol):
170170
future, but merely wait for an ongoing operation to complete.
171171
"""
172172

173-
def result(self, timeout: Optional[int] = None) -> Any:
173+
def result(self, timeout: Optional[float] = None) -> Any:
174174
"""Waits for the future to complete its operation."""
175175
...
176176

177177

178178
class NoopFuture:
179179

180-
def result(self, timeout: Optional[int] = None) -> Any:
180+
def result(self, timeout: Optional[float] = None) -> Any:
181181
del timeout
182182
return None
183183

@@ -189,21 +189,18 @@ def __init__(self, futures: Sequence[Future], cb: Callable[[], None]):
189189
self._futures = futures
190190
self._cb = cb
191191

192-
def result(self, timeout: Optional[int] = None) -> Any:
192+
def result(self, timeout: Optional[float] = None) -> Any:
193193
"""Waits for all futures to complete."""
194194
n = len(self._futures)
195195
start = time.time()
196-
time_remaining = timeout
197196
for k, f in enumerate(self._futures):
198-
f.result(timeout=time_remaining)
199-
if time_remaining is not None:
200-
time_elapsed = time.time() - start
201-
time_remaining -= time_elapsed
202-
if time_remaining <= 0:
203-
raise TimeoutError(
204-
'ChainedFuture completed {:d}/{:d} futures but timed out after'
205-
' {:.2f} seconds.'.format(k, n, time_elapsed)
206-
)
197+
try:
198+
f.result(timeout=get_remaining_time(start, timeout))
199+
except TimeoutError as e:
200+
raise TimeoutError(
201+
f'ChainedFuture completed {k}/{n} futures but timed out after'
202+
f' {time.time() - start:.2f} seconds.'
203+
) from e
207204
time_elapsed = time.time() - start
208205
logging.vlog(
209206
1,
@@ -215,6 +212,18 @@ def result(self, timeout: Optional[int] = None) -> Any:
215212
self._cb()
216213

217214

215+
def get_remaining_time(
216+
start_time: float, timeout_secs: Optional[float]
217+
) -> Optional[float]:
218+
"""Returns remaining time in secs, or None if timeout_secs is None."""
219+
if timeout_secs is None:
220+
return None
221+
elapsed = time.time() - start_time
222+
if elapsed >= timeout_secs:
223+
raise TimeoutError(f'Timed out after {elapsed} seconds.')
224+
return timeout_secs - elapsed
225+
226+
218227
def wait_for_signals(
219228
receive_signals: Sequence[synchronization.HandlerAwaitableSignal],
220229
*,

checkpoint/orbax/checkpoint/checkpoint_manager_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def test_legacy_api_timeout(self):
404404
typing.cast(
405405
AsyncCheckpointer, manager._checkpointer
406406
)._async_manager._timeout_secs,
407-
600,
407+
1200,
408408
)
409409

410410
with CheckpointManager(

checkpoint/orbax/checkpoint/options.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ class AsyncOptions:
2828
See :py:class:`.AsyncCheckpointer` for details.
2929
"""
3030

31-
timeout_secs: int = 600 # 10 minutes. Same as default in `AsyncCheckpointer`.
31+
timeout_secs: int = (
32+
1200 # 20 minutes. Same as default in `AsyncCheckpointer`.
33+
)
3234
barrier_sync_fn: Optional[multihost.BarrierSyncFn] = None
3335
post_finalization_callback: Optional[Callable[[], None]] = None
3436
create_directories_asynchronously: bool = True

0 commit comments

Comments
 (0)