diff --git a/references/classification/train.py b/references/classification/train.py
index 66dc7801d0c..210d47f7dc9 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -42,6 +42,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
         else:
             loss = criterion(output, target)
         loss.backward()
+
+        if args.clip_grad_norm is not None:
+            nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+
         optimizer.step()
 
         if model_ema and i % args.model_ema_steps == 0:
@@ -472,6 +476,7 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
+    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
diff --git a/references/classification/utils.py b/references/classification/utils.py
index 473684fe162..ac09bd69d86 100644
--- a/references/classification/utils.py
+++ b/references/classification/utils.py
@@ -409,3 +409,11 @@ def reduce_across_processes(val):
     dist.barrier()
     dist.all_reduce(t)
     return t
+
+
+def get_optimizer_params(optimizer):
+    """Generator to iterate over all parameters in the optimizer param_groups."""
+
+    for group in optimizer.param_groups:
+        for p in group["params"]:
+            yield p
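
A minimal, self-contained sketch (not part of the diff) of how the new pieces fit together: the toy model, optimizer, and data below are illustrative stand-ins for what the reference script builds, and `clip_grad_norm` plays the role of `args.clip_grad_norm`. With the change applied, clipping would be enabled from the command line via the new flag, e.g. `python train.py --clip-grad-norm 1.0 ...` (illustrative invocation, other flags omitted).

# Standalone sketch; assumptions: toy nn.Linear model, SGD optimizer, random data.
import torch
import torch.nn as nn


def get_optimizer_params(optimizer):
    """Generator to iterate over all parameters in the optimizer param_groups."""
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p


model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
clip_grad_norm = 1.0  # stands in for args.clip_grad_norm; None disables clipping, as in the diff

output = model(torch.randn(4, 10))
loss = criterion(output, torch.randint(0, 2, (4,)))
loss.backward()

# Same pattern as the train_one_epoch change: clip the gradients of every
# parameter known to the optimizer (global L2 norm) before stepping.
if clip_grad_norm is not None:
    nn.utils.clip_grad_norm_(get_optimizer_params(optimizer), clip_grad_norm)

optimizer.step()

One plausible reason for gathering parameters from optimizer.param_groups rather than calling model.parameters() directly is that it clips exactly the set of parameters being updated, including any custom parameter groups configured on the optimizer.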