Skip to content
12 changes: 11 additions & 1 deletion .circleci/create_circleci_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,15 @@ def job_name(self):
parallelism=6,
)

fsdp_ci_job = CircleCIJob(
"fsdp_ci",
additional_env={"RUN_FSDP_TESTS": True},
docker_image=[{"image": "huggingface/transformers-torch-light"}],
install_steps=["uv pip install .", "uv pip install torchao"],
marker="is_fsdp_test",
parallelism=6,
)

# We also include a `dummy.py` file in the files to be doc-tested to prevent edge case failure. Otherwise, the pytest
# hangs forever during test collection while showing `collecting 0 items / 21 errors`. (To see this, we have to remove
# the bash output redirection.)
Expand Down Expand Up @@ -435,7 +444,8 @@ def job_name(self):
DOC_TESTS = [doc_test_job]
TRAINING_CI_TESTS = [training_ci_job]
TENSOR_PARALLEL_CI_TESTS = [tensor_parallel_ci_job]
ALL_TESTS = REGULAR_TESTS + EXAMPLES_TESTS + PIPELINE_TESTS + REPO_UTIL_TESTS + DOC_TESTS + [custom_tokenizers_job] + [exotic_models_job] + TRAINING_CI_TESTS + TENSOR_PARALLEL_CI_TESTS # fmt: skip
FSDP_CI_TESTS = [fsdp_ci_job]
ALL_TESTS = REGULAR_TESTS + EXAMPLES_TESTS + PIPELINE_TESTS + REPO_UTIL_TESTS + DOC_TESTS + [custom_tokenizers_job] + [exotic_models_job] + TRAINING_CI_TESTS + TENSOR_PARALLEL_CI_TESTS + FSDP_CI_TESTS # fmt: skip


def create_circleci_config(folder=None):
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ __pycache__/

# C extensions
*.so

checkpoints/
# tests and logs
tests/fixtures/cached_*_text.txt
logs/
Expand Down
1 change: 1 addition & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
"debug_utils": [],
"dependency_versions_check": [],
"dependency_versions_table": [],
"distributed": [],
"dynamic_module_utils": [],
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/cli/serving/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def __init__(
self._tokenizer = tokenizer
self._loop = loop
self._queue = queue
self._decode_stream = DecodeStream([], skip_special_tokens)
self._decode_stream: DecodeStream = DecodeStream([], skip_special_tokens)
self._stc_id = tool_config["stc_id"] if tool_config else None
self._etc_id = tool_config["etc_id"] if tool_config else None
self._inside_tool_call = False
Expand Down Expand Up @@ -523,7 +523,7 @@ def put(self, value: "torch.Tensor") -> None:

is_start_or_end_token = _advance_thinking_state(self, token_id)

text = self._decode_stream.step(self._tokenizer, token_id)
text = self._decode_stream.step(self._tokenizer, token_id) # ty:ignore[unresolved-attribute]
if text is None or self._inside_tool_call or token_id == self._etc_id or is_start_or_end_token:
continue
if self._inside_thinking:
Expand Down Expand Up @@ -575,7 +575,7 @@ def __init__(
self._loop = loop
self._queue = queue
self._tokenizer = tokenizer
self._decode_stream = DecodeStream([], True)
self._decode_stream: DecodeStream = DecodeStream([], True)
self._stc_id = tool_config["stc_id"] if tool_config else None
self._etc_id = tool_config["etc_id"] if tool_config else None
self._inside_tool_call = False
Expand All @@ -602,7 +602,7 @@ def put(self, output: "GenerationOutput") -> None:

is_start_or_end_token = _advance_thinking_state(self, token_id)

text = self._decode_stream.step(self._tokenizer, token_id)
text = self._decode_stream.step(self._tokenizer, token_id) # ty:ignore[unresolved-attribute]
if text is None or self._inside_tool_call or token_id == self._etc_id or is_start_or_end_token:
continue
if self._inside_thinking:
Expand Down
18 changes: 18 additions & 0 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,16 @@ class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin):
naming of attributes.
- **base_model_tp_plan** (`dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
parallel plan applied to the sub-module when `model.tensor_parallel` is called.
- **base_model_sp_plan** (`dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a sequence
parallel plan, used when `distributed_config.enable_sequence_parallel=True` and
`enable_expert_parallel=False`. Same key/value shape as the TP plan.
- **base_model_tp_ep_plan** (`dict[str, Any]`) -- Complete plan for inference TP + expert parallel
(`enable_sequence_parallel=False`, `enable_expert_parallel=True`).
- **base_model_sp_ep_plan** (`dict[str, Any]`) -- Complete plan for training SP + expert parallel
(`enable_sequence_parallel=True`, `enable_expert_parallel=True`).
- **base_model_fsdp_plan** (`dict[Any, str]`) -- A dict that maps sub-modules of a base model to an FSDP2
sharding strategy (e.g. `"free_full_weight"` / `"keep_full_weight"`). Keys can be wildcard module paths
(e.g. `"layers.*"`) or tuples of paths (grouped into a single `fully_shard` call).
- **base_model_pp_plan** (`dict[str, tuple[list[str]]]`) -- A dict that maps child-modules of a base model to a
pipeline parallel plan that enables users to place the child-module on the appropriate device.

Expand Down Expand Up @@ -220,6 +230,10 @@ class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin):
keys_to_ignore_at_inference: ClassVar[list[str]] = []
attribute_map: ClassVar[dict[str, str]] = {}
base_model_tp_plan: ClassVar[dict[str, Any] | None] = None
base_model_sp_plan: ClassVar[dict[str, Any] | None] = None
base_model_tp_ep_plan: ClassVar[dict[str, Any] | None] = None
base_model_sp_ep_plan: ClassVar[dict[str, Any] | None] = None
base_model_fsdp_plan: ClassVar[dict[Any, str] | None] = None
base_model_pp_plan: ClassVar[dict[str, Sequence[list[str]]] | None] = None
base_model_ep_plan: ClassVar[dict[str, Sequence[list[str]]] | None] = None
_auto_class: ClassVar[str | None] = None
Expand Down Expand Up @@ -1022,6 +1036,9 @@ def to_dict(self) -> dict[str, Any]:
# Pop "kwargs" since they are unpacked and set in the post init
output.pop("kwargs", None)

if "distributed_config" in output and hasattr(output["distributed_config"], "to_dict"):
output["distributed_config"] = output["distributed_config"].to_dict()

def to_list(value):
if isinstance(value, tuple):
value = [to_list(item) for item in value]
Expand Down Expand Up @@ -1167,6 +1184,7 @@ def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
"_experts_implementation_internal",
"ignore_keys_at_rope_validation",
"base_model_tp_plan",
"base_model_sp_plan",
"base_model_pp_plan",
]:
d.pop(key_to_remove, None)
Expand Down
130 changes: 48 additions & 82 deletions src/transformers/core_model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@

import torch

from .distributed.sharding_utils import DtensorShardOperation, _dtensor_from_local_like
from .integrations.accelerate import get_device, offload_weight
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
from .utils import is_env_variable_true
from .utils.loading_report import LoadStateDictInfo
from .utils.logging import get_logger, tqdm


_torch_distributed_available = torch.distributed.is_available()
if _torch_distributed_available:
from torch.distributed.tensor import DTensor

if TYPE_CHECKING:
from .integrations.tensor_parallel import TensorParallelLayer
from .modeling_utils import LoadStateDictConfig, PreTrainedModel
from .quantizers import HfQuantizer


logger = get_logger(__name__)


Expand Down Expand Up @@ -388,7 +388,7 @@ def __init__(self):

def _apply(self, tensor: torch.Tensor) -> torch.Tensor:
dim1, dim2 = tensor.shape
n_heads = self.config.getattr("num_attention_heads", 1)
n_heads = getattr(self.config, "num_attention_heads", 1)

tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
tensor = tensor.transpose(1, 2).reshape(dim1, dim2)
Expand All @@ -404,11 +404,10 @@ def convert(
**kwargs,
) -> dict[str, list[torch.Tensor]]:
self.config = config
output: dict[str, list[torch.Tensor]] = {}
output = {}
for key, tensors in input_dict.items():
if len(tensors) != 1:
raise ValueError("PermuteForRope expects a single tensor per key.")
output[key] = [self._apply(tensors[0])]
tensor = tensors[0] if isinstance(tensors, list) else tensors
output[key] = self._apply(tensor)
return output


Expand Down Expand Up @@ -610,7 +609,7 @@ def __init__(self, source_patterns: str | list[str], target_patterns: str | list
self._original_target_patterns = self.target_patterns.copy()

# Init fields that will be used during conversion
self.distributed_operation: TensorParallelLayer | None = None
self.distributed_operation: Any = None
self.quantization_operation: ConversionOps | None = None
self.collected_tensors: dict[str, list[Future]] = defaultdict(list)
self.layer_targets: dict[str, set[str]] = defaultdict(set)
Expand Down Expand Up @@ -768,7 +767,9 @@ def reverse_transform(self) -> WeightTransform:
kwargs["operations"] = [op.reverse_op for op in self.operations[::-1]]

reverse_transform = self.__class__(
source_patterns=self._original_target_patterns, target_patterns=self._original_source_patterns, **kwargs
source_patterns=self._original_target_patterns,
target_patterns=self._original_source_patterns,
**kwargs,
)
reverse_transform.scope_prefix = self.scope_prefix
reverse_transform.base_model_prefix = self.base_model_prefix
Expand All @@ -795,6 +796,8 @@ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:
# Sync loading
elif callable(tensors[0]):
tensors = [func() for func in tensors]
# Some may be None for some distributed setups
tensors = [tensor for tensor in tensors if tensor is not None]
# Add them to the new dictionary
collected_tensors[key] = tensors

Expand Down Expand Up @@ -997,29 +1000,20 @@ def spawn_materialize(
tensor: torch.Tensor,
device=None,
dtype=None,
sharding_op: DtensorShardOperation | None = None,
tensor_idx: int | None = None,
) -> Future | Callable:
"""Materialize a tensor from file asynchronously if `thread_pool` is provided, or return a Callable that will
load the tensor synchronously when called."""

def _job():
return _materialize_copy(tensor, device, dtype)

if thread_pool is not None:
return thread_pool.submit(_job)
else:
# Return the Callable here, not the Tensor itself, so we actually delay loading to avoid saturating cpu
# memory during Conversion
return _job

"""Materialize (and optionally shard) a tensor, asynchronously if a thread pool is provided.

def spawn_tp_materialize(
thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, sharding_method, tensor_idx, device=None, dtype=None
) -> Future | Callable:
"""Materialize and shard a tensor (according to the TP-plan) from file asynchronously if `thread_pool` is provided, or
return a Callable that will load the tensor synchronously when called."""
When ``sharding_op`` is given the tensor is sharded according to the DTensor
placement strategy; otherwise it is simply copied to *device*/*dtype*.
Without a thread pool a deferred callable is returned instead of a Future.
"""

def _job():
return sharding_method.shard_tensor(tensor, tensor_idx=tensor_idx, device=device, dtype=dtype)
if sharding_op is not None:
return sharding_op.shard_tensor(tensor, tensor_idx=tensor_idx, device=device, dtype=dtype)
return _materialize_copy(tensor, device, dtype)

if thread_pool is not None:
return thread_pool.submit(_job)
Expand Down Expand Up @@ -1092,12 +1086,12 @@ def _format_op_name(curr_op: list[ConversionOps] | ConversionOps | None) -> str
raise SkipParameters()


@torch.no_grad()
def set_param_for_module(
model: PreTrainedModel,
target_name: str,
param_value: torch.Tensor,
loading_info: LoadStateDictInfo,
distributed_operation: TensorParallelLayer | None,
hf_quantizer: HfQuantizer,
):
module_path, _, param_name = target_name.rpartition(".")
Expand All @@ -1112,27 +1106,25 @@ def set_param_for_module(
if ref is None:
loading_info.unexpected_keys.add(target_name)
else:
if not isinstance(param_value, torch.nn.Parameter):
if not isinstance(param_value, torch.nn.Parameter) and not isinstance(ref, DTensor):
if param_name not in module_obj._buffers:
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())

# Remove from missing keys (it's either mismatched, or all good)
loading_info.missing_keys.discard(target_name)

# Determine expected shape: for TP, use sharded shape; otherwise, use full shape
if distributed_operation is not None:
expected_shape = torch.Size(distributed_operation.get_expected_sharded_shape(ref.shape))
else:
expected_shape = ref.shape
expected_shape = ref._local_tensor.shape if isinstance(ref, DTensor) else ref.shape

if ref is not None and param_value.shape != expected_shape and hf_quantizer is None:
loading_info.mismatched_keys.add((target_name, param_value.shape, expected_shape))
else:
if isinstance(ref, DTensor):
local_param = param_value.detach() if isinstance(param_value, torch.nn.Parameter) else param_value
dtensor_param = _dtensor_from_local_like(local_param, ref)
param_value = torch.nn.Parameter(dtensor_param, requires_grad=ref.requires_grad)
# super important otherwise _init_weight will re-init the param
param_value._is_hf_initialized = True
setattr(module_obj, param_name, param_value)
if distributed_operation is not None:
distributed_operation.update_module_attributes(module_obj)


def offload_and_maybe_resave_param(
Expand Down Expand Up @@ -1317,7 +1309,6 @@ def convert_and_load_state_dict_in_model(
device_map = load_config.device_map or {"": "cpu"}
hf_quantizer = load_config.hf_quantizer
dtype = load_config.dtype
device_mesh = load_config.device_mesh
disk_offload_folder = load_config.disk_offload_folder
offload_buffers = load_config.offload_buffers
dtype_plan = load_config.dtype_plan or {}
Expand Down Expand Up @@ -1352,10 +1343,6 @@ def convert_and_load_state_dict_in_model(
converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
param_name_to_load: dict[str, WeightRenaming | WeightConverter] = {}

# build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
# and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
if tp_plan != {}:
tp_plan_alt, tp_plan_by_group_name, _ = build_glob_alternation(list(tp_plan.keys()))
if dtype_plan != {}:
dtype_policy_alt, dtype_policy_by_group_name, _ = build_glob_alternation(list(dtype_plan.keys()))

Expand Down Expand Up @@ -1419,37 +1406,23 @@ def convert_and_load_state_dict_in_model(
elif empty_param is not None and empty_param.dtype != _dtype:
_dtype = empty_param.dtype # usually correct when initializing

# 4. Handle TP sharding or device_map placement
future_or_tensor = None
if device_mesh and tp_plan:
if matched_tp_pattern := tp_plan_alt.search(renamed_key):
matched_tp_pattern = tp_plan_by_group_name[matched_tp_pattern.lastgroup]
if getattr(mapping, "distributed_operation", None) is None:
tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
mapping.distributed_operation = tp_layer(
device_mesh=device_mesh, rank=device_mesh.get_local_rank(), empty_param=empty_param.clone()
)
# Per-expert sharding (EP) needs `tensor_idx` = the expert index so the
# distributed op selects whole experts. The signal is a `MergeModulelist`
# in the chain; it isn't always `operations[0]` (e.g. an FP8 quantizer
# prepends a scale-decode op), so scan the whole chain rather than just the head.
shard_index = (
len(mapping.collected_tensors.get(source_pattern, []))
if isinstance(mapping, WeightConverter)
and any(isinstance(op, MergeModulelist) for op in mapping.operations)
else None
)
future_or_tensor = spawn_tp_materialize(
thread_pool,
tensor,
mapping.distributed_operation,
shard_index,
device_map[""],
_dtype,
)

if future_or_tensor is None:
param_device = get_device(device_map, renamed_key, valid_torch_device=True)
# 4. Materialize tensor — shard-on-read for DTensor params, plain copy otherwise
param_device = get_device(device_map, renamed_key, valid_torch_device=True)
if isinstance(empty_param, DTensor):
tensor_idx = (
len(mapping.collected_tensors.get(source_pattern, []))
if isinstance(mapping, WeightConverter) and isinstance(mapping.operations[0], MergeModulelist)
else None
)
future_or_tensor = spawn_materialize(
thread_pool,
tensor,
param_device,
_dtype,
sharding_op=DtensorShardOperation(empty_param),
tensor_idx=tensor_idx,
)
else:
future_or_tensor = spawn_materialize(thread_pool, tensor, param_device, _dtype)

mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor)
Expand Down Expand Up @@ -1479,14 +1452,7 @@ def convert_and_load_state_dict_in_model(
target_name, param, loading_info, disk_offload_folder, disk_offload_index, mapping
)
else:
set_param_for_module(
model,
target_name,
param,
loading_info,
mapping.distributed_operation,
hf_quantizer,
)
set_param_for_module(model, target_name, param, loading_info, hf_quantizer)

# Cleanup all the tensors that were gathered before next iteration
del realized_value
Expand Down
Loading
Loading