Commit 031e129

xiaohu2015, datumbox, and prabhat00155 authored
fix bug in training model by amp (#4874)

* fix bug in amp
* fix bug in training by amp
* support use gradient clipping when amp is enabled

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Prabhat Roy <[email protected]>
1 parent 8af692a commit 031e129

File tree: 1 file changed (+10, -9 lines)

references/classification/train.py

Lines changed: 10 additions & 9 deletions
@@ -30,23 +30,24 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
-        output = model(image)
+        with torch.cuda.amp.autocast(enabled=args.amp):
+            output = model(image)
+            loss = criterion(output, target)
 
         optimizer.zero_grad()
         if args.amp:
-            with torch.cuda.amp.autocast():
-                loss = criterion(output, target)
             scaler.scale(loss).backward()
+            if args.clip_grad_norm is not None:
+                # we should unscale the gradients of optimizer's assigned params if do gradient clipping
+                scaler.unscale_(optimizer)
+                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
             scaler.step(optimizer)
             scaler.update()
         else:
-            loss = criterion(output, target)
             loss.backward()
-
-            if args.clip_grad_norm is not None:
-                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
-
-            optimizer.step()
+            if args.clip_grad_norm is not None:
+                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+            optimizer.step()
 
         if model_ema and i % args.model_ema_steps == 0:
             model_ema.update_parameters(model)
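
Taken out of diff context, the patched training step looks like the sketch below. This is a minimal illustration rather than the repository's full train_one_epoch: it assumes model, criterion, optimizer, data_loader, device, args.amp, and args.clip_grad_norm are already defined, creates the GradScaler locally instead of receiving it as an argument, and clips model.parameters() instead of the repo's utils.get_optimizer_params(optimizer) helper. The two points the commit fixes are that the forward pass and the loss are both computed under autocast, and that gradients are unscaled before clipping when AMP is enabled.

import torch
import torch.nn as nn

# GradScaler handles loss scaling when AMP is on; a disabled scaler is a no-op.
scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

for image, target in data_loader:
    image, target = image.to(device), target.to(device)

    # Run both the forward pass and the loss under autocast (the bug was that
    # only the loss, not the model output, was computed in mixed precision).
    with torch.cuda.amp.autocast(enabled=args.amp):
        output = model(image)
        loss = criterion(output, target)

    optimizer.zero_grad()
    if args.amp:
        scaler.scale(loss).backward()
        if args.clip_grad_norm is not None:
            # Gradients are still scaled at this point; unscale them so the
            # clipping threshold is applied to the true gradient norms.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        if args.clip_grad_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
        optimizer.step()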
