@@ -5,6 +5,7 @@
 from torchvision.io import (
     _read_video_timestamps_from_file,
     _read_video_from_file,
+    _probe_video_from_file
 )
 from torchvision.io import read_video_timestamps, read_video
 
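Note: `_probe_video_from_file` inspects container metadata (fps, timebases) without decoding frames, which is what lets `get_clip` recover per-video info on demand instead of caching `self.info` at construction. As a rough illustration only (the real function is a binding into the native video_reader backend, not this code), probing with PyAV looks something like:

```python
# Illustrative sketch only: approximates what "probing" a file means,
# using PyAV's public API; not the torchvision implementation.
import av

def probe_sketch(path):
    with av.open(path) as container:
        stream = container.streams.video[0]
        return {
            "video_fps": float(stream.average_rate),  # e.g. 29.97
            "video_timebase": stream.time_base,       # Fraction; units of pts
        }
```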
@@ -71,11 +72,11 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1
                  frame_rate=None, _precomputed_metadata=None, num_workers=0,
                  _video_width=0, _video_height=0, _video_min_dimension=0,
                  _audio_samples=0):
-        from torchvision import get_video_backend
 
         self.video_paths = video_paths
         self.num_workers = num_workers
-        self._backend = get_video_backend()
+
+        # these options are not valid for pyav backend
         self._video_width = _video_width
         self._video_height = _video_height
         self._video_min_dimension = _video_min_dimension
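The point of dropping `self._backend` here is that the decoding backend is now looked up at decode time rather than frozen into the dataset at construction. A hypothetical usage sketch (file names are placeholders) of why that matters:

```python
# Hypothetical usage: cached metadata survives a backend switch, because
# VideoClips no longer stores the backend chosen at construction time.
import torchvision
from torchvision.datasets.video_utils import VideoClips

torchvision.set_video_backend("video_reader")
clips = VideoClips(["a.mp4", "b.mp4"])       # timestamps computed once
torchvision.set_video_backend("pyav")
video, audio, info, idx = clips.get_clip(0)  # decoded with pyav
```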
@@ -89,87 +90,60 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1
 
     def _compute_frame_pts(self):
         self.video_pts = []
-        if self._backend == "pyav":
-            self.video_fps = []
-        else:
-            self.info = []
+        self.video_fps = []
 
         # strategy: use a DataLoader to parallelize read_video_timestamps
         # so need to create a dummy dataset first
         class DS(object):
-            def __init__(self, x, _backend):
+            def __init__(self, x):
                 self.x = x
-                self._backend = _backend
 
             def __len__(self):
                 return len(self.x)
 
             def __getitem__(self, idx):
-                if self._backend == "pyav":
-                    return read_video_timestamps(self.x[idx])
-                else:
-                    return _read_video_timestamps_from_file(self.x[idx])
+                return read_video_timestamps(self.x[idx])
 
         import torch.utils.data
         dl = torch.utils.data.DataLoader(
-            DS(self.video_paths, self._backend),
+            DS(self.video_paths),
             batch_size=16,
             num_workers=self.num_workers,
            collate_fn=lambda x: x)
 
         with tqdm(total=len(dl)) as pbar:
             for batch in dl:
                 pbar.update(1)
-                if self._backend == "pyav":
-                    clips, fps = list(zip(*batch))
-                    clips = [torch.as_tensor(c) for c in clips]
-                    self.video_pts.extend(clips)
-                    self.video_fps.extend(fps)
-                else:
-                    video_pts, _audio_pts, info = list(zip(*batch))
-                    video_pts = [torch.as_tensor(c) for c in video_pts]
-                    self.video_pts.extend(video_pts)
-                    self.info.extend(info)
+                clips, fps = list(zip(*batch))
+                clips = [torch.as_tensor(c) for c in clips]
+                self.video_pts.extend(clips)
+                self.video_fps.extend(fps)
 
     def _init_from_metadata(self, metadata):
         self.video_paths = metadata["video_paths"]
         assert len(self.video_paths) == len(metadata["video_pts"])
         self.video_pts = metadata["video_pts"]
-
-        if self._backend == "pyav":
-            assert len(self.video_paths) == len(metadata["video_fps"])
-            self.video_fps = metadata["video_fps"]
-        else:
-            assert len(self.video_paths) == len(metadata["info"])
-            self.info = metadata["info"]
+        assert len(self.video_paths) == len(metadata["video_fps"])
+        self.video_fps = metadata["video_fps"]
 
     @property
     def metadata(self):
         _metadata = {
             "video_paths": self.video_paths,
             "video_pts": self.video_pts,
+            "video_fps": self.video_fps
         }
-        if self._backend == "pyav":
-            _metadata.update({"video_fps": self.video_fps})
-        else:
-            _metadata.update({"info": self.info})
         return _metadata
 
     def subset(self, indices):
         video_paths = [self.video_paths[i] for i in indices]
         video_pts = [self.video_pts[i] for i in indices]
-        if self._backend == "pyav":
-            video_fps = [self.video_fps[i] for i in indices]
-        else:
-            info = [self.info[i] for i in indices]
+        video_fps = [self.video_fps[i] for i in indices]
         metadata = {
             "video_paths": video_paths,
             "video_pts": video_pts,
+            "video_fps": video_fps
         }
-        if self._backend == "pyav":
-            metadata.update({"video_fps": video_fps})
-        else:
-            metadata.update({"info": info})
         return type(self)(video_paths, self.num_frames, self.step, self.frame_rate,
                           _precomputed_metadata=metadata, num_workers=self.num_workers,
                           _video_width=self._video_width,
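The `DS` wrapper plus `collate_fn=lambda x: x` is a general trick worth naming: it reuses DataLoader's worker pool to parallelize an arbitrary per-file function, with the identity collate returning raw Python objects instead of stacked tensors. A minimal self-contained sketch of the same pattern (helper names are ours, not torchvision's):

```python
# Run an expensive per-item function in DataLoader worker processes and
# collect the raw results (the identity collate_fn skips tensor batching).
import torch.utils.data


class _PerItem(torch.utils.data.Dataset):
    def __init__(self, items, fn):
        self.items = items
        self.fn = fn

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.fn(self.items[idx])


def parallel_map(fn, items, num_workers=0, batch_size=16):
    # note: with num_workers > 0, fn must be picklable (no lambdas on
    # platforms that spawn workers instead of forking)
    dl = torch.utils.data.DataLoader(
        _PerItem(items, fn),
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=lambda x: x,  # keep a plain list of results
    )
    out = []
    for batch in dl:
        out.extend(batch)
    return out
```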
@@ -212,22 +186,10 @@ def compute_clips(self, num_frames, step, frame_rate=None):
         self.frame_rate = frame_rate
         self.clips = []
         self.resampling_idxs = []
-        if self._backend == "pyav":
-            for video_pts, fps in zip(self.video_pts, self.video_fps):
-                clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
-                self.clips.append(clips)
-                self.resampling_idxs.append(idxs)
-        else:
-            for video_pts, info in zip(self.video_pts, self.info):
-                if "video_fps" in info:
-                    clips, idxs = self.compute_clips_for_video(
-                        video_pts, num_frames, step, info["video_fps"], frame_rate)
-                    self.clips.append(clips)
-                    self.resampling_idxs.append(idxs)
-                else:
-                    # properly handle the cases where video decoding fails
-                    self.clips.append(torch.zeros(0, num_frames, dtype=torch.int64))
-                    self.resampling_idxs.append(torch.zeros(0, dtype=torch.int64))
+        for video_pts, fps in zip(self.video_pts, self.video_fps):
+            clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
+            self.clips.append(clips)
+            self.resampling_idxs.append(idxs)
         clip_lengths = torch.as_tensor([len(v) for v in self.clips])
         self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
 
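`compute_clips_for_video` (unchanged by this diff) is essentially a sliding window over each video's frame timestamps. A sketch of that windowing, using a hypothetical `sliding_clips` helper and ignoring frame-rate resampling:

```python
# Sketch of the windowing compute_clips performs per video: split a 1-D
# tensor of frame pts into clips of num_frames, starting a new clip
# every `step` frames.
import torch

def sliding_clips(video_pts, num_frames, step):
    if len(video_pts) < num_frames:
        # video too short for even one clip
        return torch.zeros(0, num_frames, dtype=video_pts.dtype)
    return video_pts.unfold(0, num_frames, step)

pts = torch.arange(10)
print(sliding_clips(pts, num_frames=4, step=3))
# tensor([[0, 1, 2, 3],
#         [3, 4, 5, 6],
#         [6, 7, 8, 9]])
```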
@@ -287,12 +249,28 @@ def get_clip(self, idx):
         video_path = self.video_paths[video_idx]
         clip_pts = self.clips[video_idx][clip_idx]
 
-        if self._backend == "pyav":
+        from torchvision import get_video_backend
+        backend = get_video_backend()
+
+        if backend == "pyav":
+            # check for invalid options
+            if self._video_width != 0:
+                raise ValueError("pyav backend doesn't support _video_width != 0")
+            if self._video_height != 0:
+                raise ValueError("pyav backend doesn't support _video_height != 0")
+            if self._video_min_dimension != 0:
+                raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
+            if self._audio_samples != 0:
+                raise ValueError("pyav backend doesn't support _audio_samples != 0")
+
+        if backend == "pyav":
             start_pts = clip_pts[0].item()
             end_pts = clip_pts[-1].item()
             video, audio, info = read_video(video_path, start_pts, end_pts)
         else:
-            info = self.info[video_idx]
+            info = _probe_video_from_file(video_path)
+            video_fps = info["video_fps"]
+            audio_fps = None
 
             video_start_pts = clip_pts[0].item()
             video_end_pts = clip_pts[-1].item()
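Since the resize and audio-sampling options are only honoured by the video_reader backend, validation now happens lazily at decode time rather than at construction. A hypothetical usage showing the failure mode this guards against:

```python
# Hypothetical: options are accepted at construction but rejected at
# decode time if the active backend can't honour them.
clips = VideoClips(video_paths, _video_width=171, _video_height=128)
torchvision.set_video_backend("pyav")
clips.get_clip(0)
# ValueError: pyav backend doesn't support _video_width != 0
```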
@@ -313,6 +291,7 @@ def get_clip(self, idx):
                     info["audio_timebase"],
                     math.ceil,
                 )
+                audio_fps = info["audio_sample_rate"]
             video, audio, info = _read_video_from_file(
                 video_path,
                 video_width=self._video_width,
@@ -324,6 +303,11 @@ def get_clip(self, idx):
                 audio_pts_range=(audio_start_pts, audio_end_pts),
                 audio_timebase=audio_timebase,
             )
+
+            info = {"video_fps": video_fps}
+            if audio_fps is not None:
+                info["audio_fps"] = audio_fps
+
         if self.frame_rate is not None:
             resampling_idx = self.resampling_idxs[video_idx][clip_idx]
             if isinstance(resampling_idx, torch.Tensor):
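With the final `info` rebuild above, both branches of `get_clip` now hand back the same dictionary shape, so callers no longer need backend-specific handling (a usage sketch, assuming the four-tuple return of `get_clip`):

```python
# Both backends now yield the same info keys from get_clip:
video, audio, info, video_idx = clips.get_clip(0)
fps = info["video_fps"]
audio_fps = info.get("audio_fps")  # only present when the clip has audio
```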