From 2910da6879a53a1b78f9415d46db199a7d443cdc Mon Sep 17 00:00:00 2001
From: Yiwen Song
Date: Mon, 1 Nov 2021 17:08:37 -0700
Subject: [PATCH 1/4] [references] Adding gradient clipping

---
 references/classification/train.py |  7 +++++++
 references/classification/utils.py | 18 ++++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/references/classification/train.py b/references/classification/train.py
index 79ba21b263d..78ea74ea12a 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -42,6 +42,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
         else:
             loss = criterion(output, target)
         loss.backward()
+
+        if args.clip_grad_norm is not None:
+            nn.utils.clip_grad_norm_(utils.master_params(optimizer), args.clip_grad_norm)
+
         optimizer.step()

         if model_ema and i % args.model_ema_steps == 0:
@@ -472,6 +476,9 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
+    parser.add_argument(
+        "--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)"
+    )

     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
diff --git a/references/classification/utils.py b/references/classification/utils.py
index 473684fe162..f7d24f218a8 100644
--- a/references/classification/utils.py
+++ b/references/classification/utils.py
@@ -409,3 +409,21 @@ def reduce_across_processes(val):
     dist.barrier()
     dist.all_reduce(t)
     return t
+
+
+try:
+    import apex
+
+    apex_available = True
+except ImportError:
+    apex_available = False
+
+def master_params(optimizer):
+    """Generator to iterate over all parameters in the optimizer param_groups."""
+
+    if apex_available:
+        yield from apex.amp.master_params(optimizer)
+    else:
+        for group in optimizer.param_groups:
+            for p in group["params"]:
+                yield p
\ No newline at end of file

From 7d7eb7a795d0fb0f29991b4a719cefa59e82f5ff Mon Sep 17 00:00:00 2001
From: sallysyw
Date: Tue, 2 Nov 2021 00:33:46 +0000
Subject: [PATCH 2/4] ufmt formatting

---
 references/classification/train.py | 6 ++----
 references/classification/utils.py | 5 +++--
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index 78ea74ea12a..b8d8e1773bb 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -42,7 +42,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
         else:
             loss = criterion(output, target)
         loss.backward()
-
+
         if args.clip_grad_norm is not None:
             nn.utils.clip_grad_norm_(utils.master_params(optimizer), args.clip_grad_norm)

@@ -476,9 +476,7 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
-    parser.add_argument(
-        "--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)"
-    )
+    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")

     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
diff --git a/references/classification/utils.py b/references/classification/utils.py
index f7d24f218a8..41c6c8a428e 100644
--- a/references/classification/utils.py
+++ b/references/classification/utils.py
@@ -418,12 +418,13 @@ def reduce_across_processes(val):
 except ImportError:
     apex_available = False

+
 def master_params(optimizer):
     """Generator to iterate over all parameters in the optimizer param_groups."""
-
+
     if apex_available:
         yield from apex.amp.master_params(optimizer)
     else:
         for group in optimizer.param_groups:
             for p in group["params"]:
-                yield p
\ No newline at end of file
+                yield p

From 28cf637f3e156c1fb91d6927ebd961b2697d3227 Mon Sep 17 00:00:00 2001
From: sallysyw
Date: Tue, 2 Nov 2021 20:57:03 +0000
Subject: [PATCH 3/4] remove apex code

---
 references/classification/utils.py | 17 +++--------------
 1 file changed, 3 insertions(+), 14 deletions(-)

diff --git a/references/classification/utils.py b/references/classification/utils.py
index 41c6c8a428e..cd4a55005e3 100644
--- a/references/classification/utils.py
+++ b/references/classification/utils.py
@@ -411,20 +411,9 @@ def reduce_across_processes(val):
     return t


-try:
-    import apex
-
-    apex_available = True
-except ImportError:
-    apex_available = False
-
-
 def master_params(optimizer):
     """Generator to iterate over all parameters in the optimizer param_groups."""

-    if apex_available:
-        yield from apex.amp.master_params(optimizer)
-    else:
-        for group in optimizer.param_groups:
-            for p in group["params"]:
-                yield p
+    for group in optimizer.param_groups:
+        for p in group["params"]:
+            yield p

From 96eb6ef6f232fffbb3abbe5080c25589e10269fa Mon Sep 17 00:00:00 2001
From: sallysyw
Date: Tue, 2 Nov 2021 21:37:49 +0000
Subject: [PATCH 4/4] resolve naming issue

---
 references/classification/train.py | 2 +-
 references/classification/utils.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index b8d8e1773bb..d0b986a0100 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -44,7 +44,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
         loss.backward()

         if args.clip_grad_norm is not None:
-            nn.utils.clip_grad_norm_(utils.master_params(optimizer), args.clip_grad_norm)
+            nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)

         optimizer.step()

diff --git a/references/classification/utils.py b/references/classification/utils.py
index cd4a55005e3..ac09bd69d86 100644
--- a/references/classification/utils.py
+++ b/references/classification/utils.py
@@ -411,7 +411,7 @@ def reduce_across_processes(val):
     return t


-def master_params(optimizer):
+def get_optimizer_params(optimizer):
     """Generator to iterate over all parameters in the optimizer param_groups."""

     for group in optimizer.param_groups:
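
For reference, below is a minimal, self-contained sketch (not part of the patch series) of the behaviour the final patch state produces: a get_optimizer_params-style helper yields every parameter tracked by the optimizer's param_groups, and nn.utils.clip_grad_norm_ clips their combined gradient norm before optimizer.step(), mirroring the --clip-grad-norm branch added to train_one_epoch. The toy model, data, and max-norm value are illustrative placeholders, not values from the patches.

import torch
from torch import nn


def get_optimizer_params(optimizer):
    """Yield every parameter tracked by the optimizer's param_groups."""
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p


# Toy stand-ins for the real model/criterion/optimizer built in train.py.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
clip_grad_norm = 1.0  # stands in for args.clip_grad_norm

x = torch.randn(8, 10)
target = torch.randint(0, 2, (8,))

optimizer.zero_grad()
loss = criterion(model(x), target)
loss.backward()

# Clip the total gradient norm before the optimizer step, as in the patched
# train_one_epoch loop (skipped when the flag is left at its default of None).
if clip_grad_norm is not None:
    nn.utils.clip_grad_norm_(get_optimizer_params(optimizer), clip_grad_norm)

optimizer.step()

Once the apex branch is dropped in patch 3, the helper is exactly this plain iteration over optimizer.param_groups, which is what the rename to get_optimizer_params in patch 4 reflects.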