package org.apache.spark.mllib.feature

- import scala.util.Random
- import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable
+ import scala.collection.mutable.ArrayBuffer
+ import scala.util.Random

import com.github.fommil.netlib.BLAS.{getInstance => blas}

-
- import org.apache.spark.annotation.Experimental
- import org.apache.spark.Logging
- import org.apache.spark.rdd._
+ import org.apache.spark.{HashPartitioner, Logging}
import org.apache.spark.SparkContext._
+ import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
- import org.apache.spark.HashPartitioner
- import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.rdd.RDDFunctions._
+ import org.apache.spark.rdd._
+ import org.apache.spark.storage.StorageLevel
+

/**
 * Entry in vocabulary
 */
@@ -52,7 +51,7 @@ private case class VocabWord(
 *
 * We used skip-gram model in our implementation and hierarchical softmax
 * method to train the model. The variable names in the implementation
- * mathes the original C implementation.
+ * matches the original C implementation.
 *
 * For original C implementation, see https://code.google.com/p/word2vec/
 * For research papers, see
@@ -61,34 +60,41 @@ private case class VocabWord(
 * Distributed Representations of Words and Phrases and their Compositionality.
 * @param size vector dimension
 * @param startingAlpha initial learning rate
- * @param window context words from [-window, window]
- * @param minCount minimum frequncy to consider a vocabulary word
- * @param parallelisum number of partitions to run Word2Vec
+ * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
+ * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
 */
@Experimental
class Word2Vec(
    val size: Int,
    val startingAlpha: Double,
-     val window: Int,
-     val minCount: Int,
-     val parallelism: Int = 1,
-     val numIterations: Int = 1)
-   extends Serializable with Logging {
-
+     val parallelism: Int,
+     val numIterations: Int) extends Serializable with Logging {
+
+   /**
+    * Word2Vec with a single thread.
+    */
+   def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
+
  private val EXP_TABLE_SIZE = 1000
  private val MAX_EXP = 6
  private val MAX_CODE_LENGTH = 40
  private val MAX_SENTENCE_LENGTH = 1000
  private val layer1Size = size
  private val modelPartitionNum = 100
-
+
+   /** context words from [-window, window] */
+   private val window = 5
+
+   /** minimum frequency to consider a vocabulary word */
+   private val minCount = 5
+
  private var trainWordsCount = 0
  private var vocabSize = 0
  private var vocab: Array[VocabWord] = null
  private var vocabHash = mutable.HashMap.empty[String, Int]
  private var alpha = startingAlpha

-   private def learnVocab(words: RDD[String]){
+   private def learnVocab(words: RDD[String]): Unit = {
    vocab = words.map(w => (w, 1))
      .reduceByKey(_ + _)
      .map(x => VocabWord(
@@ -99,7 +105,7 @@ class Word2Vec(
        0))
      .filter(_.cn >= minCount)
      .collect()
-       .sortWith((a, b)=> a.cn > b.cn)
+       .sortWith((a, b) => a.cn > b.cn)

    vocabSize = vocab.length
    var a = 0
@@ -111,22 +117,18 @@ class Word2Vec(
    logInfo("trainWordsCount = " + trainWordsCount)
  }

-   private def learnVocabPerPartition(words: RDD[String]) {
-
-   }
-
-   private def createExpTable(): Array[Double] = {
-     val expTable = new Array[Double](EXP_TABLE_SIZE)
+   private def createExpTable(): Array[Float] = {
+     val expTable = new Array[Float](EXP_TABLE_SIZE)
    var i = 0
    while (i < EXP_TABLE_SIZE) {
      val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
-       expTable(i) = tmp / (tmp + 1)
+       expTable(i) = (tmp / (tmp + 1.0)).toFloat
      i += 1
    }
    expTable
  }

-   private def createBinaryTree() {
+   private def createBinaryTree(): Unit = {
    val count = new Array[Long](vocabSize * 2 + 1)
    val binary = new Array[Int](vocabSize * 2 + 1)
    val parentNode = new Array[Int](vocabSize * 2 + 1)
@@ -208,8 +210,7 @@ class Word2Vec(
   * @param dataset an RDD of words
   * @return a Word2VecModel
   */
-
-   def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
+   def fit[S <: Iterable[String]](dataset : RDD[S]): Word2VecModel = {

    val words = dataset.flatMap(x => x)
@@ -223,39 +224,37 @@ class Word2Vec(
    val bcVocab = sc.broadcast(vocab)
    val bcVocabHash = sc.broadcast(vocabHash)

-     val sentences: RDD[Array[Int]] = words.mapPartitions {
-       iter => { new Iterator[Array[Int]] {
-         def hasNext = iter.hasNext
-
-         def next = {
-           var sentence = new ArrayBuffer[Int]
-           var sentenceLength = 0
-           while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
-             val word = bcVocabHash.value.get(iter.next)
-             word match {
-               case Some(w) => {
-                 sentence += w
-                 sentenceLength += 1
-               }
-               case None =>
-             }
+     val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
+       new Iterator[Array[Int]] {
+         def hasNext: Boolean = iter.hasNext
+
+         def next(): Array[Int] = {
+           var sentence = new ArrayBuffer[Int]
+           var sentenceLength = 0
+           while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
+             val word = bcVocabHash.value.get(iter.next())
+             word match {
+               case Some(w) =>
+                 sentence += w
+                 sentenceLength += 1
+               case None =>
            }
-           sentence.toArray
          }
+           sentence.toArray
        }
      }
    }

    val newSentences = sentences.repartition(parallelism).cache()
-     var syn0Global
-     = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
-     var syn1Global = new Array[Double](vocabSize * layer1Size)
+     var syn0Global =
+       Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+     var syn1Global = new Array[Float](vocabSize * layer1Size)

    for (iter <- 1 to numIterations) {
      val (aggSyn0, aggSyn1, _, _) =
-       // TODO: broadcast temp instead of serializing it directly
+         // TODO: broadcast temp instead of serializing it directly
        // or initialize the model in each executor
-         newSentences.treeAggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
+         newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
          seqOp = (c, v) => (c, v) match {
            case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
              var lwc = lastWordCount
@@ -280,23 +279,23 @@ class Word2Vec(
                if (c >= 0 && c < sentence.size) {
                  val lastWord = sentence(c)
                  val l1 = lastWord * layer1Size
-                 val neu1e = new Array[Double](layer1Size)
+                 val neu1e = new Array[Float](layer1Size)
                  // Hierarchical softmax
                  var d = 0
                  while (d < bcVocab.value(word).codeLen) {
                    val l2 = bcVocab.value(word).point(d) * layer1Size
                    // Propagate hidden -> output
-                   var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
+                   var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1)
                    if (f > -MAX_EXP && f < MAX_EXP) {
                      val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                      f = expTable.value(ind)
-                     val g = (1 - bcVocab.value(word).code(d) - f) * alpha
-                     blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
-                     blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+                     val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
+                     blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
+                     blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
                    }
                    d += 1
                  }
-                 blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
+                 blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1)
                }
              }
              a += 1
@@ -308,24 +307,24 @@ class Word2Vec(
          combOp = (c1, c2) => (c1, c2) match {
            case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
              val n = syn0_1.length
-               val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
-               val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
-               blas.dscal(n, weight1, syn0_1, 1)
-               blas.dscal(n, weight1, syn1_1, 1)
-               blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-               blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+               val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+               val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+               blas.sscal(n, weight1, syn0_1, 1)
+               blas.sscal(n, weight1, syn1_1, 1)
+               blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+               blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
              (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
          })
      syn0Global = aggSyn0
      syn1Global = aggSyn1
    }
    newSentences.unpersist()

-     val wordMap = new Array[(String, Array[Double])](vocabSize)
+     val wordMap = new Array[(String, Array[Float])](vocabSize)
    var i = 0
    while (i < vocabSize) {
      val word = bcVocab.value(i).word
-       val vector = new Array[Double](layer1Size)
+       val vector = new Array[Float](layer1Size)
      Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
      wordMap(i) = (word, vector)
      i += 1
@@ -341,15 +340,15 @@ class Word2Vec(
/**
 * Word2Vec model
 */
- class Word2VecModel(private val model: RDD[(String, Array[Double])]) extends Serializable {
+ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {

-   private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
+   private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
    require(v1.length == v2.length, "Vectors should have the same length")
    val n = v1.length
-     val norm1 = blas.dnrm2(n, v1, 1)
-     val norm2 = blas.dnrm2(n, v2, 1)
+     val norm1 = blas.snrm2(n, v1, 1)
+     val norm2 = blas.snrm2(n, v2, 1)
    if (norm1 == 0 || norm2 == 0) return 0.0
-     blas.ddot(n, v1, 1, v2,1) / norm1 / norm2
+     blas.sdot(n, v1, 1, v2,1) / norm1 / norm2
  }

  /**
@@ -360,9 +359,9 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
  def transform(word: String): Vector = {
    val result = model.lookup(word)
    if (result.isEmpty) {
-       throw new IllegalStateException(s"${word} not in vocabulary")
+       throw new IllegalStateException(s"$word not in vocabulary")
    }
-     else Vectors.dense(result(0))
+     else Vectors.dense(result(0).map(_.toDouble))
  }

  /**
@@ -394,7 +393,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
  def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
    require(num > 0, "Number of similar words should > 0")
    val topK = model.map { case (w, vec) =>
-       (cosineSimilarity(vector.toArray, vec), w) }
+       (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
      .sortByKey(ascending = false)
      .take(num + 1)
      .map(_.swap)
@@ -410,18 +409,16 @@ object Word2Vec{
   * @param input RDD of words
   * @param size vector dimension
   * @param startingAlpha initial learning rate
-   * @param window context words from [-window, window]
-   * @param minCount minimum frequncy to consider a vocabulary word
-   * @return Word2Vec model
-   */
+   * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
+   * @param numIterations number of iterations, should be smaller than or equal to parallelism
+   * @return Word2Vec model
+   */
  def train[S <: Iterable[String]](
      input: RDD[S],
      size: Int,
      startingAlpha: Double,
-       window: Int,
-       minCount: Int,
      parallelism: Int = 1,
      numIterations: Int = 1): Word2VecModel = {
-     new Word2Vec(size,startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
+     new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
  }
}
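
Not part of the patch: a minimal, hypothetical driver sketch of how the updated public API would be called after this change. Note that train no longer takes window or minCount (both are fixed at 5 inside Word2Vec), the model stores single-precision vectors, and transform still returns a double-precision Vector. The corpus path, master string, hyperparameter values, and query word below are placeholders.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.feature.Word2Vec

object Word2VecUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "Word2VecUsageSketch")
    // One already-tokenized sentence per line; the path is a placeholder.
    val sentences = sc.textFile("data/corpus.txt").map(_.split(" ").toSeq)
    // Updated signature: no window/minCount arguments anymore.
    val model = Word2Vec.train(sentences, size = 100, startingAlpha = 0.025)
    // transform throws IllegalStateException if the word is out of vocabulary.
    val vec = model.transform("spark")
    model.findSynonyms(vec, 5).foreach { case (word, sim) =>
      println(s"$word $sim")
    }
    sc.stop()
  }
}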
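
Also not from the patch: a self-contained sketch of the sigmoid lookup table that createExpTable builds and that the training loop indexes with ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt. Entry i holds the sigmoid of (2i/EXP_TABLE_SIZE - 1) * MAX_EXP, so that index formula maps an activation f in (-MAX_EXP, MAX_EXP) back onto [0, EXP_TABLE_SIZE); the printed values are only approximately equal.

object ExpTableSketch {
  private val EXP_TABLE_SIZE = 1000
  private val MAX_EXP = 6

  // Same table as createExpTable: entry i stores the sigmoid of
  // (2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP in single precision.
  private val expTable: Array[Float] = Array.tabulate(EXP_TABLE_SIZE) { i =>
    val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
    (tmp / (tmp + 1.0)).toFloat
  }

  // Same index arithmetic as the training loop; valid for -MAX_EXP < f < MAX_EXP.
  private def sigmoidApprox(f: Double): Float = {
    val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
    expTable(ind)
  }

  def main(args: Array[String]): Unit = {
    val f = 1.5
    println(sigmoidApprox(f))              // table lookup, roughly 0.81
    println(1.0 / (1.0 + math.exp(-f)))    // exact sigmoid, about 0.82
  }
}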