Status: Open
Labels: question, simulation
Description
First the good news:
This is an interactive gym: a pre-trained policy controls the robot while you interact with the scene in real time through the MuJoCo viewer.
Here is how to use it:
- Double-click on a body to select it.
- Ctrl + left drag applies a torque to the selected object, rotating it.
- Ctrl + right drag applies a force to the selected object in the (x,z) plane, translating it.
- Ctrl + Shift + right drag applies a force to the selected object in the (x,y) plane.
However, there are a few limitations:
- When you move the cubes, the robot doesn't seem to register the new positions and instead attempts to pick them up from their original locations.
- Only the `lerobot/act_aloha_sim_insertion_human` policy (with `--env.type=aloha`) appears to work, and even then only occasionally. The others either don't work at all or crash the program because of missing attributes that the gym does not implement.
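A possible cause of the first limitation, offered as an untested guess: ACT-style policies predict a chunk of future actions from a single observation and replay it from a queue, reading a new observation only once the queue is empty. If the pretrained checkpoints execute long chunks (the relevant field in LeRobot's ACT config is assumed here to be `n_action_steps`), a cube moved mid-chunk is not seen until the next prediction. The toy below is not LeRobot code; `ToyChunkedPolicy`, `run` and the `cube_x` observation are hypothetical and only model the queueing behaviour.

```python
from collections import deque


class ToyChunkedPolicy:
    """Hypothetical stand-in for a chunking policy such as ACT.

    When the queue is empty it "predicts" n_action_steps actions from the
    current observation; until the queue drains, new observations are ignored.
    """

    def __init__(self, n_action_steps: int):
        self.n_action_steps = n_action_steps
        self._queue: deque = deque()

    def reset(self) -> None:
        # Drop queued actions, e.g. after env.reset(), so a new episode
        # does not replay the previous plan.
        self._queue.clear()

    def select_action(self, observation: dict) -> float:
        if not self._queue:
            # The "plan" is simply to move toward where the cube is right now.
            target = observation["cube_x"]
            self._queue.extend([target] * self.n_action_steps)
        return self._queue.popleft()


def run(n_action_steps: int, move_at: int, steps: int) -> list:
    """Simulate a cube that jumps from x=0.0 to x=1.0 at step `move_at`."""
    policy = ToyChunkedPolicy(n_action_steps)
    actions = []
    for t in range(steps):
        cube_x = 0.0 if t < move_at else 1.0
        actions.append(policy.select_action({"cube_x": cube_x}))
    return actions


print(run(n_action_steps=100, move_at=10, steps=20))  # never reacts to the move
print(run(n_action_steps=5, move_at=10, steps=20))    # reacts at step 10
```

With a 100-step chunk the toy keeps heading for the old position for all 20 steps, while a 5-step chunk switches targets at step 10. Shorter chunks trade smoothness for responsiveness, so whether the pretrained policies still behave well with a smaller value would need testing.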
I'd really appreciate feedback/guidance from the repo maintainers on how to improve this snippet to support more environments and tasks.
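One way to turn the crashes into a clear message, sketched under the assumption that the missing attributes are the `model` and `data` handles the script passes to `mujoco.viewer.launch_passive` (an environment that is not MuJoCo-based would not have them): check `env.unwrapped` before launching the viewer. The helper names and the fake environments below are hypothetical, for illustration only.

```python
from typing import Any

# Attributes this script reads from env.unwrapped: `model` and `data` are
# passed to mujoco.viewer.launch_passive. Assumption: an env that is not
# MuJoCo-based will not define them.
REQUIRED_ATTRS = ("model", "data")


def missing_attributes(env: Any, required: tuple = REQUIRED_ATTRS) -> list:
    """Return the names in `required` that env.unwrapped does not define."""
    base = getattr(env, "unwrapped", env)
    return [name for name in required if not hasattr(base, name)]


def check_viewer_support(env: Any) -> None:
    """Raise a readable error before launching the viewer, not mid-run."""
    missing = missing_attributes(env)
    if missing:
        base = getattr(env, "unwrapped", env)
        raise RuntimeError(
            f"{type(base).__name__} lacks {', '.join(missing)}; "
            "it cannot be attached to the MuJoCo passive viewer as-is."
        )


# Hypothetical stand-ins for real environments and a gym-style wrapper.
class FakeMujocoEnv:
    model = object()
    data = object()

    @property
    def unwrapped(self):
        return self


class FakeNonMujocoEnv:
    @property
    def unwrapped(self):
        return self


class FakeWrapper:
    def __init__(self, env):
        self.env = env

    @property
    def unwrapped(self):
        return self.env.unwrapped


check_viewer_support(FakeWrapper(FakeMujocoEnv()))  # passes silently
print(missing_attributes(FakeWrapper(FakeNonMujocoEnv())))
```

Calling `check_viewer_support(env)` right after `gym.make(...)` would report which environments cannot be used with the viewer at all, separating them from environments that load but misbehave.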
File `interactive_gym.py`:
```python
import importlib

import gymnasium as gym
import mujoco
import mujoco.viewer
import torch

from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.utils import preprocess_observation
from lerobot.policies.factory import make_policy
from lerobot.policies.utils import get_device_from_parameters

# Usage:
# $ python interactive_gym.py --policy.path=lerobot/act_aloha_sim_insertion_human --env.type=aloha
# $ python interactive_gym.py --policy.path=lerobot/act_aloha_sim_transfer_cube_human --env.type=aloha


@parser.wrap()
def make_env_and_policy(cfg: EvalPipelineConfig):
    # Each sim env lives in its own package, e.g. gym_aloha for --env.type=aloha.
    package_name = f"gym_{cfg.env.type}"
    try:
        importlib.import_module(package_name)
    except ModuleNotFoundError:
        print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.type}]'`")
        raise

    # A single (non-vectorized) env, so the viewer can attach to its model/data.
    gym_handle = f"{package_name}/{cfg.env.task}"
    env = gym.make(gym_handle, disable_env_checker=True, **cfg.env.gym_kwargs)

    policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env)
    policy.eval()
    policy.reset()
    return env, policy


def main(env, policy, num_steps: int = 40000):
    device = get_device_from_parameters(policy)
    base_env = env.unwrapped

    # The passive viewer does not step physics itself; each viewer.sync() call
    # exchanges state and user perturbations (Ctrl + drag) with env's data.
    viewer = mujoco.viewer.launch_passive(base_env.model, base_env.data)
    observation, info = env.reset(seed=42)
    viewer.sync()

    for i in range(num_steps):
        # Stop cleanly if the viewer window has been closed.
        if not viewer.is_running():
            break

        # Convert numpy observations to batched torch tensors on the policy's device.
        observation = preprocess_observation(observation)
        observation = {
            key: value.to(device, non_blocking=device.type == "cuda") for key, value in observation.items()
        }

        # Pass a language instruction if the env exposes one. Check and read the
        # same object (env.unwrapped): wrappers may not forward custom attributes.
        if hasattr(base_env, "task_description"):
            observation["task"] = base_env.task_description
        elif hasattr(base_env, "task"):
            observation["task"] = base_env.task
        else:  # Envs without language instructions, e.g. ALOHA transfer cube.
            observation["task"] = ""

        with torch.inference_mode():
            action = policy.select_action(observation)

        # Move to CPU/numpy and drop the batch dimension for the single env.
        action = action.to("cpu").numpy()
        assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
        observation, reward, terminated, truncated, info = env.step(action[0])
        viewer.sync()

        if terminated or truncated:
            observation, info = env.reset()
            # Clear the policy's queued actions so the new episode does not
            # replay actions planned for the previous one.
            policy.reset()
            viewer.sync()

        if i % 100 == 0:
            print(f"step {i}")

    viewer.close()
    env.close()


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    env, policy = make_env_and_policy()
    main(env, policy)
```
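The task lookup in the loop can be made more robust by reading every attribute from `env.unwrapped` rather than probing the outer wrapper, since recent Gymnasium versions may not forward custom attributes through wrapper layers (worth verifying against the installed version). A standalone sketch of that lookup, with hypothetical fake environments standing in for real ones; `infer_task` is an illustrative name, not a LeRobot function.

```python
from typing import Any


def infer_task(env: Any) -> str:
    """Return the env's language instruction, or "" if it has none.

    Reads from env.unwrapped so wrapper layers do not hide the attribute,
    and skips values that are not strings.
    """
    base = getattr(env, "unwrapped", env)
    for name in ("task_description", "task"):
        value = getattr(base, name, None)
        if isinstance(value, str):
            return value
    return ""


# Hypothetical stand-ins for real environments and a gym-style wrapper.
class _FakeEnv:
    @property
    def unwrapped(self):
        return self


class LanguageEnv(_FakeEnv):
    task_description = "insert the peg into the socket"  # made-up instruction


class NamedTaskEnv(_FakeEnv):
    task = "insertion"


class PlainEnv(_FakeEnv):
    pass


class FakeWrapper:
    def __init__(self, env):
        self.env = env

    @property
    def unwrapped(self):
        return self.env.unwrapped


print(infer_task(FakeWrapper(LanguageEnv())))
print(infer_task(FakeWrapper(NamedTaskEnv())))
print(repr(infer_task(PlainEnv())))
```

The `isinstance(value, str)` check is a design choice: if some environment stores a non-string object under `task`, the sketch falls back to `""` instead of handing that object to the policy.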