Skip to content
48 changes: 46 additions & 2 deletions src/transformers/integrations/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from dataclasses import replace
from typing import TYPE_CHECKING, Any, Literal, Optional

from safetensors import safe_open

from ..conversion_mapping import (
_MODEL_TO_CONVERSION_PATTERN,
get_checkpoint_conversion_mapping,
Expand Down Expand Up @@ -55,7 +57,7 @@
from accelerate.utils import get_balanced_memory, infer_auto_device_map

# Minimum PEFT version supported for the integration
MIN_PEFT_VERSION = "0.18.0"
MIN_PEFT_VERSION = "0.18.2"


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -498,8 +500,9 @@ def load_adapter(
`find_adapter_config_file` method.
"""
from peft import PeftType
from peft.utils.save_and_load import _maybe_shard_state_dict_for_tp

from ..modeling_utils import LoadStateDictConfig, _get_resolved_checkpoint_files
from ..modeling_utils import LoadStateDictConfig, _get_resolved_checkpoint_files, load_state_dict

if local_files_only:
kwargs["local_files_only"] = True
Expand Down Expand Up @@ -608,13 +611,54 @@ def load_adapter(
checkpoint_files, sharded_metadata = [], {}

device_map = getattr(self, "hf_device_map", {"": self.device})

# If the model is tensor parallel, we shard the adapter state dict here, since the logic in `self._load_pretrained_model`
# is not compatible with the way PEFT adapters should be sharded.
has_tp_adapters = False
for module in self.modules():
tp_info = getattr(module, "_tp_info", None)
if tp_info is not None:
has_tp_adapters = True
break

if has_tp_adapters:
all_pointer = set()
if adapter_state_dict is not None:
merged_state_dict = adapter_state_dict
elif (
checkpoint_files is not None
and checkpoint_files[0].endswith(".safetensors")
and adapter_state_dict is None
):
merged_state_dict = {}
for file in checkpoint_files:
file_pointer = safe_open(file, framework="pt", device="cpu")
all_pointer.add(file_pointer)
for k in file_pointer.keys():
merged_state_dict[k] = file_pointer.get_tensor(k)
# Otherwise the checkpoint files are not safetensors (e.g. `.bin`), so fall back to `load_state_dict`
elif checkpoint_files is not None:
merged_state_dict = {}
for ckpt_file in checkpoint_files:
merged_state_dict.update(load_state_dict(ckpt_file))
else:
raise ValueError("Neither a state dict nor checkpoint files were found.")

adapter_state_dict = merged_state_dict

if any(not isinstance(v, torch.Tensor) for v in adapter_state_dict.values()):
raise ValueError("Expected all values in the adapter state dict to be tensors.")

_maybe_shard_state_dict_for_tp(self, adapter_state_dict, adapter_name)

load_config = replace(
load_config,
pretrained_model_name_or_path=peft_model_id,
sharded_metadata=sharded_metadata,
weight_mapping=peft_weight_conversions,
device_map=device_map,
)

loading_info, _ = self._load_pretrained_model(
model=self,
state_dict=adapter_state_dict,
Expand Down
Loading