1414
1515"""Leaf Handler Registry."""
1616
17- from typing import Any , Dict , Sequence , Tuple , Type
17+ from collections .abc import Sequence
18+ import dataclasses
19+ from typing import Any
1820
1921from absl import logging
2022import jax
5961}
6062
6163
@dataclasses.dataclass
class _Registration:
  """One row of the leaf handler registry.

  A row ties a concrete leaf type and its abstract counterpart to the handler
  class registered for them, together with the scores used to rank rows
  against each other during handler resolution.

  Attributes:
    leaf_type: Concrete PyTree leaf type covered by this row.
    abstract_type: Abstract representation paired with `leaf_type`.
    handler_type: LeafHandler class registered for the pair.
    secondary_typestrs: Alternate identifiers for the handler, or None when
      the handler has none.
    leaf_specificity_score: Rank of `leaf_type` among the registered leaf
      types. A larger value means the type is a subclass of more of the other
      registered leaf types, so it is preferred when several rows match.
    abstract_specificity_score: Rank of `abstract_type` among the registered
      abstract types. A larger value means the type is a subclass or
      sub-protocol of more of the other registered abstract types, so it is
      preferred when several rows match.
  """

  # Field order matters: rows are built with positional arguments.
  leaf_type: type[Any]
  abstract_type: type[Any]
  handler_type: type[types.LeafHandler[Any, Any]]
  secondary_typestrs: Sequence[str] | None
  leaf_specificity_score: int
  abstract_specificity_score: int
89+
90+
def _is_abstract_subprotocol(
    type_a: type[Any], type_b: type[Any]
) -> bool:
  """Reports whether `type_a` is a subclass or sub-protocol of `type_b`.

  When `type_b` is a protocol the comparison is delegated to
  `protocol_utils.is_subclass_protocol`; otherwise the built-in `issubclass`
  is used.

  Args:
    type_a: Candidate subclass or sub-protocol.
    type_b: Candidate base class or protocol.

  Returns:
    The result of the applicable check, or False when that check raises
    `TypeError`.
  """
  try:
    if not typing_extensions.is_protocol(type_b):  # pytype: disable=not-supported-yet
      return issubclass(type_a, type_b)
    return protocol_utils.is_subclass_protocol(cls=type_a, protocol=type_b)
  except TypeError:
    return False
103+
104+
62105class BaseLeafHandlerRegistry :
63106 """Base Leaf Handler Registry implements the LeafHandlerRegistry Protocol.
64107
@@ -87,71 +130,57 @@ class CustomArray(np.ndarray): pass
87130 """
88131
89132 def __init__ (self ):
90- self ._leaf_type_registry : Dict [
91- Type [Any ], Type [types .LeafHandler [Any , Any ]]
92- ] = {}
93- self ._abstract_type_registry : Dict [
94- Type [Any ], Type [types .LeafHandler [Any , Any ]]
95- ] = {}
96-
97- # for easy look up for replacement
98- self ._handler_to_types : Dict [
99- Type [types .LeafHandler [Any , Any ]], Tuple [Type [Any ], Type [Any ]]
100- ] = {}
101- self ._secondary_typestrs : Dict [
102- Type [types .LeafHandler [Any , Any ]], Sequence [str ]
103- ] = {}
133+ # Sorted [Generic -> Specific] primarily by leaf_specificity_score.
134+ self ._entries : list [_Registration ] = []
104135
105136 def _try_get (
106- self , leaf_type : Type [types .Leaf ]
107- ) -> Type [types .LeafHandler [types .Leaf , Any ]] | None :
108- """Returns the handler registered for a given type, if available."""
109- for registered_ty , handler_type in self ._leaf_type_registry .items ():
110- if issubclass (leaf_type , registered_ty ):
111- return handler_type
112-
113- # no handler found
137+ self , leaf_type : type [types .Leaf ]
138+ ) -> type [types .LeafHandler [types .Leaf , Any ]] | None :
139+ """Returns the most specific handler for a given type, if available."""
140+ # self._entries is sorted Generic -> Specific by leaf_specificity_score.
141+ # Iterating reversed checks the most specific handlers first.
142+ for entry in reversed (self ._entries ):
143+ try :
144+ if issubclass (leaf_type , entry .leaf_type ):
145+ return entry .handler_type
146+ except TypeError :
147+ pass
114148 return None
115149
116150 def get (
117- self , leaf_type : Type [types .Leaf ]
118- ) -> Type [types .LeafHandler [types .Leaf , Any ]]:
151+ self , leaf_type : type [types .Leaf ]
152+ ) -> type [types .LeafHandler [types .Leaf , Any ]]:
119153 if (handler_type := self ._try_get (leaf_type )) is None :
120154 raise ValueError (
121- f'Unknown Leaf type: "{ leaf_type } ". Must register it with'
155+ f'Unknown Leaf type: "{ leaf_type !r } ". Must register it with'
122156 ' LeafHandlerRegistry.'
123157 )
124-
125158 return handler_type
126159
127160 def _try_get_abstract (
128161 self ,
129- abstract_type : Type [types .AbstractLeaf ],
130- ) -> Type [types .LeafHandler [Any , types .AbstractLeaf ]] | None :
131- """Returns the handler registered for a given abstract type, if available."""
132- for (
133- registered_abstract_ty ,
134- handler_type ,
135- ) in self ._abstract_type_registry .items ():
136- if typing_extensions .is_protocol (registered_abstract_ty ): # pytype: disable=not-supported-yet
137- if protocol_utils .is_subclass_protocol (
138- cls = abstract_type , protocol = registered_abstract_ty
139- ):
140- return handler_type
141- elif issubclass (abstract_type , registered_abstract_ty ):
142- return handler_type
143-
144- # no handler found
162+ abstract_type : type [types .AbstractLeaf ],
163+ ) -> type [types .LeafHandler [Any , types .AbstractLeaf ]] | None :
164+ """Returns the most specific handler for a given abstract type."""
165+ # Sort ascending by abstract_specificity_score (lowest to highest).
166+ sorted_entries = sorted (
167+ self ._entries ,
168+ key = lambda e : e .abstract_specificity_score
169+ )
170+ # Iterating reversed checks the most specific handlers first.
171+ for entry in reversed (sorted_entries ):
172+ if _is_abstract_subprotocol (abstract_type , entry .abstract_type ):
173+ return entry .handler_type
145174 return None
146175
147176 def get_abstract (
148177 self ,
149- abstract_type : Type [types .AbstractLeaf ],
150- ) -> Type [types .LeafHandler [Any , types .AbstractLeaf ]]:
178+ abstract_type : type [types .AbstractLeaf ],
179+ ) -> type [types .LeafHandler [Any , types .AbstractLeaf ]]:
151180 if (handler_type := self ._try_get_abstract (abstract_type )) is None :
152181 raise ValueError (
153- f'Unknown AbstractLeaf type: "{ abstract_type } ". Must register it with '
154- ' LeafHandlerRegistry.'
182+ f'Unknown AbstractLeaf type: "{ abstract_type !r } ". Must register it'
183+ ' with LeafHandlerRegistry.'
155184 )
156185
157186 return handler_type
@@ -167,24 +196,32 @@ def get_all(
167196 """
168197 return [
169198 (
170- leaf_type ,
171- abstract_type ,
172- handler_type ,
173- )
174- for (leaf_type , handler_type ), abstract_type in zip (
175- self ._leaf_type_registry .items (), self ._abstract_type_registry
199+ entry .leaf_type ,
200+ entry .abstract_type ,
201+ entry .handler_type ,
176202 )
203+ for entry in self ._entries
177204 ]
178205
179206 def add (
180207 self ,
181- leaf_type : Type [types .Leaf ],
182- abstract_type : Type [types .AbstractLeaf ],
183- handler_type : Type [types .LeafHandler [types .Leaf , types .AbstractLeaf ]],
208+ leaf_type : type [types .Leaf ],
209+ abstract_type : type [types .AbstractLeaf ],
210+ handler_type : type [types .LeafHandler [types .Leaf , types .AbstractLeaf ]],
184211 override : bool = False ,
185212 secondary_typestrs : Sequence [str ] | None = None ,
186213 ):
187- """Adds a handler_type for a given leaf_type and abstract_type pair.
214+ """Registers a `handler_type` for a `leaf_type` and `abstract_type` pair.
215+
216+ The registry automatically maintains a [Generic -> Specific] hierarchy for
217+ both leaf and abstract types using dynamic topological priorities to ensure
218+ correct resolution. We maintain and recalculate these specificity scores to
219+ ensure that the most specific handler is chosen during resolution.
220+
221+ A conflict occurs if the exact `leaf_type` is already registered, or if the
222+ `abstract_type` is already mapped to a different handler. Set
223+ `override=True` to automatically remove conflicting entries and force the
224+ new registration.
188225
189226 Args:
190227 leaf_type: The concrete PyTree leaf type to register.
@@ -196,56 +233,110 @@ def add(
196233 secondary identifiers for the handler.
197234
198235 Raises:
199- ValueError: If the `leaf_type` or `abstract_type` is already registered
200- and `override` is False. Also raised if the `abstract_type` is already
201- registered with a fundamentally different handler type.
236+ ValueError: If a duplicate `leaf_type` or conflicting `abstract_type`
237+ mapping exists and `override` is False.
202238 """
203- current_handler_type = self ._try_get (leaf_type )
204- current_abstract_handle_type = self ._try_get_abstract (abstract_type )
205-
206- if not override and (current_handler_type or current_abstract_handle_type ):
207- raise ValueError (
208- f'Leaf_type[{ leaf_type } ] or abstract_type[{ abstract_type } ] has'
209- f' already registered, current_handler: { current_handler_type } , '
210- f'current_abstract_handle_type: { current_abstract_handle_type } '
211- )
212239
213- logging .vlog (
214- 1 ,
215- 'add: leaf_type[%s], abstract_type[%s], handler_type[%s],'
216- ' current_handler[%s], current_abstract_handle_type[%s]' ,
240+ # Check for exact duplicate registration
241+ for e in self ._entries :
242+ if (
243+ e .leaf_type == leaf_type
244+ and e .abstract_type == abstract_type
245+ and e .handler_type == handler_type
246+ ):
247+ logging .info (
248+ 'Registration already exists for leaf_type[%s], '
249+ 'abstract_type[%s], handler_type[%s]. Skipping.' ,
250+ leaf_type ,
251+ abstract_type ,
252+ handler_type ,
253+ )
254+ return
255+
256+ if override :
257+ # Filter out conflicting entries if override is True.
258+ new_entries = []
259+ for e in self ._entries :
260+ is_conflict = (e .leaf_type == leaf_type ) or (
261+ e .abstract_type == abstract_type and e .handler_type != handler_type
262+ )
263+ if is_conflict :
264+ logging .warning (
265+ 'clearing conflicting entry: leaf_type[%s], abstract_type[%s]'
266+ ' handler_type[%s] during override.' ,
267+ e .leaf_type ,
268+ e .abstract_type ,
269+ e .handler_type ,
270+ )
271+ else :
272+ new_entries .append (e )
273+ self ._entries = new_entries
274+ else :
275+ for e in self ._entries :
276+ if e .leaf_type == leaf_type :
277+ raise ValueError (
278+ f'leaf_type [{ leaf_type } ] is already handled by '
279+ f'{ e .handler_type } . Use override=True to replace its entry. '
280+ f'Registry: { self ._entries } '
281+ )
282+ if e .abstract_type == abstract_type and e .handler_type != handler_type :
283+ raise ValueError (
284+ f'abstract_type[{ abstract_type } ] is already handled by '
285+ f'{ e .handler_type } . Use override=True to replace all '
286+ f'conflicting entries. Registry: { self ._entries } '
287+ )
288+
289+ # Append the new entry with default priorities
290+ new_reg = _Registration (
217291 leaf_type ,
218292 abstract_type ,
219293 handler_type ,
220- current_handler_type ,
221- current_abstract_handle_type ,
294+ secondary_typestrs ,
295+ leaf_specificity_score = 0 ,
296+ abstract_specificity_score = 0 ,
297+ )
298+ self ._entries .append (new_reg )
299+ # Recalculate specificity scores for all entries since new entry was added
300+ # and may change the specificity scores of existing entries.
301+ self ._recalculate_specificity_scores ()
302+
303+ # Sort the single source of truth [Generic -> Specific] based on leaf type
304+ # primarily, and abstract type secondarily.
305+ self ._entries .sort (
306+ key = lambda x : (
307+ x .leaf_specificity_score ,
308+ x .abstract_specificity_score ,
309+ x .handler_type .__name__ ,
310+ )
222311 )
223312
224- if current_handler_type and (
225- current_abstract_handle_type
226- and current_handler_type != current_abstract_handle_type
227- ):
228- raise ValueError (
229- f'Abstract_type[{ abstract_type } ] has already registered with a'
230- ' different type.'
231- )
232- elif current_handler_type and not current_abstract_handle_type :
233- # need to remove the previous abstract type
234- _ , old_abstract_ty = self ._handler_to_types .pop (current_handler_type )
235- self ._abstract_type_registry .pop (old_abstract_ty )
236-
237- # new type and abstract type pair
238- self ._leaf_type_registry [leaf_type ] = handler_type
239- self ._abstract_type_registry [abstract_type ] = handler_type
240- self ._handler_to_types [handler_type ] = (leaf_type , abstract_type )
241- # Allows for multiple handlers to be associated with the same leaf_type and
242- # abstract_type pair, typically for backward compatibility.
243- if secondary_typestrs is not None :
244- self ._secondary_typestrs [handler_type ] = (
245- secondary_typestrs
246- )
313+ def _recalculate_specificity_scores (self ) -> None :
314+ """Recalculates specificity scores and sorts the registry."""
315+ for target_entry in self ._entries :
316+ leaf_count = 0
317+ abstract_count = 0
318+ for other_entry in self ._entries :
319+ # Count how many leaf types this target is a subclass of.
320+ try :
321+ if (
322+ target_entry .leaf_type != other_entry .leaf_type and
323+ issubclass (target_entry .leaf_type , other_entry .leaf_type )
324+ ):
325+ leaf_count += 1
326+ except TypeError :
327+ pass
328+ # Count how many abstract types this target is a subprotocol of.
329+ if (
330+ target_entry .abstract_type != other_entry .abstract_type and
331+ _is_abstract_subprotocol (
332+ target_entry .abstract_type , other_entry .abstract_type
333+ )
334+ ):
335+ abstract_count += 1
336+ target_entry .leaf_specificity_score = leaf_count
337+ target_entry .abstract_specificity_score = abstract_count
247338
248- def is_handleable (self , leaf_type : Type [Any ]) -> bool :
339+ def is_handleable (self , leaf_type : type [Any ]) -> bool :
249340 """Returns True if the type is handleable.
250341
251342 This checks if the provided concrete leaf type, or any of its base classes,
@@ -259,8 +350,8 @@ def is_handleable(self, leaf_type: Type[Any]) -> bool:
259350 """
260351 return self ._try_get (leaf_type ) is not None
261352
262- def is_abstract_handleable (self , abstract_type : Type [Any ]) -> bool :
263- """Returns True if the abstract type is handlable .
353+ def is_abstract_handleable (self , abstract_type : type [Any ]) -> bool :
354+ """Returns True if the abstract type is handleable .
264355
265356 This checks if the provided abstract leaf type, or any of its matching base
266357 classes or protocols, has a registered handler in the registry.
@@ -274,9 +365,12 @@ def is_abstract_handleable(self, abstract_type: Type[Any]) -> bool:
274365 return self ._try_get_abstract (abstract_type ) is not None
275366
276367 def get_secondary_typestrs (
277- self , handler_type : Type [types .LeafHandler [Any , Any ]]
368+ self , handler_type : type [types .LeafHandler [Any , Any ]]
278369 ) -> Sequence [str ]:
279- return self ._secondary_typestrs .get (handler_type , [])
370+ for entry in self ._entries :
371+ if entry .handler_type == handler_type :
372+ return entry .secondary_typestrs or []
373+ return []
280374
281375
282376class StandardLeafHandlerRegistry (BaseLeafHandlerRegistry ):
0 commit comments