Skip to content

Commit 203783a

Browse files
mxberlotOrbax Authors
authored andcommitted
Add DeferredPathAsyncCheckpointHandler and DeferredWritableTemporaryPath to AsyncCheckpointer
PiperOrigin-RevId: 884274073
1 parent ae44dc1 commit 203783a

File tree

2 files changed

+88
-1
lines changed

2 files changed

+88
-1
lines changed

checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,15 @@ async def _save(
504504
ckpt_args = checkpointer.construct_checkpoint_args(
505505
self._handler, True, *args, **kwargs
506506
)
507+
if isinstance(
508+
self._handler,
509+
async_checkpoint_handler.DeferredPathAsyncCheckpointHandler,
510+
) and isinstance(tmpdir, atomicity.DeferredWritableTemporaryPath):
511+
path = tmpdir.get_awaitable_path()
512+
else:
513+
path = tmpdir.get()
507514
commit_ops.extend(
508-
await self._handler.async_save(tmpdir.get(), args=ckpt_args) or []
515+
await self._handler.async_save(path, args=ckpt_args) or []
509516
)
510517
commit_ops, _ = jax.tree.flatten(commit_ops)
511518
commit_ops = [op for op in commit_ops if op is not None]

checkpoint/orbax/checkpoint/test_utils.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,3 +832,83 @@ def is_compression_used(
832832

833833
else:
834834
return read_spec['metadata']['compressor'] is not None
835+
836+
837+
class MockDeferredWritableTemporaryPath(
838+
atomicity.DeferredWritableTemporaryPath
839+
):
840+
"""Mock DeferredWritableTemporaryPath for testing deferred path resolution."""
841+
842+
instances = []
843+
844+
def __init__(
845+
self,
846+
final_path: epath.Path,
847+
*,
848+
checkpoint_metadata_store=None,
849+
file_options=None,
850+
):
851+
from orbax.checkpoint.google.path import tfhub_atomicity # pylint: disable=g-import-not-at-top
852+
853+
self._deferred_path = tfhub_atomicity.DeferredPath()
854+
super().__init__(
855+
temporary_path=None,
856+
final_path=final_path,
857+
checkpoint_metadata_store=checkpoint_metadata_store,
858+
file_options=file_options,
859+
)
860+
MockDeferredWritableTemporaryPath.instances.append(self)
861+
862+
@property
863+
def deferred_path(self):
864+
return self._deferred_path
865+
866+
def get_awaitable_path(self):
867+
return self._deferred_path
868+
869+
def get(self):
870+
if self._tmp_path is None:
871+
raise ValueError(
872+
'Temporary path has not been created yet. Please call `create` first.'
873+
)
874+
return self._tmp_path
875+
876+
def get_final(self):
877+
return self._final_path
878+
879+
async def create(self):
880+
self._tmp_path = await self._deferred_path.await_creation()
881+
self._tmp_path.mkdir(parents=True, exist_ok=True)
882+
return self._tmp_path
883+
884+
async def finalize(self, **kwargs):
885+
pass
886+
887+
@classmethod
888+
def from_final(
889+
cls,
890+
final_path,
891+
*,
892+
checkpoint_metadata_store=None,
893+
file_options=None,
894+
use_snapshot=None,
895+
):
896+
return cls(
897+
final_path,
898+
checkpoint_metadata_store=checkpoint_metadata_store,
899+
file_options=file_options,
900+
)
901+
902+
@classmethod
903+
def from_temporary(
904+
cls, temporary_path, *, file_options=None, use_snapshot=None
905+
):
906+
raise NotImplementedError
907+
908+
@classmethod
909+
async def validate(cls, temporary_path):
910+
pass
911+
912+
@classmethod
913+
async def validate_final(cls, final_path):
914+
pass

0 commit comments

Comments
 (0)