
Commit 26a948d

Merge pull request #1 from mengxr/Ishiihara-master
some updates
2 parents e93e726 + c14da41

2 files changed: 94 additions, 96 deletions


mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 79 additions & 82 deletions
@@ -17,20 +17,19 @@
 
 package org.apache.spark.mllib.feature
 
-import scala.util.Random
-import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.Logging
-import org.apache.spark.rdd._
+import org.apache.spark.{HashPartitioner, Logging}
 import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.HashPartitioner
-import org.apache.spark.storage.StorageLevel
 import org.apache.spark.mllib.rdd.RDDFunctions._
+import org.apache.spark.rdd._
+import org.apache.spark.storage.StorageLevel
+
 /**
  * Entry in vocabulary
  */
@@ -52,7 +51,7 @@ private case class VocabWord(
  *
  * We used skip-gram model in our implementation and hierarchical softmax
  * method to train the model. The variable names in the implementation
- * mathes the original C implementation.
+ * matches the original C implementation.
  *
  * For original C implementation, see https://code.google.com/p/word2vec/
 * For research papers, see
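
As background for the class doc above: the skip-gram objective predicts each word's surrounding context within a symmetric window. A minimal sketch (not part of this patch; names are illustrative) of how (center, context) training pairs fall out of a sentence of vocabulary indices:

// Illustrative only: enumerate skip-gram (center, context) pairs for one
// sentence, mirroring the [-window, window] range the class doc describes.
def skipGramPairs(sentence: Array[Int], window: Int): Seq[(Int, Int)] =
  for {
    pos <- sentence.indices
    offset <- -window to window
    ctx = pos + offset
    if offset != 0 && ctx >= 0 && ctx < sentence.length
  } yield (sentence(pos), sentence(ctx))

Hierarchical softmax then scores each pair along a Huffman-tree path instead of over the full vocabulary, which is what the point/code/codeLen fields of VocabWord encode.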
@@ -61,34 +60,41 @@ private case class VocabWord(
  * Distributed Representations of Words and Phrases and their Compositionality.
  * @param size vector dimension
  * @param startingAlpha initial learning rate
- * @param window context words from [-window, window]
- * @param minCount minimum frequncy to consider a vocabulary word
- * @param parallelisum number of partitions to run Word2Vec
+ * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
+ * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
  */
 @Experimental
 class Word2Vec(
     val size: Int,
     val startingAlpha: Double,
-    val window: Int,
-    val minCount: Int,
-    val parallelism:Int = 1,
-    val numIterations:Int = 1)
-  extends Serializable with Logging {
-
+    val parallelism: Int,
+    val numIterations: Int) extends Serializable with Logging {
+
+  /**
+   * Word2Vec with a single thread.
+   */
+  def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
+
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
   private val MAX_SENTENCE_LENGTH = 1000
   private val layer1Size = size
   private val modelPartitionNum = 100
-
+
+  /** context words from [-window, window] */
+  private val window = 5
+
+  /** minimum frequency to consider a vocabulary word */
+  private val minCount = 5
+
   private var trainWordsCount = 0
   private var vocabSize = 0
   private var vocab: Array[VocabWord] = null
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha
 
-  private def learnVocab(words:RDD[String]){
+  private def learnVocab(words:RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
      .map(x => VocabWord(
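
The constructor change above fixes window and minCount as private defaults and leaves a four-argument primary constructor. A minimal usage sketch with hypothetical hyperparameter values; note the new auxiliary constructor declares startingAlpha as Int, so its call this(size, startingAlpha, 1, 1) only compiles because Scala widens Int to Double implicitly:

// Hypothetical values, for illustration only.
val w2v = new Word2Vec(
  size = 100,            // vector dimension
  startingAlpha = 0.025, // initial learning rate
  parallelism = 4,       // number of partitions used for training
  numIterations = 2)     // should be <= parallelism, per the doc above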
@@ -99,7 +105,7 @@ class Word2Vec(
         0))
       .filter(_.cn >= minCount)
       .collect()
-      .sortWith((a, b)=> a.cn > b.cn)
+      .sortWith((a, b) => a.cn > b.cn)
 
     vocabSize = vocab.length
     var a = 0
@@ -111,22 +117,18 @@ class Word2Vec(
     logInfo("trainWordsCount = " + trainWordsCount)
   }
 
-  private def learnVocabPerPartition(words:RDD[String]) {
-
-  }
-
-  private def createExpTable(): Array[Double] = {
-    val expTable = new Array[Double](EXP_TABLE_SIZE)
+  private def createExpTable(): Array[Float] = {
+    val expTable = new Array[Float](EXP_TABLE_SIZE)
     var i = 0
     while (i < EXP_TABLE_SIZE) {
       val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
-      expTable(i) = tmp / (tmp + 1)
+      expTable(i) = (tmp / (tmp + 1.0)).toFloat
       i += 1
     }
     expTable
   }
 
-  private def createBinaryTree() {
+  private def createBinaryTree(): Unit = {
     val count = new Array[Long](vocabSize * 2 + 1)
     val binary = new Array[Int](vocabSize * 2 + 1)
    val parentNode = new Array[Int](vocabSize * 2 + 1)
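
The table switch from Array[Double] to Array[Float] halves the data shipped with each task. As a standalone illustration (same constants as the patch), the table precomputes the logistic function over [-MAX_EXP, MAX_EXP), and the training loop later inverts the mapping to look a value up:

val EXP_TABLE_SIZE = 1000
val MAX_EXP = 6

// expTable(i) ≈ 1 / (1 + exp(-x)) where x sweeps [-MAX_EXP, MAX_EXP).
val expTable = Array.tabulate(EXP_TABLE_SIZE) { i =>
  val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
  (tmp / (tmp + 1.0)).toFloat
}

// The same index arithmetic the training loop uses, valid for f strictly
// inside (-MAX_EXP, MAX_EXP).
def sigmoidLookup(f: Float): Float =
  expTable(((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt)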
@@ -208,8 +210,7 @@ class Word2Vec(
    * @param dataset an RDD of words
    * @return a Word2VecModel
    */
-
-  def fit[S <: Iterable[String]](dataset:RDD[S]): Word2VecModel = {
+  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
 
     val words = dataset.flatMap(x => x)
 
@@ -223,39 +224,37 @@
     val bcVocab = sc.broadcast(vocab)
     val bcVocabHash = sc.broadcast(vocabHash)
 
-    val sentences: RDD[Array[Int]] = words.mapPartitions {
-      iter => { new Iterator[Array[Int]] {
-        def hasNext = iter.hasNext
-
-        def next = {
-          var sentence = new ArrayBuffer[Int]
-          var sentenceLength = 0
-          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
-            val word = bcVocabHash.value.get(iter.next)
-            word match {
-              case Some(w) => {
-                sentence += w
-                sentenceLength += 1
-              }
-              case None =>
-            }
-          }
-          sentence.toArray
-        }
-      }
-      }
-    }
+    val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
+      new Iterator[Array[Int]] {
+        def hasNext: Boolean = iter.hasNext
+
+        def next(): Array[Int] = {
+          var sentence = new ArrayBuffer[Int]
+          var sentenceLength = 0
+          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
+            val word = bcVocabHash.value.get(iter.next())
+            word match {
+              case Some(w) =>
+                sentence += w
+                sentenceLength += 1
+              case None =>
+            }
+          }
+          sentence.toArray
+        }
+      }
+    }
 
     val newSentences = sentences.repartition(parallelism).cache()
-    var syn0Global
-    = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
-    var syn1Global = new Array[Double](vocabSize * layer1Size)
+    var syn0Global =
+      Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+    var syn1Global = new Array[Float](vocabSize * layer1Size)
 
     for(iter <- 1 to numIterations) {
       val (aggSyn0, aggSyn1, _, _) =
-      // TODO: broadcast temp instead of serializing it directly
+        // TODO: broadcast temp instead of serializing it directly
         // or initialize the model in each executor
-        newSentences.treeAggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
+        newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
         seqOp = (c, v) => (c, v) match {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
            var lwc = lastWordCount
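
Two things happen in this hunk: the sentence iterator is rewritten in idiomatic form (it still greedily cuts the token stream into arrays of at most MAX_SENTENCE_LENGTH vocabulary indices, silently dropping out-of-vocabulary words), and the treeAggregate zero value drops the .clone() calls, which appears safe because the zero value is serialized out to the executors anyway, so each task works on its own deserialized copy. A compact, illustrative re-statement of the chunking logic, with a toy map standing in for the broadcast vocabHash:

// Illustrative only: equivalent chunking via grouped().
val toyVocabHash = Map("a" -> 0, "b" -> 1, "c" -> 2)
val MAX_SENTENCE_LENGTH = 1000

def toSentences(words: Iterator[String]): Iterator[Array[Int]] =
  words.flatMap(toyVocabHash.get)   // drop out-of-vocabulary tokens
    .grouped(MAX_SENTENCE_LENGTH)   // bound each "sentence"
    .map(_.toArray)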
@@ -280,23 +279,23 @@ class Word2Vec(
                   if (c >= 0 && c < sentence.size) {
                     val lastWord = sentence(c)
                     val l1 = lastWord * layer1Size
-                    val neu1e = new Array[Double](layer1Size)
+                    val neu1e = new Array[Float](layer1Size)
                     // Hierarchical softmax
                     var d = 0
                     while (d < bcVocab.value(word).codeLen) {
                       val l2 = bcVocab.value(word).point(d) * layer1Size
                       // Propagate hidden -> output
-                      var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
+                      var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1)
                       if (f > -MAX_EXP && f < MAX_EXP) {
                         val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                         f = expTable.value(ind)
-                        val g = (1 - bcVocab.value(word).code(d) - f) * alpha
-                        blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
-                        blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+                        val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
+                        blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
+                        blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
                       }
                       d += 1
                     }
-                    blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
+                    blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1)
                   }
                 }
                a += 1
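
This hunk is a pure precision change: ddot/daxpy become sdot/saxpy with matching Float literals and casts. For readers tracking the math, a plain-loop sketch of one hierarchical-softmax step against a single inner node (illustrative; the patch uses netlib BLAS with array offsets instead):

// f = sigmoid(dot of syn0 at l1 and syn1 at l2), g = (1 - code - f) * alpha.
// neu1e accumulates the error that is later applied to the word vector.
def hsStep(syn0: Array[Float], l1: Int, syn1: Array[Float], l2: Int,
    neu1e: Array[Float], size: Int, code: Int, alpha: Float): Unit = {
  var f = 0.0f
  var i = 0
  while (i < size) { f += syn0(l1 + i) * syn1(l2 + i); i += 1 }
  f = (1.0 / (1.0 + math.exp(-f))).toFloat  // the patch reads expTable instead
  val g = (1 - code - f) * alpha
  i = 0
  while (i < size) {
    neu1e(i) += g * syn1(l2 + i)      // error for the input word vector
    syn1(l2 + i) += g * syn0(l1 + i)  // update the inner-node vector
    i += 1
  }
}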
@@ -308,24 +307,24 @@ class Word2Vec(
         combOp = (c1, c2) => (c1, c2) match {
           case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
             val n = syn0_1.length
-            val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
-            val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
-            blas.dscal(n, weight1, syn0_1, 1)
-            blas.dscal(n, weight1, syn1_1, 1)
-            blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-            blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+            val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+            val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+            blas.sscal(n, weight1, syn0_1, 1)
+            blas.sscal(n, weight1, syn1_1, 1)
+            blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+            blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
             (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
         })
       syn0Global = aggSyn0
       syn1Global = aggSyn1
     }
     newSentences.unpersist()
 
-    val wordMap = new Array[(String, Array[Double])](vocabSize)
+    val wordMap = new Array[(String, Array[Float])](vocabSize)
     var i = 0
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
-      val vector = new Array[Double](layer1Size)
+      val vector = new Array[Float](layer1Size)
       Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
       wordMap(i) = (word, vector)
      i += 1
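
The combiner weights each partial model by its share of the processed word count before summing, so partitions that saw more words contribute proportionally more; 1.0f keeps the whole merge in single precision. A plain-array sketch of the same merge (illustrative, without BLAS):

// result = w1 * m1 + w2 * m2, weighted by per-replica word counts; m1 is
// updated in place, matching the sscal/saxpy calls above.
def merge(m1: Array[Float], wc1: Int, m2: Array[Float], wc2: Int): Array[Float] = {
  val w1 = 1.0f * wc1 / (wc1 + wc2)
  val w2 = 1.0f * wc2 / (wc1 + wc2)
  var i = 0
  while (i < m1.length) { m1(i) = w1 * m1(i) + w2 * m2(i); i += 1 }
  m1
}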
@@ -341,15 +340,15 @@
 /**
  * Word2Vec model
  */
-class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Serializable {
+class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
 
-  private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
+  private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
     val n = v1.length
-    val norm1 = blas.dnrm2(n, v1, 1)
-    val norm2 = blas.dnrm2(n, v2, 1)
+    val norm1 = blas.snrm2(n, v1, 1)
+    val norm2 = blas.snrm2(n, v2, 1)
     if (norm1 == 0 || norm2 == 0) return 0.0
-    blas.ddot(n, v1, 1, v2,1) / norm1 / norm2
+    blas.sdot(n, v1, 1, v2,1) / norm1 / norm2
   }
 
 /**
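
cosineSimilarity now takes Float vectors but still returns a Double, since the Float results of snrm2/sdot widen on division. A plain-loop equivalent of the formula cos(v1, v2) = (v1 · v2) / (|v1| |v2|), for reference:

// Illustrative equivalent of the method above, without BLAS.
def cosine(v1: Array[Float], v2: Array[Float]): Double = {
  require(v1.length == v2.length, "Vectors should have the same length")
  var dot = 0.0; var n1 = 0.0; var n2 = 0.0
  var i = 0
  while (i < v1.length) {
    dot += v1(i) * v2(i); n1 += v1(i) * v1(i); n2 += v2(i) * v2(i)
    i += 1
  }
  if (n1 == 0 || n2 == 0) 0.0 else dot / math.sqrt(n1) / math.sqrt(n2)
}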
@@ -360,9 +359,9 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
   def transform(word: String): Vector = {
     val result = model.lookup(word)
     if (result.isEmpty) {
-      throw new IllegalStateException(s"${word} not in vocabulary")
+      throw new IllegalStateException(s"$word not in vocabulary")
     }
-    else Vectors.dense(result(0))
+    else Vectors.dense(result(0).map(_.toDouble))
   }
 
 /**
@@ -394,7 +393,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
     val topK = model.map { case(w, vec) =>
-      (cosineSimilarity(vector.toArray, vec), w) }
+      (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
       .sortByKey(ascending = false)
       .take(num + 1)
      .map(_.swap)
@@ -410,18 +409,16 @@ object Word2Vec{
    * @param input RDD of words
    * @param size vector dimension
    * @param startingAlpha initial learning rate
-   * @param window context words from [-window, window]
-   * @param minCount minimum frequncy to consider a vocabulary word
-   * @return Word2Vec model
-   */
+   * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
+   * @param numIterations number of iterations, should be smaller than or equal to parallelism
+   * @return Word2Vec model
+   */
   def train[S <: Iterable[String]](
       input: RDD[S],
       size: Int,
       startingAlpha: Double,
-      window: Int,
-      minCount: Int,
       parallelism: Int = 1,
       numIterations:Int = 1): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
+    new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
   }
 }
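
Putting the slimmed-down object API together, a minimal end-to-end sketch (schematic: sc is an existing SparkContext, the hyperparameter values are hypothetical, and with minCount now fixed at 5 a real corpus must repeat each word at least five times for it to enter the vocabulary):

import org.apache.spark.mllib.feature.Word2Vec

// One tokenized sentence/document per RDD element.
val corpus = sc.parallelize(Seq(
  "spark mllib word2vec example".split(" ").toSeq,
  "spark mllib feature transformers".split(" ").toSeq))

val model = Word2Vec.train(corpus, size = 100, startingAlpha = 0.025)
val synonyms = model.findSynonyms("spark", 2)  // Array[(String, Double)]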

mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala

Lines changed: 15 additions & 14 deletions
@@ -19,10 +19,12 @@ package org.apache.spark.mllib.feature
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.util.LocalSparkContext
 
 class Word2VecSuite extends FunSuite with LocalSparkContext {
+
+  // TODO: add more tests
+
   test("Word2Vec") {
     val sentence = "a b " * 100 + "a c " * 10
    val localDoc = Seq(sentence, sentence)
@@ -33,28 +35,27 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
     val window = 2
     val minCount = 2
     val num = 2
-    val word = "a"
 
     val model = Word2Vec.train(doc, size, startingAlpha, window, minCount)
-    val synons = model.findSynonyms("a", 2)
-    assert(synons.length == num)
-    assert(synons(0)._1 == "b")
-    assert(synons(1)._1 == "c")
+    val syms = model.findSynonyms("a", 2)
+    assert(syms.length == num)
+    assert(syms(0)._1 == "b")
+    assert(syms(1)._1 == "c")
   }
 
 
   test("Word2VecModel") {
     val num = 2
     val localModel = Seq(
-      ("china" , Array(0.50, 0.50, 0.50, 0.50)),
-      ("japan" , Array(0.40, 0.50, 0.50, 0.50)),
-      ("taiwan", Array(0.60, 0.50, 0.50, 0.50)),
-      ("korea" , Array(0.45, 0.60, 0.60, 0.60))
+      ("china" , Array(0.50f, 0.50f, 0.50f, 0.50f)),
+      ("japan" , Array(0.40f, 0.50f, 0.50f, 0.50f)),
+      ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+      ("korea" , Array(0.45f, 0.60f, 0.60f, 0.60f))
     )
     val model = new Word2VecModel(sc.parallelize(localModel, 2))
-    val synons = model.findSynonyms("china", num)
-    assert(synons.length == num)
-    assert(synons(0)._1 == "taiwan")
-    assert(synons(1)._1 == "japan")
+    val syms = model.findSynonyms("china", num)
+    assert(syms.length == num)
+    assert(syms(0)._1 == "taiwan")
+    assert(syms(1)._1 == "japan")
   }
 }
