Skip to content

Commit ceb44ff

Browse files
Merge pull request #9267 from JohnSnowLabs/feature/batch-opt-gpu-transformers
Optimizing batch processing for transformer-based Word Embeddings on GPU
2 parents 6c4f826 + e8a25da commit ceb44ff

File tree

8 files changed

+206
-110
lines changed

8 files changed

+206
-110
lines changed

src/main/scala/com/johnsnowlabs/ml/tensorflow/TensorflowXlmRoberta.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,6 @@ class TensorflowXlmRoberta(
8484
private val SentencePadTokenId = 1
8585
private val SentencePieceDelimiterId = spp.getSppModel.pieceToId("▁")
8686

87-
private val encoder =
88-
new SentencepieceEncoder(spp, caseSensitive, SentencePieceDelimiterId, pieceIdOffset = 1)
89-
9087
def prepareBatchInputs(
9188
sentences: Seq[(WordpieceTokenizedSentence, Int)],
9289
maxSequenceLength: Int): Seq[Array[Int]] = {
@@ -323,11 +320,13 @@ class TensorflowXlmRoberta(
323320
def tokenizeWithAlignment(
324321
sentences: Seq[TokenizedSentence],
325322
maxSeqLength: Int): Seq[WordpieceTokenizedSentence] = {
323+
val encoder =
324+
new SentencepieceEncoder(spp, caseSensitive, SentencePieceDelimiterId, pieceIdOffset = 1)
326325

327326
val sentenceTokenPieces = sentences.map { s =>
328-
val shrinkedSentence = s.indexedTokens.take(maxSeqLength - 2)
327+
val trimmedSentence = s.indexedTokens.take(maxSeqLength - 2)
329328
val wordpieceTokens =
330-
shrinkedSentence.flatMap(token => encoder.encode(token)).take(maxSeqLength)
329+
trimmedSentence.flatMap(token => encoder.encode(token)).take(maxSeqLength)
331330
WordpieceTokenizedSentence(wordpieceTokens)
332331
}
333332
sentenceTokenPieces
@@ -336,6 +335,8 @@ class TensorflowXlmRoberta(
336335
def tokenizeSentence(
337336
sentences: Seq[Sentence],
338337
maxSeqLength: Int): Seq[WordpieceTokenizedSentence] = {
338+
val encoder =
339+
new SentencepieceEncoder(spp, caseSensitive, SentencePieceDelimiterId, pieceIdOffset = 1)
339340

340341
val sentenceTokenPieces = sentences.map { s =>
341342
val wordpieceTokens = encoder.encodeSentence(s, maxLength = maxSeqLength).take(maxSeqLength)

src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -289,20 +289,34 @@ class AlbertEmbeddings(override val uid: String)
289289
* relationship
290290
*/
291291
override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
292-
val batchedTokenizedSentences: Array[Array[TokenizedSentence]] = batchedAnnotations
293-
.map(annotations => TokenizedWithSentence.unpack(annotations).toArray)
294-
.toArray
292+
// Unpack annotations and zip each sentence to the index of the row it belongs to
293+
val sentencesWithRow = batchedAnnotations.zipWithIndex
294+
.flatMap { case (annotations, i) =>
295+
TokenizedWithSentence.unpack(annotations).toArray.map(x => (x, i))
296+
}
295297

296298
/*Return empty if the real tokens are empty*/
297-
if (batchedTokenizedSentences.nonEmpty) batchedTokenizedSentences.map(tokenizedSentences => {
298-
299-
val embeddings = getModelIfNotSet
300-
.predict(tokenizedSentences, $(batchSize), $(maxSentenceLength), $(caseSensitive))
301-
WordpieceEmbeddingsSentence.pack(embeddings)
299+
val sentenceWordEmbeddings = getModelIfNotSet.predict(
300+
sentencesWithRow.map(_._1),
301+
$(batchSize),
302+
$(maxSentenceLength),
303+
$(caseSensitive))
304+
305+
// Group resulting annotations by rows. If there are no sentences in a given row, return empty sequence
306+
batchedAnnotations.indices.map(rowIndex => {
307+
val rowEmbeddings = sentenceWordEmbeddings
308+
// zip each annotation with its corresponding row index
309+
.zip(sentencesWithRow)
310+
// select the sentences belonging to the current row
311+
.filter(_._2._2 == rowIndex)
312+
// keep only the annotation
313+
.map(_._1)
314+
315+
if (rowEmbeddings.nonEmpty)
316+
WordpieceEmbeddingsSentence.pack(rowEmbeddings)
317+
else
318+
Seq.empty[Annotation]
302319
})
303-
else {
304-
Seq(Seq.empty[Annotation])
305-
}
306320
}
307321

308322
override def onWrite(path: String, spark: SparkSession): Unit = {

src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -270,20 +270,34 @@ class DeBertaEmbeddings(override val uid: String)
270270
* relationship
271271
*/
272272
override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
273-
val batchedTokenizedSentences: Array[Array[TokenizedSentence]] = batchedAnnotations
274-
.map(annotations => TokenizedWithSentence.unpack(annotations).toArray)
275-
.toArray
273+
// Unpack annotations and zip each sentence to the index of the row it belongs to
274+
val sentencesWithRow = batchedAnnotations.zipWithIndex
275+
.flatMap { case (annotations, i) =>
276+
TokenizedWithSentence.unpack(annotations).toArray.map(x => (x, i))
277+
}
276278

277279
/*Return empty if the real tokens are empty*/
278-
if (batchedTokenizedSentences.nonEmpty) batchedTokenizedSentences.map(tokenizedSentences => {
279-
280-
val embeddings = getModelIfNotSet
281-
.predict(tokenizedSentences, $(batchSize), $(maxSentenceLength), $(caseSensitive))
282-
WordpieceEmbeddingsSentence.pack(embeddings)
280+
val sentenceWordEmbeddings = getModelIfNotSet.predict(
281+
sentencesWithRow.map(_._1),
282+
$(batchSize),
283+
$(maxSentenceLength),
284+
$(caseSensitive))
285+
286+
// Group resulting annotations by rows. If there are no sentences in a given row, return empty sequence
287+
batchedAnnotations.indices.map(rowIndex => {
288+
val rowEmbeddings = sentenceWordEmbeddings
289+
// zip each annotation with its corresponding row index
290+
.zip(sentencesWithRow)
291+
// select the sentences belonging to the current row
292+
.filter(_._2._2 == rowIndex)
293+
// keep only the annotation
294+
.map(_._1)
295+
296+
if (rowEmbeddings.nonEmpty)
297+
WordpieceEmbeddingsSentence.pack(rowEmbeddings)
298+
else
299+
Seq.empty[Annotation]
283300
})
284-
else {
285-
Seq(Seq.empty[Annotation])
286-
}
287301
}
288302

289303
override def onWrite(path: String, spark: SparkSession): Unit = {

src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,9 @@ class DistilBertEmbeddings(override val uid: String)
295295
val content = if ($(caseSensitive)) token.token else token.token.toLowerCase()
296296
val sentenceBegin = token.begin
297297
val sentenceEnd = token.end
298-
val sentenceInedx = tokenIndex.sentenceIndex
298+
val sentenceIndex = tokenIndex.sentenceIndex
299299
val result = basicTokenizer.tokenize(
300-
Sentence(content, sentenceBegin, sentenceEnd, sentenceInedx))
300+
Sentence(content, sentenceBegin, sentenceEnd, sentenceIndex))
301301
if (result.nonEmpty) result.head else IndexedToken("")
302302
}
303303
val wordpieceTokens =
@@ -316,24 +316,38 @@ class DistilBertEmbeddings(override val uid: String)
316316
* relationship
317317
*/
318318
override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
319-
val batchedTokenizedSentences: Array[Array[TokenizedSentence]] = batchedAnnotations
320-
.map(annotations => TokenizedWithSentence.unpack(annotations).toArray)
321-
.toArray
322-
/*Return empty if the real tokens are empty*/
323-
if (batchedTokenizedSentences.nonEmpty) batchedTokenizedSentences.map(tokenizedSentences => {
324-
val tokenized = tokenizeWithAlignment(tokenizedSentences)
325-
326-
val withEmbeddings = getModelIfNotSet.predict(
327-
tokenized,
328-
tokenizedSentences,
329-
$(batchSize),
330-
$(maxSentenceLength),
331-
$(caseSensitive))
332-
WordpieceEmbeddingsSentence.pack(withEmbeddings)
319+
// Unpack annotations and zip each sentence to the index of the row it belongs to
320+
val sentencesWithRow = batchedAnnotations.zipWithIndex
321+
.flatMap { case (annotations, i) =>
322+
TokenizedWithSentence.unpack(annotations).toArray.map(x => (x, i))
323+
}
324+
325+
// Tokenize sentences
326+
val tokenizedSentences = tokenizeWithAlignment(sentencesWithRow.map(_._1))
327+
328+
// Process all sentences
329+
val sentenceWordEmbeddings = getModelIfNotSet.predict(
330+
tokenizedSentences,
331+
sentencesWithRow.map(_._1),
332+
$(batchSize),
333+
$(maxSentenceLength),
334+
$(caseSensitive))
335+
336+
// Group resulting annotations by rows. If there are no sentences in a given row, return empty sequence
337+
batchedAnnotations.indices.map(rowIndex => {
338+
val rowEmbeddings = sentenceWordEmbeddings
339+
// zip each annotation with its corresponding row index
340+
.zip(sentencesWithRow)
341+
// select the sentences belonging to the current row
342+
.filter(_._2._2 == rowIndex)
343+
// keep only the annotation
344+
.map(_._1)
345+
346+
if (rowEmbeddings.nonEmpty)
347+
WordpieceEmbeddingsSentence.pack(rowEmbeddings)
348+
else
349+
Seq.empty[Annotation]
333350
})
334-
else {
335-
Seq(Seq.empty[Annotation])
336-
}
337351
}
338352

339353
override protected def afterAnnotate(dataset: DataFrame): DataFrame = {

src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ class LongformerEmbeddings(override val uid: String)
283283
this
284284
}
285285

286-
setDefault(dimension -> 768, batchSize -> 8, maxSentenceLength -> 1024, caseSensitive -> true)
286+
setDefault(dimension -> 768, batchSize -> 4, maxSentenceLength -> 1024, caseSensitive -> true)
287287

288288
def tokenizeWithAlignment(tokens: Seq[TokenizedSentence]): Seq[WordpieceTokenizedSentence] = {
289289

@@ -301,9 +301,9 @@ class LongformerEmbeddings(override val uid: String)
301301
val content = if ($(caseSensitive)) token.token else token.token.toLowerCase()
302302
val sentenceBegin = token.begin
303303
val sentenceEnd = token.end
304-
val sentenceInedx = tokenIndex.sentenceIndex
304+
val sentenceIndex = tokenIndex.sentenceIndex
305305
val result =
306-
bpeTokenizer.tokenize(Sentence(content, sentenceBegin, sentenceEnd, sentenceInedx))
306+
bpeTokenizer.tokenize(Sentence(content, sentenceBegin, sentenceEnd, sentenceIndex))
307307
if (result.nonEmpty) result.head else IndexedToken("")
308308
}
309309
val wordpieceTokens =
@@ -322,25 +322,38 @@ class LongformerEmbeddings(override val uid: String)
322322
* relationship
323323
*/
324324
override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
325-
val batchedTokenizedSentences: Array[Array[TokenizedSentence]] = batchedAnnotations
326-
.map(annotations => TokenizedWithSentence.unpack(annotations).toArray)
327-
.toArray
328-
329-
/*Return empty if the real tokens are empty*/
330-
if (batchedTokenizedSentences.nonEmpty) batchedTokenizedSentences.map(tokenizedSentences => {
331-
val tokenized = tokenizeWithAlignment(tokenizedSentences)
332-
333-
val withEmbeddings = getModelIfNotSet.predict(
334-
tokenized,
335-
tokenizedSentences,
336-
$(batchSize),
337-
$(maxSentenceLength),
338-
$(caseSensitive))
339-
WordpieceEmbeddingsSentence.pack(withEmbeddings)
325+
// Unpack annotations and zip each sentence to the index of the row it belongs to
326+
val sentencesWithRow = batchedAnnotations.zipWithIndex
327+
.flatMap { case (annotations, i) =>
328+
TokenizedWithSentence.unpack(annotations).toArray.map(x => (x, i))
329+
}
330+
331+
// Tokenize sentences
332+
val tokenizedSentences = tokenizeWithAlignment(sentencesWithRow.map(_._1))
333+
334+
// Process all sentences
335+
val sentenceWordEmbeddings = getModelIfNotSet.predict(
336+
tokenizedSentences,
337+
sentencesWithRow.map(_._1),
338+
$(batchSize),
339+
$(maxSentenceLength),
340+
$(caseSensitive))
341+
342+
// Group resulting annotations by rows. If there are no sentences in a given row, return empty sequence
343+
batchedAnnotations.indices.map(rowIndex => {
344+
val rowEmbeddings = sentenceWordEmbeddings
345+
// zip each annotation with its corresponding row index
346+
.zip(sentencesWithRow)
347+
// select the sentences belonging to the current row
348+
.filter(_._2._2 == rowIndex)
349+
// keep only the annotation
350+
.map(_._1)
351+
352+
if (rowEmbeddings.nonEmpty)
353+
WordpieceEmbeddingsSentence.pack(rowEmbeddings)
354+
else
355+
Seq.empty[Annotation]
340356
})
341-
else {
342-
Seq(Seq.empty[Annotation])
343-
}
344357
}
345358

346359
override protected def afterAnnotate(dataset: DataFrame): DataFrame = {

src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -313,9 +313,9 @@ class RoBertaEmbeddings(override val uid: String)
313313
val content = if ($(caseSensitive)) token.token else token.token.toLowerCase()
314314
val sentenceBegin = token.begin
315315
val sentenceEnd = token.end
316-
val sentenceInedx = tokenIndex.sentenceIndex
316+
val sentenceIndex = tokenIndex.sentenceIndex
317317
val result =
318-
bpeTokenizer.tokenize(Sentence(content, sentenceBegin, sentenceEnd, sentenceInedx))
318+
bpeTokenizer.tokenize(Sentence(content, sentenceBegin, sentenceEnd, sentenceIndex))
319319
if (result.nonEmpty) result.head else IndexedToken("")
320320
}
321321
val wordpieceTokens =
@@ -334,25 +334,38 @@ class RoBertaEmbeddings(override val uid: String)
334334
* relationship
335335
*/
336336
override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
337-
val batchedTokenizedSentences: Array[Array[TokenizedSentence]] = batchedAnnotations
338-
.map(annotations => TokenizedWithSentence.unpack(annotations).toArray)
339-
.toArray
340-
341-
/*Return empty if the real tokens are empty*/
342-
if (batchedTokenizedSentences.nonEmpty) batchedTokenizedSentences.map(tokenizedSentences => {
343-
val tokenized = tokenizeWithAlignment(tokenizedSentences)
344-
345-
val withEmbeddings = getModelIfNotSet.predict(
346-
tokenized,
347-
tokenizedSentences,
348-
$(batchSize),
349-
$(maxSentenceLength),
350-
$(caseSensitive))
351-
WordpieceEmbeddingsSentence.pack(withEmbeddings)
337+
// Unpack annotations and zip each sentence to the index of the row it belongs to
338+
val sentencesWithRow = batchedAnnotations.zipWithIndex
339+
.flatMap { case (annotations, i) =>
340+
TokenizedWithSentence.unpack(annotations).toArray.map(x => (x, i))
341+
}
342+
343+
// Tokenize sentences
344+
val tokenizedSentences = tokenizeWithAlignment(sentencesWithRow.map(_._1))
345+
346+
// Process all sentences
347+
val sentenceWordEmbeddings = getModelIfNotSet.predict(
348+
tokenizedSentences,
349+
sentencesWithRow.map(_._1),
350+
$(batchSize),
351+
$(maxSentenceLength),
352+
$(caseSensitive))
353+
354+
// Group resulting annotations by rows. If there are no sentences in a given row, return empty sequence
355+
batchedAnnotations.indices.map(rowIndex => {
356+
val rowEmbeddings = sentenceWordEmbeddings
357+
// zip each annotation with its corresponding row index
358+
.zip(sentencesWithRow)
359+
// select the sentences belonging to the current row
360+
.filter(_._2._2 == rowIndex)
361+
// keep only the annotation
362+
.map(_._1)
363+
364+
if (rowEmbeddings.nonEmpty)
365+
WordpieceEmbeddingsSentence.pack(rowEmbeddings)
366+
else
367+
Seq.empty[Annotation]
352368
})
353-
else {
354-
Seq(Seq.empty[Annotation])
355-
}
356369
}
357370

358371
override protected def afterAnnotate(dataset: DataFrame): DataFrame = {

src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -277,20 +277,31 @@ class XlmRoBertaEmbeddings(override val uid: String)
277277
* relationship
278278
*/
279279
override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
280-
val batchedTokenizedSentences: Array[Array[TokenizedSentence]] = batchedAnnotations
281-
.map(annotations => TokenizedWithSentence.unpack(annotations).toArray)
282-
.toArray
283-
284-
/*Return empty if the real tokens are empty*/
285-
if (batchedTokenizedSentences.nonEmpty) batchedTokenizedSentences.map(tokenizedSentences => {
286-
287-
val embeddings =
288-
getModelIfNotSet.predict(tokenizedSentences, $(batchSize), $(maxSentenceLength))
289-
WordpieceEmbeddingsSentence.pack(embeddings)
280+
// Unpack annotations and zip each sentence to the index of the row it belongs to
281+
val sentencesWithRow = batchedAnnotations.zipWithIndex
282+
.flatMap { case (annotations, i) =>
283+
TokenizedWithSentence.unpack(annotations).toArray.map(x => (x, i))
284+
}
285+
286+
val sentenceWordEmbeddings =
287+
getModelIfNotSet.predict(sentencesWithRow.map(_._1), $(batchSize), $(maxSentenceLength))
288+
289+
// Group resulting annotations by rows. If there are no sentences in a given row, return empty sequence
290+
batchedAnnotations.indices.map(rowIndex => {
291+
val rowEmbeddings = sentenceWordEmbeddings
292+
// zip each annotation with its corresponding row index
293+
.zip(sentencesWithRow)
294+
// select the sentences belonging to the current row
295+
.filter(_._2._2 == rowIndex)
296+
// keep only the annotation
297+
.map(_._1)
298+
299+
if (rowEmbeddings.nonEmpty)
300+
WordpieceEmbeddingsSentence.pack(rowEmbeddings)
301+
else
302+
Seq.empty[Annotation]
290303
})
291-
else {
292-
Seq(Seq.empty[Annotation])
293-
}
304+
294305
}
295306

296307
override protected def afterAnnotate(dataset: DataFrame): DataFrame = {

0 commit comments

Comments
 (0)