Skip to content

Commit 2e92b59

Browse files
author
Liquan Pei
committed
modify according to feedback
1 parent 57dc50d commit 2e92b59

File tree

2 files changed

+102
-76
lines changed

2 files changed

+102
-76
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 86 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,33 @@
11
/*
2-
* Licensed to the Apache Software Foundation (ASF) under one or more
3-
* contributor license agreements. See the NOTICE file distributed with
4-
* this work for additional information regarding copyright ownership.
5-
* The ASF licenses this file to You under the Apache License, Version 2.0
6-
* Add a comment to this line
7-
* (the "License"); you may not use this file except in compliance with
8-
* the License. You may obtain a copy of the License at
9-
*
10-
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
12-
* Unless required by applicable law or agreed to in writing, software
13-
* distributed under the License is distributed on an "AS IS" BASIS,
14-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15-
* See the License for the specific language governing permissions and
16-
* limitations under the License.
17-
*/
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
1817

1918
package org.apache.spark.mllib.feature
2019

21-
import scala.util.{Random => Random}
20+
import scala.util.Random
2221
import scala.collection.mutable.ArrayBuffer
2322
import scala.collection.mutable
2423

2524
import com.github.fommil.netlib.BLAS.{getInstance => blas}
2625

27-
import org.apache.spark._
26+
import org.apache.spark.annotation.Experimental
27+
import org.apache.spark.Logging
2828
import org.apache.spark.rdd._
2929
import org.apache.spark.SparkContext._
30-
import org.apache.spark.mllib.linalg.Vector
30+
import org.apache.spark.mllib.linalg.{Vector, Vectors}
3131
import org.apache.spark.HashPartitioner
3232

3333
/**
@@ -42,8 +42,27 @@ private case class VocabWord(
4242
)
4343

4444
/**
45-
* Vector representation of word
45+
* :: Experimental ::
46+
* Word2Vec creates vector representation of words in a text corpus.
47+
* The algorithm first constructs a vocabulary from the corpus
48+
* and then learns vector representation of words in the vocabulary.
49+
* The vector representation can be used as features in
50+
* natural language processing and machine learning algorithms.
51+
*
52+
* We use the skip-gram model in our implementation and the hierarchical softmax
53+
* method to train the model.
54+
*
55+
* For original C implementation, see https://code.google.com/p/word2vec/
56+
* For research papers, see
57+
* Efficient Estimation of Word Representations in Vector Space
58+
* and
59+
* Distributed Representations of Words and Phrases and their Compositionality
60+
* @param size vector dimension
61+
* @param startingAlpha initial learning rate
62+
* @param window context words from [-window, window]
63+
* @param minCount minimum frequency to consider a vocabulary word
4664
*/
65+
@Experimental
4766
class Word2Vec(
4867
val size: Int,
4968
val startingAlpha: Double,
@@ -64,11 +83,15 @@ class Word2Vec(
6483
private var vocabHash = mutable.HashMap.empty[String, Int]
6584
private var alpha = startingAlpha
6685

67-
private def learnVocab(dataset: RDD[String]) {
68-
vocab = dataset.flatMap(line => line.split(" "))
69-
.map(w => (w, 1))
86+
private def learnVocab(words:RDD[String]) {
87+
vocab = words.map(w => (w, 1))
7088
.reduceByKey(_ + _)
71-
.map(x => VocabWord(x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0))
89+
.map(x => VocabWord(
90+
x._1,
91+
x._2,
92+
new Array[Int](MAX_CODE_LENGTH),
93+
new Array[Int](MAX_CODE_LENGTH),
94+
0))
7295
.filter(_.cn >= minCount)
7396
.collect()
7497
.sortWith((a, b)=> a.cn > b.cn)
@@ -172,15 +195,16 @@ class Word2Vec(
172195
}
173196

174197
/**
175-
* Computes the vector representation of each word in
176-
* vocabulary
177-
* @param dataset an RDD of strings
198+
* Computes the vector representation of each word in vocabulary.
199+
* @param dataset an RDD of words
178200
* @return a Word2VecModel
179201
*/
180202

181-
def fit(dataset:RDD[String]): Word2VecModel = {
203+
def fit[S <: Iterable[String]](dataset:RDD[S]): Word2VecModel = {
182204

183-
learnVocab(dataset)
205+
val words = dataset.flatMap(x => x)
206+
207+
learnVocab(words)
184208

185209
createBinaryTree()
186210

@@ -190,9 +214,10 @@ class Word2Vec(
190214
val V = sc.broadcast(vocab)
191215
val VHash = sc.broadcast(vocabHash)
192216

193-
val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions {
217+
val sentences = words.mapPartitions {
194218
iter => { new Iterator[Array[Int]] {
195219
def hasNext = iter.hasNext
220+
196221
def next = {
197222
var sentence = new ArrayBuffer[Int]
198223
var sentenceLength = 0
@@ -215,7 +240,8 @@ class Word2Vec(
215240
val newSentences = sentences.repartition(1).cache()
216241
val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
217242
val (aggSyn0, _, _, _) =
218-
// TODO: broadcast temp instead of serializing it directly or initialize the model in each executor
243+
// TODO: broadcast temp instead of serializing it directly
244+
// or initialize the model in each executor
219245
newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
220246
seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
221247
var lwc = lastWordCount
@@ -241,7 +267,7 @@ class Word2Vec(
241267
val lastWord = sentence(c)
242268
val l1 = lastWord * layer1Size
243269
val neu1e = new Array[Double](layer1Size)
244-
//HS
270+
// Hierarchical softmax
245271
var d = 0
246272
while (d < vocab(word).codeLen) {
247273
val l2 = vocab(word).point(d) * layer1Size
@@ -265,11 +291,12 @@ class Word2Vec(
265291
}
266292
(syn0, syn1, lwc, wc)
267293
},
268-
combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
269-
val n = syn0_1.length
270-
blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
271-
blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
272-
(syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
294+
combOp = (c1, c2) => (c1, c2) match {
295+
case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
296+
val n = syn0_1.length
297+
blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
298+
blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
299+
(syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
273300
})
274301

275302
val wordMap = new Array[(String, Array[Double])](vocabSize)
@@ -281,19 +308,18 @@ class Word2Vec(
281308
wordMap(i) = (word, vector)
282309
i += 1
283310
}
284-
val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum))
311+
val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
312+
.partitionBy(new HashPartitioner(modelPartitionNum))
285313
new Word2VecModel(modelRDD)
286314
}
287315
}
288316

289317
/**
290318
* Word2Vec model
291319
*/
292-
class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializable {
293-
294-
val model = _model
320+
class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Serializable {
295321

296-
private def distance(v1: Array[Double], v2: Array[Double]): Double = {
322+
private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
297323
require(v1.length == v2.length, "Vectors should have the same length")
298324
val n = v1.length
299325
val norm1 = blas.dnrm2(n, v1, 1)
@@ -307,20 +333,20 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
307333
* @param word a word
308334
* @return vector representation of word
309335
*/
310-
311-
def transform(word: String): Array[Double] = {
336+
def transform(word: String): Vector = {
312337
val result = model.lookup(word)
313-
if (result.isEmpty) Array[Double]()
314-
else result(0)
338+
if (result.isEmpty) {
339+
throw new IllegalStateException(s"${word} not in vocabulary")
340+
}
341+
else Vectors.dense(result(0))
315342
}
316343

317344
/**
318345
* Transforms an RDD to its vector representation
319346
* @param dataset an RDD of words
320347
* @return RDD of vector representation
321348
*/
322-
323-
def transform(dataset: RDD[String]): RDD[Array[Double]] = {
349+
def transform(dataset: RDD[String]): RDD[Vector] = {
324350
dataset.map(word => transform(word))
325351
}
326352

@@ -332,44 +358,44 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
332358
*/
333359
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
334360
val vector = transform(word)
335-
if (vector.isEmpty) Array[(String, Double)]()
336-
else findSynonyms(vector,num)
361+
findSynonyms(vector,num)
337362
}
338363

339364
/**
340365
* Find synonyms of the vector representation of a word
341366
* @param vector vector representation of a word
342367
* @param num number of synonyms to find
343-
* @return array of (word, similarity)
368+
* @return array of (word, cosineSimilarity)
344369
*/
345-
def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
370+
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
346371
require(num > 0, "Number of similar words should > 0")
347-
val topK = model.map(
348-
{case(w, vec) => (distance(vector, vec), w)})
372+
val topK = model.map { case(w, vec) =>
373+
(cosineSimilarity(vector.toArray, vec), w) }
349374
.sortByKey(ascending = false)
350375
.take(num + 1)
351-
.map({case (dist, w) => (w, dist)}).drop(1)
376+
.map(_.swap)
377+
.tail
352378

353379
topK
354380
}
355381
}
356382

357-
object Word2Vec extends Serializable with Logging {
383+
object Word2Vec{
358384
/**
359385
* Train Word2Vec model
360386
* @param input RDD of words
361-
* @param size vectoer dimension
387+
* @param size vector dimension
362388
* @param startingAlpha initial learning rate
363389
* @param window context words from [-window, window]
364390
* @param minCount minimum frequency to consider a vocabulary word
365391
* @return Word2Vec model
366392
*/
367-
def train(
368-
input: RDD[String],
393+
def train[S <: Iterable[String]](
394+
input: RDD[S],
369395
size: Int,
370396
startingAlpha: Double,
371397
window: Int,
372398
minCount: Int): Word2VecModel = {
373-
new Word2Vec(size,startingAlpha, window, minCount).fit(input)
399+
new Word2Vec(size,startingAlpha, window, minCount).fit[S](input)
374400
}
375401
}

mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
11
/*
2-
* Licensed to the Apache Software Foundation (ASF) under one or more
3-
* contributor license agreements. See the NOTICE file distributed with
4-
* this work for additional information regarding copyright ownership.
5-
* The ASF licenses this file to You under the Apache License, Version 2.0
6-
* Add a comment to this line
7-
* (the "License"); you may not use this file except in compliance with
8-
* the License. You may obtain a copy of the License at
9-
*
10-
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
12-
* Unless required by applicable law or agreed to in writing, software
13-
* distributed under the License is distributed on an "AS IS" BASIS,
14-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15-
* See the License for the specific language governing permissions and
16-
* limitations under the License.
17-
*/
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
1817

1918
package org.apache.spark.mllib.feature
2019

2120
import org.scalatest.FunSuite
21+
2222
import org.apache.spark.SparkContext._
2323
import org.apache.spark.mllib.util.LocalSparkContext
2424

0 commit comments

Comments
 (0)