Skip to content

Commit 0aafb1b

Browse files
author
Liquan Pei
committed
Add comments, minor fixes
1 parent 8d6befe commit 0aafb1b

File tree

1 file changed

+46
-23
lines changed

1 file changed

+46
-23
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ import org.apache.spark.SparkContext._
3131
import org.apache.spark.mllib.linalg.Vector
3232
import org.apache.spark.HashPartitioner
3333

34+
/**
35+
* Entry in vocabulary
36+
*/
3437
private case class VocabWord(
3538
var word: String,
3639
var cn: Int,
@@ -39,6 +42,9 @@ private case class VocabWord(
3942
var codeLen:Int
4043
)
4144

45+
/**
46+
* Vector representation of word
47+
*/
4248
class Word2Vec(
4349
val size: Int,
4450
val startingAlpha: Double,
@@ -51,7 +57,8 @@ class Word2Vec(
5157
private val MAX_CODE_LENGTH = 40
5258
private val MAX_SENTENCE_LENGTH = 1000
5359
private val layer1Size = size
54-
60+
private val modelPartitionNum = 100
61+
5562
private var trainWordsCount = 0
5663
private var vocabSize = 0
5764
private var vocab: Array[VocabWord] = null
@@ -169,6 +176,7 @@ class Word2Vec(
169176
* Computes the vector representation of each word in
170177
* vocabulary
171178
* @param dataset an RDD of strings
179+
* @return a Word2VecModel
172180
*/
173181

174182
def fit(dataset:RDD[String]): Word2VecModel = {
@@ -274,11 +282,14 @@ class Word2Vec(
274282
wordMap(i) = (word, vector)
275283
i += 1
276284
}
277-
val modelRDD = sc.parallelize(wordMap,100).partitionBy(new HashPartitioner(100))
285+
val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum))
278286
new Word2VecModel(modelRDD)
279287
}
280288
}
281289

290+
/**
291+
* Word2Vec model
292+
*/
282293
class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializable {
283294

284295
val model = _model
@@ -292,22 +303,46 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
292303
blas.ddot(n, v1, 1, v2,1) / norm1 / norm2
293304
}
294305

306+
/**
307+
* Transforms a word to its vector representation
308+
* @param word a word
309+
* @return vector representation of word
310+
*/
311+
295312
def transform(word: String): Array[Double] = {
296313
val result = model.lookup(word)
297314
if (result.isEmpty) Array[Double]()
298315
else result(0)
299316
}
300317

318+
/**
319+
* Transforms an RDD of words to their vector representations
320+
* @param dataset an RDD of words
321+
* @return RDD of vector representation
322+
*/
323+
301324
def transform(dataset: RDD[String]): RDD[Array[Double]] = {
302325
dataset.map(word => transform(word))
303326
}
304327

328+
/**
329+
* Find synonyms of a word
330+
* @param word a word
331+
* @param num number of synonyms to find
332+
* @return array of (word, similarity)
333+
*/
305334
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
306335
val vector = transform(word)
307336
if (vector.isEmpty) Array[(String, Double)]()
308337
else findSynonyms(vector,num)
309338
}
310339

340+
/**
341+
* Find synonyms of the vector representation of a word
342+
* @param vector vector representation of a word
343+
* @param num number of synonyms to find
344+
* @return array of (word, similarity)
345+
*/
311346
def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
312347
require(num > 0, "Number of similar words should > 0")
313348
val topK = model.map(
@@ -321,6 +356,15 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
321356
}
322357

323358
object Word2Vec extends Serializable with Logging {
359+
/**
360+
* Train Word2Vec model
361+
* @param input RDD of words
362+
* @param size vector dimension
363+
* @param startingAlpha initial learning rate
364+
* @param window context words from [-window, window]
365+
* @param minCount minimum frequency to consider a vocabulary word
366+
* @return Word2Vec model
367+
*/
324368
def train(
325369
input: RDD[String],
326370
size: Int,
@@ -329,25 +373,4 @@ object Word2Vec extends Serializable with Logging {
329373
minCount: Int): Word2VecModel = {
330374
new Word2Vec(size,startingAlpha, window, minCount).fit(input)
331375
}
332-
333-
def main(args: Array[String]) {
334-
if (args.length < 6) {
335-
println("Usage: word2vec input size startingAlpha window minCount num")
336-
sys.exit(1)
337-
}
338-
val conf = new SparkConf()
339-
.setAppName("word2vec")
340-
341-
val sc = new SparkContext(conf)
342-
val input = sc.textFile(args(0))
343-
val size = args(1).toInt
344-
val startingAlpha = args(2).toDouble
345-
val window = args(3).toInt
346-
val minCount = args(4).toInt
347-
val num = args(5).toInt
348-
val model = train(input, size, startingAlpha, window, minCount)
349-
val vec = model.findSynonyms("china", num)
350-
for((w, dist) <- vec) logInfo(w.toString + " " + dist.toString)
351-
sc.stop()
352-
}
353376
}

0 commit comments

Comments
 (0)