Skip to content

Commit c1a5f58

Browse files
alibertsfracapuano
authored andcommitted
Skip normalization parameters in load_smolvla (#1274)
1 parent f7b1605 commit c1a5f58

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

lerobot/common/policies/smolvla/modeling_smolvla.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,13 @@ def load_smolvla(
157157

158158
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
159159

160-
missing, unexpected = model.load_state_dict(state_dict)
160+
# HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
161+
norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
162+
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
161163

162-
if missing or unexpected:
164+
missing, unexpected = model.load_state_dict(state_dict, strict=False)
165+
166+
if not all(key.startswith(norm_keys) for key in missing) or unexpected:
163167
raise RuntimeError(
164168
"SmolVLA %d missing / %d unexpected keys",
165169
len(missing),

0 commit comments

Comments
 (0)