@@ -524,11 +524,53 @@ def load_pytree(
524524 ) -> tree_types .PyTreeOf [tree_types .LeafType ]:
525525 """Loads a PyTree checkpoint at the given step.
526526
527- This function behaves similarly to :py:func:`.load_pytree` (see
528- documentation).
527+ This function behaves similarly to
528+ :py:func:`~orbax.checkpoint.v1._src.loading.loading.load_pytree`.
529+
530+ **Note:** Loading a PyTree without providing an `abstract_pytree` is
531+ provided purely for convenience. For serious or production use cases, it is
532+ STRONGLY recommended to always provide an `abstract_pytree` to ensure the
533+ restored PyTree strictly matches the expected shapes, dtypes, and sharding.
534+
535+ Example:
536+ 1. Basic Loading:
537+ Load a PyTree without providing an abstract structure. By passing
538+ `step=None` (or omitting it), it automatically loads the latest step::
539+
540+ from orbax.checkpoint.v1 import training
541+
542+ # Initialize the checkpointer for the directory
543+ ckptr = training.Checkpointer(directory)
544+
545+ # Load the saved PyTree from latest step
546+ restored_tree = ckptr.load_pytree(step=None)
547+
548+ 2. Loading with an Abstract PyTree:
549+ Provide an abstract structure (such as target shapes and dtypes)
550+ to ensure the restored PyTree is safely and correctly formatted::
551+
552+ import jax
553+ import jax.numpy as jnp
554+ from orbax.checkpoint.v1 import training
555+
556+ ckptr = training.Checkpointer(directory)
557+
558+ # Define the expected structure (shapes and dtypes) to restore into
559+ target_structure = {
560+ 'weights': jax.ShapeDtypeStruct((128, 128), dtype=jnp.float32),
561+ 'bias': jax.ShapeDtypeStruct((128,), dtype=jnp.float32)
562+ }
563+
564+ # Restore exactly matching the target structure
565+ restored_tree = ckptr.load_pytree(
566+ step=1,
567+ abstract_pytree=target_structure
568+ )
529569
530570 Args:
531- step: The step number or :py:class:`.CheckpointMetadata` to load.
571+ step: The step number or :py:class:`.CheckpointMetadata` to load. If None,
572+ the checkpointer will attempt to resolve and load the latest existing
573+ checkpoint.
532574 abstract_pytree: The abstract PyTree to load.
533575
534576 Returns:
@@ -543,7 +585,111 @@ def load_checkpointables(
543585 step : int | CheckpointMetadata | None = None ,
544586 abstract_checkpointables : dict [str , Any ] | None = None ,
545587 ) -> dict [str , Any ]:
546- """Loads a set of checkpointables at the given step."""
588+ """Loads a set of checkpointables at the given step.
589+
590+ This function behaves similarly to
591+ :py:func:`~orbax.checkpoint.v1._src.loading.loading.load_checkpointables`.
592+
593+ This function retrieves multiple named items (such as model weights or
594+ optimizer states) from a specific checkpoint directory. If no step is
595+ provided, it automatically resolves to and loads the most recently saved
596+ checkpoint.
597+
598+ **Note:** Loading without providing an `abstract_checkpointables`
599+ dictionary is provided purely for convenience. For serious or production
600+ use cases, it is STRONGLY recommended to always provide
601+ `abstract_checkpointables` to ensure the restored items strictly match
602+ the exact nested structures, shapes, and data types expected.
603+
604+ Example:
605+ 1. Basic Loading:
606+ Load multiple named items (such as a model and optimizer) from a
607+ specific step. If step is omitted, it resolves to the latest
608+ available checkpoint::
609+
610+ from orbax.checkpoint.v1 import training
611+
612+ # Initialize the checkpointer for the directory
613+ ckptr = training.Checkpointer(directory)
614+
615+ # Load all checkpointables saved at the latest step
616+ restored_items = ckptr.load_checkpointables(step=None)
617+
618+ # Access the individual components by their original string keys
619+ my_model = restored_items["model"]
620+ my_opt = restored_items["optimizer"]
621+
622+ 2. Loading with Abstract Checkpointables (Recommended):
623+ Provide a dictionary of abstract structures to ensure the restored
624+ items strictly match your expected shapes and data types::
625+
626+ import jax
627+ import jax.numpy as jnp
628+ from orbax.checkpoint.v1 import training
629+
630+ ckptr = training.Checkpointer(directory)
631+
632+ # Define the expected structure for each named item using JAX arrays
633+ target_items = {
634+ "model": {
635+ 'weights': jax.ShapeDtypeStruct((128, 128), jnp.float32),
636+ 'bias': jax.ShapeDtypeStruct((128,), jnp.float32)
637+ },
638+ "optimizer": {
639+ 'momentum': jax.ShapeDtypeStruct((128, 128), jnp.float32)
640+ }
641+ }
642+
643+ # Restore exactly matching the target structures
644+ restored_items = ckptr.load_checkpointables(
645+ step=1,
646+ abstract_checkpointables=target_items
647+ )
648+
649+ 3. Partial Loading:
650+ If you only need to load a subset of checkpointables (e.g., loading
651+ model weights but omitting optimizer state), you can provide an
652+ `abstract_checkpointables` dictionary containing only the keys for the
653+ items you wish to restore::
654+
655+ import jax
656+ import jax.numpy as jnp
657+ from orbax.checkpoint.v1 import training
658+
659+ ckptr = training.Checkpointer(directory)
660+
661+ # Define abstract structure for ONLY the items to load
662+ target_items = {
663+ "model": {
664+ 'weights': jax.ShapeDtypeStruct((128, 128), jnp.float32),
665+ 'bias': jax.ShapeDtypeStruct((128,), jnp.float32)
666+ },
667+ }
668+
669+ # Load only "model", omitting "optimizer"
670+ restored_items = ckptr.load_checkpointables(
671+ step=1,
672+ abstract_checkpointables=target_items
673+ )
674+ my_model = restored_items["model"]
675+ # my_opt = restored_items["optimizer"]
676+
677+ Args:
678+ step: The step number or :py:class:`.CheckpointMetadata` to load. If None,
679+ the checkpointer will attempt to resolve and load the latest existing
680+ checkpoint.
681+ abstract_checkpointables: A dictionary mapping string names to their
682+ corresponding abstract structures (e.g., target PyTrees). This guides
683+ the loading process to ensure shape and type compliance. If provided, it
684+ can be used to load only a subset of checkpointables by providing only a
685+ subset of keys.
686+
687+ Returns:
688+ dict[str, Any]: A dictionary containing the loaded checkpointable objects,
689+ keyed by string names. If `abstract_checkpointables` was specified,
690+ returns only the keys specified in that dict, otherwise returns all
691+ keys saved with `save_checkpointables`.
692+ """
547693 step = self ._resolve_existing_checkpoint (step ).step
548694 return loading .load_checkpointables (
549695 self .directory / self ._step_name_format .build_name (step ),
0 commit comments