1414
1515"""AsyncCheckpointer."""
1616
17+ import datetime
1718import sys
1819import threading
1920import time
@@ -69,23 +70,27 @@ def _background_wait_for_commit_futures(
6970 on_commit_callback : Callable [[], None ],
7071 * ,
7172 barrier_sync_key_prefix : str ,
72- sync_fn : Callable [[str ], None ],
73+ sync_fn : Callable [[str , int ], None ],
74+ timeout_secs : int ,
7375 primary_host : int | None ,
7476):
7577 """A function to be run in a background thread that waits for futures."""
7678 current_process = multihost .process_index ()
7779 current_thread_id = threading .current_thread ().name
7880 process_count = jax .process_count ()
7981 logging .info (
80- '[process=%s][thread=%s] Background save thread started.' ,
82+ '[process=%s][thread=%s] Background save thread started. Deadline for'
83+ ' this save operation is %s' ,
8184 current_process ,
8285 current_thread_id ,
86+ datetime .datetime .now () + datetime .timedelta (seconds = timeout_secs ),
8387 )
8488 thread_start_time = time .time ()
8589
8690 # Wait for commit operations to complete.
87- for commit_future in commit_futures :
88- commit_future .result ()
91+ future .ChainedFuture (commit_futures , cb = lambda : None ).result (
92+ timeout = timeout_secs
93+ )
8994 commit_duration_secs = time .time () - thread_start_time
9095 logging .info (
9196 '[process=%s][thread=%s] %d Handler Commit operations completed. Time'
@@ -109,30 +114,48 @@ def _background_wait_for_commit_futures(
109114 # All processes will wait at the barrier. When all processes are at the
110115 # barrier, the barrier will be satisfied. If not, then it will time out.
111116 try :
117+ time_remaining_secs = future .get_remaining_time (
118+ thread_start_time , timeout_secs
119+ )
112120 sync_fn (
113121 multihost .unique_barrier_key (
114122 'async_write_complete' ,
115123 prefix = barrier_sync_key_prefix ,
116124 suffix = f'{ directory .name } ' ,
117- )
125+ ),
126+ int (time_remaining_secs * 1000 ),
118127 )
119128 except jax .errors .JaxRuntimeError as e :
120129 if sys .version_info >= (3 , 11 ):
121130 if 'DEADLINE_EXCEEDED' in str (e ):
122131 _add_deadline_exceeded_notes (e )
123- raise
132+ raise TimeoutError (
133+ 'Timed out while waiting for async_write_complete barrier.'
134+ ) from e
124135
125136 if utils .is_primary_host (primary_host ):
126137 on_commit_callback ()
127138 if process_count > 1 :
128139 # Block until process 0 completes on_commit_callback.
129- sync_fn (
130- multihost .unique_barrier_key (
131- 'async_commit_complete' ,
132- prefix = barrier_sync_key_prefix ,
133- suffix = f'{ directory .name } ' ,
134- )
135- )
140+ try :
141+ time_remaining_secs = future .get_remaining_time (
142+ thread_start_time , timeout_secs
143+ )
144+ sync_fn (
145+ multihost .unique_barrier_key (
146+ 'async_commit_complete' ,
147+ prefix = barrier_sync_key_prefix ,
148+ suffix = f'{ directory .name } ' ,
149+ ),
150+ int (time_remaining_secs * 1000 ),
151+ )
152+ except jax .errors .JaxRuntimeError as e :
153+ if sys .version_info >= (3 , 11 ):
154+ if 'DEADLINE_EXCEEDED' in str (e ):
155+ _add_deadline_exceeded_notes (e )
156+ raise TimeoutError (
157+ 'Timed out while waiting for async_commit_complete barrier.'
158+ ) from e
136159
137160 thread_duration_secs = time .time () - thread_start_time
138161 jax .monitoring .record_event_duration_secs (
@@ -163,11 +186,10 @@ def __init__(
163186 self ,
164187 * ,
165188 barrier_sync_fn : multihost .BarrierSyncFn ,
166- timeout_secs : int | None = None ,
189+ timeout_secs : int ,
167190 primary_host : Optional [int ] = 0 ,
168191 barrier_sync_key_prefix : Optional [str ] = None ,
169192 ):
170- timeout_secs = timeout_secs or multihost .coordination_timeout ()
171193 if timeout_secs <= 0 :
172194 raise ValueError (
173195 f'Timeout must be positive, but got { timeout_secs } seconds.'
@@ -188,9 +210,8 @@ def __init__(
188210 self ._thread = None
189211 self ._exception = None
190212
191- timeout_in_ms = self ._timeout_secs * 1000
192- self ._sync_fn : Callable [[str ], None ] = lambda key : barrier_sync_fn (
193- key = key , timeout_ms = timeout_in_ms
213+ self ._sync_fn : Callable [[str , int ], None ] = (
214+ lambda key , timeout_ms : barrier_sync_fn (key = key , timeout_ms = timeout_ms )
194215 )
195216
196217 def __del__ (self ):
@@ -216,6 +237,7 @@ def _thread_func(
216237 on_commit_callback ,
217238 barrier_sync_key_prefix = self ._barrier_sync_key_prefix ,
218239 sync_fn = self ._sync_fn ,
240+ timeout_secs = self ._timeout_secs ,
219241 primary_host = self ._primary_host ,
220242 )
221243 except Exception as e : # pylint: disable=broad-exception-caught
0 commit comments