Skip to content

Commit aec6f32

Browse files
angel-core authored and Orbax Authors committed
Refactor LeafHandlerRegistry state management and handler resolution logic.
PiperOrigin-RevId: 887045623
1 parent 519af35 commit aec6f32

File tree

3 files changed

+433
-105
lines changed

3 files changed

+433
-105
lines changed

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -452,7 +452,13 @@ def _get_typestr(leaf_type: Any) -> str:
452452

453453
# register standard v1 leaf handlers to the v0 type handler registry.
454454
handlers = []
455-
for leaf_type, _, leaf_handler_type in leaf_handler_registry.get_all():
455+
# We must reverse the order of the leaf handlers to ensure that the last
456+
# registered handler is the first one used as V1 registry is ordered by
457+
# priority of generic to specific, while V0 type handler registry is ordered
458+
# by the reverse.
459+
for leaf_type, _, leaf_handler_type in reversed(
460+
leaf_handler_registry.get_all()
461+
):
456462
try:
457463
leaf_handler = leaf_handler_type(context=context) # pytype: disable=wrong-keyword-args
458464
except TypeError as e:

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registry.py

Lines changed: 197 additions & 103 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,9 @@
1414

1515
"""Leaf Handler Registry."""
1616

17-
from typing import Any, Dict, Sequence, Tuple, Type
17+
from collections.abc import Sequence
18+
import dataclasses
19+
from typing import Any
1820

1921
from absl import logging
2022
import jax
@@ -59,6 +61,47 @@
5961
}
6062

6163

64+
@dataclasses.dataclass
65+
class _Registration:
66+
"""A registration entry for a LeafHandler.
67+
68+
Attributes:
69+
leaf_type: The concrete PyTree leaf type.
70+
abstract_type: The abstract representation of the leaf type.
71+
handler_type: The LeafHandler class.
72+
secondary_typestrs: Optional alternate identifiers for the handler.
73+
leaf_specificity_score: Specificity score for the leaf type. Higher value
74+
means more specific type relative to other leaf types which it is a
75+
subclass of. This determines which handler we resolve to during
76+
save/load operations.
77+
abstract_specificity_score: Specificity score for the abstract type. Higher
78+
value means more specific type relative to other abstract types which it
79+
is a subprotocol/subclass of. This determines which handler we resolve to
80+
during save/load operations.
81+
"""
82+
83+
leaf_type: type[Any]
84+
abstract_type: type[Any]
85+
handler_type: type[types.LeafHandler[Any, Any]]
86+
secondary_typestrs: Sequence[str] | None
87+
leaf_specificity_score: int
88+
abstract_specificity_score: int
89+
90+
91+
def _is_abstract_subprotocol(
92+
type_a: type[Any], type_b: type[Any]
93+
) -> bool:
94+
"""Checks if 'type_a' is a subclass or sub-protocol of 'type_b'."""
95+
try:
96+
if typing_extensions.is_protocol(type_b): # pytype: disable=not-supported-yet
97+
return protocol_utils.is_subclass_protocol(
98+
cls=type_a, protocol=type_b
99+
)
100+
return issubclass(type_a, type_b)
101+
except TypeError:
102+
return False
103+
104+
62105
class BaseLeafHandlerRegistry:
63106
"""Base Leaf Handler Registry implements the LeafHandlerRegistry Protocol.
64107
@@ -87,71 +130,57 @@ class CustomArray(np.ndarray): pass
87130
"""
88131

89132
def __init__(self):
90-
self._leaf_type_registry: Dict[
91-
Type[Any], Type[types.LeafHandler[Any, Any]]
92-
] = {}
93-
self._abstract_type_registry: Dict[
94-
Type[Any], Type[types.LeafHandler[Any, Any]]
95-
] = {}
96-
97-
# for easy look up for replacement
98-
self._handler_to_types: Dict[
99-
Type[types.LeafHandler[Any, Any]], Tuple[Type[Any], Type[Any]]
100-
] = {}
101-
self._secondary_typestrs: Dict[
102-
Type[types.LeafHandler[Any, Any]], Sequence[str]
103-
] = {}
133+
# Sorted [Generic -> Specific] primarily by leaf_specificity_score.
134+
self._entries: list[_Registration] = []
104135

105136
def _try_get(
106-
self, leaf_type: Type[types.Leaf]
107-
) -> Type[types.LeafHandler[types.Leaf, Any]] | None:
108-
"""Returns the handler registered for a given type, if available."""
109-
for registered_ty, handler_type in self._leaf_type_registry.items():
110-
if issubclass(leaf_type, registered_ty):
111-
return handler_type
112-
113-
# no handler found
137+
self, leaf_type: type[types.Leaf]
138+
) -> type[types.LeafHandler[types.Leaf, Any]] | None:
139+
"""Returns the most specific handler for a given type, if available."""
140+
# self._entries is sorted Generic -> Specific by leaf_specificity_score.
141+
# Iterating reversed checks the most specific handlers first.
142+
for entry in reversed(self._entries):
143+
try:
144+
if issubclass(leaf_type, entry.leaf_type):
145+
return entry.handler_type
146+
except TypeError:
147+
pass
114148
return None
115149

116150
def get(
117-
self, leaf_type: Type[types.Leaf]
118-
) -> Type[types.LeafHandler[types.Leaf, Any]]:
151+
self, leaf_type: type[types.Leaf]
152+
) -> type[types.LeafHandler[types.Leaf, Any]]:
119153
if (handler_type := self._try_get(leaf_type)) is None:
120154
raise ValueError(
121-
f'Unknown Leaf type: "{leaf_type}". Must register it with'
155+
f'Unknown Leaf type: "{leaf_type!r}". Must register it with'
122156
' LeafHandlerRegistry.'
123157
)
124-
125158
return handler_type
126159

127160
def _try_get_abstract(
128161
self,
129-
abstract_type: Type[types.AbstractLeaf],
130-
) -> Type[types.LeafHandler[Any, types.AbstractLeaf]] | None:
131-
"""Returns the handler registered for a given abstract type, if available."""
132-
for (
133-
registered_abstract_ty,
134-
handler_type,
135-
) in self._abstract_type_registry.items():
136-
if typing_extensions.is_protocol(registered_abstract_ty): # pytype: disable=not-supported-yet
137-
if protocol_utils.is_subclass_protocol(
138-
cls=abstract_type, protocol=registered_abstract_ty
139-
):
140-
return handler_type
141-
elif issubclass(abstract_type, registered_abstract_ty):
142-
return handler_type
143-
144-
# no handler found
162+
abstract_type: type[types.AbstractLeaf],
163+
) -> type[types.LeafHandler[Any, types.AbstractLeaf]] | None:
164+
"""Returns the most specific handler for a given abstract type."""
165+
# Sort ascending by abstract_specificity_score (lowest to highest).
166+
sorted_entries = sorted(
167+
self._entries,
168+
key=lambda e: e.abstract_specificity_score
169+
)
170+
# Iterating reversed checks the most specific handlers first.
171+
for entry in reversed(sorted_entries):
172+
if _is_abstract_subprotocol(abstract_type, entry.abstract_type):
173+
return entry.handler_type
145174
return None
146175

147176
def get_abstract(
148177
self,
149-
abstract_type: Type[types.AbstractLeaf],
150-
) -> Type[types.LeafHandler[Any, types.AbstractLeaf]]:
178+
abstract_type: type[types.AbstractLeaf],
179+
) -> type[types.LeafHandler[Any, types.AbstractLeaf]]:
151180
if (handler_type := self._try_get_abstract(abstract_type)) is None:
152181
raise ValueError(
153-
f'Unknown AbstractLeaf type: "{abstract_type}". Must register it with'
154-
' LeafHandlerRegistry.'
182+
f'Unknown AbstractLeaf type: "{abstract_type!r}". Must register it'
183+
' with LeafHandlerRegistry.'
155184
)
156185

157186
return handler_type
@@ -167,24 +196,32 @@ def get_all(
167196
"""
168197
return [
169198
(
170-
leaf_type,
171-
abstract_type,
172-
handler_type,
173-
)
174-
for (leaf_type, handler_type), abstract_type in zip(
175-
self._leaf_type_registry.items(), self._abstract_type_registry
199+
entry.leaf_type,
200+
entry.abstract_type,
201+
entry.handler_type,
176202
)
203+
for entry in self._entries
177204
]
178205

179206
def add(
180207
self,
181-
leaf_type: Type[types.Leaf],
182-
abstract_type: Type[types.AbstractLeaf],
183-
handler_type: Type[types.LeafHandler[types.Leaf, types.AbstractLeaf]],
208+
leaf_type: type[types.Leaf],
209+
abstract_type: type[types.AbstractLeaf],
210+
handler_type: type[types.LeafHandler[types.Leaf, types.AbstractLeaf]],
184211
override: bool = False,
185212
secondary_typestrs: Sequence[str] | None = None,
186213
):
187-
"""Adds a handler_type for a given leaf_type and abstract_type pair.
214+
"""Registers a `handler_type` for a `leaf_type` and `abstract_type` pair.
215+
216+
The registry automatically maintains a [Generic -> Specific] hierarchy for
217+
both leaf and abstract types using dynamic topological priorities to ensure
218+
correct resolution. We maintain and recalculate these specificity scores to
219+
ensure that the most specific handler is chosen during resolution.
220+
221+
A conflict occurs if the exact `leaf_type` is already registered, or if the
222+
`abstract_type` is already mapped to a different handler. Set
223+
`override=True` to automatically remove conflicting entries and force the
224+
new registration.
188225
189226
Args:
190227
leaf_type: The concrete PyTree leaf type to register.
@@ -196,56 +233,110 @@ def add(
196233
secondary identifiers for the handler.
197234
198235
Raises:
199-
ValueError: If the `leaf_type` or `abstract_type` is already registered
200-
and `override` is False. Also raised if the `abstract_type` is already
201-
registered with a fundamentally different handler type.
236+
ValueError: If a duplicate `leaf_type` or conflicting `abstract_type`
237+
mapping exists and `override` is False.
202238
"""
203-
current_handler_type = self._try_get(leaf_type)
204-
current_abstract_handle_type = self._try_get_abstract(abstract_type)
205-
206-
if not override and (current_handler_type or current_abstract_handle_type):
207-
raise ValueError(
208-
f'Leaf_type[{leaf_type}] or abstract_type[{abstract_type}] has'
209-
f' already registered, current_handler: {current_handler_type}, '
210-
f'current_abstract_handle_type: {current_abstract_handle_type}'
211-
)
212239

213-
logging.vlog(
214-
1,
215-
'add: leaf_type[%s], abstract_type[%s], handler_type[%s],'
216-
' current_handler[%s], current_abstract_handle_type[%s]',
240+
# Check for exact duplicate registration
241+
for e in self._entries:
242+
if (
243+
e.leaf_type == leaf_type
244+
and e.abstract_type == abstract_type
245+
and e.handler_type == handler_type
246+
):
247+
logging.info(
248+
'Registration already exists for leaf_type[%s], '
249+
'abstract_type[%s], handler_type[%s]. Skipping.',
250+
leaf_type,
251+
abstract_type,
252+
handler_type,
253+
)
254+
return
255+
256+
if override:
257+
# Filter out conflicting entries if override is True.
258+
new_entries = []
259+
for e in self._entries:
260+
is_conflict = (e.leaf_type == leaf_type) or (
261+
e.abstract_type == abstract_type and e.handler_type != handler_type
262+
)
263+
if is_conflict:
264+
logging.warning(
265+
'clearing conflicting entry: leaf_type[%s], abstract_type[%s]'
266+
' handler_type[%s] during override.',
267+
e.leaf_type,
268+
e.abstract_type,
269+
e.handler_type,
270+
)
271+
else:
272+
new_entries.append(e)
273+
self._entries = new_entries
274+
else:
275+
for e in self._entries:
276+
if e.leaf_type == leaf_type:
277+
raise ValueError(
278+
f'leaf_type [{leaf_type}] is already handled by '
279+
f'{e.handler_type}. Use override=True to replace its entry. '
280+
f'Registry: {self._entries}'
281+
)
282+
if e.abstract_type == abstract_type and e.handler_type != handler_type:
283+
raise ValueError(
284+
f'abstract_type[{abstract_type}] is already handled by '
285+
f'{e.handler_type}. Use override=True to replace all '
286+
f'conflicting entries. Registry: {self._entries}'
287+
)
288+
289+
# Append the new entry with default priorities
290+
new_reg = _Registration(
217291
leaf_type,
218292
abstract_type,
219293
handler_type,
220-
current_handler_type,
221-
current_abstract_handle_type,
294+
secondary_typestrs,
295+
leaf_specificity_score=0,
296+
abstract_specificity_score=0,
297+
)
298+
self._entries.append(new_reg)
299+
# Recalculate specificity scores for all entries since new entry was added
300+
# and may change the specificity scores of existing entries.
301+
self._recalculate_specificity_scores()
302+
303+
# Sort the single source of truth [Generic -> Specific] based on leaf type
304+
# primarily, and abstract type secondarily.
305+
self._entries.sort(
306+
key=lambda x: (
307+
x.leaf_specificity_score,
308+
x.abstract_specificity_score,
309+
x.handler_type.__name__,
310+
)
222311
)
223312

224-
if current_handler_type and (
225-
current_abstract_handle_type
226-
and current_handler_type != current_abstract_handle_type
227-
):
228-
raise ValueError(
229-
f'Abstract_type[{abstract_type}] has already registered with a'
230-
' different type.'
231-
)
232-
elif current_handler_type and not current_abstract_handle_type:
233-
# need to remove the previous abstract type
234-
_, old_abstract_ty = self._handler_to_types.pop(current_handler_type)
235-
self._abstract_type_registry.pop(old_abstract_ty)
236-
237-
# new type and abstract type pair
238-
self._leaf_type_registry[leaf_type] = handler_type
239-
self._abstract_type_registry[abstract_type] = handler_type
240-
self._handler_to_types[handler_type] = (leaf_type, abstract_type)
241-
# Allows for multiple handlers to be associated with the same leaf_type and
242-
# abstract_type pair, typically for backward compatibility.
243-
if secondary_typestrs is not None:
244-
self._secondary_typestrs[handler_type] = (
245-
secondary_typestrs
246-
)
313+
def _recalculate_specificity_scores(self) -> None:
314+
"""Recalculates specificity scores for all entries; the caller re-sorts."""
315+
for target_entry in self._entries:
316+
leaf_count = 0
317+
abstract_count = 0
318+
for other_entry in self._entries:
319+
# Count how many leaf types this target is a subclass of.
320+
try:
321+
if (
322+
target_entry.leaf_type != other_entry.leaf_type and
323+
issubclass(target_entry.leaf_type, other_entry.leaf_type)
324+
):
325+
leaf_count += 1
326+
except TypeError:
327+
pass
328+
# Count how many abstract types this target is a subprotocol of.
329+
if (
330+
target_entry.abstract_type != other_entry.abstract_type and
331+
_is_abstract_subprotocol(
332+
target_entry.abstract_type, other_entry.abstract_type
333+
)
334+
):
335+
abstract_count += 1
336+
target_entry.leaf_specificity_score = leaf_count
337+
target_entry.abstract_specificity_score = abstract_count
247338

248-
def is_handleable(self, leaf_type: Type[Any]) -> bool:
339+
def is_handleable(self, leaf_type: type[Any]) -> bool:
249340
"""Returns True if the type is handleable.
250341
251342
This checks if the provided concrete leaf type, or any of its base classes,
@@ -259,8 +350,8 @@ def is_handleable(self, leaf_type: Type[Any]) -> bool:
259350
"""
260351
return self._try_get(leaf_type) is not None
261352

262-
def is_abstract_handleable(self, abstract_type: Type[Any]) -> bool:
263-
"""Returns True if the abstract type is handlable.
353+
def is_abstract_handleable(self, abstract_type: type[Any]) -> bool:
354+
"""Returns True if the abstract type is handleable.
264355
265356
This checks if the provided abstract leaf type, or any of its matching base
266357
classes or protocols, has a registered handler in the registry.
@@ -274,9 +365,12 @@ def is_abstract_handleable(self, abstract_type: Type[Any]) -> bool:
274365
return self._try_get_abstract(abstract_type) is not None
275366

276367
def get_secondary_typestrs(
277-
self, handler_type: Type[types.LeafHandler[Any, Any]]
368+
self, handler_type: type[types.LeafHandler[Any, Any]]
278369
) -> Sequence[str]:
279-
return self._secondary_typestrs.get(handler_type, [])
370+
for entry in self._entries:
371+
if entry.handler_type == handler_type:
372+
return entry.secondary_typestrs or []
373+
return []
280374

281375

282376
class StandardLeafHandlerRegistry(BaseLeafHandlerRegistry):

0 commit comments

Comments
 (0)