Skip to content

Commit 0a9f49f

Browse files
author
Qianqian Zhu
committed
update marian decoder with option of loading binary shortlist
1 parent af67906 commit 0a9f49f

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

src/translator/translator.h

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,17 @@ class Translate : public ModelTask {
5959
trgVocab_->load(vocabs.back());
6060
auto srcVocab = corpus_->getVocabs()[0];
6161

62-
if(options_->hasAndNotEmpty("shortlist"))
63-
shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
64-
options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
62+
if(options_->hasAndNotEmpty("shortlist")) {
63+
auto slOptions = options_->get<std::vector<std::string>>("shortlist");
64+
ABORT_IF(slOptions.empty(), "No path to shortlist file given");
65+
std::string filename = slOptions[0];
66+
if(data::isBinaryShortlist(filename))
67+
shortlistGenerator_ = New<data::BinaryShortlistGenerator>(
68+
options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
69+
else
70+
shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
71+
options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
72+
}
6573

6674
auto devices = Config::getDevices(options_);
6775
numDevices_ = devices.size();
@@ -87,7 +95,6 @@ class Translate : public ModelTask {
8795
auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
8896
graph->setDefaultElementType(typeFromString(prec[0]));
8997
graph->setDevice(device);
90-
graph->getBackend()->configureDevice(options_);
9198
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
9299
graphs_[id] = graph;
93100

@@ -211,9 +218,17 @@ class TranslateService : public ModelServiceTask {
211218
trgVocab_->load(vocabPaths.back());
212219

213220
// load lexical shortlist
214-
if(options_->hasAndNotEmpty("shortlist"))
215-
shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
216-
options_, srcVocabs_.front(), trgVocab_, 0, 1, vocabPaths.front() == vocabPaths.back());
221+
if(options_->hasAndNotEmpty("shortlist")) {
222+
auto slOptions = options_->get<std::vector<std::string>>("shortlist");
223+
ABORT_IF(slOptions.empty(), "No path to shortlist file given");
224+
std::string filename = slOptions[0];
225+
if(data::isBinaryShortlist(filename))
226+
shortlistGenerator_ = New<data::BinaryShortlistGenerator>(
227+
options_, srcVocabs_.front(), trgVocab_, 0, 1, vocabPaths.front() == vocabPaths.back());
228+
else
229+
shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
230+
options_, srcVocabs_.front(), trgVocab_, 0, 1, vocabPaths.front() == vocabPaths.back());
231+
}
217232

218233
// get device IDs
219234
auto devices = Config::getDevices(options_);
@@ -226,7 +241,6 @@ class TranslateService : public ModelServiceTask {
226241
auto precison = options_->get<std::vector<std::string>>("precision", {"float32"});
227242
graph->setDefaultElementType(typeFromString(precison[0])); // only use first type, used for parameter type in graph
228243
graph->setDevice(device);
229-
graph->getBackend()->configureDevice(options_);
230244
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
231245
graphs_.push_back(graph);
232246

0 commit comments

Comments
 (0)