Skip to content

Commit 306fe48

Browse files
committed
Unify video metadata in VideoClips (#1527)
* Unify video metadata in VideoClips
* Bugfix
* Make tests a bit more robust
* Fix merge conflicts for cherry-pick for 0.4.2
1 parent bafc3dc commit 306fe48

File tree

5 files changed

+67
-61
lines changed

5 files changed

+67
-61
lines changed

test/test_datasets_video_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def test_unfold(self):
5858
self.assertTrue(r.equal(expected))
5959

6060
@unittest.skipIf(not io.video._av_available(), "this test requires av")
61+
@unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
6162
def test_video_clips(self):
6263
with get_list_of_videos(num_videos=3) as video_list:
6364
video_clips = VideoClips(video_list, 5, 5)
@@ -82,6 +83,7 @@ def test_video_clips(self):
8283
self.assertEqual(clip_idx, c_idx)
8384

8485
@unittest.skipIf(not io.video._av_available(), "this test requires av")
86+
@unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
8587
def test_video_clips_custom_fps(self):
8688
with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
8789
num_frames = 4
@@ -91,6 +93,7 @@ def test_video_clips_custom_fps(self):
9193
video, audio, info, video_idx = video_clips.get_clip(i)
9294
self.assertEqual(video.shape[0], num_frames)
9395
self.assertEqual(info["video_fps"], fps)
96+
self.assertEqual(info, {"video_fps": fps})
9497
# TODO add tests checking that the content is right
9598

9699
def test_compute_clips_for_video(self):

test/test_datasets_video_utils_opt.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import unittest
2+
from torchvision import set_video_backend
3+
import test_datasets_video_utils
4+
5+
6+
set_video_backend('video_reader')
7+
8+
9+
if __name__ == '__main__':
10+
suite = unittest.TestLoader().loadTestsFromModule(test_datasets_video_utils)
11+
unittest.TextTestRunner(verbosity=1).run(suite)

test/test_io.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def test_read_video_pts_unit_sec(self):
181181

182182
self.assertTrue(data.equal(lv))
183183
self.assertEqual(info["video_fps"], 5)
184+
self.assertEqual(info, {"video_fps": 5})
184185

185186
def test_read_timestamps_pts_unit_sec(self):
186187
with temp_video(10, 300, 300, 5) as (f_name, data):

torchvision/datasets/video_utils.py

Lines changed: 44 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torchvision.io import (
66
_read_video_timestamps_from_file,
77
_read_video_from_file,
8+
_probe_video_from_file
89
)
910
from torchvision.io import read_video_timestamps, read_video
1011

@@ -71,11 +72,11 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1
7172
frame_rate=None, _precomputed_metadata=None, num_workers=0,
7273
_video_width=0, _video_height=0, _video_min_dimension=0,
7374
_audio_samples=0):
74-
from torchvision import get_video_backend
7575

7676
self.video_paths = video_paths
7777
self.num_workers = num_workers
78-
self._backend = get_video_backend()
78+
79+
# these options are not valid for pyav backend
7980
self._video_width = _video_width
8081
self._video_height = _video_height
8182
self._video_min_dimension = _video_min_dimension
@@ -89,87 +90,60 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1
8990

9091
def _compute_frame_pts(self):
9192
self.video_pts = []
92-
if self._backend == "pyav":
93-
self.video_fps = []
94-
else:
95-
self.info = []
93+
self.video_fps = []
9694

9795
# strategy: use a DataLoader to parallelize read_video_timestamps
9896
# so need to create a dummy dataset first
9997
class DS(object):
100-
def __init__(self, x, _backend):
98+
def __init__(self, x):
10199
self.x = x
102-
self._backend = _backend
103100

104101
def __len__(self):
105102
return len(self.x)
106103

107104
def __getitem__(self, idx):
108-
if self._backend == "pyav":
109-
return read_video_timestamps(self.x[idx])
110-
else:
111-
return _read_video_timestamps_from_file(self.x[idx])
105+
return read_video_timestamps(self.x[idx])
112106

113107
import torch.utils.data
114108
dl = torch.utils.data.DataLoader(
115-
DS(self.video_paths, self._backend),
109+
DS(self.video_paths),
116110
batch_size=16,
117111
num_workers=self.num_workers,
118112
collate_fn=lambda x: x)
119113

120114
with tqdm(total=len(dl)) as pbar:
121115
for batch in dl:
122116
pbar.update(1)
123-
if self._backend == "pyav":
124-
clips, fps = list(zip(*batch))
125-
clips = [torch.as_tensor(c) for c in clips]
126-
self.video_pts.extend(clips)
127-
self.video_fps.extend(fps)
128-
else:
129-
video_pts, _audio_pts, info = list(zip(*batch))
130-
video_pts = [torch.as_tensor(c) for c in video_pts]
131-
self.video_pts.extend(video_pts)
132-
self.info.extend(info)
117+
clips, fps = list(zip(*batch))
118+
clips = [torch.as_tensor(c) for c in clips]
119+
self.video_pts.extend(clips)
120+
self.video_fps.extend(fps)
133121

134122
def _init_from_metadata(self, metadata):
135123
self.video_paths = metadata["video_paths"]
136124
assert len(self.video_paths) == len(metadata["video_pts"])
137125
self.video_pts = metadata["video_pts"]
138-
139-
if self._backend == "pyav":
140-
assert len(self.video_paths) == len(metadata["video_fps"])
141-
self.video_fps = metadata["video_fps"]
142-
else:
143-
assert len(self.video_paths) == len(metadata["info"])
144-
self.info = metadata["info"]
126+
assert len(self.video_paths) == len(metadata["video_fps"])
127+
self.video_fps = metadata["video_fps"]
145128

146129
@property
147130
def metadata(self):
148131
_metadata = {
149132
"video_paths": self.video_paths,
150133
"video_pts": self.video_pts,
134+
"video_fps": self.video_fps
151135
}
152-
if self._backend == "pyav":
153-
_metadata.update({"video_fps": self.video_fps})
154-
else:
155-
_metadata.update({"info": self.info})
156136
return _metadata
157137

158138
def subset(self, indices):
159139
video_paths = [self.video_paths[i] for i in indices]
160140
video_pts = [self.video_pts[i] for i in indices]
161-
if self._backend == "pyav":
162-
video_fps = [self.video_fps[i] for i in indices]
163-
else:
164-
info = [self.info[i] for i in indices]
141+
video_fps = [self.video_fps[i] for i in indices]
165142
metadata = {
166143
"video_paths": video_paths,
167144
"video_pts": video_pts,
145+
"video_fps": video_fps
168146
}
169-
if self._backend == "pyav":
170-
metadata.update({"video_fps": video_fps})
171-
else:
172-
metadata.update({"info": info})
173147
return type(self)(video_paths, self.num_frames, self.step, self.frame_rate,
174148
_precomputed_metadata=metadata, num_workers=self.num_workers,
175149
_video_width=self._video_width,
@@ -212,22 +186,10 @@ def compute_clips(self, num_frames, step, frame_rate=None):
212186
self.frame_rate = frame_rate
213187
self.clips = []
214188
self.resampling_idxs = []
215-
if self._backend == "pyav":
216-
for video_pts, fps in zip(self.video_pts, self.video_fps):
217-
clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
218-
self.clips.append(clips)
219-
self.resampling_idxs.append(idxs)
220-
else:
221-
for video_pts, info in zip(self.video_pts, self.info):
222-
if "video_fps" in info:
223-
clips, idxs = self.compute_clips_for_video(
224-
video_pts, num_frames, step, info["video_fps"], frame_rate)
225-
self.clips.append(clips)
226-
self.resampling_idxs.append(idxs)
227-
else:
228-
# properly handle the cases where video decoding fails
229-
self.clips.append(torch.zeros(0, num_frames, dtype=torch.int64))
230-
self.resampling_idxs.append(torch.zeros(0, dtype=torch.int64))
189+
for video_pts, fps in zip(self.video_pts, self.video_fps):
190+
clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
191+
self.clips.append(clips)
192+
self.resampling_idxs.append(idxs)
231193
clip_lengths = torch.as_tensor([len(v) for v in self.clips])
232194
self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
233195

@@ -287,12 +249,28 @@ def get_clip(self, idx):
287249
video_path = self.video_paths[video_idx]
288250
clip_pts = self.clips[video_idx][clip_idx]
289251

290-
if self._backend == "pyav":
252+
from torchvision import get_video_backend
253+
backend = get_video_backend()
254+
255+
if backend == "pyav":
256+
# check for invalid options
257+
if self._video_width != 0:
258+
raise ValueError("pyav backend doesn't support _video_width != 0")
259+
if self._video_height != 0:
260+
raise ValueError("pyav backend doesn't support _video_height != 0")
261+
if self._video_min_dimension != 0:
262+
raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
263+
if self._audio_samples != 0:
264+
raise ValueError("pyav backend doesn't support _audio_samples != 0")
265+
266+
if backend == "pyav":
291267
start_pts = clip_pts[0].item()
292268
end_pts = clip_pts[-1].item()
293269
video, audio, info = read_video(video_path, start_pts, end_pts)
294270
else:
295-
info = self.info[video_idx]
271+
info = _probe_video_from_file(video_path)
272+
video_fps = info["video_fps"]
273+
audio_fps = None
296274

297275
video_start_pts = clip_pts[0].item()
298276
video_end_pts = clip_pts[-1].item()
@@ -313,6 +291,7 @@ def get_clip(self, idx):
313291
info["audio_timebase"],
314292
math.ceil,
315293
)
294+
audio_fps = info["audio_sample_rate"]
316295
video, audio, info = _read_video_from_file(
317296
video_path,
318297
video_width=self._video_width,
@@ -324,6 +303,11 @@ def get_clip(self, idx):
324303
audio_pts_range=(audio_start_pts, audio_end_pts),
325304
audio_timebase=audio_timebase,
326305
)
306+
307+
info = {"video_fps": video_fps}
308+
if audio_fps is not None:
309+
info["audio_fps"] = audio_fps
310+
327311
if self.frame_rate is not None:
328312
resampling_idx = self.resampling_idxs[video_idx][clip_idx]
329313
if isinstance(resampling_idx, torch.Tensor):

torchvision/io/_video_opt.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def get_pts(time_base):
383383
audio_timebase = info['audio_timebase']
384384
audio_pts_range = get_pts(audio_timebase)
385385

386-
return _read_video_from_file(
386+
vframes, aframes, info = _read_video_from_file(
387387
filename,
388388
read_video_stream=True,
389389
video_pts_range=video_pts_range,
@@ -392,6 +392,13 @@ def get_pts(time_base):
392392
audio_pts_range=audio_pts_range,
393393
audio_timebase=audio_timebase,
394394
)
395+
_info = {}
396+
if has_video:
397+
_info['video_fps'] = info['video_fps']
398+
if has_audio:
399+
_info['audio_fps'] = info['audio_sample_rate']
400+
401+
return vframes, aframes, _info
395402

396403

397404
def _read_video_timestamps(filename, pts_unit='pts'):

0 commit comments

Comments (0)