support amp training for detection models #4933

Merged
prabhat00155 merged 7 commits into pytorch:main on Nov 15, 2021
Conversation

xiaohu2015
Contributor

@xiaohu2015 xiaohu2015 commented Nov 14, 2021

This PR relates to #4509.
Since AMP is already supported for classification training, I have also modified some files to support AMP training for detection models.
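
As a rough illustration of what AMP support in a training loop involves, the sketch below runs the forward pass under autocast and routes the backward pass through a gradient scaler when one is supplied. It is not the code from this PR: the `train_step` helper, the toy linear model, and the MSE loss are assumptions made for illustration. Only `torch.cuda.amp.GradScaler` and the `--amp` flag come from the changes discussed here.

```python
import torch


def train_step(model, optimizer, inputs, targets, scaler=None):
    """Run one optimization step; AMP is used only when a GradScaler is given."""
    # With enabled=False autocast does nothing, so one code path serves both modes.
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        # Hypothetical loss; a real detection model would produce its own losses.
        loss = torch.nn.functional.mse_loss(model(inputs), targets)

    optimizer.zero_grad()
    if scaler is not None:
        # Scale the loss so small fp16 gradients do not underflow. scaler.step()
        # unscales the gradients and skips the update if they contain inf/NaN.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    return loss.item()


if __name__ == "__main__":
    use_amp = torch.cuda.is_available()  # stand-in for args.amp
    device = "cuda" if use_amp else "cpu"
    model = torch.nn.Linear(4, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    x = torch.randn(8, 4, device=device)
    y = torch.randn(8, 1, device=device)
    print(train_step(model, optimizer, x, y, scaler))
```

Passing `scaler=None` keeps the full-precision path untouched, so AMP stays opt-in.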

cc @datumbox

@facebook-github-bot

facebook-github-bot commented Nov 14, 2021

💊 CI failures summary and remediations

As of commit bfc9225 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



Contributor

@datumbox datumbox left a comment


Thanks for the PR @xiaohu2015!

There are some related linter issues, but overall it looks good.

@prabhat00155 could you also have a look, since you did the classification one?

xiaohu2015 and others added 2 commits November 15, 2021 19:12
@xiaohu2015
Contributor Author

@datumbox Thanks. I have tested the AMP code, and it works well.

@datumbox
Contributor

@xiaohu2015 There are a couple more linter issues (whitespace). Have a look at the CI job; at the end it will show you the errors. For your convenience, here are the things you need to change to keep it happy:

diff --git a/references/detection/train.py b/references/detection/train.py
index 5c50dcfa..ae13a32b 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -143,7 +143,7 @@ def get_args_parser(add_help=True):
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
-    
+
     # Mixed precision training parameters
     parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
 
@@ -211,9 +211,9 @@ def main(args):
 
     params = [p for p in model.parameters() if p.requires_grad]
     optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
-    
+
     scaler = torch.cuda.amp.GradScaler() if args.amp else None
-    
+
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "multisteplr":
         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)

Other than that, the changes look good to me. :)

I'll leave it to @prabhat00155 to do the final checks on our side and merge when ready.
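
For readers unfamiliar with `store_true` flags, here is a minimal sketch of how the `--amp` option from the diff above behaves: it defaults to `False` and becomes `True` only when the flag is passed. The parser is a reduced stand-in with a single argument, not the full `get_args_parser` from `references/detection/train.py`, and the description string is made up.

```python
import argparse


def get_args_parser():
    # Reduced, hypothetical parser: only the --amp flag from the diff is kept.
    parser = argparse.ArgumentParser(description="AMP flag sketch")
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
    return parser


if __name__ == "__main__":
    parser = get_args_parser()
    # Explicit argument lists are passed so the demo does not depend on sys.argv.
    print(parser.parse_args([]).amp)
    print(parser.parse_args(["--amp"]).amp)
```

Since the flag is off by default, the line `scaler = torch.cuda.amp.GradScaler() if args.amp else None` in the diff leaves `scaler` as `None` unless `--amp` is given.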

Contributor

@prabhat00155 prabhat00155 left a comment


Thanks @xiaohu2015!

@prabhat00155 prabhat00155 merged commit 59ec1df into pytorch:main Nov 15, 2021
@github-actions

Hey @prabhat00155!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

cyyever pushed a commit to cyyever/vision that referenced this pull request Nov 16, 2021
* support amp training

* support amp training

* support amp training

* Update references/detection/train.py

Co-authored-by: Vasilis Vryniotis <[email protected]>

* Update references/detection/engine.py

Co-authored-by: Vasilis Vryniotis <[email protected]>

* fix lint issues

Co-authored-by: Vasilis Vryniotis <[email protected]>
facebook-github-bot pushed a commit that referenced this pull request Nov 17, 2021
Summary:
* support amp training

* support amp training

* support amp training

* Update references/detection/train.py

* Update references/detection/engine.py

* fix lint issues

Reviewed By: datumbox

Differential Revision: D32470476

fbshipit-source-id: d0ef0c561b4eed2d0cf654741bd2d108ce65411e

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>