@@ -30,6 +30,7 @@ import org.apache.spark.SparkContext._
30
30
import org .apache .spark .mllib .linalg .{Vector , Vectors }
31
31
import org .apache .spark .HashPartitioner
32
32
import org .apache .spark .storage .StorageLevel
33
+ import org .apache .spark .mllib .rdd .RDDFunctions ._
33
34
/**
34
35
* Entry in vocabulary
35
36
*/
@@ -111,9 +112,9 @@ class Word2Vec(
111
112
}
112
113
113
114
private def learnVocabPerPartition (words: RDD [String ]) {
114
-
115
+
115
116
}
116
-
117
+
117
118
private def createExpTable (): Array [Double ] = {
118
119
val expTable = new Array [Double ](EXP_TABLE_SIZE )
119
120
var i = 0
@@ -254,7 +255,7 @@ class Word2Vec(
254
255
val (aggSyn0, aggSyn1, _, _) =
255
256
// TODO: broadcast temp instead of serializing it directly
256
257
// or initialize the model in each executor
257
- newSentences.aggregate ((syn0Global.clone(), syn1Global.clone(), 0 , 0 ))(
258
+ newSentences.treeAggregate ((syn0Global.clone(), syn1Global.clone(), 0 , 0 ))(
258
259
seqOp = (c, v) => (c, v) match {
259
260
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
260
261
var lwc = lastWordCount
0 commit comments