Conversation

@ebsmothers (Contributor) commented Sep 6, 2024

Fixes #1498

Per-layer compile doesn't work on stable PyTorch so we need to version-gate. Also our compile code is getting a bit messy so might as well add a couple handy utils while we're at it. Note that there are problems when compiling the full model + chunked cross-entropy loss on PyTorch 2.4 that (for me) don't surface until ~20 minutes into running. This is actually not a problem with chunked cross-entropy (we can compile that alone just fine), but a problem compiling the chunked output projection.

Thanks to @gau-nernst for pointing out that we can (/should) just disable compilation of the chunked output layer regardless of the PyTorch version. This fixes the compile errors on stable PyTorch, and with per-layer compile (i.e. what we do on nightlies) we don't compile the chunked output layer anyways.

So we add a handy chunked_output method to each of our transformer classes and call into it whenever we're doing chunking. This also lets us wrap that method with @torch.compiler.disable (see the sketch after the summary below). To summarize:

  1. On stable PyTorch, when compile=True we compile the full model like we used to. But if we're doing output chunking (which is now the default), we disable compile for the chunked output projection. We still compile the loss on its own no matter what.
  2. On nightly PyTorch, when compile=True we use per-layer compile. This means that disabling compilation on the chunked output projection is a no-op anyways. So we still compile every layer and compile the loss separately, just like we do in main today.
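
To make the summary concrete, here is a minimal sketch of the idea (not the actual torchtune implementation; the toy class and the `output` / `num_output_chunks` names are assumptions for illustration, while the chunked_output name and @torch.compiler.disable come from this PR):

```python
# Minimal sketch: keep the chunked output projection out of any compiled graph.
# The TinyDecoder class and its attribute names are illustrative assumptions,
# not torchtune's TransformerDecoder.
from typing import List

import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    def __init__(self, dim: int, vocab_size: int, num_output_chunks: int = 8):
        super().__init__()
        self.output = nn.Linear(dim, vocab_size, bias=False)
        self.num_output_chunks = num_output_chunks

    @torch.compiler.disable
    def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]:
        # Project each sequence chunk separately so the full
        # (batch, seq_len, vocab_size) logits tensor is never materialized,
        # and keep this method in eager mode even when the model is compiled.
        return [
            self.output(chunk)
            for chunk in last_hidden_state.chunk(self.num_output_chunks, dim=1)
        ]
```

With something like this in place, full-model compile on stable PyTorch simply skips the chunked projection, and per-layer compile on nightlies is unaffected since it never compiles this method anyway.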

Test plan

Run the following on both stable and nightly PyTorch

(1) LoRA Llama2 single-device recipe
(2) QLoRA Gemma 2B single-device recipe
(3) Full finetune Qwen2 1.5B single-device recipe

Stable PyTorch

LoRA Llama2-7B

tune run lora_finetune_single_device --config llama2/7B_lora_single_device compile=True \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=pr_1512 \
 metric_logger.name=llama2_lora_single_stable_updated log_peak_memory_stats=True \
gradient_accumulation_steps=1 max_steps_per_epoch=100
Screenshot 2024-09-07 at 11 22 17 AM

QLoRA Gemma-2B

tune run lora_finetune_single_device --config gemma/2B_qlora_single_device compile=True \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=pr_1512 \
 metric_logger.name=qlora_gemma_single_stable_latest log_peak_memory_stats=True \
gradient_accumulation_steps=1 max_steps_per_epoch=100 epochs=1
Screenshot 2024-09-07 at 11 18 00 AM

Full finetune Qwen2-1.5B

tune run full_finetune_single_device --config qwen2/1.5B_full_single_device \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=pr_1512 \
 metric_logger.name=full_qwen2_single_stable_latest log_peak_memory_stats=True compile=True \
 max_steps_per_epoch=100
Screenshot 2024-09-07 at 11 19 08 AM

Nightly PyTorch

LoRA Llama2-7B

tune run lora_finetune_single_device --config llama2/7B_lora_single_device compile=True \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=pr_1512 \
 metric_logger.name=llama2_lora_single_nightly_updated log_peak_memory_stats=True \
gradient_accumulation_steps=1 max_steps_per_epoch=100
Screenshot 2024-09-07 at 11 18 23 AM

QLoRA Gemma-2B

tune run lora_finetune_single_device --config gemma/2B_qlora_single_device compile=True \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=pr_1512 \
 metric_logger.name=qlora_gemma_single_nightly_updated log_peak_memory_stats=True \
 gradient_accumulation_steps=1 max_steps_per_epoch=100 epochs=1
Screenshot 2024-09-07 at 11 20 05 AM

Full finetune Qwen2-1.5B

tune run full_finetune_single_device --config qwen2/1.5B_full_single_device \
 metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=pr_1512 \
 metric_logger.name=full_qwen2_single_nightly_latest log_peak_memory_stats=True compile=True \
 max_steps_per_epoch=100
Screenshot 2024-09-07 at 11 19 37 AM

@pytorch-bot (bot) commented Sep 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1512

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2e86cce with merge base 31a95a9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Sep 6, 2024
Review comment on the compile utils (code excerpt: the tail of a docstring plus the version gate):

    None
    """
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    if torch_version_ge("2.5.0"):

Contributor:

: (

Contributor:

Come 2.5 our codebase will be so beautiful

Contributor (Author):

Oh but then we will just have a bunch of 2.6 nightly features..
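
For context, a hedged sketch of what the version-gated compile util under discussion might look like (the function name compile_model, the model.layers attribute, and the stand-in torch_version_ge helper are assumptions; the merged torchtune code may differ):

```python
# Hedged sketch of a version-gated compile util: per-layer compile on
# PyTorch >= 2.5, full-model compile on stable releases before that.
import os

import torch
import torch.nn as nn


def torch_version_ge(min_version: str) -> bool:
    # Simplified stand-in for torchtune's version-check util; the real one
    # also needs to handle nightly ".dev" version strings.
    return torch.__version__ >= min_version  # torch.__version__ is a TorchVersion


def compile_model(model: nn.Module) -> None:
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    if torch_version_ge("2.5.0"):
        # Per-layer compile: compile each transformer layer in place.
        for layer in model.layers:
            layer.compile(backend=backend)
    else:
        # Stable PyTorch: compile the whole model; the chunked output
        # projection stays in eager via torch.compiler.disable.
        model.compile(backend=backend)
```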

@ebsmothers ebsmothers marked this pull request as ready for review September 7, 2024 00:15
@ebsmothers ebsmothers changed the title [WIP] compile utils and version-gating compile utils and version-gating Sep 7, 2024
@ebsmothers (Contributor, Author) commented:

@gau-nernst fyi

@gau-nernst (Contributor) commented:

> Note that there are problems when compiling the full model + chunked cross-entropy loss on PyTorch 2.4 that (for me) don't surface until ~20 minutes into running. I believe it's possible to compile just the loss in this case but there seems to be some issue with the chunking at the end of the model. Rather than get too tricky with things I am just gonna raise an error telling folks to use vanilla cross-entropy if they are trying to compile on 2.4.

To clarify

  • Compile full model + compile vanilla works on 2.4
  • Compile full model + compile chunked CE will error on 2.4 after 20min

Just curious, is the error related to chunking in LM head, or the chunked CE loss? Also, does the error correlate to a jump in peak memory reserved? From my experience running the built-in configs, peak memory consistently jumps up after a while. I'm guessing it is due to an increase in seq_len in the dataset due to dynamic padding.

@felipemello1 (Contributor) commented:

> peak memory consistently jumps up after a while. I'm guessing it is due to an increase in seq_len in the dataset due to dynamic padding.

That is interesting. Alpaca does have some very large sequences, right? Could it be that one of those pops up on step 20? Maybe we can rerun it, print tokens.shape, and see what happens on step 20.

@ebsmothers (Contributor, Author) commented:

> To clarify
>
>   • Compile full model + compile vanilla works on 2.4
>   • Compile full model + compile chunked CE will error on 2.4 after 20min

Yeah exactly. However I think I had a successful run of full finetune with compile full model + compile chunked CE on 2.4, but the LoRA one failed multiple times.

> Just curious, is the error related to chunking in LM head, or the chunked CE loss?

I believe it's the LM head chunking, because if I just compile the chunked CE loss on its own it actually seems to work OK. So I made a call here: alternatively we could just disable model compile, warn the user, and compile the loss only (I think that would work). But I figure it's better to just compile everything and tell them that this particular loss isn't supported. Neither option is ideal in my mind, but lmk if you have a strong opinion.

> Also, does the error correlate to a jump in peak memory reserved? From my experience running the built-in configs, peak memory consistently jumps up after a while. I'm guessing it is due to an increase in seq_len in the dataset due to dynamic padding.

Here's the reserved memory just before the failure. There's no big jump, but that may not account for the latest step; we could just log the seq len for each batch to test your hypothesis.

Screenshot 2024-09-06 at 8 01 29 PM

There is also a warning about torch._dynamo hitting config.cache_size_limit (8) shortly before things break. We were seeing this previously, but I believe it was fixed by pytorch/pytorch#134272, so this could be a case where we really do need the nightlies.
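
As an aside (not part of this PR): the (8) in that warning is torch._dynamo's recompile cache limit, which can be inspected or raised while debugging whether repeated recompilation is what eventually breaks the run:

```python
# For reference only: the limit referenced in the warning above is a
# torch._dynamo config knob. Raising it just allows more recompiles per
# function; it does not fix whatever is triggering them.
import torch._dynamo

print(torch._dynamo.config.cache_size_limit)   # the "(8)" in the warning
torch._dynamo.config.cache_size_limit = 16     # experiment while debugging
```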

@gau-nernst (Contributor) commented:

The warning about hitting the cache size limit indicates that something keeps recompiling, after which I think torch falls back to eager mode. Perhaps this switch from compile to eager mode causes things to break, possibly due to corrupted state/buffers.

Since the only dynamic thing during finetuning is the seq len, I'm guessing torch 2.4 (and maybe torch nightly too?) cannot do dynamic-shape compile for the LM head chunking logic. To test this hypothesis, could you try wrapping the LM chunking logic in a function and decorating it with torch.compiler.disable?

@codecov-commenter commented Sep 7, 2024

Codecov Report

Attention: Patch coverage is 33.33333% with 36 lines in your changes missing coverage. Please review.

Project coverage is 27.21%. Comparing base (277fbf8) to head (2e86cce).
Report is 508 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| torchtune/training/_compile.py | 37.03% | 17 Missing ⚠️ |
| recipes/full_finetune_single_device.py | 0.00% | 5 Missing ⚠️ |
| recipes/lora_finetune_single_device.py | 0.00% | 5 Missing ⚠️ |
| torchtune/modules/transformer.py | 50.00% | 4 Missing ⚠️ |
| recipes/full_finetune_distributed.py | 0.00% | 3 Missing ⚠️ |
| torchtune/models/gemma/transformer.py | 60.00% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1512      +/-   ##
==========================================
+ Coverage   27.10%   27.21%   +0.10%     
==========================================
  Files         284      286       +2     
  Lines       13813    13825      +12     
==========================================
+ Hits         3744     3762      +18     
+ Misses      10069    10063       -6     

☔ View full report in Codecov by Sentry.

@ebsmothers (Contributor, Author) commented:

> Since the only dynamic thing during finetuning is the seq len, I'm guessing torch 2.4 (and maybe torch nightly too?) cannot do dynamic-shape compile for the LM head chunking logic. To test this hypothesis, could you try wrapping the LM chunking logic in a function and decorating it with torch.compiler.disable?

Yeah, disabling the LM chunking does seem to solve this, and I don't see much negative perf impact either. So we could add a method compute_chunked_output and @torch.compiler.disable that. The one thing I'm a bit hesitant about is adding version gating inside the TransformerDecoder itself. But the UX hit of always erroring by default when compile=True on stable might outweigh that tbh.

@gau-nernst (Contributor) commented Sep 7, 2024

I believe "compile per layer" also does not compile "LM head chunking", so maybe nightly also doesn't work with LM head chunking compile -> good idea to always add @torch.compiler.disable to the LM head chunking logic -> no need version guard within model code.

(Even if you want to do this with version guard, you probably can get away with calling model.compute_chunked_output = torch.compiler.disable(model.compute_chunked_output) inside the compile utils -> no version guard inside model code)
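
A quick sketch of that alternative, wrapping the method at setup time instead of decorating it in the model code (the method name compute_chunked_output follows the suggestion above and may not match what was actually merged):

```python
# Hedged sketch: disable compilation of the chunking method from the compile
# utils, so the model code needs neither a decorator nor a version guard.
import torch
import torch.nn as nn


def disable_chunked_output_compile(model: nn.Module) -> None:
    if hasattr(model, "compute_chunked_output"):
        # torch.compiler.disable also works as a plain wrapper, so we can
        # rebind the bound method with compilation disabled.
        model.compute_chunked_output = torch.compiler.disable(
            model.compute_chunked_output
        )
```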

@ebsmothers (Contributor, Author) commented:

Oh, you're completely right. OK, then I think it's a no-brainer.

@felipemello1 (Contributor) left a review:

lgtm. Just one comment: we also have to update the Gemma transformer, since we haven't migrated Gemma to use the tied embedding.

@ebsmothers ebsmothers merged commit c169bcd into meta-pytorch:main Sep 7, 2024
@ebsmothers ebsmothers deleted the per-layer-compile-gating branch September 7, 2024 19:16

Labels

CLA Signed (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)

Projects

None yet

Development

Successfully merging this pull request may close these issues:

[bug] Compile error on stable PyTorch

6 participants