from __future__ import annotations

+ from typing import Optional, Union, Any, Type, Sequence, Set, Dict, List, Tuple
import os
import gc
import torch
import typing
- from typing import Any, Set, Dict, Optional

import returnn
from returnn.log import log
from returnn.util.basic import RefIdEq
+ import returnn.frontend as rf
+ from returnn.torch.frontend.bridge import pt_module_to_wrapped_rf_module

_OptimizerClassesDictInitialized = False
_OptimizerClassesDict = {}
@@ -36,12 +38,11 @@ def _init_optimizer_classes_dict():
        _OptimizerClassesDict[name.lower()] = cls


- def get_optimizer_class(class_name):
+ def get_optimizer_class(class_name) -> Type[torch.optim.Optimizer]:
    """
    :param str|()->torch.optim.Optimizer|type[torch.optim.Optimizer] class_name:
        Optimizer data, e.g. "adam", torch.optim.Adam...
    :return: Optimizer class
-     :rtype: type[torch.optim.Optimizer]
    """
    _init_optimizer_classes_dict()
    if isinstance(class_name, type):
@@ -299,9 +300,9 @@ def _create_optimizer(self, optimizer_opts):
        # If the user specified it as epsilon, parse it as eps for the optimizer
        if "eps" in optim_class_init_kwargs and "epsilon" in opt_kwargs:
            opt_kwargs["eps"] = opt_kwargs.pop("epsilon")
-         if "learning_rate" in optimizer_opts:
+         if "learning_rate" in opt_kwargs:
            raise ValueError("'learning_rate' should be set outside of the 'optimizer' dict.")
-         lr = lr * optimizer_opts.get("learning_rate_multiplier", 1.0)
+         lr = lr * opt_kwargs.pop("learning_rate_multiplier", 1.0)
        opt_kwargs["lr"] = lr

        param_groups = self._get_optimizer_param_groups(optim_class, opt_kwargs)
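
For context, a minimal sketch (not part of this commit) of how this would typically look in a RETURNN config, assuming the usual top-level `learning_rate` key and the `optimizer` dict; the multiplier is now popped from `opt_kwargs`, so it is folded into `lr` instead of leaking into the optimizer constructor:

    # Hypothetical config snippet, key names per the checks above:
    learning_rate = 1e-3                    # must live outside the optimizer dict (see the ValueError above)
    optimizer = {
        "class": "adamw",
        "weight_decay": 1e-2,
        "learning_rate_multiplier": 0.5,    # popped here; effective lr = 1e-3 * 0.5 = 5e-4
    }
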
@@ -321,7 +322,9 @@ def _create_default_optimizer(self):

        return optimizer

-     def _get_optimizer_param_groups(self, optim_class, optimizer_opts):
+     def _get_optimizer_param_groups(
+         self, optim_class: Type[torch.optim.Optimizer], optimizer_opts: Dict[str, Any]
+     ) -> List[Dict[str, Any]]:
        """
        The weight_decay parameter from AdamW affects the weights of layers such as LayerNorm and Embedding.
        This function creates a blacklist of network modules and splits the optimizer groups in two:
@@ -334,14 +337,34 @@ def _get_optimizer_param_groups(self, optim_class, optimizer_opts):
        This code is based on https://github.com/karpathy/minGPT (MIT license):
        https://github.com/karpathy/minGPT/blob/3ed14b2cec0dfdad3f4b2831f2b4a86d11aef150/mingpt/model.py#L136.

-         :param type[torch.optim.Optimizer] optim_class: Optimizer class.
-         :param dict[str] optimizer_opts: Optimizer configuration specified by the user.
+         Three variants of how this can be configured by the user in the optimizer options dict:
+
+         - ``param_groups_custom``: callable which returns a list of param groups.
+           This is the most flexible option, and could also go beyond just weight decay logic,
+           or have more than two param groups (weight decay disabled/enabled).
+         - ``weight_decay_custom_include_check``: callable which returns True/False for each param,
+           to either include it in the weight decay group or not,
+           or None to use the default logic.
+         - ``weight_decay_modules_blacklist``: list of module types which should not get weight decay.
+           Those can be RF modules or pure PyTorch modules.
+           The types can be specified as a string (e.g. ``"torch.nn.LayerNorm"``) or as the type itself.
+
+         :param optim_class: Optimizer class.
+         :param optimizer_opts: Optimizer configuration specified by the user. Might be modified inplace here.
        :return: List of configurations for the different sets of parameters.
-         :rtype: List[Dict[str]]
        """
+         custom_param_groups = optimizer_opts.pop("param_groups_custom", None)
+         if custom_param_groups is not None:
+             assert callable(custom_param_groups), f"invalid param_groups_custom {custom_param_groups!r}"
+             rf_model = pt_module_to_wrapped_rf_module(self.network)
+             custom_param_groups = custom_param_groups(
+                 model=self.network, rf_model=rf_model, optimizer_class=optim_class, optimizer_opts=optimizer_opts
+             )
+             return custom_param_groups
+
        network_params = self.network.parameters()

-         # By default insert the weight_decay constraints in the optimizer, as this is default PyTorch behavior.
+         # By default, insert the weight_decay constraints in the optimizer, as this is default PyTorch behavior.
        # If the user doesn't accept this, throw an error message.
        assert self.config.bool("decouple_constraints", True), (
            "L2/weight_decay constraints are decoupled in PyTorch, but "
@@ -366,23 +389,44 @@ def _get_optimizer_param_groups(self, optim_class, optimizer_opts):
        # Parameters without weight decay: biases + LayerNorm/Embedding layers.
        wd_params = set()
        no_wd_params = set()
-         blacklist_wd_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+         blacklist_wd_modules = optimizer_opts.pop("weight_decay_modules_blacklist", None)
+         if blacklist_wd_modules is None:
+             blacklist_wd_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+         else:
+             blacklist_wd_modules = _wrap_user_blacklist_wd_modules(blacklist_wd_modules)
+         custom_include_check = optimizer_opts.pop("weight_decay_custom_include_check", None)
+         if custom_include_check:
+             assert callable(custom_include_check), f"invalid weight_decay_custom_include_check {custom_include_check!r}"
        # Tracker of visited parameters to only add each parameter once, in case two modules share common parameters.
        # We need the wrapper class RefIdEq because Parameters are compared by value and not by reference.
        visited_params: Set[RefIdEq[torch.nn.Parameter]] = set()
        for module_name, module in self.network.named_modules():
            module_name: str
            module: torch.nn.Module
+             rf_module = pt_module_to_wrapped_rf_module(module)
            for param_name, param in module.named_parameters(recurse=False):
                param_name: str
                param: torch.nn.Parameter
                if RefIdEq(param) in visited_params:
                    continue
                visited_params.add(RefIdEq(param))
                full_param_name = "%s.%s" % (module_name, param_name) if module_name else param_name
-                 if param_name.endswith("bias"):
-                     no_wd_params.add(full_param_name)
-                 elif param_name.endswith("weight") and isinstance(module, blacklist_wd_modules):
+                 custom_include = None
+                 if custom_include_check:
+                     custom_include = custom_include_check(
+                         module=module, rf_module=rf_module, full_param_name=param_name, param=param
+                     )
+                 if custom_include is not None:
+                     assert isinstance(custom_include, bool), "weight_decay_custom_include_check did not return bool"
+                     if custom_include:
+                         wd_params.add(full_param_name)
+                     else:
+                         no_wd_params.add(full_param_name)
+                 elif (
+                     param_name.endswith("bias")
+                     or isinstance(module, blacklist_wd_modules)
+                     or isinstance(rf_module, blacklist_wd_modules)
+                 ):
                    no_wd_params.add(full_param_name)
                else:
                    wd_params.add(full_param_name)
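
A minimal sketch of what a `weight_decay_custom_include_check` callable could look like, matching the keyword arguments used in the call above; the function name and the concrete rules are illustrative only, not part of this commit:

    import returnn.frontend as rf

    def my_weight_decay_include_check(*, module, rf_module, full_param_name, param):
        """Return True/False to force a param into/out of the weight decay group, or None for the default logic."""
        if param.ndim <= 1:
            return False  # vectors (biases, norm scales): no weight decay
        if rf_module is not None and isinstance(rf_module, rf.Embedding):
            return False  # e.g. also keep embedding weights out of the weight decay group
        return None  # anything else: fall back to the default bias/blacklist logic

It would then be passed in the optimizer options dict as `"weight_decay_custom_include_check": my_weight_decay_include_check`.
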
@@ -394,3 +438,17 @@ def _get_optimizer_param_groups(self, optim_class, optimizer_opts):
        ]

        return optim_groups
+
+
+ def _wrap_user_blacklist_wd_modules(
+     mods: Sequence[Union[str, Type[rf.Module], Type[torch.nn.Module]]]
+ ) -> Tuple[type, ...]:
+     assert isinstance(mods, (list, tuple)), f"invalid blacklist_weight_decay_modules {mods!r}"
+     res = []
+     for mod in mods:
+         if isinstance(mod, str):
+             assert mod.startswith("torch.") or mod.startswith("rf."), f"invalid blacklist_weight_decay_modules {mods!r}"
+             mod = eval(mod)
+         assert issubclass(mod, (rf.Module, torch.nn.Module)), f"invalid blacklist_weight_decay_modules {mods!r}"
+         res.append(mod)
+     return tuple(res)
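
A hypothetical sketch of how the new options could be used from a RETURNN config (key names and the custom grouping rule are illustrative, not part of this commit); string entries must start with "torch." or "rf.", as checked by `_wrap_user_blacklist_wd_modules` above:

    optimizer = {
        "class": "adamw",
        "weight_decay": 1e-2,
        "weight_decay_modules_blacklist": ["torch.nn.LayerNorm", "torch.nn.Embedding", "rf.LayerNorm"],
    }

    # Or take over the grouping entirely; the callable is invoked with the kwargs shown in the diff
    # and must return a list of torch.optim param-group dicts:
    def my_param_groups(*, model, rf_model, optimizer_class, optimizer_opts):
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            (no_decay if param.ndim <= 1 else decay).append(param)
        return [
            {"params": decay, "weight_decay": optimizer_opts.get("weight_decay", 0.0)},
            {"params": no_decay, "weight_decay": 0.0},
        ]

    optimizer = {"class": "adamw", "weight_decay": 1e-2, "param_groups_custom": my_param_groups}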