Skip to content

Commit c05e483

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 808cf63 commit c05e483

File tree

16 files changed

+92
-90
lines changed

16 files changed

+92
-90
lines changed

lerobot/common/envs/factory.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
import importlib
17-
from collections import deque
1817

1918
import gymnasium as gym
2019

lerobot/common/optim/optimizers.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -99,52 +99,55 @@ def build(self, params: dict) -> torch.optim.Optimizer:
9999
@dataclass
100100
class MultiAdamConfig(OptimizerConfig):
101101
"""Configuration for multiple Adam optimizers with different parameter groups.
102-
102+
103103
This creates a dictionary of Adam optimizers, each with its own hyperparameters.
104-
104+
105105
Args:
106106
lr: Default learning rate (used if not specified for a group)
107107
weight_decay: Default weight decay (used if not specified for a group)
108108
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
109109
grad_clip_norm: Gradient clipping norm
110110
"""
111+
111112
lr: float = 1e-3
112113
weight_decay: float = 0.0
113114
grad_clip_norm: float = 10.0
114115
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
115-
116+
116117
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
117118
"""Build multiple Adam optimizers.
118-
119+
119120
Args:
120121
params_dict: Dictionary mapping parameter group names to lists of parameters
121122
The keys should match the keys in optimizer_groups
122-
123+
123124
Returns:
124125
Dictionary mapping parameter group names to their optimizers
125126
"""
126127
optimizers = {}
127-
128+
128129
for name, params in params_dict.items():
129130
# Get group-specific hyperparameters or use defaults
130131
group_config = self.optimizer_groups.get(name, {})
131-
132+
132133
# Create optimizer with merged parameters (defaults + group-specific)
133134
optimizer_kwargs = {
134135
"lr": group_config.get("lr", self.lr),
135136
"betas": group_config.get("betas", (0.9, 0.999)),
136137
"eps": group_config.get("eps", 1e-5),
137138
"weight_decay": group_config.get("weight_decay", self.weight_decay),
138139
}
139-
140+
140141
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
141-
142+
142143
return optimizers
143144

144145

145-
def save_optimizer_state(optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path) -> None:
146+
def save_optimizer_state(
147+
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
148+
) -> None:
146149
"""Save optimizer state to disk.
147-
150+
148151
Args:
149152
optimizer: Either a single optimizer or a dictionary of optimizers.
150153
save_dir: Directory to save the optimizer state.
@@ -173,11 +176,11 @@ def load_optimizer_state(
173176
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
174177
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
175178
"""Load optimizer state from disk.
176-
179+
177180
Args:
178181
optimizer: Either a single optimizer or a dictionary of optimizers.
179182
save_dir: Directory to load the optimizer state from.
180-
183+
181184
Returns:
182185
The updated optimizer(s) with loaded state.
183186
"""
@@ -201,7 +204,7 @@ def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
201204
current_state_dict = optimizer.state_dict()
202205
flat_state = load_file(save_dir / OPTIMIZER_STATE)
203206
state = unflatten_dict(flat_state)
204-
207+
205208
# Handle case where 'state' key might not exist (for newly created optimizers)
206209
if "state" in state:
207210
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}

lerobot/common/policies/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
from lerobot.common.envs.utils import env_to_policy_features
2525
from lerobot.common.policies.act.configuration_act import ACTConfig
2626
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
27+
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
2728
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
2829
from lerobot.common.policies.pretrained import PreTrainedPolicy
2930
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
3031
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
31-
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
3232
from lerobot.configs.policies import PreTrainedConfig
3333
from lerobot.configs.types import FeatureType
3434

lerobot/common/policies/hilserl/classifier/configuration_classifier.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from dataclasses import dataclass, field
2-
from typing import Dict, List
1+
from dataclasses import dataclass
2+
from typing import List
33

44
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
55
from lerobot.common.optim.schedulers import LRSchedulerConfig
66
from lerobot.configs.policies import PreTrainedConfig
7-
from lerobot.configs.types import FeatureType, PolicyFeature
87

98

109
@PreTrainedConfig.register_subclass(name="hilserl_classifier")

lerobot/common/policies/normalize.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,10 @@ def create_stats_buffers(
8282
if stats and key in stats:
8383
if norm_mode is NormalizationMode.MEAN_STD:
8484
if "mean" not in stats[key] or "std" not in stats[key]:
85-
raise ValueError(f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization")
86-
85+
raise ValueError(
86+
f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
87+
)
88+
8789
if isinstance(stats[key]["mean"], np.ndarray):
8890
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
8991
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
@@ -96,12 +98,16 @@ def create_stats_buffers(
9698
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
9799
else:
98100
type_ = type(stats[key]["mean"])
99-
raise ValueError(f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead.")
100-
101+
raise ValueError(
102+
f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
103+
)
104+
101105
elif norm_mode is NormalizationMode.MIN_MAX:
102106
if "min" not in stats[key] or "max" not in stats[key]:
103-
raise ValueError(f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization")
104-
107+
raise ValueError(
108+
f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
109+
)
110+
105111
if isinstance(stats[key]["min"], np.ndarray):
106112
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
107113
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
@@ -110,7 +116,9 @@ def create_stats_buffers(
110116
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
111117
else:
112118
type_ = type(stats[key]["min"])
113-
raise ValueError(f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead.")
119+
raise ValueError(
120+
f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
121+
)
114122

115123
stats_buffers[key] = buffer
116124
return stats_buffers

lerobot/common/policies/sac/configuration_sac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from lerobot.common.optim.optimizers import MultiAdamConfig
2121
from lerobot.configs.policies import PreTrainedConfig
22-
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
22+
from lerobot.configs.types import NormalizationMode
2323

2424

2525
@dataclass

lerobot/common/policies/sac/modeling_sac.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,6 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
897897
# for j in range(i + 1, num_critics):
898898
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
899899
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
900-
import draccus
901900

902901
from lerobot.configs import parser
903902

lerobot/common/utils/wandb_utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,13 @@ def log_policy(self, checkpoint_dir: Path):
115115
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
116116
self._wandb.log_artifact(artifact)
117117

118-
def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None):
118+
def log_dict(
119+
self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
120+
):
119121
if mode not in {"train", "eval"}:
120122
raise ValueError(mode)
121123
if step is None and custom_step_key is None:
122-
raise ValueError("Either step or custom_step_key must be provided.")
124+
raise ValueError("Either step or custom_step_key must be provided.")
123125

124126
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
125127
# increases with each wandb.log call, but in the case of asynchronous RL for example,
@@ -142,10 +144,7 @@ def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom
142144
continue
143145

144146
# Do not log the custom step key itself.
145-
if (
146-
self._wandb_custom_step_key is not None
147-
and k in self._wandb_custom_step_key
148-
):
147+
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
149148
continue
150149

151150
if custom_step_key is not None:
@@ -160,7 +159,6 @@ def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom
160159

161160
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
162161

163-
164162
def log_video(self, video_path: str, step: int, mode: str = "train"):
165163
if mode not in {"train", "eval"}:
166164
raise ValueError(mode)

lerobot/configs/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
@dataclass
3636
class TrainPipelineConfig(HubMixin):
37-
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset
37+
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need a dataset
3838
env: envs.EnvConfig | None = None
3939
policy: PreTrainedConfig | None = None
4040
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.

lerobot/scripts/server/actor_server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
python_object_to_bytes,
4848
transitions_to_bytes,
4949
)
50-
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
50+
from lerobot.scripts.server.gym_manipulator import make_robot_env
5151
from lerobot.scripts.server.network_utils import (
5252
receive_bytes_in_chunks,
5353
send_bytes_in_chunks,
@@ -444,7 +444,7 @@ def receive_policy(
444444

445445
# Initialize logging with explicit log file
446446
init_logging(log_file=log_file)
447-
logging.info(f"Actor receive policy process logging initialized")
447+
logging.info("Actor receive policy process logging initialized")
448448

449449
# Setup process handlers to handle shutdown signal
450450
# But use shutdown event from the main process

0 commit comments

Comments
 (0)