Commit 384c771

remove minCount and window from constructor
change model to use float instead of double
1 parent e93e726 commit 384c771

File tree

1 file changed (+63, -67)

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 63 additions & 67 deletions
@@ -31,6 +31,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.HashPartitioner
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.mllib.rdd.RDDFunctions._
+
 /**
  * Entry in vocabulary
  */
@@ -61,18 +62,15 @@ private case class VocabWord(
  * Distributed Representations of Words and Phrases and their Compositionality.
  * @param size vector dimension
  * @param startingAlpha initial learning rate
- * @param window context words from [-window, window]
- * @param minCount minimum frequncy to consider a vocabulary word
- * @param parallelisum number of partitions to run Word2Vec
+ * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
+ * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
  */
 @Experimental
 class Word2Vec(
     val size: Int,
     val startingAlpha: Double,
-    val window: Int,
-    val minCount: Int,
-    val parallelism:Int = 1,
-    val numIterations:Int = 1)
+    val parallelism: Int = 1,
+    val numIterations: Int = 1)
   extends Serializable with Logging {

   private val EXP_TABLE_SIZE = 1000
@@ -81,7 +79,13 @@ class Word2Vec(
   private val MAX_SENTENCE_LENGTH = 1000
   private val layer1Size = size
   private val modelPartitionNum = 100
-
+
+  /** context words from [-window, window] */
+  private val window = 5
+
+  /** minimum frequency to consider a vocabulary word */
+  private val minCount = 5
+
   private var trainWordsCount = 0
   private var vocabSize = 0
   private var vocab: Array[VocabWord] = null
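With this hunk and the constructor change above, window and minCount stop being tunable: both are pinned to 5, which matches the defaults of the reference C implementation of word2vec. A minimal sketch of construction after this commit (values illustrative):

    // Only four knobs remain; window = 5 and minCount = 5 are fixed internally.
    val w2v = new Word2Vec(size = 100, startingAlpha = 0.025,
      parallelism = 4, numIterations = 2)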
@@ -99,7 +103,7 @@ class Word2Vec(
         0))
       .filter(_.cn >= minCount)
       .collect()
-      .sortWith((a, b)=> a.cn > b.cn)
+      .sortWith((a, b) => a.cn > b.cn)

     vocabSize = vocab.length
     var a = 0
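Beyond the whitespace fix, the descending sort by count appears to be load-bearing: the Huffman-style tree construction elsewhere in this file (not touched by this commit) assumes the most frequent words come first. An illustration with hypothetical data:

    case class VW(word: String, cn: Int)               // stand-in for VocabWord
    Seq(VW("rare", 2), VW("the", 900), VW("of", 450))
      .sortWith((a, b) => a.cn > b.cn)                  // => the, of, rare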
@@ -111,16 +115,12 @@ class Word2Vec(
     logInfo("trainWordsCount = " + trainWordsCount)
   }

-  private def learnVocabPerPartition(words:RDD[String]) {
-
-  }
-
-  private def createExpTable(): Array[Double] = {
-    val expTable = new Array[Double](EXP_TABLE_SIZE)
+  private def createExpTable(): Array[Float] = {
+    val expTable = new Array[Float](EXP_TABLE_SIZE)
     var i = 0
     while (i < EXP_TABLE_SIZE) {
       val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
-      expTable(i) = tmp / (tmp + 1)
+      expTable(i) = (tmp / (tmp + 1.0)).toFloat
       i += 1
     }
     expTable
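For context on what this table is: it caches the logistic function sigma(x) = e^x / (e^x + 1) over (-MAX_EXP, MAX_EXP) in EXP_TABLE_SIZE bins, so the training loop can turn a dot product into an array index instead of calling math.exp per tree node. A sketch of the lookup side, assuming MAX_EXP = 6 as in the reference C code (the constant is defined outside this diff):

    val MAX_EXP = 6
    val EXP_TABLE_SIZE = 1000
    val expTable: Array[Float] = Array.tabulate(EXP_TABLE_SIZE) { i =>
      val e = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
      (e / (e + 1.0)).toFloat
    }
    // Same index math as the training loop; valid only for -MAX_EXP < f < MAX_EXP.
    def sigmoidLookup(f: Float): Float =
      expTable(((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt)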
@@ -209,7 +209,7 @@ class Word2Vec(
    * @return a Word2VecModel
    */

-  def fit[S <: Iterable[String]](dataset:RDD[S]): Word2VecModel = {
+  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {

     val words = dataset.flatMap(x => x)

@@ -223,39 +223,37 @@ class Word2Vec(
     val bcVocab = sc.broadcast(vocab)
     val bcVocabHash = sc.broadcast(vocabHash)

-    val sentences: RDD[Array[Int]] = words.mapPartitions {
-      iter => { new Iterator[Array[Int]] {
-          def hasNext = iter.hasNext
-
-          def next = {
-            var sentence = new ArrayBuffer[Int]
-            var sentenceLength = 0
-            while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
-              val word = bcVocabHash.value.get(iter.next)
-              word match {
-                case Some(w) => {
-                  sentence += w
-                  sentenceLength += 1
-                }
-                case None =>
-              }
+    val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
+      new Iterator[Array[Int]] {
+        def hasNext: Boolean = iter.hasNext
+
+        def next(): Array[Int] = {
+          var sentence = new ArrayBuffer[Int]
+          var sentenceLength = 0
+          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
+            val word = bcVocabHash.value.get(iter.next())
+            word match {
+              case Some(w) =>
+                sentence += w
+                sentenceLength += 1
+              case None =>
             }
-            sentence.toArray
           }
+          sentence.toArray
         }
       }
     }

     val newSentences = sentences.repartition(parallelism).cache()
-    var syn0Global
-      = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
-    var syn1Global = new Array[Double](vocabSize * layer1Size)
+    var syn0Global =
+      Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+    var syn1Global = new Array[Float](vocabSize * layer1Size)

     for(iter <- 1 to numIterations) {
       val (aggSyn0, aggSyn1, _, _) =
-      // TODO: broadcast temp instead of serializing it directly
+        // TODO: broadcast temp instead of serializing it directly
         // or initialize the model in each executor
-        newSentences.treeAggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
+        newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
           seqOp = (c, v) => (c, v) match {
             case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
               var lwc = lastWordCount
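Two substantive changes ride along with the reformatting above: the model arrays become Float, and the zero value handed to treeAggregate is no longer clone()d, presumably because treeAggregate already serializes the initial value out to each task, making the driver-side copies redundant. The initialization itself is the usual word2vec scheme, each syn0 entry drawn uniformly from (-0.5/layer1Size, 0.5/layer1Size):

    // e.g. with layer1Size = 100, every initial weight lies in (-0.005f, 0.005f);
    // initial vectors keep a small norm regardless of dimension. (scala.util.Random)
    val sample: Float = (Random.nextFloat() - 0.5f) / 100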
@@ -280,23 +278,23 @@ class Word2Vec(
                 if (c >= 0 && c < sentence.size) {
                   val lastWord = sentence(c)
                   val l1 = lastWord * layer1Size
-                  val neu1e = new Array[Double](layer1Size)
+                  val neu1e = new Array[Float](layer1Size)
                   // Hierarchical softmax
                   var d = 0
                   while (d < bcVocab.value(word).codeLen) {
                     val l2 = bcVocab.value(word).point(d) * layer1Size
                     // Propagate hidden -> output
-                    var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
+                    var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1)
                     if (f > -MAX_EXP && f < MAX_EXP) {
                       val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                       f = expTable.value(ind)
-                      val g = (1 - bcVocab.value(word).code(d) - f) * alpha
-                      blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
-                      blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+                      val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
+                      blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
+                      blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
                     }
                     d += 1
                   }
-                  blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
+                  blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1)
                 }
               }
               a += 1
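The hierarchical-softmax math is untouched; only the BLAS entry points move from the double-precision d* routines to their single-precision s* twins, plus an explicit .toFloat on g (the code bit is an Int and alpha a Double, so the product would otherwise widen to Double). In scalar form, the work the sdot/saxpy calls do for one tree node is roughly this (illustrative helper, not code from the file):

    def hsNodeUpdate(syn0: Array[Float], l1: Int, syn1: Array[Float], l2: Int,
        neu1e: Array[Float], code: Int, alpha: Float, n: Int): Unit = {
      var f = 0.0f
      var k = 0
      while (k < n) { f += syn0(l1 + k) * syn1(l2 + k); k += 1 }  // sdot
      f = 1.0f / (1.0f + math.exp(-f).toFloat)  // sigmoid; fit() uses the exp table
      val g = (1 - code - f) * alpha            // gradient scaled by learning rate
      k = 0
      while (k < n) {                           // the two saxpy calls
        neu1e(k) += g * syn1(l2 + k)            // accumulate error for the input word
        syn1(l2 + k) += g * syn0(l1 + k)        // update the inner-node vector
        k += 1
      }
    }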
@@ -308,24 +306,24 @@ class Word2Vec(
           combOp = (c1, c2) => (c1, c2) match {
             case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
               val n = syn0_1.length
-              val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
-              val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
-              blas.dscal(n, weight1, syn0_1, 1)
-              blas.dscal(n, weight1, syn1_1, 1)
-              blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-              blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+              val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+              val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+              blas.sscal(n, weight1, syn0_1, 1)
+              blas.sscal(n, weight1, syn1_1, 1)
+              blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+              blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
               (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
           })
       syn0Global = aggSyn0
       syn1Global = aggSyn1
     }
     newSentences.unpersist()

-    val wordMap = new Array[(String, Array[Double])](vocabSize)
+    val wordMap = new Array[(String, Array[Float])](vocabSize)
     var i = 0
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
-      val vector = new Array[Double](layer1Size)
+      val vector = new Array[Float](layer1Size)
       Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
       wordMap(i) = (word, vector)
       i += 1
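combOp merges two partially trained models as a convex combination weighted by how many words each side consumed; writing the weights as 1.0f * wc_1 / (wc_1 + wc_2) forces floating-point division before the Int counts can truncate. A scalar equivalent of the sscal/saxpy pair (illustrative):

    def mergeWeighted(a: Array[Float], wcA: Int, b: Array[Float], wcB: Int): Array[Float] = {
      val wA = 1.0f * wcA / (wcA + wcB)
      val wB = 1.0f * wcB / (wcA + wcB)
      var i = 0
      while (i < a.length) { a(i) = wA * a(i) + wB * b(i); i += 1 }  // a := wA*a + wB*b
      a  // merged in place, as combOp reuses syn0_1 / syn1_1
    }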
@@ -341,15 +339,15 @@ class Word2Vec(
 /**
  * Word2Vec model
  */
-class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Serializable {
+class Word2VecModel (private val model:RDD[(String, Array[Float])]) extends Serializable {

-  private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
+  private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
     val n = v1.length
-    val norm1 = blas.dnrm2(n, v1, 1)
-    val norm2 = blas.dnrm2(n, v2, 1)
+    val norm1 = blas.snrm2(n, v1, 1)
+    val norm2 = blas.snrm2(n, v2, 1)
     if (norm1 == 0 || norm2 == 0) return 0.0
-    blas.ddot(n, v1, 1, v2,1) / norm1 / norm2
+    blas.sdot(n, v1, 1, v2,1) / norm1 / norm2
   }

   /**
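Note the mixed precision here: the vectors and the snrm2/sdot results are Float, and the quotient is widened to Double only at the return, so callers still receive a Double score. The quantity is the ordinary cosine, v1.v2 / (|v1| |v2|), with 0.0 returned when either vector is all zeros. For instance (hypothetical direct calls; the method is private):

    // cosineSimilarity(Array(1f, 0f), Array(1f, 0f)) == 1.0   (same direction)
    // cosineSimilarity(Array(1f, 0f), Array(0f, 3f)) == 0.0   (orthogonal)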
@@ -362,7 +360,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Serializable {
     if (result.isEmpty) {
       throw new IllegalStateException(s"${word} not in vocabulary")
     }
-    else Vectors.dense(result(0))
+    else Vectors.dense(result(0).map(_.toDouble))
   }

   /**
@@ -394,7 +392,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Serializable {
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
     val topK = model.map { case(w, vec) =>
-      (cosineSimilarity(vector.toArray, vec), w) }
+      (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
       .sortByKey(ascending = false)
       .take(num + 1)
       .map(_.swap)
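Because the stored vectors are now Float while the public Vector API remains Double-based, the query is down-cast element-wise before the similarity scan. num + 1 rows are taken presumably so that, when the query vector belongs to a vocabulary word, that word (which ranks first as its own nearest neighbour) can be dropped. Hypothetical usage, assuming 100-dimensional vectors:

    val query: Vector = Vectors.dense(Array.fill(100)(0.1))
    val nearest: Array[(String, Double)] = model.findSynonyms(query, 10)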
@@ -410,18 +408,16 @@ object Word2Vec{
    * @param input RDD of words
    * @param size vector dimension
    * @param startingAlpha initial learning rate
-   * @param window context words from [-window, window]
-   * @param minCount minimum frequncy to consider a vocabulary word
-   * @return Word2Vec model
-   */
+   * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
+   * @param numIterations number of iterations, should be smaller than or equal to parallelism
+   * @return Word2Vec model
+   */
   def train[S <: Iterable[String]](
       input: RDD[S],
       size: Int,
       startingAlpha: Double,
-      window: Int,
-      minCount: Int,
       parallelism: Int = 1,
       numIterations:Int = 1): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
+    new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
   }
 }
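Taken together, the public entry point is down to four user-facing parameters. A minimal end-to-end sketch against the API as of this commit (assumes an active SparkContext sc; the path and settings are illustrative):

    import org.apache.spark.mllib.feature.Word2Vec
    import org.apache.spark.rdd.RDD

    val corpus: RDD[Seq[String]] = sc.textFile("data/text8").map(_.split(" ").toSeq)
    // window and minCount are fixed at 5 internally after this change.
    val model = Word2Vec.train(corpus, size = 100, startingAlpha = 0.025,
      parallelism = 4, numIterations = 2)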
