Commit 4bc9a56

Author: Orbax Authors
Improve load_pytree and load_checkpointables method docstrings
PiperOrigin-RevId: 884832259
1 parent d88be19 commit 4bc9a56

1 file changed: +150 -4 lines changed

checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py

Lines changed: 150 additions & 4 deletions
@@ -524,11 +524,53 @@ def load_pytree(
   ) -> tree_types.PyTreeOf[tree_types.LeafType]:
     """Loads a PyTree checkpoint at the given step.
 
-    This function behaves similarly to :py:func:`.load_pytree` (see
-    documentation).
+    This function behaves similarly to
+    :py:func:`~orbax.checkpoint.v1._src.loading.loading.load_pytree`.
+
+    **Note:** Loading a PyTree without providing an `abstract_pytree` is
+    provided purely for convenience. For serious or production use cases, it is
+    STRONGLY recommended to always provide an `abstract_pytree` to ensure the
+    restored PyTree strictly matches the expected shapes, dtypes, and sharding.
+
+    Example:
+      1. Basic Loading:
+        Load a PyTree without providing an abstract structure. By passing
+        `step=None` (or omitting it), it automatically loads the latest step::
+
+          from orbax.checkpoint.v1 import training
+
+          # Initialize the checkpointer for the directory
+          ckptr = training.Checkpointer(directory)
+
+          # Load the saved PyTree from latest step
+          restored_tree = ckptr.load_pytree(step=None)
+
+      2. Loading with an Abstract PyTree:
+        Provide an abstract structure (such as target shapes and dtypes)
+        to ensure the restored PyTree is safely and correctly formatted::
+
+          import jax
+          import jax.numpy as jnp
+          from orbax.checkpoint.v1 import training
+
+          ckptr = training.Checkpointer(directory)
+
+          # Define the expected structure (shapes and dtypes) to restore into
+          target_structure = {
+              'weights': jax.ShapeDtypeStruct((128, 128), dtype=jnp.float32),
+              'bias': jax.ShapeDtypeStruct((128,), dtype=jnp.float32)
+          }
+
+          # Restore exactly matching the target structure
+          restored_tree = ckptr.load_pytree(
+              step=1,
+              abstract_pytree=target_structure
+          )
 
     Args:
-      step: The step number or :py:class:`.CheckpointMetadata` to load.
+      step: The step number or :py:class:`.CheckpointMetadata` to load. If None,
+        the checkpointer will attempt to resolve and load the latest existing
+        checkpoint.
       abstract_pytree: The abstract PyTree to load.
 
     Returns:
@@ -543,7 +585,111 @@ def load_checkpointables(
       step: int | CheckpointMetadata | None = None,
       abstract_checkpointables: dict[str, Any] | None = None,
   ) -> dict[str, Any]:
-    """Loads a set of checkpointables at the given step."""
+    """Loads a set of checkpointables at the given step.
+
+    This function behaves similarly to
+    :py:func:`~orbax.checkpoint.v1._src.loading.loading.load_checkpointables`.
+
+    This function retrieves multiple named items (such as model weights or
+    optimizer states) from a specific checkpoint directory. If no step is
+    provided, it automatically resolves to and loads the most recently saved
+    checkpoint.
+
+    **Note:** Loading without providing an `abstract_checkpointables`
+    dictionary is provided purely for convenience. For serious or production
+    use cases, it is STRONGLY recommended to always provide
+    `abstract_checkpointables` to ensure the restored items strictly match
+    the exact nested structures, shapes, and data types expected.
+
+    Example:
+      1. Basic Loading:
+        Load multiple named items (such as a model and optimizer) from a
+        specific step. If step is omitted, it resolves to the latest
+        available checkpoint::
+
+          from orbax.checkpoint.v1 import training
+
+          # Initialize the checkpointer for the directory
+          ckptr = training.Checkpointer(directory)
+
+          # Load all checkpointables saved at the latest step
+          restored_items = ckptr.load_checkpointables(step=None)
+
+          # Access the individual components by their original string keys
+          my_model = restored_items["model"]
+          my_opt = restored_items["optimizer"]
+
+      2. Loading with Abstract Checkpointables (Recommended):
+        Provide a dictionary of abstract structures to ensure the restored
+        items strictly match your expected shapes and data types::
+
+          import jax
+          import jax.numpy as jnp
+          from orbax.checkpoint.v1 import training
+
+          ckptr = training.Checkpointer(directory)
+
+          # Define the expected structure for each named item using JAX arrays
+          target_items = {
+              "model": {
+                  'weights': jax.ShapeDtypeStruct((128, 128), jnp.float32),
+                  'bias': jax.ShapeDtypeStruct((128,), jnp.float32)
+              },
+              "optimizer": {
+                  'momentum': jax.ShapeDtypeStruct((128, 128), jnp.float32)
+              }
+          }
+
+          # Restore exactly matching the target structures
+          restored_items = ckptr.load_checkpointables(
+              step=1,
+              abstract_checkpointables=target_items
+          )
+
+      3. Partial Loading:
+        If you only need to load a subset of checkpointables (e.g., loading
+        model weights but omitting optimizer state), you can provide an
+        `abstract_checkpointables` dictionary containing only the keys for the
+        items you wish to restore::
+
+          import jax
+          import jax.numpy as jnp
+          from orbax.checkpoint.v1 import training
+
+          ckptr = training.Checkpointer(directory)
+
+          # Define abstract structure for ONLY the items to load
+          target_items = {
+              "model": {
+                  'weights': jax.ShapeDtypeStruct((128, 128), jnp.float32),
+                  'bias': jax.ShapeDtypeStruct((128,), jnp.float32)
+              },
+          }
+
+          # Load only "model", omitting "optimizer"
+          restored_items = ckptr.load_checkpointables(
+              step=1,
+              abstract_checkpointables=target_items
+          )
+          my_model = restored_items["model"]
+          # my_opt = restored_items["optimizer"]
+
+    Args:
+      step: The step number or :py:class:`.CheckpointMetadata` to load. If None,
+        the checkpointer will attempt to resolve and load the latest existing
+        checkpoint.
+      abstract_checkpointables: A dictionary mapping string names to their
+        corresponding abstract structures (e.g., target PyTrees). This guides
+        the loading process to ensure shape and type compliance. If provided, it
+        can be used to load only a subset of checkpointables by providing only a
+        subset of keys.
+
+    Returns:
+      dict[str, Any]: A dictionary containing the loaded checkpointable objects,
+      keyed by string names. If `abstract_checkpointables` was specified,
+      returns only the keys specified in that dict, otherwise returns all
+      keys saved with `save_checkpointables`.
+    """
     step = self._resolve_existing_checkpoint(step).step
     return loading.load_checkpointables(
         self.directory / self._step_name_format.build_name(step),
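
Reviewer note: the two behaviors the new docstrings document (a `step` of None resolving to the latest checkpoint, and `abstract_checkpointables` restricting which keys come back) can be sketched in plain Python. This is a toy model only, not Orbax code: `resolve_step`, `select_checkpointables`, the in-memory `checkpoints` dict, and the exception types raised are all hypothetical and chosen for illustration.

```python
from __future__ import annotations

from typing import Any


def resolve_step(available_steps: list[int], step: int | None = None) -> int:
    """Hypothetical helper: returns `step`, or the latest step if None."""
    if not available_steps:
        raise FileNotFoundError("No checkpoints found.")
    if step is None:
        # Mirrors the documented default: fall back to the latest step.
        return max(available_steps)
    if step not in available_steps:
        raise FileNotFoundError(f"No checkpoint found at step {step}.")
    return step


def select_checkpointables(
    saved: dict[str, Any],
    abstract_checkpointables: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Hypothetical helper: returns all items, or only the requested keys."""
    if abstract_checkpointables is None:
        # No abstract target given: return everything that was saved.
        return dict(saved)
    missing = set(abstract_checkpointables) - set(saved)
    if missing:
        # Error type is an assumption made for this sketch.
        raise KeyError(f"Not present in checkpoint: {sorted(missing)}")
    # Partial loading: only the keys named in the abstract dict are returned.
    return {name: saved[name] for name in abstract_checkpointables}


# Stand-in for a checkpoint directory: step number -> saved items.
checkpoints = {
    0: {"model": "model@0", "optimizer": "opt@0"},
    1: {"model": "model@1", "optimizer": "opt@1"},
}

latest = resolve_step(list(checkpoints))
restored = select_checkpointables(checkpoints[latest], {"model": None})
```

Here `restored` holds only the "model" entry from the latest step, matching the "Partial Loading" example in the docstring. Shape, dtype, and sharding validation are left out of this sketch entirely.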
