[MoE] DeepEP refactor and fix memory leak during training and inference#2296
Conversation
|
@elfiegg please take a look |
|
Looks good, is my understanding correct that this PR mainly implements:
|
|
FYI @goldhuang to unblock your work |
Yes. During training, handles are saved in both |
|
Sounds good, logic makes totally sense to me. Good to know inference mode |
|
@shuhuayu With the changes in the PR, can I train without selective activation checkpointing on deepep? The extra memory cost from SAC is huge. It may not be a good idea to enable SAC in some cases. |
|
@goldhuang You can configure |
|
@shuhuayu @elfiegg My point is that deepep integration in current main branch also has memory leak when you use it without SAC (meaning it's running in the pattern of |
@goldhuang Thanks for the question.
I think 1 saves most memory and 3 saves memory specifically from deepep communications. In my test of a deepseek 16b model on 16 h100s (seqlen=4096, bsz=4, fsdp=ep=8, pp=2, attention=sdpa, compile=loss, moe_communication=deepep, no moe force load balance), the results are:
So in these tests the savings from excluding deepep ops from sac save list are not significant in small scale. |
@goldhuang Thanks for pointing this out.
|
…ce (pytorch#2296) 1. Simplified the token permutation logic. 2. Updated the handle management so there will be no memory leak during training and inference. Related issue: pytorch#2273 On training on 16b deepseek v3 model, before the fix there was a growing memory usage. <img width="479" height="331" alt="image" src="https://github.com/user-attachments/assets/12571963-47a5-4e13-b66a-1b213fc10d66" /> After the fix, the memory usage stabilizes. <img width="479" height="328" alt="image" src="https://github.com/user-attachments/assets/9257c7ce-faf6-4330-a295-1ef1150d4ab0" />
…ce (pytorch#2296) 1. Simplified the token permutation logic. 2. Updated the handle management so there will be no memory leak during training and inference. Related issue: pytorch#2273 On training on 16b deepseek v3 model, before the fix there was a growing memory usage. <img width="479" height="331" alt="image" src="https://github.com/user-attachments/assets/12571963-47a5-4e13-b66a-1b213fc10d66" /> After the fix, the memory usage stabilizes. <img width="479" height="328" alt="image" src="https://github.com/user-attachments/assets/9257c7ce-faf6-4330-a295-1ef1150d4ab0" />
On training on 16b deepseek v3 model, before the fix there was a growing memory usage.


After the fix, the memory usage stabilizes.