Skip to content

Commit a87146c

Browse files
committed
add setters and make a default constructor
1 parent e5c923b commit a87146c

File tree

2 files changed

+52
-45
lines changed

2 files changed

+52
-45
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 50 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import org.apache.spark.annotation.Experimental
2828
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2929
import org.apache.spark.mllib.rdd.RDDFunctions._
3030
import org.apache.spark.rdd._
31+
import org.apache.spark.util.Utils
3132
import org.apache.spark.util.random.XORShiftRandom
3233

3334
/**
@@ -58,29 +59,63 @@ private case class VocabWord(
5859
* Efficient Estimation of Word Representations in Vector Space
5960
* and
6061
* Distributed Representations of Words and Phrases and their Compositionality.
61-
* @param size vector dimension
62-
* @param startingAlpha initial learning rate
63-
* @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
64-
* @param numIterations number of iterations to run, should be smaller than or equal to parallelism
6562
*/
6663
@Experimental
67-
class Word2Vec(
68-
val size: Int,
69-
val startingAlpha: Double,
70-
val parallelism: Int,
71-
val numIterations: Int) extends Serializable with Logging {
64+
class Word2Vec extends Serializable with Logging {
65+
66+
private var vectorSize = 100
67+
private var startingAlpha = 0.025
68+
private var numPartitions = 1
69+
private var numIterations = 1
70+
private var seed = Utils.random.nextLong()
71+
72+
/**
73+
* Sets vector size (default: 100).
74+
*/
75+
def setVectorSize(vectorSize: Int): this.type = {
76+
this.vectorSize = vectorSize
77+
this
78+
}
79+
80+
/**
81+
* Sets initial learning rate (default: 0.025).
82+
*/
83+
def setLearningRate(learningRate: Double): this.type = {
84+
this.startingAlpha = learningRate
85+
this
86+
}
7287

7388
/**
74-
* Word2Vec with a single thread.
89+
* Sets number of partitions (default: 1). Use a small number for accuracy.
7590
*/
76-
def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
91+
def setNumPartitions(numPartitions: Int): this.type = {
92+
require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
93+
this.numPartitions = numPartitions
94+
this
95+
}
96+
97+
/**
98+
* Sets number of iterations (default: 1), which should be smaller than or equal to number of
99+
* partitions.
100+
*/
101+
def setNumIterations(numIterations: Int): this.type = {
102+
this.numIterations = numIterations
103+
this
104+
}
105+
106+
/**
107+
* Sets random seed (default: a random long integer).
108+
*/
109+
def setSeed(seed: Long): this.type = {
110+
this.seed = seed
111+
this
112+
}
77113

78114
private val EXP_TABLE_SIZE = 1000
79115
private val MAX_EXP = 6
80116
private val MAX_CODE_LENGTH = 40
81117
private val MAX_SENTENCE_LENGTH = 1000
82-
private val layer1Size = size
83-
private val modelPartitionNum = 100
118+
private val layer1Size = vectorSize
84119

85120
/** context words from [-window, window] */
86121
private val window = 5
@@ -245,8 +280,7 @@ class Word2Vec(
245280
}
246281
}
247282

248-
val newSentences = sentences.repartition(parallelism).cache()
249-
val seed = 5875483L
283+
val newSentences = sentences.repartition(numPartitions).cache()
250284
val initRandom = new XORShiftRandom(seed)
251285
var syn0Global =
252286
Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
@@ -263,7 +297,7 @@ class Word2Vec(
263297
lwc = wordCount
264298
// TODO: discount by iteration?
265299
alpha =
266-
startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
300+
startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
267301
if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
268302
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
269303
}
@@ -404,23 +438,3 @@ class Word2VecModel private[mllib] (
404438
.toArray
405439
}
406440
}
407-
408-
object Word2Vec{
409-
/**
410-
* Train Word2Vec model
411-
* @param input RDD of words
412-
* @param size vector dimension
413-
* @param startingAlpha initial learning rate
414-
* @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
415-
* @param numIterations number of iterations, should be smaller than or equal to parallelism
416-
* @return Word2Vec model
417-
*/
418-
def train[S <: Iterable[String]](
419-
input: RDD[S],
420-
size: Int,
421-
startingAlpha: Double,
422-
parallelism: Int = 1,
423-
numIterations:Int = 1): Word2VecModel = {
424-
new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
425-
}
426-
}

mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,13 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
3030
val localDoc = Seq(sentence, sentence)
3131
val doc = sc.parallelize(localDoc)
3232
.map(line => line.split(" ").toSeq)
33-
val size = 10
34-
val startingAlpha = 0.025
35-
val window = 2
36-
val minCount = 2
37-
val num = 2
38-
39-
val model = Word2Vec.train(doc, size, startingAlpha)
33+
val model = new Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
4034
val syms = model.findSynonyms("a", 2)
41-
assert(syms.length == num)
35+
assert(syms.length == 2)
4236
assert(syms(0)._1 == "b")
4337
assert(syms(1)._1 == "c")
4438
}
4539

46-
4740
test("Word2VecModel") {
4841
val num = 2
4942
val word2VecMap = Map(

0 commit comments

Comments
 (0)