Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/video/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d
### Decoding parameters
**Decoder**
We tested two video decoding backends from torchvision:
- `pyav` (default)
- `pyav`
- `video_reader` (requires to build torchvision from source)

**Requested timestamps**
Expand Down
7 changes: 4 additions & 3 deletions lerobot/common/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
VideoFrame,
decode_video_frames,
encode_video_frames,
get_safe_default_codec,
get_video_info,
)
from lerobot.common.robot_devices.robots.utils import Robot
Expand Down Expand Up @@ -462,7 +463,7 @@ def __init__(
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
"""
super().__init__()
Expand All @@ -473,7 +474,7 @@ def __init__(
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else "torchcodec"
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.delta_indices = None

# Unused attributes
Expand Down Expand Up @@ -1027,7 +1028,7 @@ def create(
obj.delta_timestamps = None
obj.delta_indices = None
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "torchcodec"
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
return obj


Expand Down
24 changes: 21 additions & 3 deletions lerobot/common/datasets/video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import json
import logging
import subprocess
Expand All @@ -27,14 +28,23 @@
import torchvision
from datasets.features.features import register_feature
from PIL import Image
from torchcodec.decoders import VideoDecoder


def get_safe_default_codec():
if importlib.util.find_spec("torchcodec"):
return "torchcodec"
else:
logging.warning(
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
)
return "pyav"


def decode_video_frames(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "torchcodec",
backend: str | None = None,
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
Expand All @@ -43,13 +53,15 @@ def decode_video_frames(
video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..

Returns:
torch.Tensor: Decoded frames.

Currently supports torchcodec on cpu and pyav.
"""
if backend is None:
backend = get_safe_default_codec()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]:
Expand Down Expand Up @@ -173,6 +185,12 @@ def decode_video_frames_torchcodec(
and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes.
"""

if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
raise ImportError("torchcodec is required but not available.")

# initialize video decoder
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
loaded_frames = []
Expand Down
3 changes: 2 additions & 1 deletion lerobot/configs/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
policies, # noqa: F401
)
from lerobot.common.datasets.transforms import ImageTransformsConfig
from lerobot.common.datasets.video_utils import get_safe_default_codec


@dataclass
Expand All @@ -35,7 +36,7 @@ class DatasetConfig:
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = "pyav"
video_backend: str = field(default_factory=get_safe_default_codec)


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ dependencies = [
"rerun-sdk>=0.21.0",
"termcolor>=2.4.0",
"torch>=2.2.1",
"torchcodec>=0.2.1",
"torchcodec>=0.2.1 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')",
"torchvision>=0.21.0",
"wandb>=0.16.3",
"zarr>=2.17.0",
Expand Down
Loading