Skip to content

Commit f6ee1aa

Browse files
JustinPan-googOrbax Authors
authored andcommitted
Internal change.
PiperOrigin-RevId: 882163694
1 parent 203783a commit f6ee1aa

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,6 +1057,12 @@ class TrainState:
10571057
serialized_item, value_metadata_tree
10581058
)
10591059
else:
1060+
# Deserialize value metadata tree to the same structure as item to allow
1061+
# for comparison with item that contains rich types.
1062+
if self._pytree_metadata_options.support_rich_types:
1063+
value_metadata_tree = tree_utils.deserialize_tree(
1064+
value_metadata_tree, item
1065+
)
10601066
# is_empty_or_leaf is necessary here to treat empty nodes (e.g. empty
10611067
# dicts, lists, custom nodes) as leaves, as they do not contain any
10621068
# actual data to be restored, but are needed to maintain the structure.
@@ -1083,12 +1089,11 @@ class TrainState:
10831089
restore_args, self._pytree_metadata_options
10841090
)
10851091

1086-
value_metadata_tree_deserialized = tree_utils.deserialize_tree(
1087-
value_metadata_tree, item
1088-
)
1089-
restore_args_deserialized = tree_utils.deserialize_tree(restore_args, item)
1090-
value_metadata_tree = value_metadata_tree_deserialized
1091-
restore_args = restore_args_deserialized
1092+
if not self._pytree_metadata_options.support_rich_types:
1093+
value_metadata_tree = tree_utils.deserialize_tree(
1094+
value_metadata_tree, item
1095+
)
1096+
restore_args = tree_utils.deserialize_tree(restore_args, item)
10921097

10931098
param_infos = self._get_param_infos(
10941099
item=value_metadata_tree,

checkpoint/orbax/checkpoint/_src/tree/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Any, Callable, Mapping, Optional, Tuple, TypeVar, Union
1818

19+
import flax
1920
import jax
2021
import jax.tree_util as jtu
2122
from orbax.checkpoint._src.arrays import abstract_arrays
@@ -235,10 +236,14 @@ def _reconstruct_from_keypath(keypath, _):
235236
result = serialized
236237
for key in keypath:
237238
key_name = get_key_name(key)
238-
if isinstance(key, jax.tree_util.GetAttrKey) and isinstance_of_namedtuple(
239-
result
240-
):
241-
result = getattr(result, key_name)
239+
if isinstance(key, jax.tree_util.GetAttrKey):
240+
if isinstance_of_namedtuple(result):
241+
result = getattr(result, key_name)
242+
elif isinstance(result, flax.struct.PyTreeNode):
243+
# Special case to support flax.struct.PyTreeNode
244+
result = result.__dict__[key_name]
245+
else:
246+
result = result[key_name]
242247
else:
243248
# Special case to support Pax.
244249
if not isinstance(result, (list, tuple)) and key_name not in result:

0 commit comments

Comments
 (0)