sp + ep training / tp + ep inference #46292

Merged
3outeille merged 32 commits into tp_param_level from sp_tp_ep_plan
Jun 16, 2026

@3outeille 3outeille commented May 29, 2026

- SP=true,  EP=false    -> _sp_plan
- SP=false, EP=true     -> _tp_ep_plan (inference TP+EP)
- SP=true,  EP=true     -> _sp_ep_plan (training SP+EP)
- SP=false, EP=false    -> _tp_plan (experts keep moe_tp_*)
  • add test_tp_ep_forward, test_tp_ep_backward, test_tp_ep_generation, test_tp_ep_generation_quantized
  • add test_sp_ep_forward, test_sp_ep_backward
  • verify_tp_plan + verify_fsdp_plan
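
The flag matrix above reads as a four-way branch. A minimal sketch, assuming the plans are available as plain dicts keyed by the attribute names listed above; `select_plan_attr` and `select_plan` are hypothetical names used for illustration, not the library's API:

```python
def select_plan_attr(enable_sequence_parallel: bool, enable_expert_parallel: bool) -> str:
    """Return the plan attribute name for a given SP/EP flag pair."""
    if enable_sequence_parallel and enable_expert_parallel:
        return "_sp_ep_plan"  # training: SP + EP
    if enable_sequence_parallel:
        return "_sp_plan"
    if enable_expert_parallel:
        return "_tp_ep_plan"  # inference: TP + EP
    return "_tp_plan"  # TP only: experts keep their moe_tp_* entries


def select_plan(plans: dict, enable_sequence_parallel: bool, enable_expert_parallel: bool) -> dict:
    """Look up the plan for the flag pair; fail loudly if it is missing or empty."""
    attr = select_plan_attr(enable_sequence_parallel, enable_expert_parallel)
    plan = plans.get(attr)
    if not plan:
        # Assumption: a missing combo plan is an error rather than a silent fallback.
        raise ValueError(f"no {attr} defined for this model")
    return plan
```

Each flag pair maps to exactly one complete plan, so nothing has to be merged at runtime.
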
"""
torchrun --nproc_per_node=4 overfit_demo.py

The script overfits one sentence in the following steps:
    - Train the first half with FSDP=2 + SP&EP=2
    - Save the model and optimizer as a distributed checkpoint
    - Reload the model and optimizer from the distributed checkpoint
    - Train the rest with SP&EP=4 (changed distributed config)
    - Save the model (regular save_pretrained) and the optimizer (distributed checkpoint)
    - Reload the model from a single safetensors file
    - Run inference with TP&EP=4 and assert that greedy generation reproduces the sentence verbatim
"""

import os
import shutil

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.distributed import DistributedConfig
from transformers.distributed.utils import (
    clip_grad_norm,
    load_optimizer_distributed,
    save_optimizer_distributed,
)

NAME = "Isotonic/TinyMixtral-4x248M-MoE"
TEXT = "In a quiet village nestled between rolling hills and a slow river, the autumn mornings arrived with mist that hung low over the fields and a sky that turned from grey to pale gold as the sun climbed."
CKPT = "./checkpoints"
OPT = os.path.join(CKPT, "optimizer")
STEPS = 10
HALF = STEPS // 2

rank, local_rank = int(os.environ["RANK"]), int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
torch.manual_seed(0)

if rank == 0 and os.path.isdir(CKPT):
    shutil.rmtree(CKPT)
if torch.distributed.is_initialized():
    torch.distributed.barrier()

tokenizer = AutoTokenizer.from_pretrained(NAME)
ids = tokenizer(TEXT, return_tensors="pt").input_ids.to(f"cuda:{local_rank}")

# Train first half, distributed-save model + optimizer. (SP + EP / FSDP)
model = AutoModelForCausalLM.from_pretrained(
    NAME,
    distributed_config=DistributedConfig(tp_size=2, fsdp_size=2, enable_sequence_parallel=True, enable_expert_parallel=True),
    dtype=torch.bfloat16,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model.train()
for step in range(0, HALF):
    loss = model(ids, labels=ids).loss
    loss.backward()
    total_norm = clip_grad_norm(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
    if rank == 0:
        print(f"step {step:>2} | loss {loss.item():.5f} grad norm {total_norm.item():.5f}")

model.save_pretrained(CKPT, distributed_checkpoint=True)
save_optimizer_distributed(model, optimizer, OPT)
del model, optimizer
torch.cuda.empty_cache()

# Reload model + optimizer from the distributed checkpoint, train the rest.

model = AutoModelForCausalLM.from_pretrained(
    CKPT,
    distributed_config=DistributedConfig(tp_size=4, enable_sequence_parallel=True, enable_expert_parallel=True),
    dtype=torch.bfloat16,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
load_optimizer_distributed(model, optimizer, OPT)

model.train()
for step in range(HALF, STEPS):
    loss = model(ids, labels=ids).loss
    loss.backward()
    total_norm = clip_grad_norm(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
    if rank == 0:
        print(f"step {step:>2} | loss {loss.item():.5f} grad norm {total_norm.item():.5f}")

# Save the final model and optimizer, then run inference in TP + EP = 4.
model.save_pretrained(CKPT)
save_optimizer_distributed(model, optimizer, OPT + "_tp4")
del model, optimizer
torch.cuda.empty_cache()

model = AutoModelForCausalLM.from_pretrained(
    CKPT,
    distributed_config=DistributedConfig(tp_size=4, enable_expert_parallel=True),
    dtype=torch.bfloat16,
)
model.eval()
prompt = tokenizer("In a quiet village", return_tensors="pt").to(f"cuda:{local_rank}")
out = model.generate(**prompt, max_new_tokens=ids.shape[-1] - prompt.input_ids.shape[-1], do_sample=False)

got, want = out[0].tolist(), ids[0].tolist()
if rank == 0:
    print(f"generated: {tokenizer.decode(got, skip_special_tokens=True)!r}")
    print(f"expected: {tokenizer.decode(want, skip_special_tokens=True)!r}")
assert got == want, (
    f"generation mismatch at index {next((i for i, (g, e) in enumerate(zip(got, want)) if g != e), -1)}"
)

torch.distributed.destroy_process_group()

3outeille added 5 commits May 29, 2026 22:52
Compose SP/TP dense recipes with an optional EP overlay and strip
intra-expert moe_tp_* when expert parallelism is enabled. Add unit tests
for training (SP+EP), inference (TP+EP), and TP-only paths.

Replace exclusive SP|EP|TP plan selection with merged plans when tp_plan
is unset. Add distributed test for TP+EP merged expert sharding.

Expose resolve_parallel_plan via PreTrainedModel.tp_plan and set
active_tp_plan during from_pretrained so checkpoint sharding matches
the applied layout.

Expert weight TP under sequence parallelism comes from the EP overlay
(grouped_gemm) when enable_expert_parallel is set; keep moe_tp_* only in
base_model_tp_plan for TP-only MoE.

Update expert_parallelism guide and DistributedConfig docs for merged
plans. Export resolve_parallel_plan and extend resolve-plan tests for
trimmed SP sources.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@3outeille 3outeille mentioned this pull request May 30, 2026
5 tasks
3outeille and others added 23 commits May 30, 2026 02:54
* remove _accumulate_local_param_grad

* comments

* linting

* fix

* clean _accumulate_local_param_grad

* linting

* cleaning

* cleaning

* fix mellun test because of bug in parsing sp_ep plan with regex

* aea

Introduce base_model_tp_ep_plan / base_model_sp_ep_plan on PreTrainedConfig,
select_parallel_plan() with legacy resolve_parallel_plan fallback, and wire
apply_tensor_parallel to use the selector. Model post_init tracks _tp_ep_plan
and _sp_ep_plan for composite models.

…wen3-MoE

Define complete inference TP+EP and training SP+EP plans on the pilot MoE
configs. Qwen3-MoE expands per-layer entries in _update_parallel_plans.
Add plan_utils and golden tests against legacy resolve_parallel_plan merge.

Populate combo plans via init_combo_plans() at config init time for MoE
configs that still use split tp/sp/ep recipes. Dynamic configs call it after
_update_sp_plan(); modular sources updated for generated configuration files.

Delete runtime plan merging; select_parallel_plan now requires a complete
combo dict and raises when missing. apply_tensor_parallel uses DistributedConfig
flags directly for SP/EP behavior. Drop model._ep_plan aggregation; load-time
verification checks the active plan only. Refresh combo plans after MXFP4
quantizer patches.

Propagate init_combo_plans from modular sources to generated configuration
files and document select_parallel_plan combo lookup in expert_parallelism.md.

Use explicit if/elif branches for the SP/EP flag matrix and derive
config_attr from plan_attr instead of parallel lookup dicts.

Define base_model_tp_ep_plan and base_model_sp_ep_plan directly in each MoE
configuration (or via config-time _update_parallel_plans for dynamic models).
Delete plan_utils.py and all init_combo_plans / refresh_combo_plans usage.

Explicit combo plan selection no longer merges _sp_plan with _ep_plan,
so head-level lm_head rules must live on _tp_ep_plan/_sp_ep_plan directly.
Fixes SP+EP training loss shape mismatch under sequence parallelism.

@3outeille 3outeille marked this pull request as ready for review June 15, 2026 05:42
@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: afmoe, apertus, arcee, aria, bamba, bitnet, cohere, cohere2, cohere2_moe, csm, cwm, dbrx, deepseek_v2, deepseek_v3, deepseek_v4, diffllama

@github-actions

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=46292&sha=be3f54

@3outeille 3outeille merged commit d736cb6 into tp_param_level Jun 16, 2026
25 of 31 checks passed
@3outeille 3outeille deleted the sp_tp_ep_plan branch June 16, 2026 02:41
3outeille added a commit that referenced this pull request Jun 17, 2026
* [distributed] Add param-level MoE TP/EP styles and ep_router

Decompose MoE tensor/expert parallelism per review feedback: weight sharding
is declared per-parameter, while the experts module entry stays forward-comm only.

- MoEParamShard: parameter-only style wrapping named expert weights as DTensor
  placeholders (no forward hook). grouped_gemm shards the expert dim and updates
  module.num_experts to the per-rank local count.
- Register grouped_gemm (Shard(0)), moe_gate_up_colwise (_StridedShard(-2)),
  moe_gate_up_colwise_alt (_StridedShard(-1)), moe_down_rowwise (Shard(-1)).
- EpRouterParallel (ep_router): forward-only slicing of router outputs to local
  experts, ported from the original RouterParallel (#39501).
- moe_experts_allreduce is now forward-comm only: strip the baked shard_plan and
  drop the now-dead shard_plan ctor arg / _moe_shard_plan / shard_parameters
  override from MoEExpertsParallel; skip _AllReduceBackward on routing weights
  under EP.
- verify_tp_plan: treat moe_experts_allreduce / ep_router as forward-only.
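
To make the grouped_gemm and ep_router bullets concrete: sharding the expert dimension gives each rank a block of experts, and the router output then has to be restricted to that block. The sketch below works on plain Python lists, assumes contiguous and evenly divisible expert blocks, and uses hypothetical helper names; it is not the EpRouterParallel implementation.

```python
def local_expert_range(num_experts: int, ep_size: int, ep_rank: int) -> tuple[int, int]:
    """Half-open range [start, end) of global expert ids owned by this rank."""
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by ep_size in this sketch")
    per_rank = num_experts // ep_size  # the per-rank local expert count
    start = ep_rank * per_rank
    return start, start + per_rank


def route_to_local(topk, num_experts: int, ep_size: int, ep_rank: int):
    """Keep only routing entries for local experts and re-index them from 0.

    topk is a per-token list of (global_expert_id, routing_weight) pairs.
    """
    start, end = local_expert_range(num_experts, ep_size, ep_rank)
    return [
        [(idx - start, weight) for idx, weight in token if start <= idx < end]
        for token in topk
    ]
```

With 8 experts over 4 ranks, rank 1 owns global experts 2 and 3, which become local experts 0 and 1.
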

* [distributed] Add param-level apply pass to apply_tensor_parallel

Run tensor parallelism in two passes:
- Pass 1 (param-level): walk named_parameters() and, for styles in PARAM_ONLY_STYLES
  (grouped_gemm, moe_gate_up_colwise[_alt], moe_down_rowwise), shard the parameter
  directly via shard_parameters(). No forward hook.
- Pass 2 (module-level): the existing named_modules() loop for forward hooks, now
  skipping PARAM_ONLY_STYLES.

Param sharding runs first so module forward hooks (moe_experts_allreduce) see the
already-sharded DTensor params. Also wire the EP-plan fallback so
enable_expert_parallel uses model._ep_plan when no explicit plan is passed.
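
The two-pass ordering can be sketched on names alone, with no tensors involved. This is an illustration under assumptions: plan keys are matched with `fnmatch`-style globs (the real matcher may differ), the style set mirrors the names listed above, and `plan_two_passes` is a hypothetical helper.

```python
from fnmatch import fnmatchcase

# Styles that shard a parameter directly and install no forward hook.
PARAM_ONLY_STYLES = {
    "grouped_gemm",
    "moe_gate_up_colwise",
    "moe_gate_up_colwise_alt",
    "moe_down_rowwise",
}


def style_for(name: str, plan: dict):
    """Return the style of the first plan pattern matching name, else None."""
    for pattern, style in plan.items():
        if fnmatchcase(name, pattern):
            return style
    return None


def plan_two_passes(param_names, module_names, plan):
    """Pass 1 collects param-level sharding work, pass 2 module-level hooks."""
    pass1 = [(n, s) for n in param_names if (s := style_for(n, plan)) in PARAM_ONLY_STYLES]
    pass2 = [
        (n, s)
        for n in module_names
        if (s := style_for(n, plan)) is not None and s not in PARAM_ONLY_STYLES
    ]
    return pass1, pass2
```

Running pass 1 to completion before pass 2 is what lets a module hook assume its parameters are already sharded.
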

* [distributed] Add MoE TP/EP plan tests and a two-sided sharding assertion

- New tests/distributed/test_moe_tensor_parallel_plan.py: plan resolution, placement
  expectations for grouped_gemm / moe_gate_up_colwise[_alt] / moe_down_rowwise, gloo
  distributed integration (EP Shard(0), TP _StridedShard(-2)+Shard(-1), ep_router
  slicing), and a registry guard that moe_experts_allreduce carries no baked shard plan.
- _verify_tp_sharding: add a two-sided check asserting that every parameter whose plan
  entry is a weight-sharding style actually comes back as a non-replicate DTensor. The
  prior check only validated params that happened to be sharded, so a style that
  gracefully degrades to replicated when unsharded (e.g. MoEExpertsParallel) could pass
  output-equality while silently running unparallelized.
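
The gap the two-sided check closes can be shown without a distributed setup. In this sketch, `is_sharded` is a hypothetical stand-in for "this parameter came back as a non-replicated DTensor", and the style set covers only the MoE styles named in this message; it is not the actual _verify_tp_sharding code.

```python
WEIGHT_SHARDING_STYLES = {
    "grouped_gemm",
    "moe_gate_up_colwise",
    "moe_gate_up_colwise_alt",
    "moe_down_rowwise",
}


def planned_but_unsharded(param_styles: dict, is_sharded: dict) -> list:
    """Return parameters whose plan entry demands sharding but which are still replicated.

    param_styles maps parameter name -> plan style.
    is_sharded maps parameter name -> bool (missing names count as unsharded).
    """
    return sorted(
        name
        for name, style in param_styles.items()
        if style in WEIGHT_SHARDING_STYLES and not is_sharded.get(name, False)
    )
```

A one-sided check would iterate over the sharded parameters only and never notice the entries this function returns.
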

* [distributed] Migrate MoE configs to decomposed TP/SP expert plans

For every TP/SP plan that sharded experts, declare per-parameter entries:
    "layers.*.mlp.experts.gate_up_proj": "moe_gate_up_colwise"
    "layers.*.mlp.experts.down_proj":    "moe_down_rowwise"
while keeping the forward-only "layers.*.mlp.experts": "moe_experts_allreduce".

This matches the now-empty moe_experts_allreduce shard_plan; sharding is declared in
config at parameter granularity. EP plans already used "grouped_gemm" and are unchanged.

hy_v3 and laguna previously used "packed_colwise" / "rowwise_allreduce" on the 3D expert
*parameters*; those styles are module-level and were silently no-ops on params (the bundled
shard_plan did the work). They now use the param-level moe_gate_up_colwise / moe_down_rowwise
like every other MoE model.

Edited modular files where they own the plan literal; generated configs and inherited
plans (e.g. from qwen3_moe) propagated via modular conversion.

* [distributed] Document decomposed MoE TP/EP plans

- expert_parallelism.md: describe the param-level decomposition (grouped_gemm,
  ep_router, moe_experts_allreduce) instead of the removed GroupedGemmParallel class,
  and note the TP equivalents (moe_gate_up_colwise / moe_down_rowwise).
- weightconverter.md: note that fused expert weights are sharded at parameter
  granularity by the parallel plan.

* [distributed] Rename MoE intra-expert TP styles to moe_tp_*

Rename registry and plan entries so TP-on-expert sharding is distinct
from EP (grouped_gemm) and dense packed_colwise: moe_gate_up_colwise ->
moe_tp_gate_up_colwise, moe_down_rowwise -> moe_tp_down_rowwise. Drop
unused moe_tp_gate_up_colwise_alt (GPT-OSS-style layouts stay EP-only).

* handle sparse and dense sp plan for qwen3_moe

* better tests coverage for sp & ep

* linting

* uniformize TP Api to avoid confusion with torch native ops

* inline tp

* rename

* cleaning

* inline

* cleaning

* cleaning

* linting

* fix ci ep_backward

* linting

* remove flag expert parallel

* fix

* add tp plan + ep_plan

* revert doc

* fix install_forward

* linting

* add moe identity back

* no need anymore

* update tp_plan for ernie4_5_vl_moe

* sp + ep training / tp + ep inference (#46292)

* [distributed] Add resolve_parallel_plan merge helper

Compose SP/TP dense recipes with an optional EP overlay and strip
intra-expert moe_tp_* when expert parallelism is enabled. Add unit tests
for training (SP+EP), inference (TP+EP), and TP-only paths.
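
A rough sketch of the merge this commit describes, before it was replaced by explicit combo plans later in this PR. The function name and the plan contents are illustrative assumptions, not the resolve_parallel_plan source.

```python
from typing import Optional


def merge_plans(dense_plan: dict, ep_plan: Optional[dict]) -> dict:
    """Overlay an EP plan on a dense SP/TP plan.

    When EP is active, intra-expert moe_tp_* entries are dropped first so that
    expert weights are sharded by the EP overlay only.
    """
    if not ep_plan:
        return dict(dense_plan)  # TP-only path: keep moe_tp_* as-is
    merged = {k: v for k, v in dense_plan.items() if not v.startswith("moe_tp_")}
    merged.update(ep_plan)
    return merged
```
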

* [distributed] Wire resolve_parallel_plan into apply_tensor_parallel

Replace exclusive SP|EP|TP plan selection with merged plans when tp_plan
is unset. Add distributed test for TP+EP merged expert sharding.

* [distributed] Use merged plan in tp_plan property and load path

Expose resolve_parallel_plan via PreTrainedModel.tp_plan and set
active_tp_plan during from_pretrained so checkpoint sharding matches
the applied layout.

* [distributed] Drop intra-expert moe_tp_* from MoE SP plans

Expert weight TP under sequence parallelism comes from the EP overlay
(grouped_gemm) when enable_expert_parallel is set; keep moe_tp_* only in
base_model_tp_plan for TP-only MoE.

* [distributed] Document SP+EP and TP+EP flag combinations

Update expert_parallelism guide and DistributedConfig docs for merged
plans. Export resolve_parallel_plan and extend resolve-plan tests for
trimmed SP sources.

* refactor merging plans

* add test sp_ep and tp_ep

* extend verify_tp_plan to verify_tp_sp_ep_plan

* add ep_plan to mixtral and olmoe

* cleaning _accumulate_local_param_grad (#46394)

* remove _accumulate_local_param_grad

* comments

* linting

* fix

* clean _accumulate_local_param_grad

* linting

* cleaning

* cleaning

* fix mellun test because of bug in parsing sp_ep plan with regex

* aea

* Add select_parallel_plan and explicit combo plan config fields

Introduce base_model_tp_ep_plan / base_model_sp_ep_plan on PreTrainedConfig,
select_parallel_plan() with legacy resolve_parallel_plan fallback, and wire
apply_tensor_parallel to use the selector. Model post_init tracks _tp_ep_plan
and _sp_ep_plan for composite models.

* Add base_model_tp_ep_plan and base_model_sp_ep_plan for Mixtral and Qwen3-MoE

Define complete inference TP+EP and training SP+EP plans on the pilot MoE
configs. Qwen3-MoE expands per-layer entries in _update_parallel_plans.
Add plan_utils and golden tests against legacy resolve_parallel_plan merge.
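
Per-layer expansion presumably matters when a model mixes dense and sparse MLP layers, since a single `layers.*` wildcard would apply expert styles to dense layers too. A minimal sketch, assuming the set of sparse layer indices is already known; `expand_per_layer` and the entries passed to it are hypothetical, not the _update_parallel_plans code.

```python
def expand_per_layer(num_layers: int, sparse_layers: set, sparse_entries: dict, dense_entries: dict) -> dict:
    """Build explicit per-layer plan keys instead of one layers.* wildcard.

    sparse_entries / dense_entries map a module or parameter suffix to a style.
    """
    plan = {}
    for i in range(num_layers):
        entries = sparse_entries if i in sparse_layers else dense_entries
        for suffix, style in entries.items():
            plan[f"layers.{i}.{suffix}"] = style
    return plan
```
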

* Add explicit tp_ep / sp_ep plans for remaining MoE models

Populate combo plans via init_combo_plans() at config init time for MoE
configs that still use split tp/sp/ep recipes. Dynamic configs call it after
_update_sp_plan(); modular sources updated for generated configuration files.

* Remove resolve_parallel_plan and use explicit combo plan selection

Delete runtime plan merging; select_parallel_plan now requires a complete
combo dict and raises when missing. apply_tensor_parallel uses DistributedConfig
flags directly for SP/EP behavior. Drop model._ep_plan aggregation; load-time
verification checks the active plan only. Refresh combo plans after MXFP4
quantizer patches.

* Sync modular MoE configs and update expert parallelism docs

Propagate init_combo_plans from modular sources to generated configuration
files and document select_parallel_plan combo lookup in expert_parallelism.md.

* Refactor select_parallel_plan flag lookup for readability

Use explicit if/elif branches for the SP/EP flag matrix and derive
config_attr from plan_attr instead of parallel lookup dicts.

* Write explicit combo parallel plans in MoE configs and remove plan_utils

Define base_model_tp_ep_plan and base_model_sp_ep_plan directly in each MoE
configuration (or via config-time _update_parallel_plans for dynamic models).
Delete plan_utils.py and all init_combo_plans / refresh_combo_plans usage.

* Add lm_head entries to _tp_ep_plan and _sp_ep_plan on CausalLM classes

Explicit combo plan selection no longer merges _sp_plan with _ep_plan,
so head-level lm_head rules must live on _tp_ep_plan/_sp_ep_plan directly.
Fixes SP+EP training loss shape mismatch under sequence parallelism.

* cleaning

* cleaning

* cleaning

* cleaning

* linting

* add verify tp and fsdp plan

aeaea

* revert doc

* cleaning

* check-repository-consistency

* linting