@@ -59,9 +59,17 @@ class Translate : public ModelTask {
5959 trgVocab_->load (vocabs.back ());
6060 auto srcVocab = corpus_->getVocabs ()[0 ];
6161
62- if (options_->hasAndNotEmpty (" shortlist" ))
63- shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
64- options_, srcVocab, trgVocab_, 0 , 1 , vocabs.front () == vocabs.back ());
62+ if (options_->hasAndNotEmpty (" shortlist" )) {
63+ auto slOptions = options_->get <std::vector<std::string>>(" shortlist" );
64+ ABORT_IF (slOptions.empty (), " No path to shortlist file given" );
65+ std::string filename = slOptions[0 ];
66+ if (data::isBinaryShortlist (filename))
67+ shortlistGenerator_ = New<data::BinaryShortlistGenerator>(
68+ options_, srcVocab, trgVocab_, 0 , 1 , vocabs.front () == vocabs.back ());
69+ else
70+ shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
71+ options_, srcVocab, trgVocab_, 0 , 1 , vocabs.front () == vocabs.back ());
72+ }
6573
6674 auto devices = Config::getDevices (options_);
6775 numDevices_ = devices.size ();
@@ -87,7 +95,6 @@ class Translate : public ModelTask {
8795 auto prec = options_->get <std::vector<std::string>>(" precision" , {" float32" });
8896 graph->setDefaultElementType (typeFromString (prec[0 ]));
8997 graph->setDevice (device);
90- graph->getBackend ()->configureDevice (options_);
9198 graph->reserveWorkspaceMB (options_->get <size_t >(" workspace" ));
9299 graphs_[id] = graph;
93100
@@ -211,9 +218,17 @@ class TranslateService : public ModelServiceTask {
211218 trgVocab_->load (vocabPaths.back ());
212219
213220 // load lexical shortlist
214- if (options_->hasAndNotEmpty (" shortlist" ))
215- shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
216- options_, srcVocabs_.front (), trgVocab_, 0 , 1 , vocabPaths.front () == vocabPaths.back ());
221+ if (options_->hasAndNotEmpty (" shortlist" )) {
222+ auto slOptions = options_->get <std::vector<std::string>>(" shortlist" );
223+ ABORT_IF (slOptions.empty (), " No path to shortlist file given" );
224+ std::string filename = slOptions[0 ];
225+ if (data::isBinaryShortlist (filename))
226+ shortlistGenerator_ = New<data::BinaryShortlistGenerator>(
227+ options_, srcVocabs_.front (), trgVocab_, 0 , 1 , vocabPaths.front () == vocabPaths.back ());
228+ else
229+ shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
230+ options_, srcVocabs_.front (), trgVocab_, 0 , 1 , vocabPaths.front () == vocabPaths.back ());
231+ }
217232
218233 // get device IDs
219234 auto devices = Config::getDevices (options_);
@@ -226,7 +241,6 @@ class TranslateService : public ModelServiceTask {
226241 auto precison = options_->get <std::vector<std::string>>(" precision" , {" float32" });
227242 graph->setDefaultElementType (typeFromString (precison[0 ])); // only use first type, used for parameter type in graph
228243 graph->setDevice (device);
229- graph->getBackend ()->configureDevice (options_);
230244 graph->reserveWorkspaceMB (options_->get <size_t >(" workspace" ));
231245 graphs_.push_back (graph);
232246
0 commit comments