/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * Add a comment to this line
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */

package org.apache.spark.mllib.feature

- import scala.util.{Random => Random}
+ import scala.util.Random

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable

import com.github.fommil.netlib.BLAS.{getInstance => blas}

- import org.apache.spark._
+ import org.apache.spark.annotation.Experimental
+ import org.apache.spark.Logging
import org.apache.spark.rdd._
import org.apache.spark.SparkContext._
- import org.apache.spark.mllib.linalg.Vector
+ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.HashPartitioner

/**
@@ -42,8 +42,27 @@ private case class VocabWord(
)

/**
- * Vector representation of word
+ * :: Experimental ::
+ * Word2Vec creates vector representation of words in a text corpus.
+ * The algorithm first constructs a vocabulary from the corpus
+ * and then learns vector representation of words in the vocabulary.
+ * The vector representation can be used as features in
+ * natural language processing and machine learning algorithms.
+ *
+ * We used skip-gram model in our implementation and hierarchical softmax
+ * method to train the model.
+ *
+ * For original C implementation, see https://code.google.com/p/word2vec/
+ * For research papers, see
+ * Efficient Estimation of Word Representations in Vector Space
+ * and
+ * Distributed Representations of Words and Phrases and their Compositionality
+ * @param size vector dimension
+ * @param startingAlpha initial learning rate
+ * @param window context words from [-window, window]
+ * @param minCount minimum frequency to consider a vocabulary word
*/
+ @Experimental
class Word2Vec(
val size: Int,
val startingAlpha: Double,
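
The scaladoc above describes the skip-gram / hierarchical-softmax training and the four constructor parameters. As a rough usage sketch of the class this patch introduces; the corpus path, tokenization, and parameter values below are illustrative assumptions, not part of the change:

    // Illustrative only: parameter values and tokenization are assumptions.
    val sc: org.apache.spark.SparkContext = ???   // an existing SparkContext
    // One sentence per line, whitespace-tokenized: RDD[Seq[String]] satisfies S <: Iterable[String].
    val sentences = sc.textFile("corpus.txt").map(_.split(" ").toSeq)
    val model = new Word2Vec(
      size = 100,            // vector dimension
      startingAlpha = 0.025, // initial learning rate
      window = 5,            // context window [-window, window]
      minCount = 5           // ignore words seen fewer than 5 times
    ).fit(sentences)
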
@@ -64,11 +83,15 @@ class Word2Vec(
private var vocabHash = mutable.HashMap.empty[String, Int]
private var alpha = startingAlpha

- private def learnVocab(dataset: RDD[String]) {
- vocab = dataset.flatMap(line => line.split(" "))
- .map(w => (w, 1))
+ private def learnVocab(words: RDD[String]) {
+ vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
- .map(x => VocabWord(x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0))
+ .map(x => VocabWord(
+ x._1,
+ x._2,
+ new Array[Int](MAX_CODE_LENGTH),
+ new Array[Int](MAX_CODE_LENGTH),
+ 0))
.filter(_.cn >= minCount)
.collect()
.sortWith((a, b) => a.cn > b.cn)
@@ -172,15 +195,16 @@ class Word2Vec(
}

/**
- * Computes the vector representation of each word in
- * vocabulary
- * @param dataset an RDD of strings
+ * Computes the vector representation of each word in vocabulary.
+ * @param dataset an RDD of words
* @return a Word2VecModel
*/

- def fit(dataset: RDD[String]): Word2VecModel = {
+ def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {

- learnVocab(dataset)
+ val words = dataset.flatMap(x => x)
+
+ learnVocab(words)

createBinaryTree()

@@ -190,9 +214,10 @@ class Word2Vec(
val V = sc.broadcast(vocab)
val VHash = sc.broadcast(vocabHash)

- val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions {
+ val sentences = words.mapPartitions {
iter => { new Iterator[Array[Int]] {
def hasNext = iter.hasNext
+
def next = {
var sentence = new ArrayBuffer[Int]
var sentenceLength = 0
@@ -215,7 +240,8 @@ class Word2Vec(
val newSentences = sentences.repartition(1).cache()
val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
val (aggSyn0, _, _, _) =
- // TODO: broadcast temp instead of serializing it directly or initialize the model in each executor
+ // TODO: broadcast temp instead of serializing it directly
+ // or initialize the model in each executor
newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
@@ -241,7 +267,7 @@ class Word2Vec(
val lastWord = sentence(c)
val l1 = lastWord * layer1Size
val neu1e = new Array[Double](layer1Size)
- // HS
+ // Hierarchical softmax
var d = 0
while (d < vocab(word).codeLen) {
val l2 = vocab(word).point(d) * layer1Size
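
The renamed comment marks the hierarchical-softmax update performed for each node on the current word's Huffman code path; the rest of the loop body falls outside this hunk. A self-contained sketch of that per-node step on flattened arrays follows; the names mirror the hunk (syn0, syn1, l1, l2, neu1e, alpha), but the body illustrates the standard word2vec update, not the committed code:

    // Illustrative sketch of one hierarchical-softmax step; not the committed loop body.
    // syn0/syn1 are flattened matrices with rows of length n; l1/l2 are row offsets.
    def hsStep(syn0: Array[Double], l1: Int, syn1: Array[Double], l2: Int,
               code: Int, alpha: Double, n: Int, neu1e: Array[Double]): Unit = {
      var f = 0.0
      var i = 0
      while (i < n) { f += syn0(l1 + i) * syn1(l2 + i); i += 1 }  // inner product
      f = 1.0 / (1.0 + math.exp(-f))                              // sigmoid
      val g = (1 - code - f) * alpha                              // gradient scale for this node
      i = 0
      while (i < n) {
        neu1e(i) += g * syn1(l2 + i)      // accumulate error for the input vector
        syn1(l2 + i) += g * syn0(l1 + i)  // update the inner-node vector
        i += 1
      }
    }
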
@@ -265,11 +291,12 @@ class Word2Vec(
}
(syn0, syn1, lwc, wc)
},
- combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
- val n = syn0_1.length
- blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
- blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
- (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
+ combOp = (c1, c2) => (c1, c2) match {
+ case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+ val n = syn0_1.length
+ blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
+ blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
+ (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
})

val wordMap = new Array[(String, Array[Double])](vocabSize)
@@ -281,19 +308,18 @@ class Word2Vec(
wordMap(i) = (word, vector)
i += 1
}
- val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum))
+ val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
+ .partitionBy(new HashPartitioner(modelPartitionNum))
new Word2VecModel(modelRDD)
}
}

/**
* Word2Vec model
*/
- class Word2VecModel(val _model: RDD[(String, Array[Double])]) extends Serializable {
-
- val model = _model
+ class Word2VecModel(private val model: RDD[(String, Array[Double])]) extends Serializable {

- private def distance(v1: Array[Double], v2: Array[Double]): Double = {
+ private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
val n = v1.length
val norm1 = blas.dnrm2(n, v1, 1)
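
The rename from distance to cosineSimilarity matches what the method computes: the dot product of the two vectors divided by the product of their Euclidean norms (the remainder of the body is cut off by the hunk). A standalone sketch of that computation with the same netlib BLAS handle imported at the top of the file, assuming equal-length, non-zero vectors:

    // Sketch only: cosine similarity = <v1, v2> / (||v1|| * ||v2||)
    import com.github.fommil.netlib.BLAS.{getInstance => blas}

    def cosine(v1: Array[Double], v2: Array[Double]): Double = {
      val n = v1.length
      val norm1 = blas.dnrm2(n, v1, 1)  // Euclidean norm of v1
      val norm2 = blas.dnrm2(n, v2, 1)  // Euclidean norm of v2
      blas.ddot(n, v1, 1, v2, 1) / (norm1 * norm2)
    }
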
@@ -307,20 +333,20 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
* @param word a word
* @return vector representation of word
*/
-
- def transform(word: String): Array[Double] = {
+ def transform(word: String): Vector = {
val result = model.lookup(word)
- if (result.isEmpty) Array[Double]()
- else result(0)
+ if (result.isEmpty) {
+ throw new IllegalStateException(s"${word} not in vocabulary")
+ }
+ else Vectors.dense(result(0))
}

/**
* Transforms an RDD to its vector representation
* @param dataset a an RDD of words
* @return RDD of vector representation
*/
-
- def transform(dataset: RDD[String]): RDD[Array[Double]] = {
+ def transform(dataset: RDD[String]): RDD[Vector] = {
dataset.map(word => transform(word))
}

@@ -332,44 +358,44 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
- if (vector.isEmpty) Array[(String, Double)]()
- else findSynonyms(vector, num)
+ findSynonyms(vector, num)
}

/**
* Find synonyms of the vector representation of a word
* @param vector vector representation of a word
* @param num number of synonyms to find
- * @return array of (word, similarity)
+ * @return array of (word, cosineSimilarity)
*/
- def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
+ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
- val topK = model.map(
- { case (w, vec) => (distance(vector, vec), w)})
+ val topK = model.map { case (w, vec) =>
+ (cosineSimilarity(vector.toArray, vec), w) }
.sortByKey(ascending = false)
.take(num + 1)
- .map({case (dist, w) => (w, dist)}).drop(1)
+ .map(_.swap)
+ .tail

topK
}
}

- object Word2Vec extends Serializable with Logging {
+ object Word2Vec {
/**
* Train Word2Vec model
* @param input RDD of words
- * @param size vectoer dimension
+ * @param size vector dimension
* @param startingAlpha initial learning rate
* @param window context words from [-window, window]
* @param minCount minimum frequncy to consider a vocabulary word
* @return Word2Vec model
*/
- def train(
- input: RDD[String],
+ def train[S <: Iterable[String]](
+ input: RDD[S],
size: Int,
startingAlpha: Double,
window: Int,
minCount: Int): Word2VecModel = {
- new Word2Vec(size, startingAlpha, window, minCount).fit(input)
+ new Word2Vec(size, startingAlpha, window, minCount).fit[S](input)
}
}
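
End to end, the reworked public API introduced by this diff would be exercised roughly as follows; the corpus path, parameter values, query word, and local master are placeholders, not part of the patch:

    // Illustrative usage of the updated API; all literal values are placeholders.
    import org.apache.spark.SparkContext

    val sc = new SparkContext("local[4]", "word2vec-example")
    val sentences = sc.textFile("corpus.txt").map(_.split(" ").toSeq)

    val model = Word2Vec.train(sentences, size = 100, startingAlpha = 0.025, window = 5, minCount = 5)

    val vec = model.transform("china")             // Vector; throws if the word is not in the vocabulary
    val synonyms = model.findSynonyms("china", 40) // Array[(String, Double)] of (word, cosine similarity)
    synonyms.foreach { case (w, sim) => println(s"$w $sim") }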