Add tests and implementation for disabling dropout layers in models #2378
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2378

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 3 Pending, as of commit f89ed8f with merge base 8c9235e.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
joecummings
left a comment
High quality work!

I'm uncomfortable, though, with how many one-off files we have for various training utilities. Can you expose this through the training module's `__init__` and make the file "private" (prefix it with an underscore) so that I can follow up and consolidate some of these utilities?
torchtune/training/model_util.py (Outdated)

        model (torch.nn.Module): The model in which dropout layers should be disabled.
        """
        for module in model.modules():
            if isinstance(module, torch.nn.Dropout):
I would also check that the Dropout is set to something other than zero. There's no need to reset a Dropout to zero if it's already zero.
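For context, a minimal sketch of what the utility could look like with this check applied (hypothetical code pieced together from the snippets in this review thread, not the merged implementation):

```python
import warnings

import torch


def disable_dropout(model: torch.nn.Module) -> None:
    """Set every non-zero torch.nn.Dropout probability in the model to zero."""
    for module in model.modules():
        # Skip Dropout layers already at p=0, per the review comment above.
        if isinstance(module, torch.nn.Dropout) and module.p > 0:
            warnings.warn(
                f"Found Dropout with value {module.p} in module {module}. Setting to zero."
            )
            module.p = 0.0
```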
torchtune/training/model_util.py (Outdated)

    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            warnings.warn(
                f"Dropout found in {module}. This is likely to cause issues during training. Disabling."
Suggested change:

    - f"Dropout found in {module}. This is likely to cause issues during training. Disabling."
    + f"Found Dropout with value {module.p} in module {module}. Setting to zero."
Thanks for adding this @Ankur-singh! Could you also update our recipes to use this, please?
@SalmanMohammadi and @joecummings I have made all the requested changes 😁
joecummings
left a comment
🫡
Context
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses.
#2353
Changelog
What are the changes made in this PR?
- Added `disable_dropout` function to `torchtune/training/model_util.py`
- Added tests for the `disable_dropout` function

Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- install pre-commit hooks (`pre-commit install`)
- run unit tests via `pytest tests`
- run recipe tests via `pytest tests -m integration_test`

UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example
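A dummy example of the user experience could demonstrate the effect of this PR: once dropout probabilities are zeroed, forward passes become deterministic even in train mode. This illustrative sketch loops over modules directly rather than calling the PR's helper, so it makes no assumptions about the final import path:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
model.train()  # dropout is active in train mode

# Zero out dropout probabilities (the same effect as the PR's utility).
for m in model.modules():
    if isinstance(m, torch.nn.Dropout):
        m.p = 0.0

# With p=0, repeated forward passes in train mode produce identical outputs.
x = torch.ones(1, 4)
y1, y2 = model(x), model(x)
assert torch.equal(y1, y2)
```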