@@ -31,6 +31,9 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.HashPartitioner
 
+/**
+ * Entry in the vocabulary.
+ */
 private case class VocabWord(
   var word: String,
   var cn: Int,
@@ -39,6 +42,9 @@ private case class VocabWord(
   var codeLen: Int
 )
 
+/**
+ * Vector representation of a word.
+ */
 class Word2Vec(
   val size: Int,
   val startingAlpha: Double,
@@ -51,7 +57,8 @@ class Word2Vec(
   private val MAX_CODE_LENGTH = 40
   private val MAX_SENTENCE_LENGTH = 1000
   private val layer1Size = size
-
+  private val modelPartitionNum = 100
+
   private var trainWordsCount = 0
   private var vocabSize = 0
   private var vocab: Array[VocabWord] = null
@@ -169,6 +176,7 @@ class Word2Vec(
    * Computes the vector representation of each word in
    * vocabulary
    * @param dataset an RDD of strings
+   * @return a Word2VecModel
    */
 
   def fit(dataset: RDD[String]): Word2VecModel = {
@@ -274,11 +282,14 @@ class Word2Vec(
       wordMap(i) = (word, vector)
       i += 1
     }
-    val modelRDD = sc.parallelize(wordMap, 100).partitionBy(new HashPartitioner(100))
+    val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum))
     new Word2VecModel(modelRDD)
   }
 }
 
+/**
+ * Word2Vec model.
+ */
 class Word2VecModel(val _model: RDD[(String, Array[Double])]) extends Serializable {
 
   val model = _model
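A note on the change above: replacing the hard-coded 100 with modelPartitionNum names the constant, but the partitionBy call is what makes the model cheap to query, since RDD.lookup on a hash-partitioned pair RDD searches only the single partition that owns the key rather than scanning every partition. A minimal sketch of that behavior, assuming an existing SparkContext named sc:

    import org.apache.spark.SparkContext._
    import org.apache.spark.HashPartitioner

    // Hash-partition (word, vector) pairs by word, as the model RDD is.
    val pairs = sc.parallelize(Seq(("king", Array(0.1, 0.2)), ("queen", Array(0.3, 0.4))))
    val partitioned = pairs.partitionBy(new HashPartitioner(4)).cache()
    // lookup() consults the partitioner and runs on one partition only.
    partitioned.lookup("king")  // Seq(Array(0.1, 0.2))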
@@ -292,22 +303,46 @@ class Word2VecModel(val _model: RDD[(String, Array[Double])]) extends Serializable {
     blas.ddot(n, v1, 1, v2, 1) / norm1 / norm2
   }
 
+  /**
+   * Transforms a word to its vector representation.
+   * @param word a word
+   * @return vector representation of word
+   */
   def transform(word: String): Array[Double] = {
     val result = model.lookup(word)
     if (result.isEmpty) Array[Double]()
     else result(0)
   }
 
+  /**
+   * Transforms an RDD to its vector representation.
+   * @param dataset an RDD of words
+   * @return an RDD of vector representations
+   */
   def transform(dataset: RDD[String]): RDD[Array[Double]] = {
     dataset.map(word => transform(word))
   }
 
+  /**
+   * Find synonyms of a word.
+   * @param word a word
+   * @param num number of synonyms to find
+   * @return array of (word, similarity)
+   */
   def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
     val vector = transform(word)
     if (vector.isEmpty) Array[(String, Double)]()
     else findSynonyms(vector, num)
   }
 
+  /**
+   * Find synonyms of the vector representation of a word.
+   * @param vector vector representation of a word
+   * @param num number of synonyms to find
+   * @return array of (word, similarity)
+   */
   def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should be > 0")
     val topK = model.map(
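For reference, the context line blas.ddot(n, v1, 1, v2, 1) / norm1 / norm2 at the top of this hunk is cosine similarity, dot(v1, v2) / (||v1|| * ||v2||); findSynonyms scores every word in the model with it and keeps the top num. A BLAS-free sketch of the same computation:

    // Plain-Scala equivalent of blas.ddot(n, v1, 1, v2, 1) / norm1 / norm2.
    def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
      require(v1.length == v2.length, "vectors must have the same dimension")
      val dot = v1.zip(v2).map { case (a, b) => a * b }.sum
      val norm1 = math.sqrt(v1.map(x => x * x).sum)
      val norm2 = math.sqrt(v2.map(x => x * x).sum)
      dot / (norm1 * norm2)
    }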
@@ -321,6 +356,15 @@ class Word2VecModel(val _model: RDD[(String, Array[Double])]) extends Serializable {
   }
 
 object Word2Vec extends Serializable with Logging {
+  /**
+   * Train a Word2Vec model.
+   * @param input RDD of words
+   * @param size vector dimension
+   * @param startingAlpha initial learning rate
+   * @param window context words from [-window, window]
+   * @param minCount minimum frequency to consider a vocabulary word
+   * @return Word2Vec model
+   */
   def train(
     input: RDD[String],
     size: Int,
@@ -329,25 +373,4 @@ object Word2Vec extends Serializable with Logging {
     minCount: Int): Word2VecModel = {
     new Word2Vec(size, startingAlpha, window, minCount).fit(input)
   }
-
-  def main(args: Array[String]) {
-    if (args.length < 6) {
-      println("Usage: word2vec input size startingAlpha window minCount num")
-      sys.exit(1)
-    }
-    val conf = new SparkConf()
-      .setAppName("word2vec")
-
-    val sc = new SparkContext(conf)
-    val input = sc.textFile(args(0))
-    val size = args(1).toInt
-    val startingAlpha = args(2).toDouble
-    val window = args(3).toInt
-    val minCount = args(4).toInt
-    val num = args(5).toInt
-    val model = train(input, size, startingAlpha, window, minCount)
-    val vec = model.findSynonyms("china", num)
-    for ((w, dist) <- vec) logInfo(w.toString + " " + dist.toString)
-    sc.stop()
-  }
 }
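With the command-line main() removed, training is now driven through Word2Vec.train from user code. A hypothetical driver mirroring the removed code; the object name, input path, and hyperparameter values are placeholders, not part of this change:

    import org.apache.spark.{SparkConf, SparkContext}

    object Word2VecExample {
      def main(args: Array[String]) {
        val sc = new SparkContext(new SparkConf().setAppName("word2vec"))
        // Placeholder path; fit() expects an RDD[String] of words.
        val input = sc.textFile("/path/to/corpus")
        // Placeholder hyperparameters: 100-dimensional vectors and the
        // customary word2vec defaults for the remaining parameters.
        val model = Word2Vec.train(input, size = 100, startingAlpha = 0.025,
          window = 5, minCount = 5)
        // Same query as the removed demo: top 40 synonyms of "china".
        for ((word, similarity) <- model.findSynonyms("china", 40))
          println(word + " " + similarity)
        sc.stop()
      }
    }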