diff --git a/docs/source/moveit2.mdx b/docs/source/moveit2.mdx new file mode 120000 index 0000000000..c6cc86b884 --- /dev/null +++ b/docs/source/moveit2.mdx @@ -0,0 +1 @@ +lerobot/common/robots/moveit2/moveit2.mdx \ No newline at end of file diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 11114da0ae..0f9053c052 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -177,6 +177,7 @@ "aloha", "so100", "so101", + "annin_ar4_mk1", ] # lists all available cameras from `lerobot/common/robot_devices/cameras` diff --git a/lerobot/common/robots/ros2/__init__.py b/lerobot/common/robots/ros2/__init__.py new file mode 100644 index 0000000000..2701fa7b92 --- /dev/null +++ b/lerobot/common/robots/ros2/__init__.py @@ -0,0 +1,2 @@ +from .config_ros2 import AnninAR4Config, ROS2Config +from .ros2 import ROS2Robot diff --git a/lerobot/common/robots/ros2/config_ros2.py b/lerobot/common/robots/ros2/config_ros2.py new file mode 100644 index 0000000000..731de0860f --- /dev/null +++ b/lerobot/common/robots/ros2/config_ros2.py @@ -0,0 +1,99 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from enum import Enum + +from lerobot.common.cameras import CameraConfig + +from ..config import RobotConfig + + +class ActionType(Enum): + CARTESIAN_VELOCITY = "cartesian_velocity" + JOINT_POSITION = "joint_position" + # For future extension: + JOINT_VELOCITY = "joint_velocity" + + +@dataclass +class ROS2InterfaceConfig: + # Namespace used by ros2_control / MoveIt2 nodes + namespace: str = "" + + arm_joint_names: list[str] = field( + default_factory=lambda: [ + "joint_1", + "joint_2", + "joint_3", + "joint_4", + "joint_5", + "joint_6", + ] + ) + gripper_joint_name: str = "gripper_joint" + + # Base link name for computing end effector pose / velocity + # Only applicable for cartesian control + base_link: str = "base_link" + + # Only applicable if velocity control is used. + max_linear_velocity: float = 0.05 + max_angular_velocity: float = 0.25 # rad/s + + # Only applicable if position control is used. + min_joint_positions: list[float] | None = None + max_joint_positions: list[float] | None = None + + gripper_open_position: float = 0.0 + gripper_close_position: float = 1.0 + + +@dataclass +class ROS2Config(RobotConfig): + # Action type for controlling the robot. Can be 'cartesian_velocity' or 'joint_position'. + action_type: ActionType = ActionType.JOINT_POSITION + + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. + max_relative_target: int | None = None + + # cameras + cameras: dict[str, CameraConfig] = field(default_factory=dict) + + # ROS2 interface configuration + ros2_interface: ROS2InterfaceConfig = field(default_factory=ROS2InterfaceConfig) + + action_from_keyboard: bool = False + + +@RobotConfig.register_subclass("annin_ar4_mk1") +@dataclass +class AnninAR4Config(ROS2Config): + """Annin Robotics AR4 robot configuration - extends ROS2Config with + AR4-specific settings + """ + + action_type: ActionType = ActionType.CARTESIAN_VELOCITY + + ros2_interface: ROS2InterfaceConfig = field( + default_factory=lambda: ROS2InterfaceConfig( + base_link="base_link", + gripper_joint_name="gripper_jaw1_joint", + min_joint_positions=[-2.9671, -0.7330, -1.5533, -2.8798, -1.8326, -2.7053], + max_joint_positions=[2.9671, 1.5708, 0.9076, 2.8798, 1.8326, 2.7053], + gripper_open_position=0.014, + gripper_close_position=0.0, + ), + ) diff --git a/lerobot/common/robots/ros2/moveit2_servo.py b/lerobot/common/robots/ros2/moveit2_servo.py new file mode 100644 index 0000000000..515da10fdc --- /dev/null +++ b/lerobot/common/robots/ros2/moveit2_servo.py @@ -0,0 +1,115 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +logger = logging.getLogger(__name__) + +try: + from rclpy import qos + from rclpy.callback_groups import CallbackGroup + from rclpy.node import Node + + ROS2_AVAILABLE = True +except ImportError as e: + logger.info(f"ROS2 dependencies not available: {e}") + ROS2_AVAILABLE = False + + +class MoveIt2Servo: + """ + Python interface for MoveIt2 Servo. + """ + + def __init__( + self, + node: "Node", + frame_id: str, + callback_group: "CallbackGroup", + ): + if not ROS2_AVAILABLE: + raise ImportError("ROS2 is not available") + + self._node = node + self._frame_id = frame_id + self._enabled = False + + from geometry_msgs.msg import TwistStamped + from moveit_msgs.srv import ServoCommandType + from std_srvs.srv import SetBool + + self._twist_pub = node.create_publisher( + TwistStamped, + "/servo_node/delta_twist_cmds", + qos.QoSProfile( + durability=qos.QoSDurabilityPolicy.VOLATILE, + reliability=qos.QoSReliabilityPolicy.RELIABLE, + history=qos.QoSHistoryPolicy.KEEP_ALL, + ), + callback_group=callback_group, + ) + self._pause_srv = node.create_client( + SetBool, "/servo_node/pause_servo", callback_group=callback_group + ) + self._cmd_type_srv = node.create_client( + ServoCommandType, "/servo_node/switch_command_type", callback_group=callback_group + ) + self._twist_msg = TwistStamped() + self._enable_req = SetBool.Request(data=False) + self._disable_req = SetBool.Request(data=True) + self._twist_type_req = ServoCommandType.Request(command_type=ServoCommandType.Request.TWIST) + + def enable(self, wait_for_server_timeout_sec=1.0) -> bool: + if not self._pause_srv.wait_for_service(timeout_sec=wait_for_server_timeout_sec): + logger.warning("Pause service not available.") + return False + if not self._cmd_type_srv.wait_for_service(timeout_sec=wait_for_server_timeout_sec): + logger.warning("Command type service not available.") + return False + result = self._pause_srv.call(self._enable_req) + if not result or not result.success: + logger.error(f"Enable failed: {getattr(result, 'message', '')}") + self._enabled = False + return False + cmd_result = self._cmd_type_srv.call(self._twist_type_req) + if not cmd_result or not cmd_result.success: + logger.error("Switch to TWIST command type failed.") + self._enabled = False + return False + logger.info("MoveIt Servo enabled.") + self._enabled = True + return True + + def disable(self, wait_for_server_timeout_sec=1.0) -> bool: + if not self._pause_srv.wait_for_service(timeout_sec=wait_for_server_timeout_sec): + logger.warning("Pause service not available.") + return False + result = self._pause_srv.call(self._disable_req) + self._enabled = not (result and result.success) + return bool(result and result.success) + + def servo(self, linear=(0.0, 0.0, 0.0), angular=(0.0, 0.0, 0.0), enable_if_disabled=True): + if not self._enabled and enable_if_disabled and not self.enable(): + logger.warning("Dropping servo command because MoveIt2 Servo is not enabled.") + return + + self._twist_msg.header.frame_id = self._frame_id + self._twist_msg.header.stamp = self._node.get_clock().now().to_msg() + self._twist_msg.twist.linear.x = float(linear[0]) + self._twist_msg.twist.linear.y = float(linear[1]) + self._twist_msg.twist.linear.z = float(linear[2]) + self._twist_msg.twist.angular.x = float(angular[0]) + self._twist_msg.twist.angular.y = float(angular[1]) + self._twist_msg.twist.angular.z = float(angular[2]) + self._twist_pub.publish(self._twist_msg) diff --git a/lerobot/common/robots/ros2/ros2.mdx b/lerobot/common/robots/ros2/ros2.mdx new file mode 100644 index 0000000000..2b736e6269 --- /dev/null +++ b/lerobot/common/robots/ros2/ros2.mdx @@ -0,0 +1,146 @@ +# ROS 2 + +This guide illustrates how to integrate ros2_control or MoveIt2 compatible robots with LeRobot. +For now, a single robot arm with a 1-DoF gripper is supported. + +## Overview + +**Supported control modes:** +- ✅ Joint position via ros2_control +- ✅ Cartesian velocity via MoveIt2 +- 🚧 End-effector pose (coming soon) + +This integration allows you to use LeRobot's full range of features with your robot, including teleoperation, recording, and replay. +A basic understanding of ROS 2, ros2_control, and MoveIt2 is required. +Your robot should already be configured to work with these frameworks. + +## Prerequisites + +### System Requirements + +Before getting started, ensure you have the following installed: + +- [ROS 2 Jazzy](https://docs.ros.org/en/jazzy/Installation.html) + - Older ROS 2 versions will still work but cartesian velocity control will be broken. +- [MoveIt2](https://moveit.ai/install-moveit2/binary) if cartesian velocity control is desired + +For joint position control, the robot should be configured with the following: +- `position_controllers/JointGroupPositionController` for the robot arm joints +- `position_controllers/GripperActionController` for the gripper +- `joint_state_broadcaster/JointStateBroadcaster` for joint state feedback + +For cartesian velocity control, the robot should be configured with the same gripper and state feedback as above, and additionally: +- `joint_trajectory_controller/JointTrajectoryController` for robot arm control +- `moveit_servo` node for real-time control + +LeRobot should run in a virtual environment that has the same Python version as your ROS 2 distribution, i.e for jazzy, Python 3.12 should be used: +```bash +uv venv --python 3.12 && source .venv/bin/activate +source /opt/ros/jazzy/setup.bash +``` + +### Robot Setup + +Your robot must be properly configured and working with MoveIt2 before integration: + +- ✅ MoveIt2 configuration package exists for your robot +- ✅ `move_group` and `moveit_servo` nodes can be launched successfully +- ✅ Robot is calibrated and ready for operation + +See [AR4 ROS Driver](https://github.com/ycheng517/ar4_ros_driver) for an example of a MoveIt2-enabled robot that works with LeRobot. + +## Configuration + +Create a config class for your robot by sub-classing `ROS2Config`. +You may override joint names, gripper configurations, and other parameters as needed. +An example config class for velocity control may look like this: + +```python +from dataclasses import dataclass, field +from lerobot.common.robots.config import RobotConfig +from lerobot.common.robots.ros2.config_ros2 import ROS2Config, ROS2InterfaceConfig + +@RobotConfig.register_subclass("my_ros2_robot") +@dataclass +class MyRobotConfig(ROS2Config): + action_type: ActionType = ActionType.JOINT_VELOCITY + + ros2_interface: ROS2InterfaceConfig = field( + default_factory=lambda: ROS2InterfaceConfig( + base_link="base_link", + arm_joint_names=[ + "joint_1", + "joint_2", + "joint_3", + "joint_4", + "joint_5", + "joint_6", + ], + gripper_joint_name="gripper_joint", + gripper_open_position=0.0, + gripper_close_position=1.0, + max_linear_velocity=0.05, # m/s + max_angular_velocity=0.25, # rad/s + ) + ) +``` + +## Getting Started + +The following steps will guide you through launching your robot and performing keyboard teleoperation with cartesian velocity control. + +### Step 1: Launch Your Robot + +Start your ROS2-enabled robot stack. +The exact commands depend on your robot model and configuration. + +**Example: [Annin Robotics AR4](https://github.com/ycheng517/ar4_ros_driver)** +```bash +# Launch the robot driver and ros2_control controllers +ros2 launch annin_ar4_driver driver.launch.py ar_model:=mk1 calibrate:=True + +# Launch MoveIt2 with Moveit Servo enabled +ros2 launch annin_ar4_moveit_config moveit.launch.py ar_model:=mk1 moveit_servo:=True +``` + +### Step 2: Teleoperation (with keyboard) + +Once your robot is launched and ready, you can teleoperate it using keyboard controls: + +```bash +# Source the ROS environment +source /opt/ros/jazzy/setup.bash + +python -m lerobot.teleoperate \ + --robot.type=my_robot \ + --robot.id=my_robot_arm \ + --robot.action_from_keyboard=True \ + --teleop.type=keyboard \ + --teleop.id=keyboard_controller \ + --display_data=true +``` + +#### Control Mapping (For Cartesian Velocity Control) + +The default keyboard mapping provides intuitive control: + +| Key | Action | +|-----|--------| +| `A/D` | Move left/right (X-axis) | +| `W/S` | Move backward/forward (Y-axis) | +| `N/M` | Move down/up (Z-axis) | +| `I/K` | Rotate around X-axis | +| `J/L` | Rotate around Y-axis | +| `U/O` | Rotate around Z-axis | +| `Space` | Close gripper (hold to close) | + +> **Note**: The gripper closes when Space is held down, and opens when released. All movements are velocity-based and stop when keys are released. + +### Next Steps + +Once you have teleoperation working, you can use all standard LeRobot features as usual: + +- Incorporate cameras and other sensors +- Use `lerobot.record` to collect demonstration datasets +- Use `lerobot.replay` to test recorded trajectories +- Train policies on your robot's data diff --git a/lerobot/common/robots/ros2/ros2.py b/lerobot/common/robots/ros2/ros2.py new file mode 100644 index 0000000000..496a5de407 --- /dev/null +++ b/lerobot/common/robots/ros2/ros2.py @@ -0,0 +1,271 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from functools import cached_property +from typing import Any + +from lerobot.common.cameras.utils import make_cameras_from_configs +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..robot import Robot +from ..utils import ensure_safe_goal_position +from .config_ros2 import ActionType, ROS2Config +from .ros2_interface import ROS2Interface + +logger = logging.getLogger(__name__) + + +class ROS2Robot(Robot): + config_class = ROS2Config + name = "ros2" + + def __init__(self, config: ROS2Config): + super().__init__(config) + self.config = config + self.ros2_interface = ROS2Interface(config.ros2_interface, config.action_type) + self.cameras = make_cameras_from_configs(config.cameras) + + @property + def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + all_joint_names = self.config.ros2_interface.arm_joint_names.copy() + if self.config.ros2_interface.gripper_joint_name: + all_joint_names.append(self.config.ros2_interface.gripper_joint_name) + motor_state_ft = {f"{motor}.pos": float for motor in all_joint_names} + return {**motor_state_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + if self.config.action_type == ActionType.CARTESIAN_VELOCITY: + return { + "linear_x.vel": float, + "linear_y.vel": float, + "linear_z.vel": float, + "angular_x.vel": float, + "angular_y.vel": float, + "angular_z.vel": float, + "gripper.pos": float, + } + elif self.config.action_type == ActionType.JOINT_POSITION: + return {f"{joint}.pos": float for joint in self.config.ros2_interface.arm_joint_names} | { + "gripper.pos": float + } + else: + raise ValueError(f"Unsupported action type: {self.config.action_type}") + + @property + def is_connected(self) -> bool: + return self.ros2_interface.is_connected and all(cam.is_connected for cam in self.cameras.values()) + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + for cam in self.cameras.values(): + cam.connect() + self.ros2_interface.connect() + + @property + def is_calibrated(self) -> bool: + return True + + def calibrate(self) -> None: + pass # robot must be calibrated before running LeRobot + + def configure(self) -> None: + pass # robot must be configured before running LeRobot + + def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + obs_dict: dict[str, Any] = {} + joint_state = self.ros2_interface.joint_state + if joint_state is None: + raise ValueError("Joint state is not available yet.") + obs_dict.update({f"{joint}.pos": pos for joint, pos in joint_state["position"].items()}) + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + return obs_dict + + def send_action(self, action: dict[str, float]) -> dict[str, float]: + """Command arm to move to a target joint configuration. + + The relative action magnitude may be clipped depending on the configuration parameter + `max_relative_target`. In this case, the action sent differs from original action. + Thus, this function always returns the action actually sent. + + Args: + action (dict[str, float]): The goal positions for the motors or pressed_keys dict. + + Raises: + DeviceNotConnectedError: if robot is not connected. + + Returns: + dict[str, float]: The action sent to the motors, potentially clipped. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.config.action_type == ActionType.CARTESIAN_VELOCITY: + if self.config.action_from_keyboard: + action = self.keyboard_to_velocity_action(action) + + if self.config.max_relative_target is not None: + # We don't have the current velocity of the arm, so set it to 0.0 + # Effectively the goal velocity gets clipped by max_relative_target + goal_present_vel = {key: (act, 0.0) for key, act in action.items()} + action = ensure_safe_goal_position(goal_present_vel, self.config.max_relative_target) + + linear_vel = ( + action["linear_x.vel"], + action["linear_y.vel"], + action["linear_z.vel"], + ) + angular_vel = ( + action["angular_x.vel"], + action["angular_y.vel"], + action["angular_z.vel"], + ) + self.ros2_interface.servo(linear=linear_vel, angular=angular_vel) + elif self.config.action_type == ActionType.JOINT_POSITION: + if self.config.action_from_keyboard: + action = self.keyboard_to_joint_position_action(action) + + if self.config.max_relative_target is not None: + goal_present_pos = {} + joint_state = self.ros2_interface.joint_state + if joint_state is None: + raise ValueError("Joint state is not available yet.") + + for key, goal in action.items(): + present_pos = joint_state["position"].get(key.replace(".pos", ""), 0.0) + goal_present_pos[key] = (goal, present_pos) + action = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target) + + joint_positions = [action[joint + ".pos"] for joint in self.config.ros2_interface.arm_joint_names] + self.ros2_interface.send_joint_position_command(joint_positions) + + gripper_pos = action["gripper.pos"] + self.ros2_interface.send_gripper_command(gripper_pos) + return action + + def keyboard_to_joint_position_action(self, pressed_keys: dict[str, Any]) -> dict[str, float]: + """Convert pressed keys to joint position action commands for teleop. + hardcoded for a 6-DOF arm with a gripper. + """ + action = {f"{joint}.pos": 0.5 for joint in self.config.ros2_interface.arm_joint_names} + if "q" in pressed_keys: + action["joint_1.pos"] += 0.2 + if "a" in pressed_keys: + action["joint_1.pos"] -= 0.2 + if "w" in pressed_keys: + action["joint_2.pos"] += 0.2 + if "s" in pressed_keys: + action["joint_2.pos"] -= 0.2 + if "e" in pressed_keys: + action["joint_3.pos"] += 0.2 + if "d" in pressed_keys: + action["joint_3.pos"] -= 0.2 + if "r" in pressed_keys: + action["joint_4.pos"] += 0.2 + if "f" in pressed_keys: + action["joint_4.pos"] -= 0.2 + if "t" in pressed_keys: + action["joint_5.pos"] += 0.2 + if "g" in pressed_keys: + action["joint_5.pos"] -= 0.2 + if "y" in pressed_keys: + action["joint_6.pos"] += 0.2 + if "h" in pressed_keys: + action["joint_6.pos"] -= 0.2 + + gripper_pos = 0.0 + if "space" in pressed_keys: + gripper_pos = 1.0 + + action["gripper.pos"] = gripper_pos + return action + + def keyboard_to_velocity_action(self, pressed_keys: dict[str, Any]) -> dict[str, float]: + """Convert pressed keys to velocity action commands for teleop.""" + lin_vel_x = 0.0 + if "a" in pressed_keys: + lin_vel_x += 1.0 + if "d" in pressed_keys: + lin_vel_x -= 1.0 + lin_vel_y = 0.0 + if "w" in pressed_keys: + lin_vel_y -= 1.0 + if "s" in pressed_keys: + lin_vel_y += 1.0 + lin_vel_z = 0.0 + if "n" in pressed_keys: + lin_vel_z -= 1.0 + if "m" in pressed_keys: + lin_vel_z += 1.0 + + ang_vel_y = 0.0 + if "j" in pressed_keys: + ang_vel_y -= 1.0 + if "l" in pressed_keys: + ang_vel_y += 1.0 + + ang_vel_x = 0.0 + if "i" in pressed_keys: + ang_vel_x -= 1.0 + if "k" in pressed_keys: + ang_vel_x += 1.0 + ang_vel_z = 0.0 + if "u" in pressed_keys: + ang_vel_z += 1.0 + if "o" in pressed_keys: + ang_vel_z -= 1.0 + + gripper_pos = 0.0 + if "space" in pressed_keys: + gripper_pos = 1.0 + + return { + "linear_x.vel": lin_vel_x, + "linear_y.vel": lin_vel_y, + "linear_z.vel": lin_vel_z, + "angular_x.vel": ang_vel_x, + "angular_y.vel": ang_vel_y, + "angular_z.vel": ang_vel_z, + "gripper.pos": gripper_pos, + } + + def disconnect(self): + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + for cam in self.cameras.values(): + cam.disconnect() + self.ros2_interface.disconnect() + + logger.info(f"{self} disconnected.") diff --git a/lerobot/common/robots/ros2/ros2_interface.py b/lerobot/common/robots/ros2/ros2_interface.py new file mode 100644 index 0000000000..442a2af39c --- /dev/null +++ b/lerobot/common/robots/ros2/ros2_interface.py @@ -0,0 +1,228 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import threading +import time + +from lerobot.common.errors import DeviceNotConnectedError + +from .config_ros2 import ActionType, ROS2InterfaceConfig +from .moveit2_servo import MoveIt2Servo + +logger = logging.getLogger(__name__) + +try: + import rclpy + from control_msgs.action import GripperCommand + from rclpy.action import ActionClient + from rclpy.callback_groups import ReentrantCallbackGroup + from rclpy.executors import Executor, SingleThreadedExecutor + from rclpy.node import Node + from rclpy.publisher import Publisher + from sensor_msgs.msg import JointState + from std_msgs.msg import Float64MultiArray + + ROS2_AVAILABLE = True +except ImportError as e: + logger.info(f"ROS2 dependencies not available: {e}") + ROS2_AVAILABLE = False + + +class ROS2Interface: + """Class to interface with a MoveIt2 manipulator.""" + + def __init__(self, config: ROS2InterfaceConfig, action_type: ActionType): + self.config = config + self.action_type = action_type + self.robot_node: Node | None = None + self.pos_cmd_pub: Publisher | None = None + self.gripper_action_client: ActionClient | None = None + self.executor: Executor | None = None + self.moveit2_servo: MoveIt2Servo | None = None + self.executor_thread: threading.Thread | None = None + self.is_connected = False + self._last_joint_state: dict[str, dict[str, float]] | None = None + + def connect(self) -> None: + if not ROS2_AVAILABLE: + raise ImportError("ROS2 is not available") + + if not rclpy.ok(): + rclpy.init() + + self.robot_node = Node("moveit2_interface_node", namespace=self.config.namespace) + if self.action_type == ActionType.JOINT_POSITION: + self.pos_cmd_pub = self.robot_node.create_publisher( + Float64MultiArray, "position_controller/commands", 10 + ) + elif self.action_type == ActionType.CARTESIAN_VELOCITY: + self.moveit2_servo = MoveIt2Servo( + node=self.robot_node, + frame_id=self.config.base_link, + callback_group=ReentrantCallbackGroup(), + ) + self.gripper_action_client = ActionClient( + self.robot_node, + GripperCommand, + "/gripper_controller/gripper_cmd", + callback_group=ReentrantCallbackGroup(), + ) + self._goal_msg = GripperCommand.Goal() + self.joint_state_sub = self.robot_node.create_subscription( + JointState, + "joint_states", + self._joint_state_callback, + 10, + ) + + # Create and start the executor in a separate thread + self.executor = SingleThreadedExecutor() + self.executor.add_node(self.robot_node) + self.executor_thread = threading.Thread(target=self.executor.spin, daemon=True) + self.executor_thread.start() + time.sleep(3) # Give some time to connect to services and receive messages + + self.is_connected = True + + def send_joint_position_command(self, joint_positions: list[float], unnormalize: bool = True) -> None: + """ + Send a command to the robot's joints. + Args: + joint_positions (list[float]): The target positions for the joints. + normalize (bool): Whether to unnormalize the joint positions based on the robot's configuration. + """ + if not self.robot_node: + raise DeviceNotConnectedError("ROS2Interface is not connected. You need to call `connect()`.") + + if unnormalize: + if self.config.min_joint_positions is None or self.config.max_joint_positions is None: + raise ValueError( + "Joint position normalization requires min and max joint positions to be set." + ) + joint_positions = [ + min(max(pos, min_pos), max_pos) + for pos, min_pos, max_pos in zip( + joint_positions, + self.config.min_joint_positions, + self.config.max_joint_positions, + strict=True, + ) + ] + + if len(joint_positions) != len(self.config.arm_joint_names): + raise ValueError( + f"Expected {len(self.config.arm_joint_names)} joint positions, but got {len(joint_positions)}." + ) + msg = Float64MultiArray() + msg.data = joint_positions + if self.pos_cmd_pub is None: + raise DeviceNotConnectedError("Position command publisher is not initialized.") + self.pos_cmd_pub.publish(msg) + + def servo(self, linear, angular, normalize: bool = True) -> None: + if not self.moveit2_servo: + raise DeviceNotConnectedError("ROS2Interface is not connected. You need to call `connect()`.") + + if normalize: + linear = [v * self.config.max_linear_velocity for v in linear] + angular = [v * self.config.max_angular_velocity for v in angular] + self.moveit2_servo.servo(linear=linear, angular=angular) + + def send_gripper_command(self, position: float, unnormalize: bool = True) -> bool: + """ + Send a command to the gripper to move to a specific position. + Args: + position (float): The target position for the gripper (0=open, 1=closed). + Returns: + bool: True if the command was sent successfully, False otherwise. + """ + if not self.gripper_action_client: + raise RuntimeError("ROS2Interface is not connected. You need to call `connect()`.") + + if not self.gripper_action_client.wait_for_server(timeout_sec=1.0): + logger.error("Gripper action server not available") + return False + + if unnormalize: + # Map normalized position (0=open, 1=closed) to actual gripper joint position + open_pos = self.config.gripper_open_position + closed_pos = self.config.gripper_close_position + gripper_goal = open_pos + position * (closed_pos - open_pos) + else: + gripper_goal = position + + self._goal_msg.command.position = gripper_goal + if not (resp := self.gripper_action_client.send_goal(self._goal_msg)): + logger.error("Failed to send gripper command") + return False + result = resp.result # type: ignore # ROS2 types available at runtime + if result.reached_goal: + return True + logger.error( + f"Gripper did not reach goal. stalled: {result.stalled}, " + f"effort: {result.effort}, position: {result.position}" + ) + return False + + @property + def joint_state(self) -> dict[str, dict[str, float]] | None: + """Get the last received joint state.""" + return self._last_joint_state + + def _joint_state_callback(self, msg: "JointState") -> None: + self._last_joint_state = self._last_joint_state or {} + positions = {} + velocities = {} + name_to_index = {name: i for i, name in enumerate(msg.name)} + for joint_name in self.config.arm_joint_names: + idx = name_to_index.get(joint_name) + if idx is None: + raise ValueError(f"Joint '{joint_name}' not found in joint state.") + positions[joint_name] = msg.position[idx] + velocities[joint_name] = msg.velocity[idx] + + if self.config.gripper_joint_name: + idx = name_to_index.get(self.config.gripper_joint_name) + if idx is None: + raise ValueError( + f"Gripper joint '{self.config.gripper_joint_name}' not found in joint state." + ) + positions[self.config.gripper_joint_name] = msg.position[idx] + velocities[self.config.gripper_joint_name] = msg.velocity[idx] + + self._last_joint_state["position"] = positions + self._last_joint_state["velocity"] = velocities + + def disconnect(self): + if self.joint_state_sub: + self.joint_state_sub.destroy() + self.joint_state_sub = None + if self.gripper_action_client: + self.gripper_action_client.destroy() + self.gripper_action_client = None + if self.robot_node: + self.robot_node.destroy_node() + self.robot_node = None + if self.moveit2_servo: + self.moveit2_servo = None + + if self.executor: + self.executor.shutdown() + self.executor = None + if self.executor_thread: + self.executor_thread.join() + self.executor_thread = None + + self.is_connected = False diff --git a/lerobot/common/robots/utils.py b/lerobot/common/robots/utils.py index ccc1c58e86..4eedda73bd 100644 --- a/lerobot/common/robots/utils.py +++ b/lerobot/common/robots/utils.py @@ -49,6 +49,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .viperx import ViperX return ViperX(config) + elif config.type == "annin_ar4_mk1": + from .ros2 import ROS2Robot + + return ROS2Robot(config) elif config.type == "mock_robot": from tests.mocks.mock_robot import MockRobot diff --git a/lerobot/common/teleoperators/keyboard/teleop_keyboard.py b/lerobot/common/teleoperators/keyboard/teleop_keyboard.py index bd3ab903ef..d773722849 100644 --- a/lerobot/common/teleoperators/keyboard/teleop_keyboard.py +++ b/lerobot/common/teleoperators/keyboard/teleop_keyboard.py @@ -103,11 +103,15 @@ def calibrate(self) -> None: def _on_press(self, key): if hasattr(key, "char"): self.event_queue.put((key.char, True)) + elif keyboard is not None and key == keyboard.Key.space: + self.event_queue.put(("space", True)) def _on_release(self, key): if hasattr(key, "char"): self.event_queue.put((key.char, False)) - if key == keyboard.Key.esc: + elif keyboard is not None and key == keyboard.Key.space: + self.event_queue.put(("space", False)) + if keyboard is not None and key == keyboard.Key.esc: logging.info("ESC pressed, disconnecting.") self.disconnect() diff --git a/lerobot/teleoperate.py b/lerobot/teleoperate.py index 6080dfb403..5d2407019c 100644 --- a/lerobot/teleoperate.py +++ b/lerobot/teleoperate.py @@ -46,6 +46,7 @@ RobotConfig, koch_follower, make_robot_from_config, + ros2, so100_follower, so101_follower, ) @@ -58,7 +59,7 @@ from lerobot.common.utils.utils import init_logging, move_cursor_up from lerobot.common.utils.visualization_utils import _init_rerun -from .common.teleoperators import gamepad, koch_leader, so100_leader, so101_leader # noqa: F401 +from .common.teleoperators import gamepad, keyboard, koch_leader, so100_leader, so101_leader # noqa: F401 @dataclass @@ -100,7 +101,8 @@ def teleop_loop( print("\n" + "-" * (display_len + 10)) print(f"{'NAME':<{display_len}} | {'NORM':>7}") for motor, value in action.items(): - print(f"{motor:<{display_len}} | {value:>7.2f}") + display_value = "None" if value is None else f"{value:>7.2f}" + print(f"{motor:<{display_len}} | {display_value}") print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)") if duration is not None and time.perf_counter() - start >= duration: diff --git a/tests/robots/test_ros2.py b/tests/robots/test_ros2.py new file mode 100644 index 0000000000..43b0bc3acb --- /dev/null +++ b/tests/robots/test_ros2.py @@ -0,0 +1,273 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from lerobot.common.robots.ros2 import ROS2Config, ROS2Robot + + +def _make_moveit2_interface_mock() -> MagicMock: + """Return a ROS2Interface mock with just the attributes used by the robot.""" + interface = MagicMock(name="MoveIt2InterfaceMock") + interface.is_connected = False + + cfg = ROS2Config() + + # Mock joint state + all_joint_names = cfg.ros2_interface.arm_joint_names + [cfg.ros2_interface.gripper_joint_name] + interface.joint_state = { + "position": dict.fromkeys(all_joint_names, 0.1), + "velocity": dict.fromkeys(all_joint_names, 0.2), + } + + # Mock config + config_mock = MagicMock() + config_mock.arm_joint_names = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"] + config_mock.gripper_joint_name = "gripper_jaw1_joint" + interface.config = config_mock + + def _connect(): + interface.is_connected = True + + def _disconnect(): + interface.is_connected = False + + interface.connect.side_effect = _connect + interface.disconnect.side_effect = _disconnect + + # Mock servo and gripper commands + interface.servo.return_value = None + interface.send_gripper_command.return_value = True + + return interface + + +@pytest.fixture +def moveit2_robot(): + interface_mock = _make_moveit2_interface_mock() + + with patch("lerobot.common.robots.moveit2.moveit2.ROS2Interface", return_value=interface_mock): + cfg = ROS2Config() + robot = ROS2Robot(cfg) + yield robot + if robot.is_connected: + robot.disconnect() + + +def test_connect_disconnect(moveit2_robot): + """Test basic connection and disconnection.""" + assert not moveit2_robot.is_connected + + moveit2_robot.connect() + assert moveit2_robot.is_connected + + moveit2_robot.disconnect() + assert not moveit2_robot.is_connected + + +def test_get_observation(moveit2_robot): + """Test getting observations from the robot.""" + moveit2_robot.connect() + obs = moveit2_robot.get_observation() + + # Check that all expected joint positions are in the observation + expected_joints = moveit2_robot.config.ros2_interface.arm_joint_names + [ + moveit2_robot.config.ros2_interface.gripper_joint_name + ] + expected_keys = {f"{joint}.pos" for joint in expected_joints} + + # Only check motor keys since cameras might not be configured + motor_keys = {key for key in obs if key.endswith(".pos")} + assert motor_keys == expected_keys + + # Check that the values match the mocked joint state + for joint in expected_joints: + assert obs[f"{joint}.pos"] == 0.1 + + +def test_send_action(moveit2_robot): + """Test sending action commands to the robot.""" + moveit2_robot.connect() + + action = { + "lienar_x.vel": 0.1, + "linear_y.vel": 0.2, + "linear_z.vel": 0.3, + "angular_x.vel": 0.4, + "angular_y.vel": 0.5, + "angular_z.vel": 0.6, + "gripper.pos": 0.8, + } + + returned_action = moveit2_robot.send_action(action) + assert returned_action == action + + # Verify that the interface methods were called correctly + moveit2_robot.ros2_interface.servo.assert_called_once_with( + linear=(0.1, 0.2, 0.3), angular=(0.4, 0.5, 0.6) + ) + moveit2_robot.ros2_interface.send_gripper_command.assert_called_once_with(0.8) + + +def test_send_action_with_max_relative_target(moveit2_robot): + """Test sending action with safety limits applied.""" + # Configure with a small max relative target for testing + moveit2_robot.config.max_relative_target = 0.1 + moveit2_robot.connect() + + action = { + "lienar_x.vel": 1.0, # Large value that should be clipped + "linear_y.vel": 0.0, + "linear_z.vel": 0.0, + "angular_x.vel": 0.0, + "angular_y.vel": 0.0, + "angular_z.vel": 0.0, + "gripper.pos": 0.5, + } + + returned_action = moveit2_robot.send_action(action) + + # The action should be clipped due to max_relative_target + assert returned_action["lienar_x.vel"] <= 0.1 + assert returned_action["gripper.pos"] <= 0.1 + + +def test_keyboard_action_conversion(moveit2_robot): + """Test conversion from keyboard input to action commands.""" + moveit2_robot.config.action_from_keyboard = True + moveit2_robot.connect() + + pressed_keys = {"w": True, "d": True, "space": True} + + returned_action = moveit2_robot.send_action(pressed_keys) + + # Check that keyboard inputs were converted correctly + assert returned_action["lienar_x.vel"] == 1.0 # 'd' key + assert returned_action["linear_y.vel"] == 1.0 # 'w' key + assert returned_action["linear_z.vel"] == 0.0 + assert returned_action["gripper.pos"] == 1.0 # 'space' key + + +def test_from_keyboard_to_action(moveit2_robot): + """Test the keyboard to action conversion method directly.""" + # Test all movement keys + pressed_keys = { + "w": True, # +linear_y.vel + "s": True, # -linear_y.vel (should cancel with w) + "a": True, # -lienar_x.vel + "d": True, # +lienar_x.vel (should cancel with a) + "q": True, # -linear_z.vel + "e": True, # +linear_z.vel (should cancel with q) + "i": True, # +angular_x.vel + "k": True, # -angular_x.vel (should cancel with i) + "j": True, # -angular_y.vel + "l": True, # +angular_y.vel (should cancel with j) + "u": True, # +angular_z.vel + "o": True, # -angular_z.vel (should cancel with u) + "space": True, # gripper + } + + action = moveit2_robot.from_keyboard_to_action(pressed_keys) + + # All opposing keys should cancel out to 0 + assert action["lienar_x.vel"] == 0.0 + assert action["linear_y.vel"] == 0.0 + assert action["linear_z.vel"] == 0.0 + assert action["angular_x.vel"] == 0.0 + assert action["angular_y.vel"] == 0.0 + assert action["angular_z.vel"] == 0.0 + assert action["gripper.pos"] == 1.0 + + # Test individual keys + single_key_tests = [ + ({"w": True}, {"linear_y.vel": 1.0}), + ({"s": True}, {"linear_y.vel": -1.0}), + ({"a": True}, {"lienar_x.vel": -1.0}), + ({"d": True}, {"lienar_x.vel": 1.0}), + ({"q": True}, {"linear_z.vel": -1.0}), + ({"e": True}, {"linear_z.vel": 1.0}), + ({"i": True}, {"angular_x.vel": -1.0}), + ({"k": True}, {"angular_x.vel": 1.0}), + ({"j": True}, {"angular_y.vel": -1.0}), + ({"l": True}, {"angular_y.vel": 1.0}), + ({"u": True}, {"angular_z.vel": 1.0}), + ({"o": True}, {"angular_z.vel": -1.0}), + ] + + for pressed_keys, expected_non_zero in single_key_tests: + action = moveit2_robot.from_keyboard_to_action(pressed_keys) + for key, expected_value in expected_non_zero.items(): + assert action[key] == expected_value + + +def test_observation_features(moveit2_robot): + """Test that observation features are correctly defined.""" + features = moveit2_robot.observation_features + + # Check that all joint position features are defined + expected_joints = moveit2_robot.config.ros2_interface.arm_joint_names + [ + moveit2_robot.config.ros2_interface.gripper_joint_name + ] + for joint in expected_joints: + assert f"{joint}.pos" in features + assert features[f"{joint}.pos"] is float + + +def test_action_features(moveit2_robot): + """Test that action features are correctly defined.""" + features = moveit2_robot.action_features + + expected_actions = [ + "lienar_x.vel", + "linear_y.vel", + "linear_z.vel", + "angular_x.vel", + "angular_y.vel", + "angular_z.vel", + "gripper.pos", + ] + + for action in expected_actions: + assert action in features + assert features[action] is float + + +def test_calibration_methods(moveit2_robot): + """Test calibration-related methods.""" + # MoveIt2 robot should always be considered calibrated + assert moveit2_robot.is_calibrated + + # Calibrate method should do nothing (no-op) + moveit2_robot.calibrate() # Should not raise any exception + + +def test_configure_method(moveit2_robot): + """Test configure method.""" + # Configure method should do nothing (no-op) + moveit2_robot.configure() # Should not raise any exception + + +def test_error_handling_when_not_connected(moveit2_robot): + """Test that appropriate errors are raised when robot is not connected.""" + from lerobot.common.errors import DeviceNotConnectedError + + # Should raise error when trying to get observation without connection + with pytest.raises(DeviceNotConnectedError): + moveit2_robot.get_observation() + + # Should raise error when trying to send action without connection + with pytest.raises(DeviceNotConnectedError): + moveit2_robot.send_action({"lienar_x.vel": 0.1, "gripper.pos": 0.5}) + + # Should raise error when trying to disconnect without connection + with pytest.raises(DeviceNotConnectedError): + moveit2_robot.disconnect() + + +def test_double_connect_error(moveit2_robot): + """Test that connecting twice raises an error.""" + from lerobot.common.errors import DeviceAlreadyConnectedError + + moveit2_robot.connect() + + with pytest.raises(DeviceAlreadyConnectedError): + moveit2_robot.connect()