Skip to content

Commit c0e5b49

Browse files
imstevenpmworkAdilZouitine
authored andcommitted
fix(codec): hot-fix for default codec in linux arm platforms (#868)
1 parent 95f02be commit c0e5b49

File tree

5 files changed

+29
-9
lines changed

5 files changed

+29
-9
lines changed

benchmarks/video/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d
5151
### Decoding parameters
5252
**Decoder**
5353
We tested two video decoding backends from torchvision:
54-
- `pyav` (default)
54+
- `pyav`
5555
- `video_reader` (requires to build torchvision from source)
5656

5757
**Requested timestamps**

lerobot/common/datasets/lerobot_dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
VideoFrame,
7070
decode_video_frames,
7171
encode_video_frames,
72+
get_safe_default_codec,
7273
get_video_info,
7374
)
7475
from lerobot.common.robot_devices.robots.utils import Robot
@@ -462,7 +463,7 @@ def __init__(
462463
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
463464
video files are already present on local disk, they won't be downloaded again. Defaults to
464465
True.
465-
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
466+
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
466467
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
467468
"""
468469
super().__init__()
@@ -473,7 +474,7 @@ def __init__(
473474
self.episodes = episodes
474475
self.tolerance_s = tolerance_s
475476
self.revision = revision if revision else CODEBASE_VERSION
476-
self.video_backend = video_backend if video_backend else "torchcodec"
477+
self.video_backend = video_backend if video_backend else get_safe_default_codec()
477478
self.delta_indices = None
478479

479480
# Unused attributes
@@ -1027,7 +1028,7 @@ def create(
10271028
obj.delta_timestamps = None
10281029
obj.delta_indices = None
10291030
obj.episode_data_index = None
1030-
obj.video_backend = video_backend if video_backend is not None else "torchcodec"
1031+
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
10311032
return obj
10321033

10331034

lerobot/common/datasets/video_utils.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16+
import importlib
1617
import json
1718
import logging
1819
import subprocess
@@ -27,14 +28,23 @@
2728
import torchvision
2829
from datasets.features.features import register_feature
2930
from PIL import Image
30-
from torchcodec.decoders import VideoDecoder
31+
32+
33+
def get_safe_default_codec():
34+
if importlib.util.find_spec("torchcodec"):
35+
return "torchcodec"
36+
else:
37+
logging.warning(
38+
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
39+
)
40+
return "pyav"
3141

3242

3343
def decode_video_frames(
3444
video_path: Path | str,
3545
timestamps: list[float],
3646
tolerance_s: float,
37-
backend: str = "torchcodec",
47+
backend: str | None = None,
3848
) -> torch.Tensor:
3949
"""
4050
Decodes video frames using the specified backend.
@@ -43,13 +53,15 @@ def decode_video_frames(
4353
video_path (Path): Path to the video file.
4454
timestamps (list[float]): List of timestamps to extract frames.
4555
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
46-
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
56+
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
4757
4858
Returns:
4959
torch.Tensor: Decoded frames.
5060
5161
Currently supports torchcodec on cpu and pyav.
5262
"""
63+
if backend is None:
64+
backend = get_safe_default_codec()
5365
if backend == "torchcodec":
5466
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
5567
elif backend in ["pyav", "video_reader"]:
@@ -173,6 +185,12 @@ def decode_video_frames_torchcodec(
173185
and all subsequent frames until reaching the requested frame. The number of key frames in a video
174186
can be adjusted during encoding to take into account decoding time and video size in bytes.
175187
"""
188+
189+
if importlib.util.find_spec("torchcodec"):
190+
from torchcodec.decoders import VideoDecoder
191+
else:
192+
raise ImportError("torchcodec is required but not available.")
193+
176194
# initialize video decoder
177195
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
178196
loaded_frames = []

lerobot/configs/default.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
policies, # noqa: F401
2121
)
2222
from lerobot.common.datasets.transforms import ImageTransformsConfig
23+
from lerobot.common.datasets.video_utils import get_safe_default_codec
2324

2425

2526
@dataclass
@@ -35,7 +36,7 @@ class DatasetConfig:
3536
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
3637
revision: str | None = None
3738
use_imagenet_stats: bool = True
38-
video_backend: str = "pyav"
39+
video_backend: str = field(default_factory=get_safe_default_codec)
3940

4041

4142
@dataclass

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ dependencies = [
6969
"rerun-sdk>=0.21.0",
7070
"termcolor>=2.4.0",
7171
"torch>=2.2.1",
72-
"torchcodec>=0.2.1",
72+
"torchcodec>=0.2.1 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')",
7373
"torchvision>=0.21.0",
7474
"wandb>=0.16.3",
7575
"zarr>=2.17.0",

0 commit comments

Comments
 (0)