@@ -84,9 +84,6 @@ class TensorflowXlmRoberta(
   private val SentencePadTokenId = 1
   private val SentencePieceDelimiterId = spp.getSppModel.pieceToId("▁")
 
-  private val encoder =
-    new SentencepieceEncoder(spp, caseSensitive, SentencePieceDelimiterId, pieceIdOffset = 1)
-
   def prepareBatchInputs(
       sentences: Seq[(WordpieceTokenizedSentence, Int)],
       maxSequenceLength: Int): Seq[Array[Int]] = {
@@ -323,11 +320,13 @@ class TensorflowXlmRoberta(
   def tokenizeWithAlignment(
       sentences: Seq[TokenizedSentence],
       maxSeqLength: Int): Seq[WordpieceTokenizedSentence] = {
+    val encoder =
+      new SentencepieceEncoder(spp, caseSensitive, SentencePieceDelimiterId, pieceIdOffset = 1)
 
     val sentenceTokenPieces = sentences.map { s =>
-      val shrinkedSentence = s.indexedTokens.take(maxSeqLength - 2)
+      val trimmedSentence = s.indexedTokens.take(maxSeqLength - 2)
       val wordpieceTokens =
-        shrinkedSentence.flatMap(token => encoder.encode(token)).take(maxSeqLength)
+        trimmedSentence.flatMap(token => encoder.encode(token)).take(maxSeqLength)
       WordpieceTokenizedSentence(wordpieceTokens)
     }
     sentenceTokenPieces
@@ -336,6 +335,8 @@ class TensorflowXlmRoberta(
   def tokenizeSentence(
       sentences: Seq[Sentence],
       maxSeqLength: Int): Seq[WordpieceTokenizedSentence] = {
+    val encoder =
+      new SentencepieceEncoder(spp, caseSensitive, SentencePieceDelimiterId, pieceIdOffset = 1)
 
     val sentenceTokenPieces = sentences.map { s =>
       val wordpieceTokens = encoder.encodeSentence(s, maxLength = maxSeqLength).take(maxSeqLength)
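The hunks above move the `SentencepieceEncoder` construction from a class field into each tokenize method. The diff does not state the motivation, but the shape of the refactor is the familiar "build per call instead of caching on the instance" pattern, which keeps a possibly stateful helper from being shared between concurrent or reused callers. A minimal, generic Scala sketch of that pattern (JDK classes only; nothing here is Spark NLP API):

```scala
import java.text.SimpleDateFormat
import java.util.Date

object PerCallStateSketch extends App {
  // Risky when calls may overlap: SimpleDateFormat keeps mutable internal
  // state, so one shared instance can be corrupted by concurrent use.
  private val shared = new SimpleDateFormat("yyyy-MM-dd")
  def formatShared(ts: Long): String = shared.format(new Date(ts))

  // Shape of the change above: construct the helper inside the method, so
  // every call owns its own instance and no state crosses call boundaries.
  def formatPerCall(ts: Long): String =
    new SimpleDateFormat("yyyy-MM-dd").format(new Date(ts))

  println(formatPerCall(0L)) // 1970-01-01 (timezone-dependent, but always well-formed)
}
```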
@@ -289,20 +289,34 @@ class AlbertEmbeddings(override val uid: String)
    * relationship
    */
   override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
-    val batchedTokenizedSentences: Array[Array[TokenizedSentence]] = batchedAnnotations
-      .map(annotations => TokenizedWithSentence.unpack(annotations).toArray)
-      .toArray
+    // Unpack annotations and zip each sentence with the index of the row it belongs to
+    val sentencesWithRow = batchedAnnotations.zipWithIndex
+      .flatMap { case (annotations, i) =>
+        TokenizedWithSentence.unpack(annotations).toArray.map(x => (x, i))
+      }
 
     /*Return empty if the real tokens are empty*/
-    if (batchedTokenizedSentences.nonEmpty) batchedTokenizedSentences.map(tokenizedSentences => {
-
-      val embeddings = getModelIfNotSet
-        .predict(tokenizedSentences, $(batchSize), $(maxSentenceLength), $(caseSensitive))
-      WordpieceEmbeddingsSentence.pack(embeddings)
+    val sentenceWordEmbeddings = getModelIfNotSet.predict(
+      sentencesWithRow.map(_._1),
+      $(batchSize),
+      $(maxSentenceLength),
+      $(caseSensitive))
+
+    // Group resulting annotations by row. If there are no sentences in a given row, return an empty sequence
+    batchedAnnotations.indices.map(rowIndex => {
+      val rowEmbeddings = sentenceWordEmbeddings
+        // zip each annotation with its corresponding row index
+        .zip(sentencesWithRow)
+        // select the sentences belonging to the current row
+        .filter(_._2._2 == rowIndex)
+        // keep the annotation only
+        .map(_._1)
+
+      if (rowEmbeddings.nonEmpty)
+        WordpieceEmbeddingsSentence.pack(rowEmbeddings)
+      else
+        Seq.empty[Annotation]
     })
-    else {
-      Seq(Seq.empty[Annotation])
-    }
   }
 
   override def onWrite(path: String, spark: SparkSession): Unit = {
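The rewritten `batchAnnotate` above (repeated with small variations in the annotators that follow) is a flatten–tag–regroup pattern: tag every sentence with the index of its row, run one `predict` over the flattened batch, then rebuild exactly one result per input row. A self-contained sketch of just that pattern, with a toy `process` standing in for `getModelIfNotSet.predict` (illustrative names only, not the Spark NLP API):

```scala
object RegroupSketch extends App {
  // Stand-in for the model's predict: one call over all sentences of a batch
  def process(sentences: Seq[String]): Seq[Int] = sentences.map(_.length)

  val rows: Seq[Seq[String]] = Seq(Seq("a", "bb"), Seq.empty, Seq("ccc"))

  // Tag every sentence with the index of the row it came from, then flatten
  val sentencesWithRow: Seq[(String, Int)] = rows.zipWithIndex.flatMap { case (ss, i) =>
    ss.map(s => (s, i))
  }

  // A single call processes the whole batch; results stay in input order
  val results: Seq[Int] = process(sentencesWithRow.map(_._1))

  // Regroup results by row; a row with no sentences yields an empty sequence
  val byRow: Seq[Seq[Int]] = rows.indices.map { rowIndex =>
    results.zip(sentencesWithRow).filter(_._2._2 == rowIndex).map(_._1)
  }

  println(byRow) // Vector(List(1, 2), List(), List(3))
}
```

Building the output from `rows.indices` rather than from the non-empty groups is what guarantees one entry per input row, so rows without sentences come back as empty sequences instead of disappearing.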
@@ -270,20 +270,34 @@ class DeBertaEmbeddings(override val uid: String)
    * relationship
    */
   override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
-    val batchedTokenizedSentences: Array[Array[TokenizedSentence]] = batchedAnnotations
-      .map(annotations => TokenizedWithSentence.unpack(annotations).toArray)
-      .toArray
+    // Unpack annotations and zip each sentence with the index of the row it belongs to
+    val sentencesWithRow = batchedAnnotations.zipWithIndex
+      .flatMap { case (annotations, i) =>
+        TokenizedWithSentence.unpack(annotations).toArray.map(x => (x, i))
+      }
 
     /*Return empty if the real tokens are empty*/
-    if (batchedTokenizedSentences.nonEmpty) batchedTokenizedSentences.map(tokenizedSentences => {
-
-      val embeddings = getModelIfNotSet
-        .predict(tokenizedSentences, $(batchSize), $(maxSentenceLength), $(caseSensitive))
-      WordpieceEmbeddingsSentence.pack(embeddings)
+    val sentenceWordEmbeddings = getModelIfNotSet.predict(
+      sentencesWithRow.map(_._1),
+      $(batchSize),
+      $(maxSentenceLength),
+      $(caseSensitive))
+
+    // Group resulting annotations by row. If there are no sentences in a given row, return an empty sequence
+    batchedAnnotations.indices.map(rowIndex => {
+      val rowEmbeddings = sentenceWordEmbeddings
+        // zip each annotation with its corresponding row index
+        .zip(sentencesWithRow)
+        // select the sentences belonging to the current row
+        .filter(_._2._2 == rowIndex)
+        // keep the annotation only
+        .map(_._1)
+
+      if (rowEmbeddings.nonEmpty)
+        WordpieceEmbeddingsSentence.pack(rowEmbeddings)
+      else
+        Seq.empty[Annotation]
     })
-    else {
-      Seq(Seq.empty[Annotation])
-    }
   }
 
   override def onWrite(path: String, spark: SparkSession): Unit = {
@@ -295,9 +295,9 @@ class DistilBertEmbeddings(override val uid: String)
       val content = if ($(caseSensitive)) token.token else token.token.toLowerCase()
       val sentenceBegin = token.begin
       val sentenceEnd = token.end
-      val sentenceInedx = tokenIndex.sentenceIndex
+      val sentenceIndex = tokenIndex.sentenceIndex
       val result = basicTokenizer.tokenize(
-        Sentence(content, sentenceBegin, sentenceEnd, sentenceInedx))
+        Sentence(content, sentenceBegin, sentenceEnd, sentenceIndex))
       if (result.nonEmpty) result.head else IndexedToken("")
     }
     val wordpieceTokens =
@@ -316,24 +316,38 @@
    * relationship
    */
   override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
-    val batchedTokenizedSentences: Array[Array[TokenizedSentence]] = batchedAnnotations
-      .map(annotations => TokenizedWithSentence.unpack(annotations).toArray)
-      .toArray
-    /*Return empty if the real tokens are empty*/
-    if (batchedTokenizedSentences.nonEmpty) batchedTokenizedSentences.map(tokenizedSentences => {
-      val tokenized = tokenizeWithAlignment(tokenizedSentences)
-
-      val withEmbeddings = getModelIfNotSet.predict(
-        tokenized,
-        tokenizedSentences,
-        $(batchSize),
-        $(maxSentenceLength),
-        $(caseSensitive))
-      WordpieceEmbeddingsSentence.pack(withEmbeddings)
+    // Unpack annotations and zip each sentence with the index of the row it belongs to
+    val sentencesWithRow = batchedAnnotations.zipWithIndex
+      .flatMap { case (annotations, i) =>
+        TokenizedWithSentence.unpack(annotations).toArray.map(x => (x, i))
+      }
+
+    // Tokenize sentences
+    val tokenizedSentences = tokenizeWithAlignment(sentencesWithRow.map(_._1))
+
+    // Process all sentences
+    val sentenceWordEmbeddings = getModelIfNotSet.predict(
+      tokenizedSentences,
+      sentencesWithRow.map(_._1),
+      $(batchSize),
+      $(maxSentenceLength),
+      $(caseSensitive))
+
+    // Group resulting annotations by row. If there are no sentences in a given row, return an empty sequence
+    batchedAnnotations.indices.map(rowIndex => {
+      val rowEmbeddings = sentenceWordEmbeddings
+        // zip each annotation with its corresponding row index
+        .zip(sentencesWithRow)
+        // select the sentences belonging to the current row
+        .filter(_._2._2 == rowIndex)
+        // keep the annotation only
+        .map(_._1)
+
+      if (rowEmbeddings.nonEmpty)
+        WordpieceEmbeddingsSentence.pack(rowEmbeddings)
+      else
+        Seq.empty[Annotation]
     })
-    else {
-      Seq(Seq.empty[Annotation])
-    }
   }
 
   override protected def afterAnnotate(dataset: DataFrame): DataFrame = {
@@ -283,7 +283,7 @@ class LongformerEmbeddings(override val uid: String)
     this
   }
 
-  setDefault(dimension -> 768, batchSize -> 8, maxSentenceLength -> 1024, caseSensitive -> true)
+  setDefault(dimension -> 768, batchSize -> 4, maxSentenceLength -> 1024, caseSensitive -> true)
 
   def tokenizeWithAlignment(tokens: Seq[TokenizedSentence]): Seq[WordpieceTokenizedSentence] = {
 
@@ -301,9 +301,9 @@ class LongformerEmbeddings(override val uid: String)
       val content = if ($(caseSensitive)) token.token else token.token.toLowerCase()
       val sentenceBegin = token.begin
       val sentenceEnd = token.end
-      val sentenceInedx = tokenIndex.sentenceIndex
+      val sentenceIndex = tokenIndex.sentenceIndex
       val result =
-        bpeTokenizer.tokenize(Sentence(content, sentenceBegin, sentenceEnd, sentenceInedx))
+        bpeTokenizer.tokenize(Sentence(content, sentenceBegin, sentenceEnd, sentenceIndex))
       if (result.nonEmpty) result.head else IndexedToken("")
     }
     val wordpieceTokens =
@@ -322,25 +322,38 @@
    * relationship
    */
   override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
-    val batchedTokenizedSentences: Array[Array[TokenizedSentence]] = batchedAnnotations
-      .map(annotations => TokenizedWithSentence.unpack(annotations).toArray)
-      .toArray
-
-    /*Return empty if the real tokens are empty*/
-    if (batchedTokenizedSentences.nonEmpty) batchedTokenizedSentences.map(tokenizedSentences => {
-      val tokenized = tokenizeWithAlignment(tokenizedSentences)
-
-      val withEmbeddings = getModelIfNotSet.predict(
-        tokenized,
-        tokenizedSentences,
-        $(batchSize),
-        $(maxSentenceLength),
-        $(caseSensitive))
-      WordpieceEmbeddingsSentence.pack(withEmbeddings)
+    // Unpack annotations and zip each sentence with the index of the row it belongs to
+    val sentencesWithRow = batchedAnnotations.zipWithIndex
+      .flatMap { case (annotations, i) =>
+        TokenizedWithSentence.unpack(annotations).toArray.map(x => (x, i))
+      }
+
+    // Tokenize sentences
+    val tokenizedSentences = tokenizeWithAlignment(sentencesWithRow.map(_._1))
+
+    // Process all sentences
+    val sentenceWordEmbeddings = getModelIfNotSet.predict(
+      tokenizedSentences,
+      sentencesWithRow.map(_._1),
+      $(batchSize),
+      $(maxSentenceLength),
+      $(caseSensitive))
+
+    // Group resulting annotations by row. If there are no sentences in a given row, return an empty sequence
+    batchedAnnotations.indices.map(rowIndex => {
+      val rowEmbeddings = sentenceWordEmbeddings
+        // zip each annotation with its corresponding row index
+        .zip(sentencesWithRow)
+        // select the sentences belonging to the current row
+        .filter(_._2._2 == rowIndex)
+        // keep the annotation only
+        .map(_._1)
+
+      if (rowEmbeddings.nonEmpty)
+        WordpieceEmbeddingsSentence.pack(rowEmbeddings)
+      else
+        Seq.empty[Annotation]
     })
-    else {
-      Seq(Seq.empty[Annotation])
-    }
   }
 
   override protected def afterAnnotate(dataset: DataFrame): DataFrame = {
@@ -313,9 +313,9 @@ class RoBertaEmbeddings(override val uid: String)
       val content = if ($(caseSensitive)) token.token else token.token.toLowerCase()
       val sentenceBegin = token.begin
       val sentenceEnd = token.end
-      val sentenceInedx = tokenIndex.sentenceIndex
+      val sentenceIndex = tokenIndex.sentenceIndex
       val result =
-        bpeTokenizer.tokenize(Sentence(content, sentenceBegin, sentenceEnd, sentenceInedx))
+        bpeTokenizer.tokenize(Sentence(content, sentenceBegin, sentenceEnd, sentenceIndex))
       if (result.nonEmpty) result.head else IndexedToken("")
     }
     val wordpieceTokens =
@@ -334,25 +334,38 @@
    * relationship
    */
   override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
-    val batchedTokenizedSentences: Array[Array[TokenizedSentence]] = batchedAnnotations
-      .map(annotations => TokenizedWithSentence.unpack(annotations).toArray)
-      .toArray
-
-    /*Return empty if the real tokens are empty*/
-    if (batchedTokenizedSentences.nonEmpty) batchedTokenizedSentences.map(tokenizedSentences => {
-      val tokenized = tokenizeWithAlignment(tokenizedSentences)
-
-      val withEmbeddings = getModelIfNotSet.predict(
-        tokenized,
-        tokenizedSentences,
-        $(batchSize),
-        $(maxSentenceLength),
-        $(caseSensitive))
-      WordpieceEmbeddingsSentence.pack(withEmbeddings)
+    // Unpack annotations and zip each sentence with the index of the row it belongs to
+    val sentencesWithRow = batchedAnnotations.zipWithIndex
+      .flatMap { case (annotations, i) =>
+        TokenizedWithSentence.unpack(annotations).toArray.map(x => (x, i))
+      }
+
+    // Tokenize sentences
+    val tokenizedSentences = tokenizeWithAlignment(sentencesWithRow.map(_._1))
+
+    // Process all sentences
+    val sentenceWordEmbeddings = getModelIfNotSet.predict(
+      tokenizedSentences,
+      sentencesWithRow.map(_._1),
+      $(batchSize),
+      $(maxSentenceLength),
+      $(caseSensitive))
+
+    // Group resulting annotations by row. If there are no sentences in a given row, return an empty sequence
+    batchedAnnotations.indices.map(rowIndex => {
+      val rowEmbeddings = sentenceWordEmbeddings
+        // zip each annotation with its corresponding row index
+        .zip(sentencesWithRow)
+        // select the sentences belonging to the current row
+        .filter(_._2._2 == rowIndex)
+        // keep the annotation only
+        .map(_._1)
+
+      if (rowEmbeddings.nonEmpty)
+        WordpieceEmbeddingsSentence.pack(rowEmbeddings)
+      else
+        Seq.empty[Annotation]
     })
-    else {
-      Seq(Seq.empty[Annotation])
-    }
   }
 
   override protected def afterAnnotate(dataset: DataFrame): DataFrame = {
@@ -277,20 +277,31 @@ class XlmRoBertaEmbeddings(override val uid: String)
    * relationship
    */
   override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
-    val batchedTokenizedSentences: Array[Array[TokenizedSentence]] = batchedAnnotations
-      .map(annotations => TokenizedWithSentence.unpack(annotations).toArray)
-      .toArray
-
-    /*Return empty if the real tokens are empty*/
-    if (batchedTokenizedSentences.nonEmpty) batchedTokenizedSentences.map(tokenizedSentences => {
-
-      val embeddings =
-        getModelIfNotSet.predict(tokenizedSentences, $(batchSize), $(maxSentenceLength))
-      WordpieceEmbeddingsSentence.pack(embeddings)
+    // Unpack annotations and zip each sentence with the index of the row it belongs to
+    val sentencesWithRow = batchedAnnotations.zipWithIndex
+      .flatMap { case (annotations, i) =>
+        TokenizedWithSentence.unpack(annotations).toArray.map(x => (x, i))
+      }
+
+    val sentenceWordEmbeddings =
+      getModelIfNotSet.predict(sentencesWithRow.map(_._1), $(batchSize), $(maxSentenceLength))
+
+    // Group resulting annotations by row. If there are no sentences in a given row, return an empty sequence
+    batchedAnnotations.indices.map(rowIndex => {
+      val rowEmbeddings = sentenceWordEmbeddings
+        // zip each annotation with its corresponding row index
+        .zip(sentencesWithRow)
+        // select the sentences belonging to the current row
+        .filter(_._2._2 == rowIndex)
+        // keep the annotation only
+        .map(_._1)
+
+      if (rowEmbeddings.nonEmpty)
+        WordpieceEmbeddingsSentence.pack(rowEmbeddings)
+      else
+        Seq.empty[Annotation]
     })
-    else {
-      Seq(Seq.empty[Annotation])
-    }
+
   }
 
   override protected def afterAnnotate(dataset: DataFrame): DataFrame = {
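One practical reading of these rewrites (an inference from the code, not a claim made in the diff): the old implementations called `predict` once per row, so the configured `batchSize` could only group sentences within a single row; flattening first lets one call batch across all rows handed to `batchAnnotate`. A toy comparison of the two call shapes:

```scala
object CallCountSketch extends App {
  var calls = 0
  // Stand-in for the model's predict; counts how often it is invoked
  def predict(sentences: Seq[String]): Seq[Int] = { calls += 1; sentences.map(_.length) }

  val rows = Seq(Seq("a", "bb"), Seq("c"), Seq("dd", "e", "f"))

  calls = 0
  rows.foreach(predict) // old shape: one model call per row
  println(s"per-row: $calls calls") // per-row: 3 calls

  calls = 0
  predict(rows.flatten) // new shape: one model call for the flattened batch
  println(s"flattened: $calls calls") // flattened: 1 call
}
```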