diff --git a/references/classification/train.py b/references/classification/train.py
index 220cf001d60..0b855d105c9 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -30,23 +30,24 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
-        output = model(image)
+        with torch.cuda.amp.autocast(enabled=args.amp):
+            output = model(image)
+            loss = criterion(output, target)
 
         optimizer.zero_grad()
         if args.amp:
-            with torch.cuda.amp.autocast():
-                loss = criterion(output, target)
             scaler.scale(loss).backward()
+            if args.clip_grad_norm is not None:
+                # we should unscale the gradients of optimizer's assigned params if do gradient clipping
+                scaler.unscale_(optimizer)
+                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
             scaler.step(optimizer)
             scaler.update()
         else:
-            loss = criterion(output, target)
             loss.backward()
-
-            if args.clip_grad_norm is not None:
-                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
-
-            optimizer.step()
+            if args.clip_grad_norm is not None:
+                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+            optimizer.step()
 
         if model_ema and i % args.model_ema_steps == 0:
             model_ema.update_parameters(model)
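
For context, the sketch below shows the same training-step pattern in a minimal, self-contained form: the forward pass and loss run under autocast, and when AMP is enabled the gradients are unscaled before clipping so that clip_grad_norm_ sees true gradient magnitudes rather than scaled ones. This is an illustrative sketch, not the patched file itself: the toy model, tensors, and hyperparameters are assumptions, and model.parameters() stands in for the repository's utils.get_optimizer_params(optimizer).

import torch
from torch import nn

use_amp = torch.cuda.is_available()  # autocast/GradScaler only take effect on CUDA
device = "cuda" if use_amp else "cpu"
clip_grad_norm = 1.0

model = nn.Linear(16, 4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

image = torch.randn(8, 16, device=device)
target = torch.randint(0, 4, (8,), device=device)

# Forward pass and loss under autocast, mirroring the patched loop.
with torch.cuda.amp.autocast(enabled=use_amp):
    output = model(image)
    loss = criterion(output, target)

optimizer.zero_grad()
if use_amp:
    scaler.scale(loss).backward()
    if clip_grad_norm is not None:
        # Gradients are still scaled here; unscale them in place so the
        # clipping threshold applies to the real gradient norms.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
    scaler.step(optimizer)  # skips the step if any gradient is inf/nan
    scaler.update()
else:
    loss.backward()
    if clip_grad_norm is not None:
        nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
    optimizer.step()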