@@ -87,7 +87,7 @@ class Word2Vec(
87
87
private var vocabHash = mutable.HashMap .empty[String , Int ]
88
88
private var alpha = startingAlpha
89
89
90
- private def learnVocab (words: RDD [String ]) {
90
+ private def learnVocab (words: RDD [String ]){
91
91
vocab = words.map(w => (w, 1 ))
92
92
.reduceByKey(_ + _)
93
93
.map(x => VocabWord (
@@ -110,6 +110,10 @@ class Word2Vec(
110
110
logInfo(" trainWordsCount = " + trainWordsCount)
111
111
}
112
112
113
+ private def learnVocabPerPartition (words: RDD [String ]) {
114
+
115
+ }
116
+
113
117
private def createExpTable (): Array [Double ] = {
114
118
val expTable = new Array [Double ](EXP_TABLE_SIZE )
115
119
var i = 0
@@ -303,8 +307,12 @@ class Word2Vec(
303
307
combOp = (c1, c2) => (c1, c2) match {
304
308
case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
305
309
val n = syn0_1.length
306
- blas.daxpy(n, 1.0 , syn0_2, 1 , syn0_1, 1 )
307
- blas.daxpy(n, 1.0 , syn1_2, 1 , syn1_1, 1 )
310
+ val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
311
+ val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
312
+ blas.dscal(n, weight1, syn0_1, 1 )
313
+ blas.dscal(n, weight1, syn1_1, 1 )
314
+ blas.daxpy(n, weight2, syn0_2, 1 , syn0_1, 1 )
315
+ blas.daxpy(n, weight2, syn1_2, 1 , syn1_1, 1 )
308
316
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
309
317
})
310
318
syn0Global = aggSyn0
0 commit comments