Skip to content

Commit c14da41

Browse files
committed
fix styles
1 parent 384c771 commit c14da41

File tree

2 files changed

+34
-32
lines changed

2 files changed

+34
-32
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,18 @@
1717

1818
package org.apache.spark.mllib.feature
1919

20-
import scala.util.Random
21-
import scala.collection.mutable.ArrayBuffer
2220
import scala.collection.mutable
21+
import scala.collection.mutable.ArrayBuffer
22+
import scala.util.Random
2323

2424
import com.github.fommil.netlib.BLAS.{getInstance => blas}
25-
26-
import org.apache.spark.annotation.Experimental
27-
import org.apache.spark.Logging
28-
import org.apache.spark.rdd._
25+
import org.apache.spark.{HashPartitioner, Logging}
2926
import org.apache.spark.SparkContext._
27+
import org.apache.spark.annotation.Experimental
3028
import org.apache.spark.mllib.linalg.{Vector, Vectors}
31-
import org.apache.spark.HashPartitioner
32-
import org.apache.spark.storage.StorageLevel
3329
import org.apache.spark.mllib.rdd.RDDFunctions._
30+
import org.apache.spark.rdd._
31+
import org.apache.spark.storage.StorageLevel
3432

3533
/**
3634
* Entry in vocabulary
@@ -53,7 +51,7 @@ private case class VocabWord(
5351
*
5452
* We use the skip-gram model in our implementation and the hierarchical softmax
5553
* method to train the model. The variable names in the implementation
56-
* mathes the original C implementation.
54+
* match the original C implementation.
5755
*
5856
* For original C implementation, see https://code.google.com/p/word2vec/
5957
* For research papers, see
@@ -69,10 +67,14 @@ private case class VocabWord(
6967
class Word2Vec(
7068
val size: Int,
7169
val startingAlpha: Double,
72-
val parallelism: Int = 1,
73-
val numIterations: Int = 1)
74-
extends Serializable with Logging {
75-
70+
val parallelism: Int,
71+
val numIterations: Int) extends Serializable with Logging {
72+
73+
/**
74+
* Word2Vec with a single thread.
75+
*/
76+
// startingAlpha is a Double, matching the primary constructor, so fractional
// learning rates can be passed; parallelism and numIterations are both set to 1.
def this(size: Int, startingAlpha: Double) = this(size, startingAlpha, 1, 1)
77+
7678
private val EXP_TABLE_SIZE = 1000
7779
private val MAX_EXP = 6
7880
private val MAX_CODE_LENGTH = 40
@@ -92,7 +94,7 @@ class Word2Vec(
9294
private var vocabHash = mutable.HashMap.empty[String, Int]
9395
private var alpha = startingAlpha
9496

95-
private def learnVocab(words:RDD[String]){
97+
private def learnVocab(words:RDD[String]): Unit = {
9698
vocab = words.map(w => (w, 1))
9799
.reduceByKey(_ + _)
98100
.map(x => VocabWord(
@@ -126,7 +128,7 @@ class Word2Vec(
126128
expTable
127129
}
128130

129-
private def createBinaryTree() {
131+
private def createBinaryTree(): Unit = {
130132
val count = new Array[Long](vocabSize * 2 + 1)
131133
val binary = new Array[Int](vocabSize * 2 + 1)
132134
val parentNode = new Array[Int](vocabSize * 2 + 1)
@@ -208,7 +210,6 @@ class Word2Vec(
208210
* @param dataset an RDD of sentences, each an iterable of words
209211
* @return a Word2VecModel
210212
*/
211-
212213
def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
213214

214215
val words = dataset.flatMap(x => x)
@@ -339,7 +340,7 @@ class Word2Vec(
339340
/**
340341
* Word2Vec model
341342
*/
342-
class Word2VecModel (private val model:RDD[(String, Array[Float])]) extends Serializable {
343+
class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
343344

344345
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
345346
require(v1.length == v2.length, "Vectors should have the same length")
@@ -358,7 +359,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Float])]) extends Seri
358359
def transform(word: String): Vector = {
359360
val result = model.lookup(word)
360361
if (result.isEmpty) {
361-
throw new IllegalStateException(s"${word} not in vocabulary")
362+
throw new IllegalStateException(s"$word not in vocabulary")
362363
}
363364
else Vectors.dense(result(0).map(_.toDouble))
364365
}

mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ package org.apache.spark.mllib.feature
1919

2020
import org.scalatest.FunSuite
2121

22-
import org.apache.spark.SparkContext._
2322
import org.apache.spark.mllib.util.LocalSparkContext
2423

2524
class Word2VecSuite extends FunSuite with LocalSparkContext {
25+
26+
// TODO: add more tests
27+
2628
test("Word2Vec") {
2729
val sentence = "a b " * 100 + "a c " * 10
2830
val localDoc = Seq(sentence, sentence)
@@ -33,28 +35,27 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
3335
val window = 2
3436
val minCount = 2
3537
val num = 2
36-
val word = "a"
3738

3839
val model = Word2Vec.train(doc, size, startingAlpha, window, minCount)
39-
val synons = model.findSynonyms("a", 2)
40-
assert(synons.length == num)
41-
assert(synons(0)._1 == "b")
42-
assert(synons(1)._1 == "c")
40+
val syms = model.findSynonyms("a", 2)
41+
assert(syms.length == num)
42+
assert(syms(0)._1 == "b")
43+
assert(syms(1)._1 == "c")
4344
}
4445

4546

4647
test("Word2VecModel") {
4748
val num = 2
4849
val localModel = Seq(
49-
("china" , Array(0.50, 0.50, 0.50, 0.50)),
50-
("japan" , Array(0.40, 0.50, 0.50, 0.50)),
51-
("taiwan", Array(0.60, 0.50, 0.50, 0.50)),
52-
("korea" , Array(0.45, 0.60, 0.60, 0.60))
50+
("china" , Array(0.50f, 0.50f, 0.50f, 0.50f)),
51+
("japan" , Array(0.40f, 0.50f, 0.50f, 0.50f)),
52+
("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
53+
("korea" , Array(0.45f, 0.60f, 0.60f, 0.60f))
5354
)
5455
val model = new Word2VecModel(sc.parallelize(localModel, 2))
55-
val synons = model.findSynonyms("china", num)
56-
assert(synons.length == num)
57-
assert(synons(0)._1 == "taiwan")
58-
assert(synons(1)._1 == "japan")
56+
val syms = model.findSynonyms("china", num)
57+
assert(syms.length == num)
58+
assert(syms(0)._1 == "taiwan")
59+
assert(syms(1)._1 == "japan")
5960
}
6061
}

0 commit comments

Comments
 (0)