Skip to content

MoE expert parallelism + sequence parallelism#45408

Merged
3outeille merged 10 commits into
refactor-tp-dtensorfrom
moe-sequence-parallel
Apr 14, 2026
Merged

MoE expert parallelism + sequence parallelism#45408
3outeille merged 10 commits into
refactor-tp-dtensorfrom
moe-sequence-parallel

Conversation

@3outeille

Copy link
Copy Markdown
Member

Summary

  • Extends the TPStyle API (from FSDP + TP & native save/load distributed #45028) with MoE expert parallelism and sequence parallelism support
  • Adds PackedColwiseParallel, MoEExpertsParallel, PrepareModuleInputOutput, _AllReduceBackward custom ParallelStyle subclasses
  • Extends TPStyle with moe_experts, packed_colwise, activation, module, loss_parallel kinds
  • _StridedShard handling in core_model_loading.py for interleaved gate_up_proj weights
  • MoE model configs for mixtral, deepseek_v3, qwen3 with sequence parallelism plans

Part of the distributed training API chain: #44989

Chain: main ← #44989 ← #44083 ← #44974 ← #45028 ← this PR ← orchestration+save PR

Review question

Are the custom ParallelStyle subclasses correct for expert sharding + sequence parallelism?

Test plan

  • Verify MoE expert sharding produces correct DTensor placements
  • Test sequence parallelism with allgather/split hooks
  • Run existing TP mixin tests to ensure no regression

- Add PackedColwiseParallel for fused gate_up_proj weights
- Add MoEExpertsParallel with per-expert DTensor sharding
- Add PrepareModuleInputOutput for SP allgather/split hooks
- Add _AllReduceBackward for MoE routing weight gradients
- Extend TPStyle with moe_experts, packed_colwise, activation, module kinds
- _StridedShard handling in core_model_loading for interleaved weights
- MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans
- DTensor rotary_pos_emb guard for mixtral
@3outeille 3outeille force-pushed the moe-sequence-parallel branch from 24ca327 to 7f297e0 Compare April 14, 2026 13:44
3outeille and others added 6 commits April 14, 2026 14:24
# Conflicts:
#	src/transformers/integrations/tensor_parallel.py
# Conflicts:
#	src/transformers/integrations/tensor_parallel.py
The _IdentityOp class (added by PR #44983) was accidentally deleted
during the MoE expert parallelism work. It is needed by
finegrained_fp8.py and metal_quantization.py as a pass-through
reverse_op for dequantize operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@3outeille 3outeille force-pushed the moe-sequence-parallel branch from 5031188 to 01866b8 Compare April 14, 2026 15:37
@github-actions

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: deepseek_v3, dots1, mixtral, nanochat, qwen3, qwen3_5, qwen3_5_moe, qwen3_moe, qwen3_next, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe, youtu

@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

* from_pretrained orchestration + save/load

- Add gather_full_state_dict() for DTensor→full tensor saving
- Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
- Add _redistribute_dtensor() helper
- Full distributed_config integration in from_pretrained/save_pretrained
- Rename apply_fsdp2 → apply_fully_shard_data_parallel
- save_optimizer() / load_optimizer() in distributed/utils
- Trainer integration with distributed_config
- Updated FSDP and TP tests for new orchestration API
- DTensor shard-on-read test updates

* revert distributed utils

* eaaea

* all tests for core modeling are passing

* populate import from init for tp

* ruff

* ruff
@3outeille 3outeille merged commit 7ca7911 into refactor-tp-dtensor Apr 14, 2026
12 of 28 checks passed
@3outeille 3outeille deleted the moe-sequence-parallel branch April 14, 2026 16:12
@github-actions

Copy link
Copy Markdown
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45408&sha=bbf3ab

vasqu pushed a commit to zhang-prog/transformers that referenced this pull request May 19, 2026
* init

* FSDP2 (fully_shard) integration

- Add apply_fully_shard_data_parallel() with auto/manual mode block detection
- FSDP vs DDP loss/grad parity tests
- Distributed test helpers (testing_utils.py)
- is_fsdp_enabled(), is_fsdp_managed_module() utilities
- Minimal FSDP hooks in from_pretrained
- FSDP-aware flash attention check

* DistributedConfig + shard-on-read loading

- DtensorShardOperation for range-math shard-on-read
- spawn_materialize() enhancements
- from_pretrained wiring for distributed config
- Shard operation helpers in tensor_parallel
- Shard-on-read and LoadStateDictConfig tests

* TPStyle API + dense model tensor parallelism

- Replace hook-based TP with DTensor-based TPStyle API
- TPStyle dataclass with dense kinds: colwise, rowwise, vocab
- apply_tensor_parallel() using PyTorch parallelize_module
- verify_tp_plan() for plan validation
- Update dense model configs (llama, mistral, qwen2, phi, glm) to TPStyle
- DTensor apply_rotary_pos_emb guard for llama, mistral, qwen3
- Extended DistributedConfig with tp/fsdp size and plan fields
- DistributedConfig serialization in configuration_utils
- MXFP4 NotImplementedError for DTensor TP
- Dense TP tests

* revert some files

* Add distributed training scripts

- train_fsdp_tp.py: minimal FSDP+TP training example
- train_fsdp_tp_torchtitan_style.py: torchtitan-style training example
- verify_loading.py: save/load roundtrip verification
- run_compare.sh: FSDP+TP vs FSDP-only comparison
- run_verify_all.sh: run verification across all modes
- tmp_generate.py: quick generation test

* Remove train_fsdp_tp_torchtitan_style.py

* unify the utils for fsdp

* Fix CI: re-export moved FSDP utils + remove stale type: ignore

- Re-export is_fsdp_enabled and is_fsdp_managed_module from
  integrations/fsdp.py (moved to distributed/utils.py)
- Remove unused # type: ignore comments in generation/utils.py

* Fix ruff formatting in core_model_loading.py

* Fix ruff linting and formatting

* Backport new TP/FSDP API from orchestration-save-load branch

* Fix DTensor imports in Copied-from model files

* MoE expert parallelism + sequence parallelism (huggingface#45408)

* MoE expert parallelism + sequence parallelism

- Add PackedColwiseParallel for fused gate_up_proj weights
- Add MoEExpertsParallel with per-expert DTensor sharding
- Add PrepareModuleInputOutput for SP allgather/split hooks
- Add _AllReduceBackward for MoE routing weight gradients
- Extend TPStyle with moe_experts, packed_colwise, activation, module kinds
- _StridedShard handling in core_model_loading for interleaved weights
- MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans
- DTensor rotary_pos_emb guard for mixtral

* Fix ruff linting and formatting

* Fix ruff formatting in core_model_loading.py

* Restore _IdentityOp accidentally removed in 25a1f48

The _IdentityOp class (added by PR huggingface#44983) was accidentally deleted
during the MoE expert parallelism work. It is needed by
finegrained_fp8.py and metal_quantization.py as a pass-through
reverse_op for dequantize operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Backport new TP/FSDP API + fix DTensor imports in Copied-from models

* from_pretrained orchestration + distributed save/load (huggingface#45409)

* from_pretrained orchestration + save/load

- Add gather_full_state_dict() for DTensor→full tensor saving
- Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
- Add _redistribute_dtensor() helper
- Full distributed_config integration in from_pretrained/save_pretrained
- Rename apply_fsdp2 → apply_fully_shard_data_parallel
- save_optimizer() / load_optimizer() in distributed/utils
- Trainer integration with distributed_config
- Updated FSDP and TP tests for new orchestration API
- DTensor shard-on-read test updates

* revert distributed utils

* eaaea

* all tests for core modeling are passing

* populate import from init for tp

* ruff

* ruff

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* do monkey patching for rotary

* Revert modeling file diffs to match fsdp-core-model-loading base

Restores modeling files to their base branch versions so the PR diff
only shows the distributed/patches.py monkey-patch approach instead of
noisy function moves in modeling files.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Migrate all model TP plans from strings to TPStyle

- Convert string plan values ("colwise", "rowwise", etc.) to TPStyle
  objects across 66+ model configs and modular files
- Consolidate MoE expert sub-entries into TPStyle("moe_experts", ...)
  with shard_plan
- Remove "replicated_with_grad_allreduce" entries (not needed for
  DTensor TP)
- Migrate _tp_plan class attributes in modeling files from
  "colwise_gather_output" to TPStyle("colwise", "allgather")
- Add TypeError in apply_tensor_parallel for unsupported plan values
- Remove old TensorParallelLayer tests (API removed in DTensor refactor)
- Regenerate auto-generated files via modular converter

* Restore mxfp4.py to match base branch

* Drop mla_kv_a_proj and moe_identity_expert from TP plans

These string plan values have no TPStyle equivalent in the DTensor
system. Remove them to avoid TypeError at apply_tensor_parallel time.
Affected models: deepseek_v2, glm4_moe_lite, glm_moe_dsa, longcat_flash.

* more comments

* fix tp for most models.  PyTorch doesn't implement all placement conversions (e.g. _StridedShard↔Shard). We force replicate beforehand

* fix tp through _replicate_dtensor

* revert small change

* push temporary fix for TP and strided shard for backward

* refactor a bit

* patches for rotary

* refactor MoEExpertsParallel

* fix tp for last models

* refactor moe expert parallels

* linting

* add sp plan for models

* add deepseek v2 sp plan

* undo sp plan for some tricky models

* remove lm_head from  config

* first pass of refactoring dtensor shard operator

* better refacto

* batter explanation of DtensorShardOperation

* refactor dtensor test to reflect real world scenario

* more comments

* fix tp olmo hybrid and exaone

* Enhance tensor parallel weight tying logic to prevent clobbering of lm_head when embed_tokens is not in the plan.

* fix fsdp mixin test due to missing args

* fix test non model

* skip sp plan for exaone and olmo hybrid

* linting

* fix import for ci

* test distributed config

* attempt to fix guarding import ci

* fix ci check repro

* add ALL_PARALLEL_STYLES registry alongside TPStyle

* route apply_tensor_parallel through ALL_PARALLEL_STYLES

* migrate modular files to string-based TP plans

* migrate standalone configs and modelings to string-based TP plans

* delete TPStyle dataclass

* fix use_local_output defaults for SequenceParallel and PrepareModuleInput in registry

* use parallel style from torch

* revert changes in weight converter

* remove dead code in set_param_for_module

* remove dead code

* cleaning again

* cleaning

* revert change

* linting

* refactor dtensor shard ops

* revert some stuff in core model loading

* core model loading clean

* guarding import

* better separation tensor parall and generic utils

* isolate DtensorShardOperation into a separate file

* no need to patch rotary

* better seperation

* simplify gather_full_state_dict

* simplify _replicate_dtensor

* fix and clean _replicate_dtensor

* better doc for DtensorShardOperation

* fix saving optimizer with DCP for fused weights

* save_pretrained(distributed_checkpoint=true)

* linting

* refactor into a single function _dtensor_from_local_like

* zeros_like instead of empty_like

* move tp and fsdp under distributed

* distribute_model

* fix deadlock when saving

* clip grad norm function

* maybe_disable_foreach_and_fused_for_mixed_dtensor_groups

* better TP api for ease of understanding

* remove shard_param to make it easier

* fix import in test

* _swap_dtensor_params_for_local

* fix qwen3 nanochat dots1

* add tpu

* move TP refactor experimentation scripts to backup branch

Move ad-hoc training / verification / compare scripts off this branch
into refactor-tp-dtensor-scripts so the diff stays focused on library
changes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* linting

* register distributed sharding_utils and utils in __init__

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* rename TP plan styles to match new ALL_PARALLEL_STYLES registry

Replace pre-refactor names that no longer exist in
src/transformers/distributed/tensor_parallel.py:
  rowwise -> rowwise_allreduce
  moe_tp_experts -> moe_experts_allreduce
  replicated_with_grad_allreduce -> activation_seq_dim_2

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* enable EP

* Add enable_expert_parallel configuration option in test_distributed_config

* no more auto mode

* edit fsdp plan to every other models

* update fsdp mixin tests

* linting

* fix test fsdp

* fsdp linting

* revert gitignore

* _apply within for loop

* rename

* doc sp plan

* fix

* unified settattr + torch no grad + _local_tensor

* revert

* linting

* fix ruff

* make check-repository-consistency

* trigger fsdp mixin test in CI

* fix fsdp ci

* Reset tests/test_modeling_common.py to main

Restores legitimate improvements that were accidentally undone during a
stale merge of main into fsdp-vs-ddp:

- Restore test_resize_embeddings_untied_no_reinit_on_post_init
- Restore clipseg / Timm / evolla / parakeet_* / pi0 / musicflamingo
  special-cases
- Restore skip_base_model parameter on test_reverse_loading_mapping
- Restore "is not None" guard on subconfig in test_initialization
- Fix typo: "ot" -> "or" in test_reverse_loading_mapping assert message

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request May 28, 2026
* init

* FSDP2 (fully_shard) integration

- Add apply_fully_shard_data_parallel() with auto/manual mode block detection
- FSDP vs DDP loss/grad parity tests
- Distributed test helpers (testing_utils.py)
- is_fsdp_enabled(), is_fsdp_managed_module() utilities
- Minimal FSDP hooks in from_pretrained
- FSDP-aware flash attention check

* DistributedConfig + shard-on-read loading

- DtensorShardOperation for range-math shard-on-read
- spawn_materialize() enhancements
- from_pretrained wiring for distributed config
- Shard operation helpers in tensor_parallel
- Shard-on-read and LoadStateDictConfig tests

* TPStyle API + dense model tensor parallelism

- Replace hook-based TP with DTensor-based TPStyle API
- TPStyle dataclass with dense kinds: colwise, rowwise, vocab
- apply_tensor_parallel() using PyTorch parallelize_module
- verify_tp_plan() for plan validation
- Update dense model configs (llama, mistral, qwen2, phi, glm) to TPStyle
- DTensor apply_rotary_pos_emb guard for llama, mistral, qwen3
- Extended DistributedConfig with tp/fsdp size and plan fields
- DistributedConfig serialization in configuration_utils
- MXFP4 NotImplementedError for DTensor TP
- Dense TP tests

* revert some files

* Add distributed training scripts

- train_fsdp_tp.py: minimal FSDP+TP training example
- train_fsdp_tp_torchtitan_style.py: torchtitan-style training example
- verify_loading.py: save/load roundtrip verification
- run_compare.sh: FSDP+TP vs FSDP-only comparison
- run_verify_all.sh: run verification across all modes
- tmp_generate.py: quick generation test

* Remove train_fsdp_tp_torchtitan_style.py

* unify the utils for fsdp

* Fix CI: re-export moved FSDP utils + remove stale type: ignore

- Re-export is_fsdp_enabled and is_fsdp_managed_module from
  integrations/fsdp.py (moved to distributed/utils.py)
- Remove unused # type: ignore comments in generation/utils.py

* Fix ruff formatting in core_model_loading.py

* Fix ruff linting and formatting

* Backport new TP/FSDP API from orchestration-save-load branch

* Fix DTensor imports in Copied-from model files

* MoE expert parallelism + sequence parallelism (huggingface#45408)

* MoE expert parallelism + sequence parallelism

- Add PackedColwiseParallel for fused gate_up_proj weights
- Add MoEExpertsParallel with per-expert DTensor sharding
- Add PrepareModuleInputOutput for SP allgather/split hooks
- Add _AllReduceBackward for MoE routing weight gradients
- Extend TPStyle with moe_experts, packed_colwise, activation, module kinds
- _StridedShard handling in core_model_loading for interleaved weights
- MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans
- DTensor rotary_pos_emb guard for mixtral

* Fix ruff linting and formatting

* Fix ruff formatting in core_model_loading.py

* Restore _IdentityOp accidentally removed in 25a1f48

The _IdentityOp class (added by PR huggingface#44983) was accidentally deleted
during the MoE expert parallelism work. It is needed by
finegrained_fp8.py and metal_quantization.py as a pass-through
reverse_op for dequantize operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Backport new TP/FSDP API + fix DTensor imports in Copied-from models

* from_pretrained orchestration + distributed save/load (huggingface#45409)

* from_pretrained orchestration + save/load

- Add gather_full_state_dict() for DTensor→full tensor saving
- Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
- Add _redistribute_dtensor() helper
- Full distributed_config integration in from_pretrained/save_pretrained
- Rename apply_fsdp2 → apply_fully_shard_data_parallel
- save_optimizer() / load_optimizer() in distributed/utils
- Trainer integration with distributed_config
- Updated FSDP and TP tests for new orchestration API
- DTensor shard-on-read test updates

* revert distributed utils

* eaaea

* all tests for core modeling are passing

* populate import from init for tp

* ruff

* ruff

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* do monkey patching for rotary

* Revert modeling file diffs to match fsdp-core-model-loading base

Restores modeling files to their base branch versions so the PR diff
only shows the distributed/patches.py monkey-patch approach instead of
noisy function moves in modeling files.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Migrate all model TP plans from strings to TPStyle

- Convert string plan values ("colwise", "rowwise", etc.) to TPStyle
  objects across 66+ model configs and modular files
- Consolidate MoE expert sub-entries into TPStyle("moe_experts", ...)
  with shard_plan
- Remove "replicated_with_grad_allreduce" entries (not needed for
  DTensor TP)
- Migrate _tp_plan class attributes in modeling files from
  "colwise_gather_output" to TPStyle("colwise", "allgather")
- Add TypeError in apply_tensor_parallel for unsupported plan values
- Remove old TensorParallelLayer tests (API removed in DTensor refactor)
- Regenerate auto-generated files via modular converter

* Restore mxfp4.py to match base branch

* Drop mla_kv_a_proj and moe_identity_expert from TP plans

These string plan values have no TPStyle equivalent in the DTensor
system. Remove them to avoid TypeError at apply_tensor_parallel time.
Affected models: deepseek_v2, glm4_moe_lite, glm_moe_dsa, longcat_flash.

* more comments

* fix tp for most models.  PyTorch doesn't implement all placement conversions (e.g. _StridedShard↔Shard). We force replicate beforehand

* fix tp through _replicate_dtensor

* revert small change

* push temporary fix for TP and strided shard for backward

* refactor a bit

* patches for rotary

* refactor MoEExpertsParallel

* fix tp for last models

* refactor moe expert parallels

* linting

* add sp plan for models

* add deepseek v2 sp plan

* undo sp plan for some tricky models

* remove lm_head from  config

* first pass of refactoring dtensor shard operator

* better refacto

* batter explanation of DtensorShardOperation

* refactor dtensor test to reflect real world scenario

* more comments

* fix tp olmo hybrid and exaone

* Enhance tensor parallel weight tying logic to prevent clobbering of lm_head when embed_tokens is not in the plan.

* fix fsdp mixin test due to missing args

* fix test non model

* skip sp plan for exaone and olmo hybrid

* linting

* fix import for ci

* test distributed config

* attempt to fix guarding import ci

* fix ci check repro

* add ALL_PARALLEL_STYLES registry alongside TPStyle

* route apply_tensor_parallel through ALL_PARALLEL_STYLES

* migrate modular files to string-based TP plans

* migrate standalone configs and modelings to string-based TP plans

* delete TPStyle dataclass

* fix use_local_output defaults for SequenceParallel and PrepareModuleInput in registry

* use parallel style from torch

* revert changes in weight converter

* remove dead code in set_param_for_module

* remove dead code

* cleaning again

* cleaning

* revert change

* linting

* refactor dtensor shard ops

* revert some stuff in core model loading

* core model loading clean

* guarding import

* better separation tensor parall and generic utils

* isolate DtensorShardOperation into a separate file

* no need to patch rotary

* better seperation

* simplify gather_full_state_dict

* simplify _replicate_dtensor

* fix and clean _replicate_dtensor

* better doc for DtensorShardOperation

* fix saving optimizer with DCP for fused weights

* save_pretrained(distributed_checkpoint=true)

* linting

* refactor into a single function _dtensor_from_local_like

* zeros_like instead of empty_like

* move tp and fsdp under distributed

* distribute_model

* fix deadlock when saving

* clip grad norm function

* maybe_disable_foreach_and_fused_for_mixed_dtensor_groups

* better TP api for ease of understanding

* remove shard_param to make it easier

* fix import in test

* _swap_dtensor_params_for_local

* fix qwen3 nanochat dots1

* add tpu

* move TP refactor experimentation scripts to backup branch

Move ad-hoc training / verification / compare scripts off this branch
into refactor-tp-dtensor-scripts so the diff stays focused on library
changes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* linting

* register distributed sharding_utils and utils in __init__

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* rename TP plan styles to match new ALL_PARALLEL_STYLES registry

Replace pre-refactor names that no longer exist in
src/transformers/distributed/tensor_parallel.py:
  rowwise -> rowwise_allreduce
  moe_tp_experts -> moe_experts_allreduce
  replicated_with_grad_allreduce -> activation_seq_dim_2

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* enable EP

* Add enable_expert_parallel configuration option in test_distributed_config

* no more auto mode

* edit fsdp plan to every other models

* update fsdp mixin tests

* linting

* fix test fsdp

* fsdp linting

* revert gitignore

* _apply within for loop

* rename

* doc sp plan

* fix

* unified settattr + torch no grad + _local_tensor

* revert

* linting

* fix ruff

* make check-repository-consistency

* trigger fsdp mixin test in CI

* fix fsdp ci

* Reset tests/test_modeling_common.py to main

Restores legitimate improvements that were accidentally undone during a
stale merge of main into fsdp-vs-ddp:

- Restore test_resize_embeddings_untied_no_reinit_on_post_init
- Restore clipseg / Timm / evolla / parakeet_* / pi0 / musicflamingo
  special-cases
- Restore skip_base_model parameter on test_reverse_loading_mapping
- Restore "is not None" guard on subconfig in test_initialization
- Fix typo: "ot" -> "or" in test_reverse_loading_mapping assert message

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
kashif pushed a commit to kashif/transformers that referenced this pull request Jun 1, 2026
* init

* FSDP2 (fully_shard) integration

- Add apply_fully_shard_data_parallel() with auto/manual mode block detection
- FSDP vs DDP loss/grad parity tests
- Distributed test helpers (testing_utils.py)
- is_fsdp_enabled(), is_fsdp_managed_module() utilities
- Minimal FSDP hooks in from_pretrained
- FSDP-aware flash attention check

* DistributedConfig + shard-on-read loading

- DtensorShardOperation for range-math shard-on-read
- spawn_materialize() enhancements
- from_pretrained wiring for distributed config
- Shard operation helpers in tensor_parallel
- Shard-on-read and LoadStateDictConfig tests

* TPStyle API + dense model tensor parallelism

- Replace hook-based TP with DTensor-based TPStyle API
- TPStyle dataclass with dense kinds: colwise, rowwise, vocab
- apply_tensor_parallel() using PyTorch parallelize_module
- verify_tp_plan() for plan validation
- Update dense model configs (llama, mistral, qwen2, phi, glm) to TPStyle
- DTensor apply_rotary_pos_emb guard for llama, mistral, qwen3
- Extended DistributedConfig with tp/fsdp size and plan fields
- DistributedConfig serialization in configuration_utils
- MXFP4 NotImplementedError for DTensor TP
- Dense TP tests

* revert some files

* Add distributed training scripts

- train_fsdp_tp.py: minimal FSDP+TP training example
- train_fsdp_tp_torchtitan_style.py: torchtitan-style training example
- verify_loading.py: save/load roundtrip verification
- run_compare.sh: FSDP+TP vs FSDP-only comparison
- run_verify_all.sh: run verification across all modes
- tmp_generate.py: quick generation test

* Remove train_fsdp_tp_torchtitan_style.py

* unify the utils for fsdp

* Fix CI: re-export moved FSDP utils + remove stale type: ignore

- Re-export is_fsdp_enabled and is_fsdp_managed_module from
  integrations/fsdp.py (moved to distributed/utils.py)
- Remove unused # type: ignore comments in generation/utils.py

* Fix ruff formatting in core_model_loading.py

* Fix ruff linting and formatting

* Backport new TP/FSDP API from orchestration-save-load branch

* Fix DTensor imports in Copied-from model files

* MoE expert parallelism + sequence parallelism (huggingface#45408)

* MoE expert parallelism + sequence parallelism

- Add PackedColwiseParallel for fused gate_up_proj weights
- Add MoEExpertsParallel with per-expert DTensor sharding
- Add PrepareModuleInputOutput for SP allgather/split hooks
- Add _AllReduceBackward for MoE routing weight gradients
- Extend TPStyle with moe_experts, packed_colwise, activation, module kinds
- _StridedShard handling in core_model_loading for interleaved weights
- MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans
- DTensor rotary_pos_emb guard for mixtral

* Fix ruff linting and formatting

* Fix ruff formatting in core_model_loading.py

* Restore _IdentityOp accidentally removed in 25a1f48

The _IdentityOp class (added by PR huggingface#44983) was accidentally deleted
during the MoE expert parallelism work. It is needed by
finegrained_fp8.py and metal_quantization.py as a pass-through
reverse_op for dequantize operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Backport new TP/FSDP API + fix DTensor imports in Copied-from models

* from_pretrained orchestration + distributed save/load (huggingface#45409)

* from_pretrained orchestration + save/load

- Add gather_full_state_dict() for DTensor→full tensor saving
- Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
- Add _redistribute_dtensor() helper
- Full distributed_config integration in from_pretrained/save_pretrained
- Rename apply_fsdp2 → apply_fully_shard_data_parallel
- save_optimizer() / load_optimizer() in distributed/utils
- Trainer integration with distributed_config
- Updated FSDP and TP tests for new orchestration API
- DTensor shard-on-read test updates

* revert distributed utils

* eaaea

* all tests for core modeling are passing

* populate import from init for tp

* ruff

* ruff

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* do monkey patching for rotary

* Revert modeling file diffs to match fsdp-core-model-loading base

Restores modeling files to their base branch versions so the PR diff
only shows the distributed/patches.py monkey-patch approach instead of
noisy function moves in modeling files.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Migrate all model TP plans from strings to TPStyle

- Convert string plan values ("colwise", "rowwise", etc.) to TPStyle
  objects across 66+ model configs and modular files
- Consolidate MoE expert sub-entries into TPStyle("moe_experts", ...)
  with shard_plan
- Remove "replicated_with_grad_allreduce" entries (not needed for
  DTensor TP)
- Migrate _tp_plan class attributes in modeling files from
  "colwise_gather_output" to TPStyle("colwise", "allgather")
- Add TypeError in apply_tensor_parallel for unsupported plan values
- Remove old TensorParallelLayer tests (API removed in DTensor refactor)
- Regenerate auto-generated files via modular converter

* Restore mxfp4.py to match base branch

* Drop mla_kv_a_proj and moe_identity_expert from TP plans

These string plan values have no TPStyle equivalent in the DTensor
system. Remove them to avoid TypeError at apply_tensor_parallel time.
Affected models: deepseek_v2, glm4_moe_lite, glm_moe_dsa, longcat_flash.

* more comments

* fix tp for most models.  PyTorch doesn't implement all placement conversions (e.g. _StridedShard↔Shard). We force replicate beforehand

* fix tp through _replicate_dtensor

* revert small change

* push temporary fix for TP and strided shard for backward

* refactor a bit

* patches for rotary

* refactor MoEExpertsParallel

* fix tp for last models

* refactor moe expert parallels

* linting

* add sp plan for models

* add deepseek v2 sp plan

* undo sp plan for some tricky models

* remove lm_head from  config

* first pass of refactoring dtensor shard operator

* better refacto

* batter explanation of DtensorShardOperation

* refactor dtensor test to reflect real world scenario

* more comments

* fix tp olmo hybrid and exaone

* Enhance tensor parallel weight tying logic to prevent clobbering of lm_head when embed_tokens is not in the plan.

* fix fsdp mixin test due to missing args

* fix test non model

* skip sp plan for exaone and olmo hybrid

* linting

* fix import for ci

* test distributed config

* attempt to fix guarding import ci

* fix ci check repro

* add ALL_PARALLEL_STYLES registry alongside TPStyle

* route apply_tensor_parallel through ALL_PARALLEL_STYLES

* migrate modular files to string-based TP plans

* migrate standalone configs and modelings to string-based TP plans

* delete TPStyle dataclass

* fix use_local_output defaults for SequenceParallel and PrepareModuleInput in registry

* use parallel style from torch

* revert changes in weight converter

* remove dead code in set_param_for_module

* remove dead code

* cleaning again

* cleaning

* revert change

* linting

* refactor dtensor shard ops

* revert some stuff in core model loading

* core model loading clean

* guarding import

* better separation tensor parall and generic utils

* isolate DtensorShardOperation into a separate file

* no need to patch rotary

* better seperation

* simplify gather_full_state_dict

* simplify _replicate_dtensor

* fix and clean _replicate_dtensor

* better doc for DtensorShardOperation

* fix saving optimizer with DCP for fused weights

* save_pretrained(distributed_checkpoint=true)

* linting

* refactor into a single function _dtensor_from_local_like

* zeros_like instead of empty_like

* move tp and fsdp under distributed

* distribute_model

* fix deadlock when saving

* clip grad norm function

* maybe_disable_foreach_and_fused_for_mixed_dtensor_groups

* better TP api for ease of understanding

* remove shard_param to make it easier

* fix import in test

* _swap_dtensor_params_for_local

* fix qwen3 nanochat dots1

* add tpu

* move TP refactor experimentation scripts to backup branch

Move ad-hoc training / verification / compare scripts off this branch
into refactor-tp-dtensor-scripts so the diff stays focused on library
changes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* linting

* register distributed sharding_utils and utils in __init__

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* rename TP plan styles to match new ALL_PARALLEL_STYLES registry

Replace pre-refactor names that no longer exist in
src/transformers/distributed/tensor_parallel.py:
  rowwise -> rowwise_allreduce
  moe_tp_experts -> moe_experts_allreduce
  replicated_with_grad_allreduce -> activation_seq_dim_2

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* enable EP

* Add enable_expert_parallel configuration option in test_distributed_config

* no more auto mode

* edit fsdp plan to every other models

* update fsdp mixin tests

* linting

* fix test fsdp

* fsdp linting

* revert gitignore

* _apply within for loop

* rename

* doc sp plan

* fix

* unified settattr + torch no grad + _local_tensor

* revert

* linting

* fix ruff

* make check-repository-consistency

* trigger fsdp mixin test in CI

* fix fsdp ci

* Reset tests/test_modeling_common.py to main

Restores legitimate improvements that were accidentally undone during a
stale merge of main into fsdp-vs-ddp:

- Restore test_resize_embeddings_untied_no_reinit_on_post_init
- Restore clipseg / Timm / evolla / parakeet_* / pi0 / musicflamingo
  special-cases
- Restore skip_base_model parameter on test_reverse_loading_mapping
- Restore "is not None" guard on subconfig in test_initialization
- Fix typo: "ot" -> "or" in test_reverse_loading_mapping assert message

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
khushali9 pushed a commit to khushali9/transformers that referenced this pull request Jun 8, 2026
* init

* FSDP2 (fully_shard) integration

- Add apply_fully_shard_data_parallel() with auto/manual mode block detection
- FSDP vs DDP loss/grad parity tests
- Distributed test helpers (testing_utils.py)
- is_fsdp_enabled(), is_fsdp_managed_module() utilities
- Minimal FSDP hooks in from_pretrained
- FSDP-aware flash attention check

* DistributedConfig + shard-on-read loading

- DtensorShardOperation for range-math shard-on-read
- spawn_materialize() enhancements
- from_pretrained wiring for distributed config
- Shard operation helpers in tensor_parallel
- Shard-on-read and LoadStateDictConfig tests

* TPStyle API + dense model tensor parallelism

- Replace hook-based TP with DTensor-based TPStyle API
- TPStyle dataclass with dense kinds: colwise, rowwise, vocab
- apply_tensor_parallel() using PyTorch parallelize_module
- verify_tp_plan() for plan validation
- Update dense model configs (llama, mistral, qwen2, phi, glm) to TPStyle
- DTensor apply_rotary_pos_emb guard for llama, mistral, qwen3
- Extended DistributedConfig with tp/fsdp size and plan fields
- DistributedConfig serialization in configuration_utils
- MXFP4 NotImplementedError for DTensor TP
- Dense TP tests

* revert some files

* Add distributed training scripts

- train_fsdp_tp.py: minimal FSDP+TP training example
- train_fsdp_tp_torchtitan_style.py: torchtitan-style training example
- verify_loading.py: save/load roundtrip verification
- run_compare.sh: FSDP+TP vs FSDP-only comparison
- run_verify_all.sh: run verification across all modes
- tmp_generate.py: quick generation test

* Remove train_fsdp_tp_torchtitan_style.py

* unify the utils for fsdp

* Fix CI: re-export moved FSDP utils + remove stale type: ignore

- Re-export is_fsdp_enabled and is_fsdp_managed_module from
  integrations/fsdp.py (moved to distributed/utils.py)
- Remove unused # type: ignore comments in generation/utils.py

* Fix ruff formatting in core_model_loading.py

* Fix ruff linting and formatting

* Backport new TP/FSDP API from orchestration-save-load branch

* Fix DTensor imports in Copied-from model files

* MoE expert parallelism + sequence parallelism (huggingface#45408)

* MoE expert parallelism + sequence parallelism

- Add PackedColwiseParallel for fused gate_up_proj weights
- Add MoEExpertsParallel with per-expert DTensor sharding
- Add PrepareModuleInputOutput for SP allgather/split hooks
- Add _AllReduceBackward for MoE routing weight gradients
- Extend TPStyle with moe_experts, packed_colwise, activation, module kinds
- _StridedShard handling in core_model_loading for interleaved weights
- MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans
- DTensor rotary_pos_emb guard for mixtral

* Fix ruff linting and formatting

* Fix ruff formatting in core_model_loading.py

* Restore _IdentityOp accidentally removed in 25a1f48

The _IdentityOp class (added by PR huggingface#44983) was accidentally deleted
during the MoE expert parallelism work. It is needed by
finegrained_fp8.py and metal_quantization.py as a pass-through
reverse_op for dequantize operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Backport new TP/FSDP API + fix DTensor imports in Copied-from models

* from_pretrained orchestration + distributed save/load (huggingface#45409)

* from_pretrained orchestration + save/load

- Add gather_full_state_dict() for DTensor→full tensor saving
- Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
- Add _redistribute_dtensor() helper
- Full distributed_config integration in from_pretrained/save_pretrained
- Rename apply_fsdp2 → apply_fully_shard_data_parallel
- save_optimizer() / load_optimizer() in distributed/utils
- Trainer integration with distributed_config
- Updated FSDP and TP tests for new orchestration API
- DTensor shard-on-read test updates

* revert distributed utils

* eaaea

* all tests for core modeling are passing

* populate import from init for tp

* ruff

* ruff

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* do monkey patching for rotary

* Revert modeling file diffs to match fsdp-core-model-loading base

Restores modeling files to their base branch versions so the PR diff
only shows the distributed/patches.py monkey-patch approach instead of
noisy function moves in modeling files.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Migrate all model TP plans from strings to TPStyle

- Convert string plan values ("colwise", "rowwise", etc.) to TPStyle
  objects across 66+ model configs and modular files
- Consolidate MoE expert sub-entries into TPStyle("moe_experts", ...)
  with shard_plan
- Remove "replicated_with_grad_allreduce" entries (not needed for
  DTensor TP)
- Migrate _tp_plan class attributes in modeling files from
  "colwise_gather_output" to TPStyle("colwise", "allgather")
- Add TypeError in apply_tensor_parallel for unsupported plan values
- Remove old TensorParallelLayer tests (API removed in DTensor refactor)
- Regenerate auto-generated files via modular converter

* Restore mxfp4.py to match base branch

* Drop mla_kv_a_proj and moe_identity_expert from TP plans

These string plan values have no TPStyle equivalent in the DTensor
system. Remove them to avoid TypeError at apply_tensor_parallel time.
Affected models: deepseek_v2, glm4_moe_lite, glm_moe_dsa, longcat_flash.

* more comments

* fix tp for most models.  PyTorch doesn't implement all placement conversions (e.g. _StridedShard↔Shard). We force replicate beforehand

* fix tp through _replicate_dtensor

* revert small change

* push temporary fix for TP and strided shard for backward

* refactor a bit

* patches for rotary

* refactor MoEExpertsParallel

* fix tp for last models

* refactor moe expert parallels

* linting

* add sp plan for models

* add deepseek v2 sp plan

* undo sp plan for some tricky models

* remove lm_head from  config

* first pass of refactoring dtensor shard operator

* better refacto

* batter explanation of DtensorShardOperation

* refactor dtensor test to reflect real world scenario

* more comments

* fix tp olmo hybrid and exaone

* Enhance tensor parallel weight tying logic to prevent clobbering of lm_head when embed_tokens is not in the plan.

* fix fsdp mixin test due to missing args

* fix test non model

* skip sp plan for exaone and olmo hybrid

* linting

* fix import for ci

* test distributed config

* attempt to fix guarding import ci

* fix ci check repro

* add ALL_PARALLEL_STYLES registry alongside TPStyle

* route apply_tensor_parallel through ALL_PARALLEL_STYLES

* migrate modular files to string-based TP plans

* migrate standalone configs and modelings to string-based TP plans

* delete TPStyle dataclass

* fix use_local_output defaults for SequenceParallel and PrepareModuleInput in registry

* use parallel style from torch

* revert changes in weight converter

* remove dead code in set_param_for_module

* remove dead code

* cleaning again

* cleaning

* revert change

* linting

* refactor dtensor shard ops

* revert some stuff in core model loading

* core model loading clean

* guarding import

* better separation tensor parall and generic utils

* isolate DtensorShardOperation into a separate file

* no need to patch rotary

* better seperation

* simplify gather_full_state_dict

* simplify _replicate_dtensor

* fix and clean _replicate_dtensor

* better doc for DtensorShardOperation

* fix saving optimizer with DCP for fused weights

* save_pretrained(distributed_checkpoint=true)

* linting

* refactor into a single function _dtensor_from_local_like

* zeros_like instead of empty_like

* move tp and fsdp under distributed

* distribute_model

* fix deadlock when saving

* clip grad norm function

* maybe_disable_foreach_and_fused_for_mixed_dtensor_groups

* better TP api for ease of understanding

* remove shard_param to make it easier

* fix import in test

* _swap_dtensor_params_for_local

* fix qwen3 nanochat dots1

* add tpu

* move TP refactor experimentation scripts to backup branch

Move ad-hoc training / verification / compare scripts off this branch
into refactor-tp-dtensor-scripts so the diff stays focused on library
changes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* linting

* register distributed sharding_utils and utils in __init__

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* rename TP plan styles to match new ALL_PARALLEL_STYLES registry

Replace pre-refactor names that no longer exist in
src/transformers/distributed/tensor_parallel.py:
  rowwise -> rowwise_allreduce
  moe_tp_experts -> moe_experts_allreduce
  replicated_with_grad_allreduce -> activation_seq_dim_2

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* enable EP

* Add enable_expert_parallel configuration option in test_distributed_config

* no more auto mode

* edit fsdp plan to every other models

* update fsdp mixin tests

* linting

* fix test fsdp

* fsdp linting

* revert gitignore

* _apply within for loop

* rename

* doc sp plan

* fix

* unified settattr + torch no grad + _local_tensor

* revert

* linting

* fix ruff

* make check-repository-consistency

* trigger fsdp mixin test in CI

* fix fsdp ci

* Reset tests/test_modeling_common.py to main

Restores legitimate improvements that were accidentally undone during a
stale merge of main into fsdp-vs-ddp:

- Restore test_resize_embeddings_untied_no_reinit_on_post_init
- Restore clipseg / Timm / evolla / parakeet_* / pi0 / musicflamingo
  special-cases
- Restore skip_base_model parameter on test_reverse_loading_mapping
- Restore "is not None" guard on subconfig in test_initialization
- Fix typo: "ot" -> "or" in test_reverse_loading_mapping assert message

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants