Skip to content

Commit 6bd8412

Browse files
committed
WIP on MultiArmFollower and MultiArmLeader
1 parent 2b71789 commit 6bd8412

File tree

8 files changed

+504
-0
lines changed

8 files changed

+504
-0
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .config_multi_arm_follower import MultiArmFollowerConfig
2+
from .multi_arm_follower import MultiArmFollower
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from dataclasses import dataclass, field
18+
19+
from lerobot.common.cameras import CameraConfig
20+
21+
from ..config import RobotConfig
22+
23+
24+
@RobotConfig.register_subclass("multi_arm_follower")
25+
@dataclass
26+
class MultiArmFollowerConfig(RobotConfig):
27+
arms: list[RobotConfig]
28+
29+
# cameras
30+
cameras: dict[str, CameraConfig] = field(default_factory=dict)
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import logging
18+
import time
19+
from functools import cached_property
20+
from typing import Any
21+
22+
from lerobot.common.cameras.utils import make_cameras_from_configs
23+
from lerobot.common.robots.utils import make_robot_from_config
24+
25+
from ..robot import Robot
26+
from .config_multi_arm_follower import MultiArmFollowerConfig
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
class MultiArmFollower(Robot):
32+
"""
33+
Multiple Arms Follower.
34+
"""
35+
36+
config_class = MultiArmFollowerConfig
37+
name = "multi_arm_follower"
38+
39+
def __init__(self, config: MultiArmFollowerConfig):
40+
super().__init__(config)
41+
self.config = config
42+
43+
self.arms = [make_robot_from_config(arm_config) for arm_config in config.arms]
44+
45+
self.cameras = make_cameras_from_configs(config.cameras)
46+
47+
def _encode_arm_index(self, key: str, index: int) -> str:
48+
return f"arm{index}__{key}"
49+
50+
def _decode_arm_index(self, key: str) -> int:
51+
arm_id, *remaining = key.split("__")
52+
assert arm_id.startswith("arm"), (arm_id, key)
53+
return int(arm_id[len("arm") :]), "__".join(remaining)
54+
55+
@cached_property
56+
def observation_features(self) -> dict[str, type | tuple]:
57+
# Get quickly all observation_features
58+
# assuming minimal latency due the loop
59+
all_observations = [arm.observation_features for arm in self.arms]
60+
# Post-process the results:
61+
all_observations = [
62+
{self._encode_arm_index(key, i): value for key, value in obs.items()}
63+
for i, obs in enumerate(all_observations)
64+
]
65+
return {k: v for obs_ft in all_observations for k, v in obs_ft.items()}
66+
67+
@cached_property
68+
def action_features(self) -> dict[str, type]:
69+
# Get quickly all action_features
70+
# assuming minimal latency due the loop
71+
all_actions = [arm.action_features for arm in self.arms]
72+
# Post-process the results:
73+
all_actions = [
74+
{self._encode_arm_index(key, i): value for key, value in actions.items()}
75+
for i, actions in enumerate(all_actions)
76+
]
77+
return {k: v for actions in all_actions for k, v in actions.items()}
78+
79+
@property
80+
def is_connected(self) -> bool:
81+
all_arms_connected = all(arm.is_connected for arm in self.arms)
82+
return all_arms_connected and all(cam.is_connected for cam in self.cameras.values())
83+
84+
def connect(self, calibrate: bool = True) -> None:
85+
"""
86+
We assume that at connection time, arms are in a rest position,
87+
and torque can be safely disabled to run calibration.
88+
"""
89+
for arm in self.arms:
90+
arm.connect(calibrate=calibrate)
91+
92+
for cam in self.cameras.values():
93+
cam.connect()
94+
95+
logger.info(f"{self} connected.")
96+
97+
@property
98+
def is_calibrated(self) -> bool:
99+
return all(arm.is_calibrated for arm in self.arms)
100+
101+
def calibrate(self) -> None:
102+
logger.info(f"\nRunning calibration of {self}")
103+
for arm in self.arms:
104+
arm.calibrate()
105+
106+
def configure(self) -> None:
107+
for arm in self.arms:
108+
arm.configure()
109+
110+
def setup_motors(self) -> None:
111+
for arm in self.arms:
112+
arm.setup_motors()
113+
114+
def get_observation(self) -> dict[str, Any]:
115+
# Get quickly all observations
116+
# assuming minimal latency due the loop
117+
all_observations = [arm.get_observation() for arm in self.arms]
118+
# Post-process the results:
119+
all_observations = [
120+
{self._encode_arm_index(key, i): value for key, value in obs.items()}
121+
for i, obs in enumerate(all_observations)
122+
]
123+
obs_dict = {k: v for obs in all_observations for k, v in obs.items()}
124+
125+
# Capture images from cameras
126+
for cam_key, cam in self.cameras.items():
127+
start = time.perf_counter()
128+
obs_dict[cam_key] = cam.async_read()
129+
dt_ms = (time.perf_counter() - start) * 1e3
130+
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
131+
132+
return obs_dict
133+
134+
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
135+
"""Command arm to move to a target joint configuration.
136+
137+
The relative action magnitude may be clipped depending on the configuration parameter
138+
`max_relative_target`. In this case, the action sent differs from original action.
139+
Thus, this function always returns the action actually sent.
140+
141+
Raises:
142+
RobotDeviceNotConnectedError: if robot is not connected.
143+
144+
Returns:
145+
the action sent to the motors, potentially clipped.
146+
"""
147+
action_per_arm = [None] * len(self.arms)
148+
for key in action:
149+
index, base_key = self._decode_arm_index(key)
150+
if action_per_arm[index] is None:
151+
action_per_arm[index] = {base_key: action[key]}
152+
else:
153+
action_per_arm[index][base_key] = action[key]
154+
155+
output = [
156+
arm.send_action(action_per_arm)
157+
for arm, action_per_arm in zip(self.arms, action_per_arm, strict=False)
158+
]
159+
output = [
160+
{self._encode_arm_index(key, i): value for key, value in action.items()}
161+
for i, action in enumerate(output)
162+
]
163+
return {k: v for action in output for k, v in action.items()}
164+
165+
def disconnect(self):
166+
for arm in self.arms:
167+
arm.disconnect()
168+
169+
for cam in self.cameras.values():
170+
cam.disconnect()
171+
172+
logger.info(f"{self} disconnected.")

lerobot/common/robots/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
5353
from tests.mocks.mock_robot import MockRobot
5454

5555
return MockRobot(config)
56+
elif config.type == "multi_arm_follower":
57+
from .multi_arm_follower import MultiArmFollower
58+
59+
return MultiArmFollower(config)
5660
else:
5761
raise ValueError(config.type)
5862

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .config_multi_arm_leader import MultiArmLeaderConfig
2+
from .multi_arm_leader import MultiArmLeader
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from dataclasses import dataclass
18+
19+
from ..config import TeleoperatorConfig
20+
21+
22+
@TeleoperatorConfig.register_subclass("multi_arm_leader")
23+
@dataclass
24+
class MultiArmLeaderConfig(TeleoperatorConfig):
25+
arms: list[TeleoperatorConfig]
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import logging
18+
19+
from lerobot.common.teleoperators.utils import make_teleoperator_from_config
20+
21+
from ..teleoperator import Teleoperator
22+
from .config_multi_arm_leader import MultiArmLeaderConfig
23+
24+
logger = logging.getLogger(__name__)
25+
26+
27+
class MultiArmLeader(Teleoperator):
28+
"""
29+
Multiple Arms Leader.
30+
"""
31+
32+
config_class = MultiArmLeaderConfig
33+
name = "multi_arm_leader"
34+
35+
def __init__(self, config: MultiArmLeaderConfig):
36+
super().__init__(config)
37+
self.config = config
38+
39+
self.arms = [make_teleoperator_from_config(arm_config) for arm_config in config.arms]
40+
41+
def _encode_arm_index(self, key: str, index: int) -> str:
42+
return f"arm{index}__{key}"
43+
44+
def _decode_arm_index(self, key: str) -> int:
45+
arm_id, *remaining = key.split("__")
46+
assert arm_id.startswith("arm"), (arm_id, key)
47+
return int(arm_id[len("arm") :]), "__".join(remaining)
48+
49+
@property
50+
def action_features(self) -> dict[str, type]:
51+
# Get quickly all action_features
52+
# assuming minimal latency due the loop
53+
all_actions = [arm.action_features for arm in self.arms]
54+
# Post-process the results:
55+
all_actions = [
56+
{self._encode_arm_index(key, i): value for key, value in action.items()}
57+
for i, action in enumerate(all_actions)
58+
]
59+
return {k: v for action_fts in all_actions for k, v in action_fts.items()}
60+
61+
@property
62+
def feedback_features(self) -> dict[str, type]:
63+
# Get quickly all action_features
64+
# assuming minimal latency due the loop
65+
all_feedback_fts = [arm.feedback_features for arm in self.arms]
66+
# Post-process the results:
67+
all_feedback_fts = [
68+
{self._encode_arm_index(key, i): value for key, value in feedback_ft.items()}
69+
for i, feedback_ft in enumerate(all_feedback_fts)
70+
]
71+
return {k: v for feedback_fts in all_feedback_fts for k, v in feedback_fts.items()}
72+
73+
@property
74+
def is_connected(self) -> bool:
75+
return all(arm.is_connected for arm in self.arms)
76+
77+
def connect(self, calibrate: bool = True) -> None:
78+
for arm in self.arms:
79+
arm.connect(calibrate=calibrate)
80+
81+
logger.info(f"{self} connected.")
82+
83+
@property
84+
def is_calibrated(self) -> bool:
85+
return all(arm.is_calibrated for arm in self.arms)
86+
87+
def calibrate(self) -> None:
88+
logger.info(f"\nRunning calibration of {self}")
89+
for arm in self.arms:
90+
arm.calibrate()
91+
92+
def configure(self) -> None:
93+
for arm in self.arms:
94+
arm.configure()
95+
96+
def setup_motors(self) -> None:
97+
for arm in self.arms:
98+
arm.setup_motors()
99+
100+
def get_action(self) -> dict[str, float]:
101+
all_actions = [arm.get_action() for arm in self.arms]
102+
all_actions = [
103+
{self._encode_arm_index(key, i): value for key, value in actions.items()}
104+
for i, actions in enumerate(all_actions)
105+
]
106+
return {k: v for actions in all_actions for k, v in actions.items()}
107+
108+
def send_feedback(self, feedback: dict[str, float]) -> None:
109+
# TODO(rcadene, aliberts): Implement force feedback
110+
raise NotImplementedError
111+
112+
def disconnect(self) -> None:
113+
for arm in self.arms:
114+
arm.disconnect()
115+
logger.info(f"{self} disconnected.")

0 commit comments

Comments
 (0)