
Commit 3c8fa50

Ishiihara authored and mengxr committed
[SPARK-3097][MLlib] Word2Vec performance improvement
mengxr Please review the code. Adding weights in reduceByKey soon.

Only output model entries for words that appear in the partition before merging, and use reduceByKey to combine the models. In general, this implementation is about 30s faster than the implementation using a big array.

Author: Liquan Pei <[email protected]>

Closes apache#1932 from Ishiihara/Word2Vec-improve2 and squashes the following commits:

d5377a9 [Liquan Pei] use syn0Global and syn1Global to represent model
cad2011 [Liquan Pei] bug fix for synModify array out of bound
083aa66 [Liquan Pei] update synGlobal in place and reduce synOut size
9075e1c [Liquan Pei] combine syn0Global and syn1Global to synGlobal
aa2ab36 [Liquan Pei] use reduceByKey to combine models
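A minimal sketch of the approach described above, using hypothetical helper names (emitTouchedRows, combine) and signatures that are assumptions for illustration rather than the actual Word2Vec internals: each partition emits only the model rows it modified, keyed by row index, and reduceByKey sums those sparse contributions instead of shuffling the full vocabSize * vectorSize arrays.

import org.apache.spark.rdd.RDD

object SparseModelCombineSketch {

  // Emit only the rows this partition modified, keyed by row index.
  // `touched(row)` counts how often a row was updated during local training.
  def emitTouchedRows(
      localModel: Array[Float],
      touched: Array[Int],
      vectorSize: Int): Iterator[(Int, Array[Float])] = {
    touched.iterator.zipWithIndex.collect {
      case (count, row) if count != 0 =>
        (row, localModel.slice(row * vectorSize, (row + 1) * vectorSize))
    }
  }

  // Sum the sparse per-partition contributions for each row and bring the
  // result back to the driver.
  def combine(partial: RDD[(Int, Array[Float])]): Array[(Int, Array[Float])] = {
    partial.reduceByKey { (v1, v2) =>
      var i = 0
      while (i < v1.length) {
        v1(i) += v2(i)  // element-wise add, same effect as blas.saxpy(n, 1.0f, v2, 1, v1, 1)
        i += 1
      }
      v1
    }.collect()
  }
}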
1 parent df652ea commit 3c8fa50

File tree

1 file changed, 35 insertions(+), 15 deletions(-)


mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 35 additions & 15 deletions
@@ -34,6 +34,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
 
 /**
  * Entry in vocabulary
@@ -287,11 +288,12 @@ class Word2Vec extends Serializable with Logging {
     var syn0Global =
       Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
     var syn1Global = new Array[Float](vocabSize * vectorSize)
-
     var alpha = startingAlpha
     for (k <- 1 to numIterations) {
       val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
         val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+        val syn0Modify = new Array[Int](vocabSize)
+        val syn1Modify = new Array[Int](vocabSize)
         val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
@@ -321,7 +323,8 @@
                     // Hierarchical softmax
                     var d = 0
                     while (d < bcVocab.value(word).codeLen) {
-                      val l2 = bcVocab.value(word).point(d) * vectorSize
+                      val inner = bcVocab.value(word).point(d)
+                      val l2 = inner * vectorSize
                       // Propagate hidden -> output
                       var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
                       if (f > -MAX_EXP && f < MAX_EXP) {
@@ -330,10 +333,12 @@
                         val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
                         blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
                         blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
+                        syn1Modify(inner) += 1
                       }
                       d += 1
                     }
                     blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
+                    syn0Modify(lastWord) += 1
                   }
                 }
                 a += 1
@@ -342,21 +347,36 @@
             }
             (syn0, syn1, lwc, wc)
         }
-        Iterator(model)
+        val syn0Local = model._1
+        val syn1Local = model._2
+        val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
+        var index = 0
+        while (index < vocabSize) {
+          if (syn0Modify(index) != 0) {
+            synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
+          }
+          if (syn1Modify(index) != 0) {
+            synOut.update(index + vocabSize,
+              syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
+          }
+          index += 1
+        }
+        Iterator(synOut)
       }
-      val (aggSyn0, aggSyn1, _, _) =
-        partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-          val n = syn0_1.length
-          val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
-          val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
-          blas.sscal(n, weight1, syn0_1, 1)
-          blas.sscal(n, weight1, syn1_1, 1)
-          blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-          blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
-          (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+      val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
+        blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
+        v1
+      }.collect()
+      var i = 0
+      while (i < synAgg.length) {
+        val index = synAgg(i)._1
+        if (index < vocabSize) {
+          Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
+        } else {
+          Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
         }
-      syn0Global = aggSyn0
-      syn1Global = aggSyn1
+        i += 1
+      }
     }
     newSentences.unpersist()

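A note on the design choice in the last hunk: the patch packs both matrices into a single key space so that one reduceByKey pass can aggregate updates to either of them. The sketch below restates that encoding (it mirrors the diff above rather than adding new behavior): row i of syn0 is keyed as i, row i of syn1 as i + vocabSize, and the copy-back step decodes the key by comparing it against vocabSize.

// Sketch of the copy-back step, mirroring the diff above: decode each aggregated
// (key, vector) pair and write it into the matching half of the model.
def copyBack(
    synAgg: Array[(Int, Array[Float])],
    syn0Global: Array[Float],
    syn1Global: Array[Float],
    vocabSize: Int,
    vectorSize: Int): Unit = {
  synAgg.foreach { case (key, vector) =>
    if (key < vocabSize) {
      // keys 0 until vocabSize address rows of syn0
      Array.copy(vector, 0, syn0Global, key * vectorSize, vectorSize)
    } else {
      // keys vocabSize until 2 * vocabSize address rows of syn1
      Array.copy(vector, 0, syn1Global, (key - vocabSize) * vectorSize, vectorSize)
    }
  }
}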