diff --git a/references/classification/train.py b/references/classification/train.py
index 29de4fce91c..2fbe61dd65f 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -40,13 +40,13 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
             if args.clip_grad_norm is not None:
                 # we should unscale the gradients of optimizer's assigned params if do gradient clipping
                 scaler.unscale_(optimizer)
-                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
             scaler.step(optimizer)
             scaler.update()
         else:
             loss.backward()
             if args.clip_grad_norm is not None:
-                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
             optimizer.step()
 
         if model_ema and i % args.model_ema_steps == 0:
diff --git a/references/classification/utils.py b/references/classification/utils.py
index ac09bd69d86..473684fe162 100644
--- a/references/classification/utils.py
+++ b/references/classification/utils.py
@@ -409,11 +409,3 @@ def reduce_across_processes(val):
     dist.barrier()
     dist.all_reduce(t)
     return t
-
-
-def get_optimizer_params(optimizer):
-    """Generator to iterate over all parameters in the optimizer param_groups."""
-
-    for group in optimizer.param_groups:
-        for p in group["params"]:
-            yield p
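
For reference, below is a minimal, self-contained sketch of the clipping pattern the diff settles on: when training with `torch.cuda.amp.GradScaler`, gradients must be unscaled via `scaler.unscale_(optimizer)` before `nn.utils.clip_grad_norm_` so the clip threshold is compared against true gradient norms, and the clipping is applied to `model.parameters()` directly. The model, data, and `clip_grad_norm` value are illustrative stand-ins (`clip_grad_norm` plays the role of `args.clip_grad_norm` in the reference script); this is not the reference script itself.

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # GradScaler is only meaningful with CUDA autocast

model = nn.Linear(10, 2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
clip_grad_norm = 1.0  # illustrative stand-in for args.clip_grad_norm

images = torch.randn(8, 10, device=device)
targets = torch.randint(0, 2, (8,), device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=use_amp):
    loss = criterion(model(images), targets)
scaler.scale(loss).backward()

# Unscale first so clip_grad_norm_ sees real gradient magnitudes,
# then clip the model's parameters directly, as in the diff.
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)

scaler.step(optimizer)
scaler.update()
```

Since the reference script builds its optimizer from the model's own parameters, clipping `model.parameters()` covers the same tensors as iterating over the optimizer's param groups, which is presumably why the `get_optimizer_params` helper could be dropped.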