Data Loading Performance Bottleneck: set_format(type='torch') Incompatible with select/stack, Unable to Accelerate Large-Scale Training #1282

@elapstjtl

Description

System Info

- `lerobot` version: 0.1.0
- Platform: Linux-5.15.0-139-generic-x86_64-with-glibc2.31
- Python version: 3.10.13
- Huggingface_hub version: 0.33.0
- Datasets version: 3.6.0
- Numpy version: 2.2.6
- PyTorch version (GPU?): 2.7.1+cu126 (True)
- Cuda version: 12060
- Using GPU in script?: Yes (RTX 4070 Ti Super)

Information

  • One of the scripts in the examples/ folder of LeRobot
  • My own task or dataset (give details below)

Reproduction

Steps to Reproduce

  1. Use LeRobotDataset to load a large dataset (e.g., tens of thousands of frames).
  2. Try to accelerate data loading by calling dataset.hf_dataset.set_format(type='torch') so that DataLoader returns tensors directly, reducing Python object deepcopy overhead.
  3. After adding set_format(type='torch'), the following error occurs during training:
  File ".../lerobot_dataset.py", line 697, in <dictcomp>
    key: torch.stack(self.hf_dataset.select(q_idx)[key])
TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor
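The root cause is visible in torch.stack itself: it requires a sequence of tensors, but once set_format(type='torch') is active, datasets hands the column back as one pre-stacked Tensor. A minimal sketch, independent of LeRobot (variable names are illustrative):

```python
import torch

# What the LeRobot code path expects: a list of per-frame tensors to batch.
rows = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
batched = torch.stack(rows)  # fine, shape (2, 2)

# What set_format(type='torch') effectively yields: one already-stacked Tensor.
# Passing it to torch.stack reproduces the exact TypeError from the traceback.
try:
    torch.stack(batched)  # not a tuple/list of Tensors
except TypeError as exc:
    print(type(exc).__name__, exc)
```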
The helper where set_format is applied is shown below; the training script calls it as:

dataset, input_features, output_features = load_dataset(task_name, task_config)

def load_dataset(task_name: str, task_config: Dict[str, Any]) -> Tuple[LeRobotDataset, Dict[str, Any], Dict[str, Any]]:

	dataset_metadata = LeRobotDatasetMetadata(f"lerobot/{task_name}")
	features = dataset_to_policy_features(dataset_metadata.features)
	output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
	input_features = {key: ft for key, ft in features.items() if key not in output_features}

	delta_timestamps = task_config["delta_timestamps"]
	dataset = LeRobotDataset(
		repo_id=f"lerobot/{task_name}",  # dataset repo ID, formatted as "org/dataset_name"
		root="",                         # local storage root; empty string uses the default path
		episodes=None,                   # episodes to load; None loads all episodes
		image_transforms=None,           # image preprocessing transforms; None disables preprocessing
		delta_timestamps=delta_timestamps,  # timestamp offsets controlling temporal sampling
		tolerance_s=1e-4,               # tolerance (seconds) for matching inexact timestamps
		revision=None,                   # dataset revision; None uses the latest version
		force_cache_sync=False,          # force cache sync; False updates only when necessary
		download_videos=True,            # whether to download video data
		video_backend=None,              # video decoding backend; None uses the default
	)
	
	dataset.hf_dataset.set_format(type='torch')
	
	return dataset, input_features, output_features
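One possible workaround (my own sketch, not an official LeRobot fix) is to guard the stack call so it tolerates both formats. Here query_frames is a hypothetical stand-in for the dict comprehension at lerobot_dataset.py line 697:

```python
import torch

def query_frames(hf_dataset, query_indices):
    # Hypothetical guard around the comprehension at lerobot_dataset.py:697.
    # Default (python) format: select(...)[key] is a list of tensors -> stack.
    # After set_format(type='torch'): the column arrives pre-stacked -> pass through.
    out = {}
    for key, q_idx in query_indices.items():
        column = hf_dataset.select(q_idx)[key]
        out[key] = column if isinstance(column, torch.Tensor) else torch.stack(column)
    return out
```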

Expected behavior

I expect that after calling dataset.hf_dataset.set_format(type='torch'),

  • DataLoader can directly return PyTorch tensors for each sample,
  • There are no type errors or compatibility issues with internal select/stack operations,
  • Data loading becomes significantly faster, with reduced Python object copying and higher GPU utilization,
  • Large-scale training can proceed efficiently without being bottlenecked by data loading.
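Until set_format is compatible with the select/stack path, throughput can usually be recovered on the DataLoader side instead. The settings below are a sketch of standard PyTorch options, not LeRobot-specific tuning:

```python
from torch.utils.data import DataLoader

def make_loader(dataset, batch_size=64, num_workers=4):
    # Parallel workers amortize the per-sample Python copy overhead that
    # set_format(type='torch') was meant to avoid.
    kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    if num_workers > 0:
        # persistent_workers/prefetch_factor are only valid with worker processes.
        kwargs.update(persistent_workers=True, prefetch_factor=2)
    return DataLoader(dataset, **kwargs)
```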

Metadata


    Labels

    dataset: Issues regarding data inputs, processing, or datasets
    performance: Issues aimed at improving speed or resource usage
