 package org.apache.spark.mllib.feature
 
-import scala.util.Random
-import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.Logging
-import org.apache.spark.rdd._
+import org.apache.spark.{HashPartitioner, Logging}
 import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.HashPartitioner
-import org.apache.spark.storage.StorageLevel
 import org.apache.spark.mllib.rdd.RDDFunctions._
+import org.apache.spark.rdd._
+import org.apache.spark.storage.StorageLevel
 
 /**
  * Entry in vocabulary
@@ -53,7 +51,7 @@ private case class VocabWord(
  *
  * We used skip-gram model in our implementation and hierarchical softmax
  * method to train the model. The variable names in the implementation
- * mathes the original C implementation.
+ * matches the original C implementation.
  *
  * For original C implementation, see https://code.google.com/p/word2vec/
  * For research papers, see
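For readers skimming this hunk: the skip-gram model mentioned in the doc comment trains each word to predict the words around it within a fixed window. A minimal sketch of that pair generation, for orientation only (`genPairs` and `window` are hypothetical names, not part of this patch or of the file):

```scala
// Illustrative sketch of skip-gram training-pair generation (not this file's code).
// For each center word, emit (center, context) pairs within a fixed window.
def genPairs(sentence: Array[String], window: Int): Seq[(String, String)] =
  sentence.indices.flatMap { i =>
    val lo = math.max(0, i - window)
    val hi = math.min(sentence.length - 1, i + window)
    (lo to hi).filter(_ != i).map(j => (sentence(i), sentence(j)))
  }

// genPairs(Array("the", "quick", "brown", "fox"), window = 1)
// => ("the","quick"), ("quick","the"), ("quick","brown"), ("brown","quick"), ...
```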
@@ -69,10 +67,14 @@ private case class VocabWord(
 class Word2Vec(
     val size: Int,
     val startingAlpha: Double,
-    val parallelism: Int = 1,
-    val numIterations: Int = 1)
-  extends Serializable with Logging {
-
+    val parallelism: Int,
+    val numIterations: Int) extends Serializable with Logging {
+
+  /**
+   * Word2Vec with a single thread.
+   */
+  def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
+
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
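This hunk replaces the default parameter values with an explicit auxiliary constructor, which keeps the class callable from Java. Note that the auxiliary constructor types `startingAlpha` as `Int`; it still compiles because Scala widens `Int` to `Double` when forwarding to the primary constructor, but it only admits integer learning rates. A hedged usage sketch of both constructors as they stand in this patch:

```scala
// Primary constructor: vector size, starting learning rate, parallelism, iterations.
val distributed = new Word2Vec(size = 100, startingAlpha = 0.025,
  parallelism = 4, numIterations = 1)

// Auxiliary constructor: single thread, single iteration.
// As written in this patch, startingAlpha must be an Int here (e.g. 1).
val singleThreaded = new Word2Vec(100, 1)
```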
@@ -92,7 +94,7 @@ class Word2Vec(
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha
 
-  private def learnVocab(words: RDD[String]){
+  private def learnVocab(words: RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .map(x => VocabWord(
@@ -126,7 +128,7 @@ class Word2Vec(
126
128
expTable
127
129
}
128
130
129
- private def createBinaryTree () {
131
+ private def createBinaryTree (): Unit = {
130
132
val count = new Array [Long ](vocabSize * 2 + 1 )
131
133
val binary = new Array [Int ](vocabSize * 2 + 1 )
132
134
val parentNode = new Array [Int ](vocabSize * 2 + 1 )
@@ -208,7 +210,6 @@ class Word2Vec(
    * @param dataset an RDD of words
    * @return a Word2VecModel
    */
-
   def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
 
     val words = dataset.flatMap(x => x)
@@ -339,7 +340,7 @@ class Word2Vec(
339
340
/**
340
341
* Word2Vec model
341
342
*/
342
- class Word2VecModel (private val model : RDD [(String , Array [Float ])]) extends Serializable {
343
+ class Word2VecModel (private val model : RDD [(String , Array [Float ])]) extends Serializable {
343
344
344
345
private def cosineSimilarity (v1 : Array [Float ], v2 : Array [Float ]): Double = {
345
346
require(v1.length == v2.length, " Vectors should have the same length" )
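For reference, `cosineSimilarity` computes dot(v1, v2) / (‖v1‖ · ‖v2‖); this file routes the arithmetic through netlib BLAS (the `blas` import above), but a plain-Scala equivalent of the same quantity looks like this (a sketch for readers, not the code in this file):

```scala
// Plain-Scala cosine similarity over float vectors (illustrative only).
def cosine(v1: Array[Float], v2: Array[Float]): Double = {
  require(v1.length == v2.length, "Vectors should have the same length")
  val dot = (v1, v2).zipped.map(_ * _).sum.toDouble
  val n1 = math.sqrt(v1.map(x => x * x).sum)  // Euclidean norm of v1
  val n2 = math.sqrt(v2.map(x => x * x).sum)  // Euclidean norm of v2
  if (n1 == 0 || n2 == 0) 0.0 else dot / (n1 * n2)
}
```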
@@ -358,7 +359,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Float])]) extends Seri
   def transform(word: String): Vector = {
     val result = model.lookup(word)
     if (result.isEmpty) {
-      throw new IllegalStateException(s"${word} not in vocabulary")
+      throw new IllegalStateException(s"$word not in vocabulary")
     }
     else Vectors.dense(result(0).map(_.toDouble))
   }
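End to end, the model is fit on an RDD of tokenized sentences and then queried per word. A hedged usage sketch against the API shown in this diff (`sc` is an existing SparkContext and the input path is a placeholder):

```scala
// Sketch: train on an RDD of tokenized sentences, then look up one word's vector.
val sentences = sc.textFile("data/text8").map(_.split(" ").toSeq)

val model = new Word2Vec(100, 1).fit(sentences)  // single-threaded constructor from this patch
val vec = model.transform("spark")               // throws IllegalStateException if absent
```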