Skip to content

Commit 1045eec

Browse files
committed
Merge pull request #2 from jkbradley/hhbyyh-ldaonline2
Various cleanups, use random seed, optimization
2 parents 6149ca6 + cf376ff commit 1045eec

File tree

3 files changed

+59
-31
lines changed

3 files changed

+59
-31
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import java.util.Random
2121

2222
import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron}
2323
import breeze.numerics.{digamma, exp, abs}
24-
import breeze.stats.distributions.Gamma
24+
import breeze.stats.distributions.{Gamma, RandBasis}
2525

2626
import org.apache.spark.annotation.Experimental
2727
import org.apache.spark.graphx._
@@ -227,20 +227,37 @@ class OnlineLDAOptimizer extends LDAOptimizer {
227227
private var k: Int = 0
228228
private var corpusSize: Long = 0
229229
private var vocabSize: Int = 0
230-
private[clustering] var alpha: Double = 0
231-
private[clustering] var eta: Double = 0
230+
231+
/** alias for docConcentration */
232+
private var alpha: Double = 0
233+
234+
/** (private[clustering] for debugging) Get docConcentration */
235+
private[clustering] def getAlpha: Double = alpha
236+
237+
/** alias for topicConcentration */
238+
private var eta: Double = 0
239+
240+
/** (private[clustering] for debugging) Get topicConcentration */
241+
private[clustering] def getEta: Double = eta
242+
232243
private var randomGenerator: java.util.Random = null
233244

234245
// Online LDA specific parameters
246+
// Learning rate is: (tau_0 + t)^{-kappa}
235247
private var tau_0: Double = 1024
236248
private var kappa: Double = 0.51
237-
private var miniBatchFraction: Double = 0.01
249+
private var miniBatchFraction: Double = 0.05
238250

239251
// internal data structure
240252
private var docs: RDD[(Long, Vector)] = null
241-
private[clustering] var lambda: BDM[Double] = null
242253

243-
// count of invocation to next, which helps deciding the weight for each iteration
254+
/** Dirichlet parameter for the posterior over topics */
255+
private var lambda: BDM[Double] = null
256+
257+
/** (private[clustering] for debugging) Get parameter for topics */
258+
private[clustering] def getLambda: BDM[Double] = lambda
259+
260+
/** Current iteration (count of invocations of [[next()]]) */
244261
private var iteration: Int = 0
245262
private var gammaShape: Double = 100
246263

@@ -285,7 +302,12 @@ class OnlineLDAOptimizer extends LDAOptimizer {
285302
/**
286303
* Mini-batch fraction in (0, 1], which sets the fraction of documents sampled and used in
287304
* each iteration.
288-
* Default: 0.01, i.e., 1% of total documents
305+
*
306+
* Note that this should be adjusted in sync with [[LDA.setMaxIterations()]]
307+
* so the entire corpus is used. Specifically, set both so that
308+
* maxIterations * miniBatchFraction >= 1.
309+
*
310+
* Default: 0.05, i.e., 5% of total documents.
289311
*/
290312
def setMiniBatchFraction(miniBatchFraction: Double): this.type = {
291313
require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0,
@@ -295,15 +317,20 @@ class OnlineLDAOptimizer extends LDAOptimizer {
295317
}
296318

297319
/**
298-
* The function is for test only now. In the future, it can help support training stop/resume
320+
* (private[clustering])
321+
* Set the Dirichlet parameter for the posterior over topics.
322+
* This is only used for testing now. In the future, it can help support training stop/resume.
299323
*/
300324
private[clustering] def setLambda(lambda: BDM[Double]): this.type = {
301325
this.lambda = lambda
302326
this
303327
}
304328

305329
/**
306-
* Used to control the gamma distribution. Larger value produces values closer to 1.0.
330+
* (private[clustering])
331+
* Used for random initialization of the variational parameters.
332+
* Larger value produces values closer to 1.0.
333+
* This is only used for testing currently.
307334
*/
308335
private[clustering] def setGammaShape(shape: Double): this.type = {
309336
this.gammaShape = shape
@@ -380,12 +407,11 @@ class OnlineLDAOptimizer extends LDAOptimizer {
380407
meanchange = sum(abs(gammad - lastgamma)) / k
381408
}
382409

383-
val m1 = expElogthetad.t.toDenseMatrix.t
384-
val m2 = (ctsVector / phinorm).t.toDenseMatrix
385-
val outerResult = kron(m1, m2) // K * ids
410+
val m1 = expElogthetad.t
411+
val m2 = (ctsVector / phinorm).t.toDenseVector
386412
var i = 0
387413
while (i < ids.size) {
388-
stat(::, ids(i)) := (stat(::, ids(i)) + outerResult(::, i))
414+
stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
389415
i += 1
390416
}
391417
}
@@ -423,7 +449,9 @@ class OnlineLDAOptimizer extends LDAOptimizer {
423449
* Get a random matrix to initialize lambda
424450
*/
425451
private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
426-
val gammaRandomGenerator = new Gamma(gammaShape, 1.0 / gammaShape)
452+
val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(
453+
randomGenerator.nextLong()))
454+
val gammaRandomGenerator = new Gamma(gammaShape, 1.0 / gammaShape)(randBasis)
427455
val temp = gammaRandomGenerator.sample(row * col).toArray
428456
new BDM[Double](col, row, temp).t
429457
}

mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import java.io.Serializable;
2121
import java.util.ArrayList;
2222

23-
import org.apache.spark.api.java.JavaRDD;
2423
import scala.Tuple2;
2524

2625
import org.junit.After;
@@ -30,6 +29,7 @@
3029
import org.junit.Test;
3130

3231
import org.apache.spark.api.java.JavaPairRDD;
32+
import org.apache.spark.api.java.JavaRDD;
3333
import org.apache.spark.api.java.JavaSparkContext;
3434
import org.apache.spark.mllib.linalg.Matrix;
3535
import org.apache.spark.mllib.linalg.Vector;
@@ -148,6 +148,6 @@ public void OnlineOptimizerCompatibility() {
148148
private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();
149149
private static Tuple2<int[], double[]>[] tinyTopicDescription =
150150
LDASuite$.MODULE$.tinyTopicDescription();
151-
JavaPairRDD<Long, Vector> corpus;
151+
private JavaPairRDD<Long, Vector> corpus;
152152

153153
}

mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
3939

4040
// Check: describeTopics() with all terms
4141
val fullTopicSummary = model.describeTopics()
42-
assert(fullTopicSummary.size === tinyK)
42+
assert(fullTopicSummary.length === tinyK)
4343
fullTopicSummary.zip(tinyTopicDescription).foreach {
4444
case ((algTerms, algTermWeights), (terms, termWeights)) =>
4545
assert(algTerms === terms)
@@ -101,7 +101,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
101101
// Check: per-doc topic distributions
102102
val topicDistributions = model.topicDistributions.collect()
103103
// Ensure all documents are covered.
104-
assert(topicDistributions.size === tinyCorpus.size)
104+
assert(topicDistributions.length === tinyCorpus.length)
105105
assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
106106
// Ensure we have proper distributions
107107
topicDistributions.foreach { case (docId, topicDistribution) =>
@@ -139,8 +139,8 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
139139
val corpus = sc.parallelize(tinyCorpus, 2)
140140
val op = new OnlineLDAOptimizer().initialize(corpus, lda)
141141
op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau_0(567)
142-
assert(op.alpha == 0.5) // default 1.0 / k
143-
assert(op.eta == 0.5) // default 1.0 / k
142+
assert(op.getAlpha == 0.5) // default 1.0 / k
143+
assert(op.getEta == 0.5) // default 1.0 / k
144144
assert(op.getKappa == 0.9876)
145145
assert(op.getMiniBatchFraction == 0.123)
146146
assert(op.getTau_0 == 567)
@@ -154,14 +154,14 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
154154

155155
def docs: Array[(Long, Vector)] = Array(
156156
Vectors.sparse(vocabSize, Array(0, 1, 2), Array(1, 1, 1)), // apple, orange, banana
157-
Vectors.sparse(vocabSize, Array(3, 4, 5), Array(1, 1, 1))) // tiger, cat, dog
158-
.zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
157+
Vectors.sparse(vocabSize, Array(3, 4, 5), Array(1, 1, 1)) // tiger, cat, dog
158+
).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
159159
val corpus = sc.parallelize(docs, 2)
160160

161-
// setGammaShape large so to avoid the stochastic impact.
161+
// Set GammaShape large to avoid the stochastic impact.
162162
val op = new OnlineLDAOptimizer().setTau_0(1024).setKappa(0.51).setGammaShape(1e40)
163163
.setMiniBatchFraction(1)
164-
val lda = new LDA().setK(k).setMaxIterations(1).setOptimizer(op)
164+
val lda = new LDA().setK(k).setMaxIterations(1).setOptimizer(op).setSeed(12345)
165165

166166
val state = op.initialize(corpus, lda)
167167
// override lambda to simulate an intermediate state
@@ -175,8 +175,8 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
175175

176176
// Verify the result. Note this generates an identical result to
177177
// [[https://github.com/Blei-Lab/onlineldavb]]
178-
val topic1 = op.lambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
179-
val topic2 = op.lambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
178+
val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
179+
val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
180180
assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1)
181181
assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2)
182182
}
@@ -186,7 +186,6 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
186186
Vectors.sparse(6, Array(0, 1), Array(1, 1)),
187187
Vectors.sparse(6, Array(1, 2), Array(1, 1)),
188188
Vectors.sparse(6, Array(0, 2), Array(1, 1)),
189-
190189
Vectors.sparse(6, Array(3, 4), Array(1, 1)),
191190
Vectors.sparse(6, Array(3, 5), Array(1, 1)),
192191
Vectors.sparse(6, Array(4, 5), Array(1, 1))
@@ -200,6 +199,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
200199
.setTopicConcentration(0.01)
201200
.setMaxIterations(100)
202201
.setOptimizer(op)
202+
.setSeed(12345)
203203

204204
val ldaModel = lda.run(docs)
205205
val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
@@ -208,10 +208,10 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
208208
}
209209

210210
// check distribution for each topic, typical distribution is (0.3, 0.3, 0.3, 0.02, 0.02, 0.02)
211-
topics.foreach(topic =>{
212-
val smalls = topic.filter(t => (t._2 < 0.1)).map(_._2)
213-
assert(smalls.size == 3 && smalls.sum < 0.2)
214-
})
211+
topics.foreach { topic =>
212+
val smalls = topic.filter(t => t._2 < 0.1).map(_._2)
213+
assert(smalls.length == 3 && smalls.sum < 0.2)
214+
}
215215
}
216216

217217
}

0 commit comments

Comments
 (0)