fix bug in training model by amp #4874
Changes from all commits: a7fa1e6, 4e5f2b4, 88cc0b4, 24e404d, c22cfc0, 25842f6
@@ -30,23 +30,24 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args
     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
-        output = model(image)
+        with torch.cuda.amp.autocast(enabled=args.amp):
+            output = model(image)
+            loss = criterion(output, target)

         optimizer.zero_grad()
         if args.amp:
-            with torch.cuda.amp.autocast():
-                loss = criterion(output, target)
             scaler.scale(loss).backward()
             if args.clip_grad_norm is not None:
                 # we should unscale the gradients of optimizer's assigned params if do gradient clipping
                 scaler.unscale_(optimizer)
                 nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
             scaler.step(optimizer)
             scaler.update()
         else:
-            loss = criterion(output, target)
             loss.backward()

             if args.clip_grad_norm is not None:
                 nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)

             optimizer.step()

         if model_ema and i % args.model_ema_steps == 0:
             model_ema.update_parameters(model)
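For readers skimming the diff: the bug being fixed is that autocast previously wrapped only the loss computation, so when --amp was enabled the forward pass still ran in full precision and the model itself got no benefit from mixed precision. The fix wraps both the forward pass and the criterion in autocast(enabled=args.amp), which also lets the AMP and non-AMP branches share a single loss computation. Below is a minimal, self-contained sketch of that corrected pattern; the model, optimizer, and batch are hypothetical stand-ins (and use_amp plays the role of args.amp), not the torchvision reference script itself.

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # stand-in for args.amp

model = nn.Linear(128, 10).to(device)  # hypothetical model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

image = torch.randn(8, 128, device=device)  # dummy batch
target = torch.randint(0, 10, (8,), device=device)

# Forward pass AND loss both run under autocast, so the model actually
# executes in mixed precision when AMP is enabled.
with torch.cuda.amp.autocast(enabled=use_amp):
    output = model(image)
    loss = criterion(output, target)

optimizer.zero_grad()
if use_amp:
    # Scaled backward pass, then scaler-managed optimizer step.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
else:
    loss.backward()
    optimizer.step()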
Review thread on the gradient clipping call (marked as resolved by datumbox):

Thanks for bringing this out - I was just referring to ClassyVision's implementation before. Given that the official documentation is using [...]

Classy might have had it like this to support learnable params on the loss (we don't have this in Vision). Another reason might be that it was convenient in terms of code structure. @mannatsingh Do you have any idea why it was used like that in Classy?

Yes, so the only important reason that I can think of is that Apex's AMP works on its own (different) parameters, which are disconnected from the model in certain settings (like O2). If you used the other approach, you would not actually be clipping the gradients. I'm not sure if torchvision even supports Apex AMP, though! Other situations are manageable; for instance, if you optimize the model and the loss, you just need to make sure to use both everywhere (it's slightly risky but not a blocker).

Thanks a lot @mannatsingh, this was very helpful. I think this means we can go with [...]
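For reference, the gradient clipping recipe in the official PyTorch AMP documentation, which the thread above leans towards, looks roughly like the sketch below: unscale the optimizer's gradients first, clip on model.parameters(), then let the scaler perform the step. With native torch.cuda.amp the gradients live on the model's own parameters (there are no separate master params as in Apex O2), so clipping via model.parameters() reaches the same tensors the optimizer will step. The function and argument names here are placeholders and not part of the PR.

import torch
import torch.nn as nn

def amp_step_with_clipping(model, criterion, optimizer, scaler, image, target, max_norm=1.0):
    # Assumes CUDA; mirrors the clipping pattern from the PyTorch AMP docs.
    with torch.cuda.amp.autocast():
        loss = criterion(model(image), target)

    optimizer.zero_grad()
    scaler.scale(loss).backward()

    # Unscale the gradients held by the optimizer's assigned params in place,
    # so that clip_grad_norm_ sees true (unscaled) gradient magnitudes.
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    # scaler.step() notices unscale_ was already called this iteration and
    # does not unscale again before calling optimizer.step().
    scaler.step(optimizer)
    scaler.update()
    return loss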