fix bug in training model by amp #4874

Merged: 6 commits merged into pytorch:main on Nov 9, 2021

Conversation

xiaohu2015 (Contributor) commented Nov 5, 2021

This PR:

  • Moves model prediction inside the autocast context
  • Performs gradient clipping properly when autocast is enabled.
  • Stops calling optimizer step twice when autocast is enabled.

cc @datumbox
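
A minimal sketch of what the first and third bullets amount to, assuming the standard torch.cuda.amp API (autocast plus GradScaler). The function and argument names below are illustrative placeholders rather than the actual torchvision reference training script, and the gradient-clipping point is covered further down in the thread:

```python
import torch


def train_one_epoch(model, criterion, optimizer, data_loader, device, scaler=None):
    # Hypothetical helper for illustration only, not the torchvision reference code.
    use_amp = scaler is not None
    model.train()
    for image, target in data_loader:
        image, target = image.to(device), target.to(device)
        optimizer.zero_grad()

        # The forward pass (and loss) run inside the autocast context, so the
        # model actually computes in mixed precision when AMP is enabled.
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(image)
            loss = criterion(output, target)

        if use_amp:
            # Backward on the scaled loss, then a single optimizer step taken
            # through the scaler (no extra optimizer.step() call afterwards).
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
```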

facebook-github-bot commented Nov 5, 2021

💊 CI failures summary and remediations

As of commit 25842f6 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

1 failure not recognized by patterns:

Job                                        | Step  | Action
CircleCI binary_libtorchvision_ops_android | Build | 🔁 rerun

This comment was automatically generated by Dr. CI.

xiaohu2015 changed the title from "fix bug in train model using amp" to "fix bug in training model by amp" on Nov 5, 2021
datumbox (Contributor) commented Nov 5, 2021

@xiaohu2015 Thanks for the PR. If there is no related issue describing the bug, please add the relevant information to the PR description so that it's clear what the previous issue was.

@prabhat00155 Could you please have a look, as you've recently worked on this in #4547?

datumbox requested a review from prabhat00155 on November 5, 2021 at 16:27
xiaohu2015 (Contributor, Author)

@xiaohu2015 Thanks for the PR. If there is no related issue describing the bug, please add the relevant information to the PR description so that it's clear what the previous issue was.

@prabhat00155 Could you please have a look, as you've recently worked on this in #4547?

@prabhat00155 Can you review the PR? I found that AMP was not working, so I updated the training code to fix the bugs.

prabhat00155 (Contributor) left a comment


Thanks @xiaohu2015! Could you please upload the logs before and after the changes, since this is not covered by our unit tests?

datumbox (Contributor) left a comment


Marking this as "Request changes" to avoid accidental merges before we gather enough information about the nature of the bug and the training logs, and do a proper investigation on our side.

@xiaohu2015 You can unblock this by adding context info on this PR as discussed above. Thanks!

xiaohu2015 (Contributor, Author)

Thanks @xiaohu2015! Could you please upload the logs before and after the changes, since this is not covered by our unit tests?

After changing the code, I trained ResNet50; the result is 75.7 (with AMP) and 75.5 (without AMP).

datumbox (Contributor) left a comment


@xiaohu2015 Your PR contains some good corrections but it's still very thin on information on the bug itself.

After changing the code, I trained ResNet50; the result is 75.7 (with AMP) and 75.5 (without AMP).

This is not necessarily an indication of an improvement. Doing multiple runs with different seeds can lead to slightly different results every time due to the random initialization and random transforms applied to the data.


optimizer.step()
if args.clip_grad_norm is not None:
    nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
Contributor


@sallysyw As far as I can see, you introduced get_optimizer_params() in #4824. Could you talk about the reasons you didn't grab the parameters directly from the model, i.e. doing:
nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)

Contributor


Thanks for bringing this up - I was just referring to ClassyVision's implementation before.

Given that the official documentation is using model.parameters(), I think we can switch to it.

Contributor


Classy might have had it like this to support learnable params on the loss (we don't have this in Vision). Another reason might be that it was convenient in terms of code structure.

@mannatsingh Do you have any idea why it was used like that in Classy?

mannatsingh commented Nov 8, 2021


Yes, so the only important reason that I can think of is that Apex's AMP works on its own (different) parameters which are disconnected from the model in certain settings (like O2). If you used the other approach, you would not actually be clipping the gradients. I'm not sure if torchvision even supports Apex AMP though!

Other situations are manageable, for instance, if you optimize the model and the loss, you just need to make sure to use both everywhere (it's slightly risky but not a blocker).
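
For context, the Apex behaviour described above means gradient clipping has to target the optimizer-side master parameters rather than model.parameters(). A rough sketch, assuming apex.amp with opt_level="O2"; torchvision's scripts use torch.cuda.amp instead, so this is purely illustrative and the names (optimizer, loss, max_norm) are placeholders:

```python
import torch
from apex import amp  # NVIDIA Apex; not used by the torchvision reference scripts


def clip_and_step_apex(optimizer, loss, max_norm):
    # Assumes the model and optimizer were wrapped earlier with
    # amp.initialize(model, optimizer, opt_level="O2").
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    # With O2 the FP32 master weights are kept by Apex and are disconnected from
    # model.parameters(), so clipping the model's parameters would not affect
    # what the optimizer actually updates; clip the master params instead.
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
    optimizer.step()
```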

Contributor


Thanks a lot @mannatsingh, this was very helpful. I think this means we can go with model.parameters().

datumbox (Contributor) left a comment


@xiaohu2015 I've marked as resolved all the "FYI" comments above and left only those that need to be addressed to merge the PR.

Effectively, the only thing required is to add support for gradient clipping when AMP is active. I provided a reference from the documentation on how to do it. It's also worth refactoring the code slightly to simplify it according to the comments.

Please let me know if you plan to continue working on the PR. Thanks!
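
For reference, the documented recipe for clipping under AMP boils down to unscaling the gradients before clipping and then stepping through the scaler. Roughly, with loss, model, optimizer and scaler as in the sketch earlier in the thread, and max_norm being the clipping threshold:

```python
scaler.scale(loss).backward()

if max_norm is not None:
    # Unscale the gradients in place so the clip threshold applies to the real
    # (unscaled) gradient norms rather than the scaled ones.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

# scaler.step() skips the underlying optimizer.step() if the gradients contain
# infs/NaNs, and scaler.update() adjusts the scale factor for the next iteration.
scaler.step(optimizer)
scaler.update()
```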

xiaohu2015 (Contributor, Author)

@xiaohu2015 Your PR contains some good corrections but it's still very thin on information on the bug itself.

After changing the code, I trained ResNet50; the result is 75.7 (with AMP) and 75.5 (without AMP).

This is not necessarily an indication of an improvement. Doing multiple runs with different seeds can lead to slightly different results every time due to the random initialization and random transforms applied to the data.

Yes, the experiment was to check that AMP is working, not to prove that the model trained with AMP is better.

xiaohu2015 (Contributor, Author)

@xiaohu2015 I've marked as resolved all the "FYI" comments above and left only those that need to be addressed to merge the PR.

Effectively, the only thing required is to add support for gradient clipping when AMP is active. I provided a reference from the documentation on how to do it. It's also worth refactoring the code slightly to simplify it according to the comments.

Please let me know if you plan to continue working on the PR. Thanks!

Thanks, I made some modifications with the help of the documentation.

datumbox (Contributor) left a comment


@xiaohu2015 Thanks for the PR, LGTM!

Given that the code was modified heavily and it's not covered with tests, it would be good to do a run on our side to confirm that everything works as expected.

@prabhat00155 Let me know if you have the bandwidth for this.

@sallysyw Concerning the simplification discussed at https://github.com/pytorch/vision/pull/4874/files#r744539820, is this something you would be interested in doing or shall we create an issue about it?

prabhat00155 (Contributor)

Given that the code was modified heavily and it's not covered with tests, it would be good to do a run on our side to confirm that everything works as expected.

@prabhat00155 Let me know if you have the bandwidth for this.

Yeah sure, let me kick off a training run.

prabhat00155 (Contributor)

This runs fine. Here is the output log:
output_log.txt

prabhat00155 merged commit 031e129 into pytorch:main on Nov 9, 2021
github-actions bot commented Nov 9, 2021

Hey @prabhat00155!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

facebook-github-bot pushed a commit that referenced this pull request Nov 15, 2021
Summary:
* fix bug in amp

* fix bug in training by amp

* support use gradient clipping when amp is enabled

Reviewed By: datumbox

Differential Revision: D32298968

fbshipit-source-id: 4366674522dc0faf5688207faa7e3cd33be2a6ea

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Prabhat Roy <[email protected]>
cyyever pushed a commit to cyyever/vision that referenced this pull request Nov 16, 2021
* fix bug in amp

* fix bug in training by amp

* support use gradient clipping when amp is enabled

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Prabhat Roy <[email protected]>