Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 34 additions & 18 deletions src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -291,24 +291,40 @@ class BertEmbeddings(override val uid: String)
* @return any number of annotations processed for every input annotation. Not necessary one to one relationship
*/
override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
val batchedTokenizedSentences: Array[Array[TokenizedSentence]] = batchedAnnotations.map(annotations =>
TokenizedWithSentence.unpack(annotations).toArray
).toArray
/*Return empty if the real tokens are empty*/
if (batchedTokenizedSentences.nonEmpty) batchedTokenizedSentences.map(tokenizedSentences => {
val tokenized = tokenizeWithAlignment(tokenizedSentences)

val withEmbeddings = getModelIfNotSet.calculateEmbeddings(
tokenized,
tokenizedSentences,
$(batchSize),
$(maxSentenceLength),
$(caseSensitive)
)
WordpieceEmbeddingsSentence.pack(withEmbeddings)
}) else {
Seq(Seq.empty[Annotation])
}

//Unpack annotations and zip each sentence to the index or the row it belongs to
val sentencesWithRow = batchedAnnotations
.zipWithIndex
.flatMap { case (annotations, i) => TokenizedWithSentence.unpack(annotations).toArray.map(x => (x, i)) }

//Tokenize sentences
val tokenizedSentences = tokenizeWithAlignment(sentencesWithRow.map(_._1))

//Process all sentences
val sentenceWordEmbeddings = getModelIfNotSet.calculateEmbeddings(
tokenizedSentences,
sentencesWithRow.map(_._1),
$(batchSize),
$(maxSentenceLength),
$(caseSensitive)
)

//Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence
batchedAnnotations.indices.map(rowIndex => {
val rowEmbeddings = sentenceWordEmbeddings
//zip each annotation with its corresponding row index
.zip(sentencesWithRow)
//select the sentences belonging to the current row
.filter(_._2._2 == rowIndex)
//leave the annotation only
.map(_._1)

if (rowEmbeddings.nonEmpty)
WordpieceEmbeddingsSentence.pack(rowEmbeddings)
else
Seq.empty[Annotation]
})

}

override protected def afterAnnotate(dataset: DataFrame): DataFrame = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,119 +17,119 @@
package com.johnsnowlabs.nlp.embeddings

import java.io.File

import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}
import com.johnsnowlabs.nlp.serialization.MapFeature
import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs, ResourceHelper}
import com.johnsnowlabs.storage.HasStorageRef

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.param.{IntArrayParam, IntParam, BooleanParam}
import org.apache.spark.ml.param.{BooleanParam, IntArrayParam, IntParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.collection.mutable.ArrayBuffer

/**
* Sentence-level embeddings using BERT. BERT (Bidirectional Encoder Representations from Transformers) provides dense
* vector representations for natural language by using a deep, pre-trained neural network with the Transformer architecture.
*
* Pretrained models can be loaded with `pretrained` of the companion object:
* {{{
* val embeddings = BertSentenceEmbeddings.pretrained()
* .setInputCols("sentence")
* .setOutputCol("sentence_bert_embeddings")
* }}}
* The default model is `"sent_small_bert_L2_768"`, if no name is provided.
*
* For available pretrained models please see the [[https://nlp.johnsnowlabs.com/models?task=Embeddings Models Hub]].
*
* For extended examples of usage, see the [[https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/jupyter/transformers/HuggingFace%20in%20Spark%20NLP%20-%20BERT%20Sentence.ipynb Spark NLP Workshop]]
* and the [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddingsTestSpec.scala BertSentenceEmbeddingsTestSpec]].
*
* '''Sources''' :
*
* [[https://arxiv.org/abs/1810.04805 BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding]]
*
* [[https://github.com/google-research/bert]]
*
* ''' Paper abstract '''
*
* ''We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations
* from Transformers. Unlike recent language representation models, BERT is designed to pre-train deep bidirectional
* representations from unlabeled text by jointly conditioning on both left and right context in all layers. As a
* result, the pre-trained BERT model can be fine-tuned with just one additional output layer to create
* state-of-the-art models for a wide range of tasks, such as question answering and language inference, without
* substantial task-specific architecture modifications. BERT is conceptually simple and empirically powerful. It
* obtains new state-of-the-art results on eleven natural language processing tasks, including pushing the GLUE score
* to 80.5% (7.7% point absolute improvement), MultiNLI accuracy to 86.7% (4.6% absolute improvement), SQuAD v1.1
* question answering Test F1 to 93.2 (1.5 point absolute improvement) and SQuAD v2.0 Test F1 to 83.1 (5.1 point
* absolute improvement).''
*
* ==Example==
* {{{
* import spark.implicits._
* import com.johnsnowlabs.nlp.base.DocumentAssembler
* import com.johnsnowlabs.nlp.annotator.SentenceDetector
* import com.johnsnowlabs.nlp.embeddings.BertSentenceEmbeddings
* import com.johnsnowlabs.nlp.EmbeddingsFinisher
* import org.apache.spark.ml.Pipeline
*
* val documentAssembler = new DocumentAssembler()
* .setInputCol("text")
* .setOutputCol("document")
*
* val sentence = new SentenceDetector()
* .setInputCols("document")
* .setOutputCol("sentence")
*
* val embeddings = BertSentenceEmbeddings.pretrained("sent_small_bert_L2_128")
* .setInputCols("sentence")
* .setOutputCol("sentence_bert_embeddings")
*
* val embeddingsFinisher = new EmbeddingsFinisher()
* .setInputCols("sentence_bert_embeddings")
* .setOutputCols("finished_embeddings")
* .setOutputAsVector(true)
*
* val pipeline = new Pipeline().setStages(Array(
* documentAssembler,
* sentence,
* embeddings,
* embeddingsFinisher
* ))
*
* val data = Seq("John loves apples. Mary loves oranges. John loves Mary.").toDF("text")
* val result = pipeline.fit(data).transform(data)
*
* result.selectExpr("explode(finished_embeddings) as result").show(5, 80)
* +--------------------------------------------------------------------------------+
* | result|
* +--------------------------------------------------------------------------------+
* |[-0.8951074481010437,0.13753940165042877,0.3108254075050354,-1.65693199634552...|
* |[-0.6180210709571838,-0.12179657071828842,-0.191165953874588,-1.4497021436691...|
* |[-0.822715163230896,0.7568016648292542,-0.1165061742067337,-1.59048593044281,...|
* +--------------------------------------------------------------------------------+
* }}}
*
* @see [[BertEmbeddings]] for token-level embeddings
* @see [[https://nlp.johnsnowlabs.com/docs/en/annotators Annotators Main Page]] for a list of transformer based embeddings
* @param uid required uid for storing annotator to disk
* @groupname anno Annotator types
* @groupdesc anno Required input and expected output annotator types
* @groupname Ungrouped Members
* @groupname param Parameters
* @groupname setParam Parameter setters
* @groupname getParam Parameter getters
* @groupname Ungrouped Members
* @groupprio param 1
* @groupprio anno 2
* @groupprio Ungrouped 3
* @groupprio setParam 4
* @groupprio getParam 5
* @groupdesc param A list of (hyper-)parameter keys this annotator can take. Users can set and get the parameter values through setters and getters, respectively.
* */
* Sentence-level embeddings using BERT. BERT (Bidirectional Encoder Representations from Transformers) provides dense
* vector representations for natural language by using a deep, pre-trained neural network with the Transformer architecture.
*
* Pretrained models can be loaded with `pretrained` of the companion object:
* {{{
* val embeddings = BertSentenceEmbeddings.pretrained()
* .setInputCols("sentence")
* .setOutputCol("sentence_bert_embeddings")
* }}}
* The default model is `"sent_small_bert_L2_768"`, if no name is provided.
*
* For available pretrained models please see the [[https://nlp.johnsnowlabs.com/models?task=Embeddings Models Hub]].
*
* For extended examples of usage, see the [[https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/jupyter/transformers/HuggingFace%20in%20Spark%20NLP%20-%20BERT%20Sentence.ipynb Spark NLP Workshop]]
* and the [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddingsTestSpec.scala BertSentenceEmbeddingsTestSpec]].
*
* '''Sources''' :
*
* [[https://arxiv.org/abs/1810.04805 BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding]]
*
* [[https://github.com/google-research/bert]]
*
* ''' Paper abstract '''
*
* ''We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations
* from Transformers. Unlike recent language representation models, BERT is designed to pre-train deep bidirectional
* representations from unlabeled text by jointly conditioning on both left and right context in all layers. As a
* result, the pre-trained BERT model can be fine-tuned with just one additional output layer to create
* state-of-the-art models for a wide range of tasks, such as question answering and language inference, without
* substantial task-specific architecture modifications. BERT is conceptually simple and empirically powerful. It
* obtains new state-of-the-art results on eleven natural language processing tasks, including pushing the GLUE score
* to 80.5% (7.7% point absolute improvement), MultiNLI accuracy to 86.7% (4.6% absolute improvement), SQuAD v1.1
* question answering Test F1 to 93.2 (1.5 point absolute improvement) and SQuAD v2.0 Test F1 to 83.1 (5.1 point
* absolute improvement).''
*
* ==Example==
* {{{
* import spark.implicits._
* import com.johnsnowlabs.nlp.base.DocumentAssembler
* import com.johnsnowlabs.nlp.annotator.SentenceDetector
* import com.johnsnowlabs.nlp.embeddings.BertSentenceEmbeddings
* import com.johnsnowlabs.nlp.EmbeddingsFinisher
* import org.apache.spark.ml.Pipeline
*
* val documentAssembler = new DocumentAssembler()
* .setInputCol("text")
* .setOutputCol("document")
*
* val sentence = new SentenceDetector()
* .setInputCols("document")
* .setOutputCol("sentence")
*
* val embeddings = BertSentenceEmbeddings.pretrained("sent_small_bert_L2_128")
* .setInputCols("sentence")
* .setOutputCol("sentence_bert_embeddings")
*
* val embeddingsFinisher = new EmbeddingsFinisher()
* .setInputCols("sentence_bert_embeddings")
* .setOutputCols("finished_embeddings")
* .setOutputAsVector(true)
*
* val pipeline = new Pipeline().setStages(Array(
* documentAssembler,
* sentence,
* embeddings,
* embeddingsFinisher
* ))
*
* val data = Seq("John loves apples. Mary loves oranges. John loves Mary.").toDF("text")
* val result = pipeline.fit(data).transform(data)
*
* result.selectExpr("explode(finished_embeddings) as result").show(5, 80)
* +--------------------------------------------------------------------------------+
* | result|
* +--------------------------------------------------------------------------------+
* |[-0.8951074481010437,0.13753940165042877,0.3108254075050354,-1.65693199634552...|
* |[-0.6180210709571838,-0.12179657071828842,-0.191165953874588,-1.4497021436691...|
* |[-0.822715163230896,0.7568016648292542,-0.1165061742067337,-1.59048593044281,...|
* +--------------------------------------------------------------------------------+
* }}}
*
* @see [[BertEmbeddings]] for token-level embeddings
* @see [[https://nlp.johnsnowlabs.com/docs/en/annotators Annotators Main Page]] for a list of transformer based embeddings
* @param uid required uid for storing annotator to disk
* @groupname anno Annotator types
* @groupdesc anno Required input and expected output annotator types
* @groupname Ungrouped Members
* @groupname param Parameters
* @groupname setParam Parameter setters
* @groupname getParam Parameter getters
* @groupname Ungrouped Members
* @groupprio param 1
* @groupprio anno 2
* @groupprio Ungrouped 3
* @groupprio setParam 4
* @groupprio getParam 5
* @groupdesc param A list of (hyper-)parameter keys this annotator can take. Users can set and get the parameter values through setters and getters, respectively.
* */
class BertSentenceEmbeddings(override val uid: String)
extends AnnotatorModel[BertSentenceEmbeddings]
with HasBatchedAnnotate[BertSentenceEmbeddings]
Expand Down Expand Up @@ -316,23 +316,40 @@ class BertSentenceEmbeddings(override val uid: String)
* @return any number of annotations processed for every input annotation. Not necessary one to one relationship
*/
override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
/*Return empty if the real sentences are empty*/
batchedAnnotations.map(annotations => {
val sentences = SentenceSplit.unpack(annotations).toArray

if (sentences.nonEmpty) {
val tokenized = tokenize(sentences)
getModelIfNotSet.calculateSentenceEmbeddings(
tokenized,
sentences,
$(batchSize),
$(maxSentenceLength),
getIsLong
)
} else {

//Unpack annotations and zip each sentence to the index or the row it belongs to
val sentencesWithRow = batchedAnnotations
.zipWithIndex
.flatMap { case (annotations, i) => SentenceSplit.unpack(annotations).map(x => (x, i)) }

//Tokenize sentences
val tokenizedSentences = tokenize(sentencesWithRow.map(_._1))

//Process all sentences
val allAnnotations = getModelIfNotSet.calculateSentenceEmbeddings(
tokenizedSentences,
sentencesWithRow.map(_._1),
$(batchSize),
$(maxSentenceLength),
getIsLong
)

//Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence
batchedAnnotations.indices.map(rowIndex => {
val rowAnnotations = allAnnotations
//zip each annotation with its corresponding row index
.zip(sentencesWithRow)
//select the sentences belonging to the current row
.filter(_._2._2 == rowIndex)
//leave the annotation only
.map(_._1)

if (rowAnnotations.nonEmpty)
rowAnnotations
else
Seq.empty[Annotation]
}
})

}

override protected def afterAnnotate(dataset: DataFrame): DataFrame = {
Expand Down