Skip to content

Commit 1bdda8c

Browse files
bryant1410pmeier
andauthored
Add a warning if any clip can't be obtained from a video in VideoClips. (#2513)
* Add a warning if a clip can't be get from a video in VideoClips * Update torchvision/datasets/video_utils.py Co-authored-by: Philip Meier <[email protected]> * Add a test Co-authored-by: Philip Meier <[email protected]>
1 parent 4521f6d commit 1bdda8c

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

test/test_datasets_video_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,16 @@ def test_compute_clips_for_video(self):
119119
self.assertTrue(clips.equal(idxs))
120120
self.assertTrue(idxs.flatten().equal(resampled_idxs))
121121

122+
# case 3: frames aren't enough for a clip
123+
num_frames = 32
124+
orig_fps = 30
125+
new_fps = 13
126+
with self.assertWarns(UserWarning):
127+
clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames,
128+
orig_fps, new_fps)
129+
self.assertEqual(len(clips), 0)
130+
self.assertEqual(len(idxs), 0)
131+
122132

123133
if __name__ == '__main__':
124134
unittest.main()

torchvision/datasets/video_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import bisect
22
import math
3+
import warnings
34
from fractions import Fraction
45
from typing import List
56

@@ -204,6 +205,9 @@ def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
204205
)
205206
video_pts = video_pts[idxs]
206207
clips = unfold(video_pts, num_frames, step)
208+
if not clips.numel():
209+
warnings.warn("There aren't enough frames in the current video to get a clip for the given clip length and "
210+
"frames between clips. The video (and potentially others) will be skipped.")
207211
if isinstance(idxs, slice):
208212
idxs = [idxs] * len(clips)
209213
else:

0 commit comments

Comments
 (0)