
Conversation

@felipemello1 (Contributor) commented Jul 27, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Inspired by #1046

Releasing the logits before the backward pass frees that memory earlier and reduces peak allocated memory.

[image: table comparing peak allocated memory with and without the change across recipes/configs]

Changelog

Delete the logits before the backward pass in the LoRA/FFT/QAT recipes. I didn't do it for the RL recipes, since there is a bit more complexity there. LoRA distributed already had it.
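For reference, the pattern is roughly the following (an illustrative sketch, not the exact recipe code; the label shifting and loss shapes are assumptions):

```python
import torch.nn.functional as F

def train_step(model, tokens, labels):
    """Illustrative training step that frees the logits before backward."""
    logits = model(tokens)                   # [bsz, seq_len, vocab_size]; can be several GB
    # shift so that tokens < n predict token n
    logits = logits[..., :-1, :].contiguous()
    labels = labels[..., 1:].contiguous()
    loss = F.cross_entropy(
        logits.transpose(1, 2),              # [bsz, vocab_size, seq_len - 1]
        labels,                              # [bsz, seq_len - 1]
        ignore_index=-100,
    )
    # Drop the Python reference to the logits so the tensor can be freed
    # before backward instead of staying alive until the end of the step.
    del logits
    loss.backward()
    return loss
```

Autograd keeps whatever it saved for the backward pass either way; the `del` only removes the extra Python reference so the large activation isn't held longer than needed.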

Test plan

Ran it for 5 epochs with and without the change: same loss and tokens/sec, but lower memory.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

@pytorch-bot (bot) commented Jul 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1235

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 71920dd with merge base 1157b94:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jul 27, 2024
@codecov-commenter commented Jul 27, 2024

Codecov Report

Attention: Patch coverage is 0% with 4 lines in your changes missing coverage. Please review.

Project coverage is 71.33%. Comparing base (7eb89e2) to head (71920dd).
Report is 694 commits behind head on main.

Files with missing lines                  Patch %   Lines
recipes/full_finetune_distributed.py        0.00%   1 Missing ⚠️
recipes/full_finetune_single_device.py      0.00%   1 Missing ⚠️
recipes/lora_finetune_single_device.py      0.00%   1 Missing ⚠️
recipes/qat_distributed.py                  0.00%   1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1235      +/-   ##
==========================================
+ Coverage   67.81%   71.33%   +3.52%     
==========================================
  Files         219      221       +2     
  Lines        9908    10013     +105     
==========================================
+ Hits         6719     7143     +424     
+ Misses       3189     2870     -319     

☔ View full report in Codecov by Sentry.

@ebsmothers (Contributor) left a comment

Nice find! I think this should be a free win. Looks like some CI jobs are failing but otherwise no concerns from me

@joecummings merged commit 8a98fba into meta-pytorch:main on Jul 28, 2024
@musabgultekin (Contributor) commented Jul 29, 2024

Hi @felipemello1, I saw that you added an "nproc8" config to the table. Could you elaborate on what that is? Is it FSDP with 8x GPUs? I'm asking because I see it has a lower memory footprint on FFT.

@felipemello1 (Contributor, Author) replied:
> Hi @felipemello1, I saw that you added an "nproc8" config to the table. Could you elaborate on what that is? Is it FSDP with 8x GPUs? I'm asking because I see it has a lower memory footprint on FFT.

@musabgultekin you are correct, 8xA100 with FSDP

@felipemello1 (Contributor, Author) commented Jul 29, 2024

@SalmanMohammadi since you are touching some RL recipes, if you have a chance and don't mind testing it there too, there may be similar gains :)

@SalmanMohammadi (Contributor)
Hey @felipemello1. Thanks for your work here - a very neat change.
The main blocker in the RL recipes is that the logits are also used for metric logging, either to measure the extent of divergence of the model being trained, or to measure how well differentiated the logits/logprobs for "preferred" vs "non-preferred" are.

I can take a look to see if there's a neat way around this. As an aside, how did the improvements you saw in memory usage scale with batch size/max seq len?

@felipemello1 (Contributor, Author) commented Jul 29, 2024

> The main blocker in the RL recipes is that the logits are also used for metric logging

I think that doing .detach().cpu() should do the trick? But it may make the code a bit ugly/weird. Before going down the rabbit hole of how to organize the code, I guess we could just test it with the metric logging deleted and see if it impacts memory?
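Something along these lines, as a hypothetical sketch (illustrative names, not the recipe's API; it assumes the labels are valid vocab indices and that the metrics only need per-token log-probs):

```python
import torch.nn.functional as F

def step_with_logged_logprobs(model, tokens, labels, loss_fn):
    logits = model(tokens)                          # [bsz, seq_len, vocab_size]
    loss = loss_fn(logits.transpose(1, 2), labels)  # e.g. token-level cross-entropy

    # Keep only a small, detached CPU copy of what the metrics need:
    # .detach() drops the autograd reference, .cpu() moves it off the GPU.
    label_logprobs = (
        F.log_softmax(logits.detach(), dim=-1)
        .gather(-1, labels.unsqueeze(-1))
        .squeeze(-1)
        .cpu()                                      # [bsz, seq_len]
    )

    del logits                                      # free the GPU reference before backward
    loss.backward()
    return loss.detach(), label_logprobs
```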

> how did the improvements you saw in memory usage scale with batch size/max seq len?

13 GB less for 24k sequence length with QLoRA, bsz=1, but I don't have a graph showing the % difference across multiple seq_len and bsz values :/

[image: memory comparison for QLoRA, 24k sequence length, bsz=1]
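For a rough sense of the scaling (my own back-of-the-envelope, not a measurement from this PR; the vocab size and dtype below are assumptions): the logits tensor grows linearly with both batch size and sequence length, so the absolute savings should too.

```python
def logits_gb(bsz: int, seq_len: int, vocab_size: int, bytes_per_elem: int = 2) -> float:
    """Size in GB of one [bsz, seq_len, vocab_size] logits tensor (bf16 by default)."""
    return bsz * seq_len * vocab_size * bytes_per_elem / 1e9

# Assumed example numbers: one bf16 logits tensor at bsz=1, seq_len=24k,
# vocab_size=128k is ~6 GB, and it scales linearly with bsz and seq_len.
print(logits_gb(bsz=1, seq_len=24_000, vocab_size=128_000))  # ~6.1
```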

@SalmanMohammadi (Contributor) commented Jul 29, 2024

Thanks so much for pointing me towards this @felipemello1. I made a hopefully minimal change to maintain the logging behaviour whilst deleting the unnecessary logits.

This only affects DPO, since RLHF is still in the works, but I've done my best to be ruthless about freeing memory wherever I can when implementing RLHF. If you're interested, I wouldn't say no to a review of the relevant functions :)
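For illustration, the general shape of keeping only reduced log-probs for the loss and the logged metrics, so the full logits references can be dropped early (hypothetical names, not the actual DPO recipe code):

```python
import torch.nn.functional as F

def chosen_rejected_logprobs(model, chosen_tokens, chosen_labels,
                             rejected_tokens, rejected_labels):
    def seq_logprob(logits, labels):
        # sum of per-token log-probs of the labels: one scalar per sequence
        logps = F.log_softmax(logits, dim=-1)
        return logps.gather(-1, labels.unsqueeze(-1)).squeeze(-1).sum(-1)

    chosen_logits = model(chosen_tokens)
    chosen_logps = seq_logprob(chosen_logits, chosen_labels)
    del chosen_logits                 # the [bsz, seq_len, vocab_size] reference is no longer needed

    rejected_logits = model(rejected_tokens)
    rejected_logps = seq_logprob(rejected_logits, rejected_labels)
    del rejected_logits

    # Both the DPO loss and the logged metrics (e.g. how separated the chosen
    # vs rejected log-probs are) only need these reduced tensors.
    return chosen_logps, rejected_logps
```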


Labels

CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

7 participants