@@ -28,6 +28,7 @@ import org.apache.spark.annotation.Experimental
28
28
import org .apache .spark .mllib .linalg .{Vector , Vectors }
29
29
import org .apache .spark .mllib .rdd .RDDFunctions ._
30
30
import org .apache .spark .rdd ._
31
+ import org .apache .spark .util .Utils
31
32
import org .apache .spark .util .random .XORShiftRandom
32
33
33
34
/**
@@ -58,29 +59,63 @@ private case class VocabWord(
58
59
* Efficient Estimation of Word Representations in Vector Space
59
60
* and
60
61
* Distributed Representations of Words and Phrases and their Compositionality.
61
- * @param size vector dimension
62
- * @param startingAlpha initial learning rate
63
- * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
64
- * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
65
62
*/
66
63
@ Experimental
67
- class Word2Vec (
68
- val size : Int ,
69
- val startingAlpha : Double ,
70
- val parallelism : Int ,
71
- val numIterations : Int ) extends Serializable with Logging {
64
+ class Word2Vec extends Serializable with Logging {
65
+
66
// Tunable training parameters. Each one is mutable so that its setter can override the default.
private var vectorSize: Int = 100
private var startingAlpha: Double = 0.025
private var numPartitions: Int = 1
private var numIterations: Int = 1
// The default seed is drawn from Utils.random when the instance is constructed.
private var seed: Long = Utils.random.nextLong()
71
+
72
/**
 * Sets vector size (default: 100).
 *
 * @param vectorSize dimension of the word vectors; must be greater than 0
 * @return this instance, so that setter calls can be chained
 * @throws IllegalArgumentException if `vectorSize` is not greater than 0
 */
def setVectorSize(vectorSize: Int): this.type = {
  // Fail fast: the size is used both as an array-length factor and as a divisor during
  // weight initialization, so zero or negative values cannot produce a usable model.
  require(vectorSize > 0, s"vectorSize must be greater than 0 but got $vectorSize")
  this.vectorSize = vectorSize
  this
}
79
+
80
/**
 * Sets initial learning rate (default: 0.025).
 *
 * @param learningRate initial learning rate; must be greater than 0
 * @return this instance, so that setter calls can be chained
 * @throws IllegalArgumentException if `learningRate` is not greater than 0 (NaN is rejected too)
 */
def setLearningRate(learningRate: Double): this.type = {
  // `NaN > 0` is false, so this single check also rejects NaN.
  require(learningRate > 0, s"learningRate must be greater than 0 but got $learningRate")
  this.startingAlpha = learningRate
  this
}
72
87
73
88
/**
74
- * Word2Vec with a single thread .
89
+ * Sets number of partitions (default: 1). Use a small number for accuracy.
75
90
*/
76
- def this (size : Int , startingAlpha : Int ) = this (size, startingAlpha, 1 , 1 )
91
def setNumPartitions(numPartitions: Int): this.type = {
  // Reject non-positive partition counts up front with a descriptive message.
  require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
  this.numPartitions = numPartitions
  // Return this instance so that setter calls can be chained.
  this
}
96
+
97
/**
 * Sets number of iterations (default: 1), which should be smaller than or equal to the number
 * of partitions.
 *
 * The relation to the number of partitions is not enforced here, because the setters may be
 * called in any order.
 *
 * @param numIterations number of iterations to run; must be non-negative
 * @return this instance, so that setter calls can be chained
 * @throws IllegalArgumentException if `numIterations` is negative
 */
def setNumIterations(numIterations: Int): this.type = {
  require(numIterations >= 0, s"numIterations must be non-negative but got $numIterations")
  this.numIterations = numIterations
  this
}
105
+
106
/**
 * Sets random seed (default: a random long integer).
 *
 * @param seed seed value to store for this instance
 * @return this instance, so that setter calls can be chained
 */
def setSeed(seed: Long): this.type = {
  this.seed = seed
  this
}
77
113
78
114
// Internal numeric constants; their uses lie outside this excerpt.
private val EXP_TABLE_SIZE: Int = 1000
private val MAX_EXP: Int = 6
private val MAX_CODE_LENGTH: Int = 40
private val MAX_SENTENCE_LENGTH: Int = 1000
82
- private val layer1Size = size
83
- private val modelPartitionNum = 100
118
// Must be a `def`, not a `val`: a `val` is initialized once when the instance is constructed,
// at which point vectorSize still holds its default, so a later setVectorSize call would
// never be reflected here.
private def layer1Size: Int = vectorSize
84
119
85
120
/** Context words are taken from the range [-window, window]. */
private val window: Int = 5
@@ -245,8 +280,7 @@ class Word2Vec(
245
280
}
246
281
}
247
282
248
- val newSentences = sentences.repartition(parallelism).cache()
249
- val seed = 5875483L
283
+ val newSentences = sentences.repartition(numPartitions).cache()
250
284
val initRandom = new XORShiftRandom (seed)
251
285
var syn0Global =
252
286
Array .fill[Float ](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f ) / layer1Size)
@@ -263,7 +297,7 @@ class Word2Vec(
263
297
lwc = wordCount
264
298
// TODO: discount by iteration?
265
299
alpha =
266
- startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1 ))
300
+ startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1 ))
267
301
if (alpha < startingAlpha * 0.0001 ) alpha = startingAlpha * 0.0001
268
302
logInfo(" wordCount = " + wordCount + " , alpha = " + alpha)
269
303
}
@@ -404,23 +438,3 @@ class Word2VecModel private[mllib] (
404
438
.toArray
405
439
}
406
440
}
407
-
408
- object Word2Vec {
409
- /**
410
- * Train Word2Vec model
411
- * @param input RDD of words
412
- * @param size vector dimension
413
- * @param startingAlpha initial learning rate
414
- * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
415
- * @param numIterations number of iterations, should be smaller than or equal to parallelism
416
- * @return Word2Vec model
417
- */
418
- def train [S <: Iterable [String ]](
419
- input : RDD [S ],
420
- size : Int ,
421
- startingAlpha : Double ,
422
- parallelism : Int = 1 ,
423
- numIterations: Int = 1 ): Word2VecModel = {
424
- new Word2Vec (size,startingAlpha, parallelism, numIterations).fit[S ](input)
425
- }
426
- }
0 commit comments