compile utils and version-gating #1512
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1512
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 2e86cce with merge base 31a95a9.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
From the diff under review:

        None
    """
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    if torch_version_ge("2.5.0"):
: (
Come 2.5 our codebase will be so beautiful
Oh but then we will just have a bunch of 2.6 nightly features..
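For context on the version gate being discussed, here is a minimal sketch of how a `torch_version_ge`-style check and the resulting gating could work. This is only an assumption about the helper's behavior, and `compile_model` is a made-up name for illustration; the real torchtune utilities may differ (for example in how they handle nightly version strings).

```python
import os

import torch
import torch.nn as nn
from packaging.version import parse


def torch_version_ge(min_version: str) -> bool:
    # Sketch: compare the installed torch version against a minimum version.
    # Note that under PEP 440 a nightly like "2.5.0.dev20240901" compares as
    # *less than* "2.5.0", which a real implementation has to account for.
    return parse(torch.__version__) >= parse(min_version)


def compile_model(model: nn.Module) -> None:
    # Backend is configurable via an env var, defaulting to inductor.
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    if torch_version_ge("2.5.0"):
        # Newer PyTorch: compile each transformer layer separately
        # (assumes the model exposes its layers as `model.layers`).
        for layer in model.layers:
            layer.compile(backend=backend)
    else:
        # Stable PyTorch: compile the whole model in one go.
        model.compile(backend=backend)
```

A layout along these lines keeps the version gate in one place instead of scattering it through the recipes.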
@gau-nernst fyi
To clarify
Just curious, is the error related to chunking in the LM head, or the chunked CE loss? Also, does the error correlate with a jump in peak memory reserved? From my experience running the built-in configs, peak memory consistently jumps up after a while. I'm guessing it is due to an increase in seq_len in the dataset due to dynamic padding.
That is interesting. Alpaca does have some very large sequences, right? Could it be that on step 20 one of those pops up? Maybe we can rerun it and print tokens.shape, and see what happens on step 20.
Yeah exactly. However, I think I had a successful run of full finetune with compile full model + compile chunked CE on 2.4, but the LoRA one failed multiple times.
I believe it's the LM head chunking, because if I just compile the chunked CE loss it actually seems to work OK. So I made a judgment call here: alternatively, we could just disable model compile, warn the user, and compile the loss only (I think that would work). But I figure it's better to just compile everything and tell the user that this particular loss isn't supported. Neither option is ideal in my mind, but lmk if you have a strong opinion.
Here's the reserved memory just before the failure. There's no big jump, but that may not account for the latest step; we could just log the seq len for each batch to test your hypothesis.
There is also a warning about hitting the cache size limit.
The warning about hitting the cache size limit indicates that something keeps recompiling. After that, torch will fall back to eager mode, I think. Perhaps this switch from compile to eager mode causes things to break, possibly due to corrupted state/buffers. Since the only dynamic thing during finetuning is the seq len, I'm guessing torch 2.4 (and maybe torch nightly too?) cannot do dynamic-shape compile for the LM head chunking logic. To test this hypothesis, could you try wrapping the LM chunking logic in a function and decorating it with @torch.compiler.disable?
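A minimal sketch of that experiment, assuming the chunking logic can be pulled out into a standalone function (the function name, signature, and chunk count here are illustrative, not torchtune's actual code):

```python
from typing import List

import torch
import torch.nn as nn


@torch.compiler.disable
def chunked_output_projection(
    hidden: torch.Tensor, output_proj: nn.Module, num_chunks: int = 8
) -> List[torch.Tensor]:
    # Apply the output projection chunk-by-chunk along the sequence dim so the
    # full [b, s, vocab_size] logits tensor is never materialized. The decorator
    # keeps dynamo from tracing (and repeatedly recompiling) this dynamic-shaped
    # piece; it always runs in eager mode even when the rest of the model is
    # compiled.
    return [output_proj(chunk) for chunk in hidden.chunk(num_chunks, dim=1)]
```

If the cache-size warning and the step-20 failure disappear with this in place, that points at the chunked projection rather than the chunked CE loss.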
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1512 +/- ##
==========================================
+ Coverage 27.10% 27.21% +0.10%
==========================================
Files 284 286 +2
Lines 13813 13825 +12
==========================================
+ Hits 3744 3762 +18
+ Misses 10069 10063 -6

☔ View full report in Codecov by Sentry.
Yeah, disabling the LM chunking does seem to solve this, and I don't see too much negative perf impact either. So we could add a method chunked_output to our transformer classes and disable compile on it.
I believe "compile per layer" also does not compile "LM head chunking", so maybe nightly also doesn't work with LM head chunking compile -> good idea to always add (Even if you want to do this with version guard, you probably can get away with calling |
|
Oh you're completely right. OK then I think it's a no-brainer |
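For illustration, a rough sketch of what that could look like (a toy class, not torchtune's actual TransformerDecoder; the attribute names and chunk count are assumptions):

```python
from typing import List

import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Toy stand-in for a transformer decoder with output chunking."""

    def __init__(self, dim: int = 64, vocab_size: int = 1000, num_output_chunks: int = 8):
        super().__init__()
        self.output = nn.Linear(dim, vocab_size, bias=False)
        self.num_output_chunks = num_output_chunks

    @torch.compiler.disable
    def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]:
        # Never compiled: on stable PyTorch this sidesteps the recompilation
        # failures discussed above, and under per-layer compile (nightlies) the
        # decorator is effectively a no-op, since the output projection isn't
        # compiled there in the first place.
        return [
            self.output(chunk)
            for chunk in last_hidden_state.chunk(self.num_output_chunks, dim=1)
        ]
```

The chunked CE loss can then consume the list of per-chunk logits and still be compiled on its own.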
felipemello1 left a comment:
lgtm. Just one comment: we also have to update the Gemma transformer, since we haven't migrated Gemma to use the tied embedding.

Fixes #1498
Per-layer compile doesn't work on stable PyTorch so we need to version-gate. Also our compile code is getting a bit messy so might as well add a couple handy utils while we're at it. Note that there are problems when compiling the full model + chunked cross-entropy loss on PyTorch 2.4 that (for me) don't surface until ~20 minutes into running. This is actually not a problem with chunked cross-entropy (we can compile that alone just fine), but a problem compiling the chunked output projection.
Thanks to @gau-nernst for pointing out that we can (/should) just disable compilation of the chunked output layer regardless of the PyTorch version. This fixes the compile errors on stable PyTorch, and with per-layer compile (i.e. what we do on nightlies) we don't compile the chunked output layer anyways.
So we add a handy chunked_output method to each of our transformer classes and call into that whenever we're doing chunking. This also allows us to wrap that method with @torch.compiler.disable. So to summarize:

- On stable PyTorch with compile=True, we compile the full model like we used to. But if we're doing output chunking (which is now the default), we disable compile for the chunked output projection. We still compile the loss on its own no matter what.
- On nightly PyTorch with compile=True, we use per-layer compile. This means that disabling compilation on the chunked output projection is a no-op anyways. So we still compile every layer and compile the loss separately, just like we do in main today.

Test plan
Run the following on both stable and nightly PyTorch
(1) LoRA Llama2 single-device recipe
(2) QLoRA Gemma 2B single-device recipe
(3) Full finetune Qwen2 1.5B single-device recipe
Stable PyTorch
LoRA Llama2-7B
QLoRA Gemma-2B
Full finetune Qwen2-1.5B
Nightly PyTorch
LoRA Llama2-7B
QLoRA Gemma-2B
Full finetune Qwen2-1.5B