Commit d052118

danaaubakirova, pre-commit-ci[bot], and aliberts authored
fix issues: checkpoints keys mismatch and 'task' tokenisation in smolvla (#1256)
Co-authored-by: danaaubakirova <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <[email protected]>
Co-authored-by: Simon Alibert <[email protected]>
1 parent 10b7b35 commit d052118

File tree

2 files changed: +117 −1

  lerobot/common/policies/smolvla/modeling_smolvla.py (+116, −0)
  pyproject.toml (+1, −1)

lerobot/common/policies/smolvla/modeling_smolvla.py

Lines changed: 116 additions & 0 deletions
@@ -53,8 +53,11 @@
 """
 
 import math
+import os
+import re
 from collections import deque
 
+import safetensors
 import torch
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor, nn
@@ -73,6 +76,98 @@
 )
 from lerobot.common.utils.utils import get_safe_dtype
 
+# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
+_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
+
+
+def canonicalise(k: str) -> str:
+    """
+    Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
+    normalisation-buffer key.
+    """
+    return _VARIANT_RE.sub(".buffer_", k)
+
+
+def standardise_state_dict(
+    checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
+) -> tuple[dict[str, torch.Tensor], list[str]]:
+    """
+    • Re-keys `checkpoint` so that every entry matches the *reference* key set.
+    • If several variant keys collapse to the same canonical name we keep the
+      first one and log the collision.
+    • Returns the new dict + a list of entries that could not be matched.
+    """
+    out, collisions, unmatched = {}, {}, []
+
+    for k, v in checkpoint.items():
+        canon = canonicalise(k)
+        if canon in ref_keys:
+            if canon in out:  # duplicate after collapsing
+                collisions.setdefault(canon, []).append(k)
+            else:
+                out[canon] = v
+        else:
+            unmatched.append(k)
+
+    if verbose:
+        for canon, variants in collisions.items():
+            print(f"[standardise_state_dict] '{canon}' ← {variants}")
+        if unmatched:
+            print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
+
+    out.update({k: checkpoint[k] for k in unmatched})
+    return out, unmatched
+
+
+def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
+    """
+    Renames keys in a checkpoint dictionary based on the given rename string.
+
+    Args:
+        checkpoint (dict): The checkpoint dictionary.
+        rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
+
+    Returns:
+        dict: The modified checkpoint with renamed keys.
+    """
+
+    rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
+
+    new_checkpoint = {}
+    for k, v in checkpoint.items():
+        for old_key, new_key in rename_dict.items():
+            if old_key in k:
+                k = k.replace(old_key, new_key)
+        new_checkpoint[k] = v
+    return new_checkpoint
+
+
+def load_smolvla(
+    model: torch.nn.Module,
+    filename: str | os.PathLike,
+    *,
+    device: str = "cpu",
+    checkpoint_keys_mapping: str = "",
+) -> torch.nn.Module:
+    state_dict = safetensors.torch.load_file(filename, device=device)
+
+    # Optional user-supplied renames (e.g. "model._orig_mod.//model.")
+    if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
+        state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
+
+    state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
+
+    missing, unexpected = model.load_state_dict(state_dict)
+
+    if missing or unexpected:
+        raise RuntimeError(
+            "SmolVLA %d missing / %d unexpected keys",
+            len(missing),
+            len(unexpected),
+        )
+
+    return model
+
 
 def create_sinusoidal_pos_embedding(
     time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
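
Note: as a quick, hedged illustration of the two key-normalisation helpers added above. The buffer and parameter names below are made up for the example (real SmolVLA checkpoints may use different names), and the import assumes an environment that includes this commit.

from lerobot.common.policies.smolvla.modeling_smolvla import canonicalise, rename_checkpoint_keys

# Hypothetical variant-specific normalisation-buffer keys, only for illustration.
print(canonicalise("normalize_inputs.so100_buffer_observation_state.mean"))
# -> "normalize_inputs.buffer_observation_state.mean"
print(canonicalise("normalize_targets.so100-blue_buffer_action.std"))
# -> "normalize_targets.buffer_action.std"

# The "model._orig_mod.//model." mapping strips a prefix from checkpoint keys.
ckpt = {"model._orig_mod.proj.weight": 0}  # dummy value standing in for a tensor
print(rename_checkpoint_keys(ckpt, "model._orig_mod.//model."))
# -> {"model.proj.weight": 0}

standardise_state_dict then applies canonicalise to every checkpoint key and keeps the renamed entry only when the canonical name exists in the model's reference key set; unmatched keys are passed through under their original names.
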
@@ -264,6 +359,23 @@ def reset(self):
             ACTION: deque(maxlen=self.config.n_action_steps),
         }
 
+    # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
+    @classmethod
+    def _load_as_safetensor(
+        cls,
+        model: "SmolVLAPolicy",
+        model_file: str,
+        map_location: str,
+        strict: bool,
+    ):
+        safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
+        return load_smolvla(
+            model,
+            model_file,
+            device=map_location,
+            checkpoint_keys_mapping="model._orig_mod.//model.",
+        )
+
     def get_optim_params(self) -> dict:
         return self.parameters()
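
Note: the "model._orig_mod." prefix handled here comes from torch.compile, which stores the wrapped module under an `_orig_mod` attribute, so its parameters are saved with that extra prefix. A minimal, SmolVLA-independent sketch of the mismatch the override fixes, assuming PyTorch 2.x:

import torch
from torch import nn

class TinyPolicy(nn.Module):
    # Stand-in for a policy whose `model` submodule was wrapped with torch.compile before saving.
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 2)

policy = TinyPolicy()
policy.model = torch.compile(policy.model)
print(sorted(policy.state_dict().keys()))
# -> ['model._orig_mod.bias', 'model._orig_mod.weight']
# An uncompiled policy expects 'model.bias' / 'model.weight', which is what
# rename_checkpoint_keys(state_dict, "model._orig_mod.//model.") restores at load time.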

@@ -387,10 +499,14 @@ def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
         """Tokenize the text input"""
         device = batch[OBS_STATE].device
         tasks = batch["task"]
+        if isinstance(tasks, str):
+            tasks = [tasks]
+
         if len(tasks) == 1:
             tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
 
         tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
+
         tokenized_prompt = self.language_tokenizer.__call__(
             tasks,
             padding=self.config.pad_language_to,
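
Note: the two added branches make prepare_language robust to a bare string task and to batches carrying a single task. A standalone sketch of that normalisation (the task text is illustrative; the real method then tokenizes the resulting list with the policy's language tokenizer):

def normalise_tasks(tasks, batch_size: int) -> list[str]:
    # Mirrors the logic above: wrap a lone string, broadcast a single task to the
    # whole batch, and guarantee a trailing newline before tokenisation.
    if isinstance(tasks, str):
        tasks = [tasks]
    if len(tasks) == 1:
        tasks = [tasks[0] for _ in range(batch_size)]
    return [t if t.endswith("\n") else f"{t}\n" for t in tasks]

print(normalise_tasks("pick up the cube", batch_size=2))
# -> ['pick up the cube\n', 'pick up the cube\n']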

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ intelrealsense = [
     "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
 ]
 pi0 = ["transformers>=4.48.0"]
-smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"]
+smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
 pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
 stretch = [
     "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",

0 commit comments
