|
53 | 53 | """
|
54 | 54 |
|
55 | 55 | import math
|
| 56 | +import os |
| 57 | +import re |
56 | 58 | from collections import deque
|
57 | 59 |
|
| 60 | +import safetensors |
58 | 61 | import torch
|
59 | 62 | import torch.nn.functional as F # noqa: N812
|
60 | 63 | from torch import Tensor, nn
|
|
73 | 76 | )
|
74 | 77 | from lerobot.common.utils.utils import get_safe_dtype
|
75 | 78 |
|
| 79 | +# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker |
| 80 | +_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") |
| 81 | + |
| 82 | + |
| 83 | +def canonicalise(k: str) -> str: |
| 84 | + """ |
| 85 | + Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a |
| 86 | + normalisation-buffer key. |
| 87 | + """ |
| 88 | + return _VARIANT_RE.sub(".buffer_", k) |
| 89 | + |
| 90 | + |
| 91 | +def standardise_state_dict( |
| 92 | + checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True |
| 93 | +) -> tuple[dict[str, torch.Tensor], list[str]]: |
| 94 | + """ |
| 95 | + • Re-keys `checkpoint ` so that every entry matches the *reference* key set. |
| 96 | + • If several variant keys collapse to the same canonical name we keep the |
| 97 | + first one and log the collision. |
| 98 | + • Returns the new dict + a list of entries that could not be matched. |
| 99 | + """ |
| 100 | + out, collisions, unmatched = {}, {}, [] |
| 101 | + |
| 102 | + for k, v in checkpoint.items(): |
| 103 | + canon = canonicalise(k) |
| 104 | + if canon in ref_keys: |
| 105 | + if canon in out: # duplicate after collapsing |
| 106 | + collisions.setdefault(canon, []).append(k) |
| 107 | + else: |
| 108 | + out[canon] = v |
| 109 | + else: |
| 110 | + unmatched.append(k) |
| 111 | + |
| 112 | + if verbose: |
| 113 | + for canon, variants in collisions.items(): |
| 114 | + print(f"[standardise_state_dict] '{canon}' ← {variants}") |
| 115 | + if unmatched: |
| 116 | + print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys") |
| 117 | + |
| 118 | + out.update({k: checkpoint[k] for k in unmatched}) |
| 119 | + return out, unmatched |
| 120 | + |
| 121 | + |
| 122 | +def rename_checkpoint_keys(checkpoint: dict, rename_str: str): |
| 123 | + """ |
| 124 | + Renames keys in a checkpoint dictionary based on the given rename string. |
| 125 | +
|
| 126 | + Args: |
| 127 | + checkpoint (dict): The checkpoint dictionary. |
| 128 | + rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2". |
| 129 | +
|
| 130 | + Returns: |
| 131 | + dict: The modified checkpoint with renamed keys. |
| 132 | + """ |
| 133 | + |
| 134 | + rename_dict = dict(pair.split("//") for pair in rename_str.split(",")) |
| 135 | + |
| 136 | + new_checkpoint = {} |
| 137 | + for k, v in checkpoint.items(): |
| 138 | + for old_key, new_key in rename_dict.items(): |
| 139 | + if old_key in k: |
| 140 | + k = k.replace(old_key, new_key) |
| 141 | + new_checkpoint[k] = v |
| 142 | + return new_checkpoint |
| 143 | + |
| 144 | + |
| 145 | +def load_smolvla( |
| 146 | + model: torch.nn.Module, |
| 147 | + filename: str | os.PathLike, |
| 148 | + *, |
| 149 | + device: str = "cpu", |
| 150 | + checkpoint_keys_mapping: str = "", |
| 151 | +) -> torch.nn.Module: |
| 152 | + state_dict = safetensors.torch.load_file(filename, device=device) |
| 153 | + |
| 154 | + # Optional user-supplied renames (e.g. "model._orig_mod.//model.") |
| 155 | + if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping: |
| 156 | + state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping) |
| 157 | + |
| 158 | + state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys())) |
| 159 | + |
| 160 | + missing, unexpected = model.load_state_dict(state_dict) |
| 161 | + |
| 162 | + if missing or unexpected: |
| 163 | + raise RuntimeError( |
| 164 | + "SmolVLA %d missing / %d unexpected keys", |
| 165 | + len(missing), |
| 166 | + len(unexpected), |
| 167 | + ) |
| 168 | + |
| 169 | + return model |
| 170 | + |
76 | 171 |
|
77 | 172 | def create_sinusoidal_pos_embedding(
|
78 | 173 | time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
@@ -264,6 +359,23 @@ def reset(self):
|
264 | 359 | ACTION: deque(maxlen=self.config.n_action_steps),
|
265 | 360 | }
|
266 | 361 |
|
| 362 | + # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues |
| 363 | + @classmethod |
| 364 | + def _load_as_safetensor( |
| 365 | + cls, |
| 366 | + model: "SmolVLAPolicy", |
| 367 | + model_file: str, |
| 368 | + map_location: str, |
| 369 | + strict: bool, |
| 370 | + ): |
| 371 | + safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) |
| 372 | + return load_smolvla( |
| 373 | + model, |
| 374 | + model_file, |
| 375 | + device=map_location, |
| 376 | + checkpoint_keys_mapping="model._orig_mod.//model.", |
| 377 | + ) |
| 378 | + |
267 | 379 | def get_optim_params(self) -> dict:
|
268 | 380 | return self.parameters()
|
269 | 381 |
|
@@ -387,10 +499,14 @@ def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
387 | 499 | """Tokenize the text input"""
|
388 | 500 | device = batch[OBS_STATE].device
|
389 | 501 | tasks = batch["task"]
|
| 502 | + if isinstance(tasks, str): |
| 503 | + tasks = [tasks] |
| 504 | + |
390 | 505 | if len(tasks) == 1:
|
391 | 506 | tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
|
392 | 507 |
|
393 | 508 | tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
| 509 | + |
394 | 510 | tokenized_prompt = self.language_tokenizer.__call__(
|
395 | 511 | tasks,
|
396 | 512 | padding=self.config.pad_language_to,
|
|
0 commit comments