[MoE] DeepEP refactor and fix memory leak during training and inference#2296

Merged
shuhuayu merged 9 commits into
pytorch:mainfrom
shuhuayu:deepep
Jan 29, 2026
@shuhuayu shuhuayu commented Jan 29, 2026

Contributor
  1. Simplified the token permutation logic.
  2. Updated the handle management so there will be no memory leak during training and inference. Related issue: [DeepEP] How is _handle_cache handled during inference? #2273

When training a 16B DeepSeek-V3 model, memory usage grew steadily before the fix.
[image: memory usage before the fix]
After the fix, memory usage stabilizes.
[image: memory usage after the fix]

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jan 29, 2026
@tianyu-l

Contributor

@elfiegg please take a look

@elfiegg

elfiegg commented Jan 29, 2026

Contributor

Looks good, is my understanding correct that this PR mainly implements:

  1. use torch.argsort for sorting the tokens for performance reasons
  2. save handle in dispatch's setup_context, and clear handle cache in combine forward if inference mode else clear it in setup_context
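As an illustration of point 1, here is a minimal sketch of argsort-based token permutation in plain Python. It is not the PR's code: the `argsort` helper, `expert_ids`, and `tokens` are hypothetical stand-ins, and a stable sort is assumed so that tokens routed to the same expert keep their original order (`torch.argsort` exposes a `stable` flag for the same purpose).

```python
def argsort(values):
    """Return the indices that sort `values`; ties keep their original order."""
    return sorted(range(len(values)), key=values.__getitem__)


# Hypothetical routing result: token i is assigned to expert expert_ids[i].
expert_ids = [2, 0, 1, 0, 2, 1]
tokens = ["t0", "t1", "t2", "t3", "t4", "t5"]

# Permute: gather tokens so that each expert's tokens are contiguous.
perm = argsort(expert_ids)
permuted = [tokens[i] for i in perm]

# Unpermute: the argsort of a permutation is its inverse permutation.
inverse = argsort(perm)
restored = [permuted[i] for i in inverse]
```

A single sort produces both the forward permutation and, via a second argsort, the inverse needed to put expert outputs back in token order.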

@elfiegg

elfiegg commented Jan 29, 2026

Contributor

FYI @goldhuang to unblock your work

@shuhuayu

Contributor Author

> Looks good, is my understanding correct that this PR mainly implements:
>
>   1. use torch.argsort for sorting the tokens for performance reasons
>   2. save handle in dispatch's setup_context, and clear handle cache in combine forward if inference mode else clear it in setup_context

Yes. During training, handles are saved in both dispatch_ctx and combine_ctx, and the handle in _handle_cache is cleared in combine_setup_context. During inference, no handles are saved in op_ctx, and _handle_cache is cleared during combine_forward.

@elfiegg

elfiegg commented Jan 29, 2026

Contributor

Sounds good, the logic makes total sense to me. Good to know that setup_context won't be called in inference mode.

@goldhuang


@shuhuayu With the changes in the PR, can I train without selective activation checkpointing on deepep? The extra memory cost from SAC is huge. It may not be a good idea to enable SAC in some cases.

@elfiegg

elfiegg commented Jan 29, 2026

Contributor

@goldhuang You can set --activation_checkpoint.mode=full and --parallelism.expert_parallel_comm_backend=deepep to turn on full activation checkpointing with deepep as the all-to-all backend.

@goldhuang


@shuhuayu @elfiegg My point is that the deepep integration on the current main branch also leaks memory when used without SAC (that is, when it runs in a fwd-fwd-bwd pattern). You may want to make sure your changes cover the fwd-fwd-bwd case as well.
The current main branch can only run fwd-bwd without a leak; both fwd-only and fwd-fwd-bwd leak. I only reported the fwd-only case earlier.
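One simple bookkeeping model that is consistent with this report is sketched below. It is a hypothetical illustration, not a reproduction of the main-branch code: it assumes each forward stores one cache entry and each backward frees at most one, and `cache_growth` is a made-up helper.

```python
def cache_growth(pattern, steps=100):
    """Count leftover cache entries after `steps` repetitions of `pattern`.

    Assumed model: each "fwd" stores one entry, each "bwd" frees at most one.
    """
    entries = 0
    for _ in range(steps):
        for phase in pattern:
            if phase == "fwd":
                entries += 1
            elif phase == "bwd" and entries > 0:
                entries -= 1
    return entries


patterns = {
    "fwd-bwd": ["fwd", "bwd"],             # plain training step
    "fwd": ["fwd"],                        # inference only
    "fwd-fwd-bwd": ["fwd", "fwd", "bwd"],  # forward recomputed before backward
}
growth = {name: cache_growth(p) for name, p in patterns.items()}
```

Under this assumption only the fwd-bwd pattern is balanced; the other two leave one stale entry per step, which would show up as steadily growing memory.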

@shuhuayu

Contributor Author

> @shuhuayu With the changes in the PR, can I train without selective activation checkpointing on deepep? The extra memory cost from SAC is huge. It may not be a good idea to enable SAC in some cases.

@goldhuang Thanks for the question.

  1. You can use full AC to save memory.
  2. Currently, during training, dispatch_op.ctx and combine_op.ctx save the layout metadata (i.e., the handle) regardless of the AC configuration; this should not be large.
  3. If we use selective AC with the op option, deepep ops are included in the op_sac_save_list, so their activations get saved. We can remove them from op_sac_save_list by commenting out these two lines:
    _op_sac_save_list.add(torch.ops.deepep.dispatch.default)
    _op_sac_save_list.add(torch.ops.deepep.combine.default)

I think option 1 saves the most memory, while option 3 saves only the memory attributable to deepep communication.
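To make option 3 concrete, here is a minimal sketch of how a save list drives a save-or-recompute decision. It is plain Python with placeholder names: the strings stand in for op overloads such as torch.ops.deepep.dispatch.default, `sac_decision` and `forward_ops` are hypothetical, and only the two deepep entries are taken from the thread ("aten.mm" is an assumed example of another saved op).

```python
# Placeholder op names; only the two deepep entries come from the thread.
DEEPEP_OPS = {"deepep.dispatch", "deepep.combine"}
op_sac_save_list = {"aten.mm"} | DEEPEP_OPS


def sac_decision(op, save_list):
    """Save the op's output if it is listed, otherwise recompute it in backward."""
    return "save" if op in save_list else "recompute"


# Hypothetical sequence of ops seen in one MoE layer's forward pass.
forward_ops = ["aten.mm", "deepep.dispatch", "aten.silu", "deepep.combine"]

# Default behavior: deepep outputs are kept for backward.
with_deepep = {op: sac_decision(op, op_sac_save_list) for op in forward_ops}

# Option 3: drop the deepep ops from the list so they are recomputed instead.
without_deepep = {
    op: sac_decision(op, op_sac_save_list - DEEPEP_OPS) for op in forward_ops
}
```

Removing the two entries trades the memory held by the dispatch and combine outputs for an extra round of communication during recomputation.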

In my test of a deepseek 16b model on 16 h100s (seqlen=4096, bsz=4, fsdp=ep=8, pp=2, attention=sdpa, compile=loss, moe_communication=deepep, no moe force load balance), the results are:

  1. selective ac with op: mfu 11.46%, memory 27.5%.
  2. selective ac with op, but recompute deepep ops (bullet 3 above): mfu 11.5%, memory 27.38%
  3. full ac: mfu 12.8%, memory 24.5%.

So in these tests, the savings from excluding deepep ops from the SAC save list are not significant at small scale.
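The differences between the three runs can be worked out directly from the reported numbers. The short script below only does that arithmetic; the dictionary keys are made-up labels, and the deltas are percentage points relative to the first configuration.

```python
# Reported results (percent), keyed by hypothetical labels for the three runs.
results = {
    "selective_op": {"mfu": 11.46, "memory": 27.5},
    "selective_op_recompute_deepep": {"mfu": 11.5, "memory": 27.38},
    "full_ac": {"mfu": 12.8, "memory": 24.5},
}

baseline = results["selective_op"]

# Change in percentage points relative to selective AC with op.
deltas = {
    name: {key: round(run[key] - baseline[key], 2) for key in run}
    for name, run in results.items()
    if name != "selective_op"
}

for name, delta in deltas.items():
    print(f"{name}: mfu {delta['mfu']:+.2f} pts, memory {delta['memory']:+.2f} pts")
```

Recomputing the deepep ops changes memory by about 0.12 points, against 3 points for full AC, which is the gap the comment above calls not significant.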

@shuhuayu

shuhuayu commented Jan 29, 2026

Contributor Author

> @shuhuayu @elfiegg My point is that deepep integration in current main branch also has memory leak when you use it without SAC (meaning it's running in the pattern of fwd-fwd-bwd). You may want to make sure your changes also cover the fwd-fwd-bwd case. The current main branch can only do fwd-bwd without a leak. Both fwd and fwd-fwd-bwd will cause a leak. I only reported the fwd case earlier.

@goldhuang Thanks for pointing this out.

  1. The fwd-only case should be fixed if you run under torch.inference_mode().
  2. The fwd-fwd-bwd case, as in full AC, is now covered in this PR and was tested by running the full AC test.

@shuhuayu shuhuayu merged commit 808cdf7 into pytorch:main Jan 29, 2026
25 checks passed
TXacs pushed a commit to McmillanTAC/torchtitan that referenced this pull request Apr 13, 2026
…ce (pytorch#2296)

ACharacterInASimulation pushed a commit to ACharacterInASimulation/torchtitan that referenced this pull request Apr 21, 2026
…ce (pytorch#2296)
