[memory Improvement] delete logits before bwd #1235
Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1235.
✅ No failures as of commit 71920dd with merge base 1157b94. (Generated by Dr. CI; updates every 15 minutes.)
Codecov Report

@@            Coverage Diff             @@
##             main    #1235      +/-   ##
==========================================
+ Coverage   67.81%   71.33%   +3.52%
==========================================
  Files         219      221       +2
  Lines        9908    10013     +105
==========================================
+ Hits         6719     7143     +424
+ Misses       3189     2870     -319
ebsmothers left a comment:
Nice find! I think this should be a free win. Looks like some CI jobs are failing, but otherwise no concerns from me.
Hi @felipemello1, I saw that you added the "nproc8" config to the table. Could you elaborate on what that is? Is it FSDP with 8x GPUs? I'm asking because I see it has a lower memory footprint on FFT.
@musabgultekin you are correct, 8x A100 with FSDP.
@SalmanMohammadi since you are touching some RL recipes, if you have a chance to test it there too and you don't mind, there may be similar gains :)
Hey @felipemello1. Thanks for your work here - a very neat change. I can take a look to see if there's a neat way around this. As an aside, how did the improvements you saw in memory usage scale with batch size/max seq len?
I think that doing .detach().cpu() should do the trick? But then it may make the code a bit ugly/weird. Before going down the rabbit hole of how to organize the code, I guess we could just test it, deleting the metrics logging, and see if it impacts memory?
13GB less for 24k sequence len with QLoRA, bsz=1, but I don't have a graph showing the % diff for multiple seq_len and bsz :/
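A minimal sketch of the .detach().cpu() idea discussed above, assuming a hypothetical single-device step; `model`, `batch`, and `loss_fn` are placeholders rather than torchtune recipe attributes:

```python
import torch

def training_step(model, batch, loss_fn):
    tokens, labels = batch["tokens"], batch["labels"]
    logits = model(tokens)  # [b, s, vocab] -- the large activation

    loss = loss_fn(logits, labels)

    # Keep any logits-derived metric as a detached CPU scalar so logging
    # still works after the big tensor is released.
    logit_mean_for_logging = logits.detach().mean().cpu()

    # Drop the Python reference before backward so the memory can be
    # reclaimed as soon as autograd no longer needs it.
    del logits
    loss.backward()

    return loss, logit_mean_for_logging
```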
Thanks so much for pointing me towards this @felipemello1. I made a hopefully minimal change to maintain the logging behaviour whilst deleting the unnecessary logits. This only affects DPO, since RLHF is still in the works, but I've done my best to be ruthless about freeing memory wherever I can when implementing RLHF. If you're interested, I wouldn't say no to a review in the relevant functions :)

Context
What is the purpose of this PR?
Inspired by #1046
Releasing the logits before the backward pass frees memory and reduces peak allocated memory.
Changelog
Delete the logits before backward for the LoRA/FFT/QAT recipes. I didn't do it for RL, since there is a bit more complexity there. LoRA distributed already had it. See the sketch below for the general shape of the change.
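A simplified, hypothetical sketch of the pattern (the real recipes shift labels and reshape the logits differently; `model`, `tokens`, and `labels` here are placeholders):

```python
import torch
import torch.nn.functional as F

def step(model: torch.nn.Module, tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    logits = model(tokens)  # [b, s, vocab_size]

    # Flatten to [b * s, vocab_size] vs. [b * s] for cross entropy.
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))

    # The full-vocabulary logits are no longer referenced after the loss is
    # computed; deleting the reference before backward lowers peak allocated memory.
    del logits

    loss.backward()
    return loss
```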
Test plan
Ran it for 5 epochs with and without the change: same loss and tok/s, but lower memory.
- pre-commit hooks (`pre-commit install`)
- `pytest tests`
- `pytest tests -m integration_test`
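One way to reproduce the memory comparison (this is not the PR's actual benchmarking setup, just PyTorch's built-in CUDA memory stats):

```python
import torch

torch.cuda.reset_peak_memory_stats()

# ... run the recipe for a fixed number of steps here ...

peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"Peak allocated memory: {peak_gib:.2f} GiB")
```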