Skip to content

Commit 1a8fb41

Browse files
author
Liquan Pei
committed
use weighted sum in combOp
1 parent 7efbb6f commit 1a8fb41

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,7 @@ class Word2Vec(
8787
private var vocabHash = mutable.HashMap.empty[String, Int]
8888
private var alpha = startingAlpha
8989

90-
private def learnVocab(words:RDD[String]) {
90+
private def learnVocab(words:RDD[String]){
9191
vocab = words.map(w => (w, 1))
9292
.reduceByKey(_ + _)
9393
.map(x => VocabWord(
@@ -110,6 +110,10 @@ class Word2Vec(
110110
logInfo("trainWordsCount = " + trainWordsCount)
111111
}
112112

113+
private def learnVocabPerPartition(words:RDD[String]) {
114+
115+
}
116+
113117
private def createExpTable(): Array[Double] = {
114118
val expTable = new Array[Double](EXP_TABLE_SIZE)
115119
var i = 0
@@ -303,8 +307,12 @@ class Word2Vec(
303307
combOp = (c1, c2) => (c1, c2) match {
304308
case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
305309
val n = syn0_1.length
306-
blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
307-
blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
310+
val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
311+
val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
312+
blas.dscal(n, weight1, syn0_1, 1)
313+
blas.dscal(n, weight1, syn1_1, 1)
314+
blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1)
315+
blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1)
308316
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
309317
})
310318
syn0Global = aggSyn0

0 commit comments

Comments
 (0)