@@ -241,23 +241,43 @@ def get_merged_lora_ckpt(
241241
242242@contextlib .contextmanager
243243def disable_adapter (model : nn .Module ) -> Generator [None , None , None ]:
244- for _ , v in model .named_modules ():
244+ """
245+ Temporarily disable the adapters in a neural network model. This can be used,
246+ for example, in DPO for treating the lora adapters as the policy model
247+ and disabling it to treat the base model as the reference model.
248+
249+ This context manager goes through all modules in the provided neural network model,
250+ and if a module has an 'adapter_params' attribute that is callable and a 'disabled' attribute,
251+ it sets 'disabled' to True. Then, the control is given back to caller. Once that finalizes,
252+ it sets 'disabled' back to False for all modules that were temporarily disabled.
253+
254+ Args:
255+ model (nn.Module): The neural network model whose adapters are to be temporarily disabled.
256+ Yields:
257+ None: This function yields control back to the caller, with the adapters disabled.
258+ Example:
259+ >>> with disable_adapter(model):
260+ ... # Perform operations with adapters disabled
261+ ... pass
262+
263+ """
264+ for _ , module in model .named_modules ():
245265 if (
246- hasattr (v , "adapter_params" )
247- and callable (v .adapter_params )
248- and hasattr (v , "disabled" )
266+ hasattr (module , "adapter_params" )
267+ and callable (module .adapter_params )
268+ and hasattr (module , "disabled" )
249269 ):
250- v .disabled = True
270+ module .disabled = True
251271 try :
252272 yield
253273 finally :
254- for _ , v in model .named_modules ():
274+ for _ , module in model .named_modules ():
255275 if (
256- hasattr (v , "adapter_params" )
257- and callable (v .adapter_params )
258- and hasattr (v , "disabled" )
276+ hasattr (module , "adapter_params" )
277+ and callable (module .adapter_params )
278+ and hasattr (module , "disabled" )
259279 ):
260- v .disabled = False
280+ module .disabled = False
261281
262282
263283def validate_missing_and_unexpected_for_lora (
@@ -272,7 +292,7 @@ def validate_missing_and_unexpected_for_lora(
272292 """
273293 A more memory-efficient way to validate that LoRA state dict loading was done properly.
274294
275- Similar to validate_state_dict_for_lora, this function uses a model's LoRA config to
295+ Similar to :func:` validate_state_dict_for_lora` , this function uses a model's LoRA config to
276296 check that LoRA and/or base model weights are loaded into the full model correctly.
277297 Unlike that function, this method relies only on the values of missing and unexpected
278298 as returned by the load_state_dict API with strict=False. This allows us to do the
0 commit comments