Commit 52abd3a

Torch opt param groups configurable
There are three variants of how this can be configured by the user in the optimizer options dict (an illustrative config sketch follows below):

- ``param_groups_custom``: callable which returns a list of param groups. This is the most flexible option, and could also go beyond just weight-decay logic, or use more than two param groups (weight decay disabled/enabled).
- ``weight_decay_custom_include_check``: callable which returns True/False for each param, to either include it in the weight-decay group or not, or None to use the default logic.
- ``weight_decay_modules_blacklist``: list of module types which should not get weight decay. Those can be RF modules or pure PyTorch modules. The types can be specified as strings (e.g. ``"torch.nn.LayerNorm"``) or as the types themselves.
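For illustration, the simplest variant might be used in a config roughly like this. This is a sketch, not part of the commit: the ``optimizer`` dict with a ``"class"`` entry and the separate ``learning_rate`` setting follow the usual RETURNN config conventions, and only the ``weight_decay_modules_blacklist`` key is taken from this change.

optimizer = {
    "class": "adamw",
    "weight_decay": 0.01,
    # No weight decay for params of these module types.
    # Strings must start with "torch." or "rf."; passing the types themselves
    # (e.g. torch.nn.LayerNorm) works as well.
    "weight_decay_modules_blacklist": ["torch.nn.LayerNorm", "torch.nn.Embedding", "rf.LayerNorm"],
}
learning_rate = 1e-3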
1 parent 427cefb commit 52abd3a

2 files changed: +84 -14 lines changed

returnn/torch/frontend/bridge.py

Lines changed: 12 additions & 0 deletions
@@ -6,6 +6,7 @@
 """

 from __future__ import annotations
+from typing import Optional
 import torch
 import returnn.frontend as rf
 from returnn.tensor import Dim
@@ -22,6 +23,17 @@ def pt_module_to_rf_module(pt_module: torch.nn.Module) -> rf.Module:
     return _PTModuleAsRFModule(pt_module=pt_module)


+def pt_module_to_wrapped_rf_module(pt_module: torch.nn.Module) -> Optional[rf.Module]:
+    """
+    :param pt_module: torch module
+    :return: RF module if the torch module is a wrapped RF module, or None otherwise
+    """
+    assert isinstance(pt_module, torch.nn.Module)
+    if isinstance(pt_module, _RFModuleAsPTModule):
+        return pt_module.rf_module
+    return None
+
+
 def rf_module_to_pt_module(rf_module: rf.Module) -> torch.nn.Module:
     """
     :param rf_module: RF module
returnn/torch/updater.py

Lines changed: 72 additions & 14 deletions
@@ -5,15 +5,17 @@

 from __future__ import annotations

+from typing import Optional, Union, Any, Type, Sequence, Set, Dict, List, Tuple
 import os
 import gc
 import torch
 import typing
-from typing import Any, Set, Dict, Optional

 import returnn
 from returnn.log import log
 from returnn.util.basic import RefIdEq
+import returnn.frontend as rf
+from returnn.torch.frontend.bridge import pt_module_to_wrapped_rf_module

 _OptimizerClassesDictInitialized = False
 _OptimizerClassesDict = {}
@@ -36,12 +38,11 @@ def _init_optimizer_classes_dict():
         _OptimizerClassesDict[name.lower()] = cls


-def get_optimizer_class(class_name):
+def get_optimizer_class(class_name) -> Type[torch.optim.Optimizer]:
     """
     :param str|()->torch.optim.Optimizer|type[torch.optim.Optimizer] class_name:
         Optimizer data, e.g. "adam", torch.optim.Adam...
     :return: Optimizer class
-    :rtype: type[torch.optim.Optimizer]
     """
     _init_optimizer_classes_dict()
     if isinstance(class_name, type):
@@ -299,9 +300,9 @@ def _create_optimizer(self, optimizer_opts):
         # If the user specified it as epsilon, parse it as eps for the optimizer
         if "eps" in optim_class_init_kwargs and "epsilon" in opt_kwargs:
             opt_kwargs["eps"] = opt_kwargs.pop("epsilon")
-        if "learning_rate" in optimizer_opts:
+        if "learning_rate" in opt_kwargs:
             raise ValueError("'learning_rate' should be set outside of the 'optimizer' dict.")
-        lr = lr * optimizer_opts.get("learning_rate_multiplier", 1.0)
+        lr = lr * opt_kwargs.pop("learning_rate_multiplier", 1.0)
         opt_kwargs["lr"] = lr

         param_groups = self._get_optimizer_param_groups(optim_class, opt_kwargs)
@@ -321,7 +322,9 @@ def _create_default_optimizer(self):

         return optimizer

-    def _get_optimizer_param_groups(self, optim_class, optimizer_opts):
+    def _get_optimizer_param_groups(
+        self, optim_class: Type[torch.optim.Optimizer], optimizer_opts: Dict[str, Any]
+    ) -> List[Dict[str, Any]]:
         """
         The weight_decay parameter from AdamW affects the weights of layers such as LayerNorm and Embedding.
         This function creates a blacklist of network modules and splits the optimizer groups in two:
@@ -334,14 +337,34 @@ def _get_optimizer_param_groups(self, optim_class, optimizer_opts):
         This code is based on https://github.com/karpathy/minGPT (MIT license):
             https://github.com/karpathy/minGPT/blob/3ed14b2cec0dfdad3f4b2831f2b4a86d11aef150/mingpt/model.py#L136.

-        :param type[torch.optim.Optimizer] optim_class: Optimizer class.
-        :param dict[str] optimizer_opts: Optimizer configuration specified by the user.
+        Three variants how this can be configured by the user in the optimizer options dict:
+
+        - ``param_groups_custom``: callable which returns a list of param groups.
+          This is the most flexible option, and could also go beyond just weight decay logic,
+          or having more than two param groups (weight decay disabled/enabled).
+        - ``weight_decay_custom_include_check``: callable which returns True/False for each param,
+          to either include it in the weight decay group or not,
+          or None to use the default logic.
+        - ``weight_decay_modules_blacklist``: list of modules types which should not get weight decay.
+          Those can be RF modules or pure PyTorch modules.
+          The types can be specified as string (e.g. ``"torch.nn.LayerNorm"``) or as the type itself.
+
+        :param optim_class: Optimizer class.
+        :param optimizer_opts: Optimizer configuration specified by the user. Might be modified inplace here.
         :return: List of configurations for the different sets of parameters.
-        :rtype: List[Dict[str]]
         """
+        custom_param_groups = optimizer_opts.pop("param_groups_custom", None)
+        if custom_param_groups is not None:
+            assert callable(custom_param_groups), f"invalid param_groups_custom {custom_param_groups!r}"
+            rf_model = pt_module_to_wrapped_rf_module(self.network)
+            custom_param_groups = custom_param_groups(
+                model=self.network, rf_model=rf_model, optimizer_class=optim_class, optimizer_opts=optimizer_opts
+            )
+            return custom_param_groups
+
         network_params = self.network.parameters()

-        # By default insert the weight_decay constraints in the optimizer, as this is default PyTorch behavior.
+        # By default, insert the weight_decay constraints in the optimizer, as this is default PyTorch behavior.
         # If the user doesn't accept this, throw an error message.
         assert self.config.bool("decouple_constraints", True), (
             "L2/weight_decay constraints are decoupled in PyTorch, but "
@@ -366,23 +389,44 @@ def _get_optimizer_param_groups(self, optim_class, optimizer_opts):
         # Parameters without weight decay: biases + LayerNorm/Embedding layers.
         wd_params = set()
         no_wd_params = set()
-        blacklist_wd_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+        blacklist_wd_modules = optimizer_opts.pop("weight_decay_modules_blacklist", None)
+        if blacklist_wd_modules is None:
+            blacklist_wd_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+        else:
+            blacklist_wd_modules = _wrap_user_blacklist_wd_modules(blacklist_wd_modules)
+        custom_include_check = optimizer_opts.pop("weight_decay_custom_include_check", None)
+        if custom_include_check:
+            assert callable(custom_include_check), f"invalid weight_decay_custom_include_check {custom_include_check!r}"
         # Tracker of visited parameters to only add each parameter once, in case two modules share common parameters.
         # We need the wrapper class RefIdEq because Parameters are compared by value and not by reference.
         visited_params: Set[RefIdEq[torch.nn.Parameter]] = set()
         for module_name, module in self.network.named_modules():
             module_name: str
             module: torch.nn.Module
+            rf_module = pt_module_to_wrapped_rf_module(module)
             for param_name, param in module.named_parameters(recurse=False):
                 param_name: str
                 param: torch.nn.Parameter
                 if RefIdEq(param) in visited_params:
                     continue
                 visited_params.add(RefIdEq(param))
                 full_param_name = "%s.%s" % (module_name, param_name) if module_name else param_name
-                if param_name.endswith("bias"):
-                    no_wd_params.add(full_param_name)
-                elif param_name.endswith("weight") and isinstance(module, blacklist_wd_modules):
+                custom_include = None
+                if custom_include_check:
+                    custom_include = custom_include_check(
+                        module=module, rf_module=rf_module, full_param_name=param_name, param=param
+                    )
+                if custom_include is not None:
+                    assert isinstance(custom_include, bool), "weight_decay_custom_include_check did not return bool"
+                    if custom_include:
+                        wd_params.add(full_param_name)
+                    else:
+                        no_wd_params.add(full_param_name)
+                elif (
+                    param_name.endswith("bias")
+                    or isinstance(module, blacklist_wd_modules)
+                    or isinstance(rf_module, blacklist_wd_modules)
+                ):
                     no_wd_params.add(full_param_name)
                 else:
                     wd_params.add(full_param_name)
@@ -394,3 +438,17 @@ def _get_optimizer_param_groups(self, optim_class, optimizer_opts):
         ]

         return optim_groups
+
+
+def _wrap_user_blacklist_wd_modules(
+    mods: Sequence[Union[str, Type[rf.Module], Type[torch.nn.Module]]]
+) -> Tuple[type, ...]:
+    assert isinstance(mods, (list, tuple)), f"invalid blacklist_weight_decay_modules {mods!r}"
+    res = []
+    for mod in mods:
+        if isinstance(mod, str):
+            assert mod.startswith("torch.") or mod.startswith("rf."), f"invalid blacklist_weight_decay_modules {mods!r}"
+            mod = eval(mod)
+        assert issubclass(mod, (rf.Module, torch.nn.Module)), f"invalid blacklist_weight_decay_modules {mods!r}"
+        res.append(mod)
+    return tuple(res)
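To make the two callable-based options concrete, here is a hedged sketch of user-side definitions. The keyword arguments match how updater.py calls the callables in this commit; everything else (the rules inside the callables, the surrounding ``optimizer`` dict) is illustrative only. Note that ``rf_model``/``rf_module`` is None unless the corresponding module is a wrapped RF module.

import returnn.frontend as rf  # noqa: F401  # available for RF-module-based rules


def my_weight_decay_include_check(*, module, rf_module, full_param_name, param):
    """Force params whose name ends with "scale" out of the weight-decay group (hypothetical rule);
    return None to fall back to the default bias/blacklist logic."""
    if full_param_name.endswith("scale"):
        return False
    return None


def my_param_groups(*, model, rf_model, optimizer_class, optimizer_opts):
    """Take over param grouping completely: biases get their own group without weight decay."""
    biases = [p for name, p in model.named_parameters() if name.endswith("bias")]
    others = [p for name, p in model.named_parameters() if not name.endswith("bias")]
    return [{"params": others}, {"params": biases, "weight_decay": 0.0}]


optimizer = {
    "class": "adamw",
    "weight_decay": 0.01,
    "weight_decay_custom_include_check": my_weight_decay_include_check,
    # or instead, the most flexible variant:
    # "param_groups_custom": my_param_groups,
}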
