diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 081f332b..4ca9b3a6 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Literal, Optional, Tuple, Union -from torch import device, Tensor +from torch import device as torch_device, Tensor from torchcodec import Frame, FrameBatch from torchcodec.decoders import _core as core @@ -72,7 +72,7 @@ def __init__( stream_index: Optional[int] = None, dimension_order: Literal["NCHW", "NHWC"] = "NCHW", num_ffmpeg_threads: int = 1, - device: Optional[Union[str, device]] = "cpu", + device: Optional[Union[str, torch_device]] = "cpu", seek_mode: Literal["exact", "approximate"] = "exact", ): allowed_seek_modes = ("exact", "approximate") @@ -94,6 +94,9 @@ def __init__( if num_ffmpeg_threads is None: raise ValueError(f"{num_ffmpeg_threads = } should be an int.") + if isinstance(device, torch_device): + device = str(device) + core.add_video_stream( self._decoder, stream_index=stream_index, diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index cc47e116..4115553f 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -285,6 +285,11 @@ def test_getitem_slice(self, device, seek_mode): # See https://github.com/pytorch/torchcodec/issues/428 assert_frames_equal(sliced, ref) + def test_device_instance(self): + # Non-regression test for https://github.com/pytorch/torchcodec/issues/602 + decoder = VideoDecoder(NASA_VIDEO.path, device=torch.device("cpu")) + assert isinstance(decoder.metadata, VideoStreamMetadata) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_getitem_fails(self, device, seek_mode):