Skip to content

Commit 3bfdb42

Browse files
bjuncekdatumboxfmassa
authored
Implementing multithreaded video decoding (#3389)
* multithreading allowed in stream codec context * numThreads is passed as a decoder parameter. At this stage code should be unchanged * enabling multithreading in videoReader API * moving defaults to header files * replace long with int64_t because torchscript * docstring for Num threads * Enable codec-related heuristics as defaults * Update torchvision/csrc/io/decoder/stream.cpp Co-authored-by: Vasilis Vryniotis <[email protected]> * Fixing build errors * minor docs * Linting * updating defaults for the C++ function calls to be single threaded * adding special case for single threaded stuff Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Francisco Massa <[email protected]>
1 parent 38bb270 commit 3bfdb42

File tree

7 files changed

+82
-19
lines changed

7 files changed

+82
-19
lines changed

torchvision/csrc/io/decoder/decoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ bool Decoder::openStreams(std::vector<DecoderMetadata>* metadata) {
432432
it->format,
433433
params_.loggingUuid);
434434
CHECK(stream);
435-
if (stream->openCodec(metadata) < 0) {
435+
if (stream->openCodec(metadata, params_.numThreads) < 0) {
436436
LOG(ERROR) << "uuid=" << params_.loggingUuid
437437
<< " open codec failed, stream_idx=" << i;
438438
return false;

torchvision/csrc/io/decoder/defs.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ struct DecoderParameters {
194194
bool preventStaleness{true};
195195
// seek tolerated accuracy (us)
196196
double seekAccuracy{1000000.0};
197+
// Allow multithreaded decoding for numThreads > 1;
198+
// 0 numThreads=0 sets up sensible defaults
199+
int numThreads{1};
197200
// what media types should be processed, default none
198201
std::set<MediaFormat> formats;
199202

torchvision/csrc/io/decoder/stream.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#include "stream.h"
2+
#include <ATen/Parallel.h>
23
#include <c10/util/Logging.h>
4+
#include <stdio.h>
5+
#include <string.h>
36
#include "util.h"
47

58
namespace ffmpeg {
@@ -28,7 +31,7 @@ AVCodec* Stream::findCodec(AVCodecParameters* params) {
2831
return avcodec_find_decoder(params->codec_id);
2932
}
3033

31-
int Stream::openCodec(std::vector<DecoderMetadata>* metadata) {
34+
int Stream::openCodec(std::vector<DecoderMetadata>* metadata, int num_threads) {
3235
AVStream* steam = inputCtx_->streams[format_.stream];
3336

3437
AVCodec* codec = findCodec(steam->codecpar);
@@ -53,6 +56,49 @@ int Stream::openCodec(std::vector<DecoderMetadata>* metadata) {
5356
return ret;
5457
}
5558

59+
// multithreading heuristics
60+
int max_threads = at::get_num_threads();
61+
// first a safety check
62+
if (num_threads > max_threads) {
63+
num_threads = max_threads;
64+
}
65+
66+
if (num_threads > 0) {
67+
if (num_threads > 1) {
68+
codecCtx_->active_thread_type = 1;
69+
}
70+
// if user defined, respect that
71+
codecCtx_->thread_count = num_threads;
72+
73+
} else {
74+
// otherwise set sensible defaults
75+
// with the special case for the different MPEG4 codecs
76+
// that don't have threading context functions
77+
// TODO: potentially automate this using native c++ function lookups
78+
if (strcmp(codecCtx_->codec->name, "mpeg4") == 0 &&
79+
codecCtx_->codec_type == 0) {
80+
if (codecCtx_->codec_tag == 1684633208) {
81+
codecCtx_->thread_count = (8 <= max_threads) ? 8 : max_threads;
82+
codecCtx_->thread_type = 1;
83+
} else {
84+
codecCtx_->thread_count = (2 <= max_threads) ? 2 : max_threads;
85+
codecCtx_->thread_type = 2;
86+
}
87+
} else {
88+
// otherwise default to multithreading
89+
codecCtx_->thread_count = (8 <= max_threads) ? 8 : max_threads;
90+
codecCtx_->active_thread_type = 1;
91+
}
92+
}
93+
94+
// print codec type and number of threads
95+
LOG(INFO) << "Codec " << codecCtx_->codec->long_name
96+
<< " Codec id: " << codecCtx_->codec_id
97+
<< " Codec tag: " << codecCtx_->codec_tag
98+
<< " Codec type: " << codecCtx_->codec_type
99+
<< " Codec extradata: " << codecCtx_->extradata
100+
<< " Number of threads: " << at::get_num_threads();
101+
56102
// after avcodec_open2, value of codecCtx_->time_base is NOT meaningful
57103
if ((ret = avcodec_open2(codecCtx_, codec, nullptr)) < 0) {
58104
LOG(ERROR) << "LoggingUuid #" << loggingUuid_

torchvision/csrc/io/decoder/stream.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ class Stream {
2020
virtual ~Stream();
2121

2222
// returns 0 - on success or negative error
23-
int openCodec(std::vector<DecoderMetadata>* metadata);
23+
// num_threads sets up the codec context for multithreading if needed
24+
int openCodec(std::vector<DecoderMetadata>* metadata, int num_threads = 1);
2425
// returns 1 - if packet got consumed, 0 - if it's not, and < 0 on error
2526
int decodePacket(
2627
const AVPacket* packet,

torchvision/csrc/io/video/video.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,17 @@ void Video::_getDecoderParams(
9797
double videoStartS,
9898
int64_t getPtsOnly,
9999
std::string stream,
100-
long stream_id = -1,
101-
bool all_streams = false,
102-
double seekFrameMarginUs = 10) {
100+
long stream_id,
101+
bool all_streams,
102+
int64_t num_threads,
103+
double seekFrameMarginUs) {
103104
int64_t videoStartUs = int64_t(videoStartS * 1e6);
104105

105106
params.timeoutMs = decoderTimeoutMs;
106107
params.startOffset = videoStartUs;
107108
params.seekAccuracy = seekFrameMarginUs;
108109
params.headerOnly = false;
110+
params.numThreads = num_threads;
109111

110112
params.preventStaleness = false; // not sure what this is about
111113

@@ -152,7 +154,9 @@ void Video::_getDecoderParams(
152154

153155
} // _get decoder params
154156

155-
Video::Video(std::string videoPath, std::string stream) {
157+
Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
158+
// set number of threads global
159+
numThreads_ = numThreads;
156160
// parse stream information
157161
current_stream = _parseStream(stream);
158162
// note that in the initial call we want to get all streams
@@ -161,7 +165,8 @@ Video::Video(std::string videoPath, std::string stream) {
161165
0, // headerOnly
162166
std::get<0>(current_stream), // stream info - remove that
163167
long(-1), // stream_id parsed from info above change to -2
164-
true // read all streams
168+
true, // read all streams
169+
numThreads_ // global number of Threads for decoding
165170
);
166171

167172
std::string logMessage, logType;
@@ -225,7 +230,7 @@ Video::Video(std::string videoPath, std::string stream) {
225230
}
226231
} // video
227232

228-
bool Video::setCurrentStream(std::string stream = "video") {
233+
bool Video::setCurrentStream(std::string stream) {
229234
if ((!stream.empty()) && (_parseStream(stream) != current_stream)) {
230235
current_stream = _parseStream(stream);
231236
}
@@ -241,7 +246,8 @@ bool Video::setCurrentStream(std::string stream = "video") {
241246
std::get<0>(current_stream), // stream
242247
long(std::get<1>(
243248
current_stream)), // stream_id parsed from info above change to -2
244-
false // read all streams
249+
false, // read all streams
250+
numThreads_ // global number of threads
245251
);
246252

247253
// calback and metadata defined in Video.h
@@ -265,7 +271,8 @@ void Video::Seek(double ts) {
265271
std::get<0>(current_stream), // stream
266272
long(std::get<1>(
267273
current_stream)), // stream_id parsed from info above change to -2
268-
false // read all streams
274+
false, // read all streams
275+
numThreads_ // global num threads
269276
);
270277

271278
// calback and metadata defined in Video.h
@@ -331,7 +338,7 @@ std::tuple<torch::Tensor, double> Video::Next() {
331338

332339
static auto registerVideo =
333340
torch::class_<Video>("torchvision", "Video")
334-
.def(torch::init<std::string, std::string>())
341+
.def(torch::init<std::string, std::string, int64_t>())
335342
.def("get_current_stream", &Video::getCurrentStream)
336343
.def("set_current_stream", &Video::setCurrentStream)
337344
.def("get_metadata", &Video::getStreamMetadata)

torchvision/csrc/io/video/video.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@ struct Video : torch::CustomClassHolder {
1616
// global video metadata
1717
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
1818
streamsMetadata;
19+
int64_t numThreads_{0};
1920

2021
public:
21-
Video(std::string videoPath, std::string stream);
22+
Video(std::string videoPath, std::string stream, int64_t numThreads);
2223
std::tuple<std::string, int64_t> getCurrentStream() const;
2324
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
2425
getStreamMetadata() const;
2526
void Seek(double ts);
26-
bool setCurrentStream(std::string stream);
27+
bool setCurrentStream(std::string stream = "video");
2728
std::tuple<torch::Tensor, double> Next();
2829

2930
private:
@@ -37,9 +38,10 @@ struct Video : torch::CustomClassHolder {
3738
double videoStartS,
3839
int64_t getPtsOnly,
3940
std::string stream,
40-
long stream_id,
41-
bool all_streams,
42-
double seekFrameMarginUs); // this needs to be improved
41+
long stream_id = -1,
42+
bool all_streams = false,
43+
int64_t num_threads = 0,
44+
double seekFrameMarginUs = 10); // this needs to be improved
4345

4446
std::map<std::string, std::vector<double>> streamTimeBase; // not used
4547

torchvision/io/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,17 +96,21 @@ class VideoReader:
9696
stream (string, optional): descriptor of the required stream, followed by the stream id,
9797
in the format ``{stream_type}:{stream_id}``. Defaults to ``"video:0"``.
9898
Currently available options include ``['video', 'audio']``
99+
100+
num_threads (int, optional): number of threads used by the codec to decode video.
101+
Default value (0) enables multithreading with codec-dependent heuristic. The performance
102+
will depend on the version of FFMPEG codecs supported.
99103
"""
100104

101-
def __init__(self, path, stream="video"):
105+
def __init__(self, path, stream="video", num_threads=0):
102106
if not _has_video_opt():
103107
raise RuntimeError(
104108
"Not compiled with video_reader support, "
105109
+ "to enable video_reader support, please install "
106110
+ "ffmpeg (version 4.2 is currently supported) and"
107111
+ "build torchvision from source."
108112
)
109-
self._c = torch.classes.torchvision.Video(path, stream)
113+
self._c = torch.classes.torchvision.Video(path, stream, num_threads)
110114

111115
def __next__(self):
112116
"""Decodes and returns the next frame of the current stream.

0 commit comments

Comments
 (0)