
Commit d640d9c

online lda initial checkin
1 parent 6580929 commit d640d9c

File tree: 4 files changed, +133 −20 lines


examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala

Lines changed: 3 additions & 2 deletions
@@ -26,7 +26,7 @@ import scopt.OptionParser
 import org.apache.log4j.{Level, Logger}
 
 import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.rdd.RDD
 
@@ -137,7 +137,7 @@ object LDAExample {
       lda.setCheckpointDir(params.checkpointDir.get)
     }
     val startTime = System.nanoTime()
-    val ldaModel = lda.run(corpus)
+    val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
     val elapsed = (System.nanoTime() - startTime) / 1e9
 
     println(s"Finished training LDA model. Summary:")
@@ -159,6 +159,7 @@ object LDAExample {
       }
       println()
     }
+    sc.stop()
 
   }
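Note: run() now has the static return type LDAModel, so the example downcasts the result to reach DistributedLDAModel-only members such as toLocal. A minimal, self-contained sketch of that calling pattern (the setter values here are illustrative, not from the diff):

    import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA, LocalLDAModel}
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.rdd.RDD

    // corpus: (document id, term-count vector) pairs, prepared as in LDAExample.
    def trainWithEm(corpus: RDD[(Long, Vector)]): LocalLDAModel = {
      val lda = new LDA().setK(10).setMaxIterations(10)
      // The default training path still produces a DistributedLDAModel underneath,
      // but the static type is now LDAModel, so a cast is required.
      val distModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
      distModel.toLocal
    }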

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

Lines changed: 128 additions & 16 deletions
@@ -19,15 +19,17 @@ package org.apache.spark.mllib.clustering
 
 import java.util.Random
 
-import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy}
+import breeze.linalg.{DenseVector => BDV, normalize, kron, sum, axpy => brzAxpy, DenseMatrix => BDM}
+import breeze.numerics.{exp, abs, digamma}
+import breeze.stats.distributions.Gamma
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaPairRDD
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
 
@@ -250,6 +252,10 @@ class LDA private (
     this
   }
 
+  object LDAMode extends Enumeration {
+    val EM, Online = Value
+  }
+
   /**
    * Learn an LDA model using the given dataset.
    *
@@ -259,24 +265,39 @@
    *          Document IDs must be unique and >= 0.
    * @return  Inferred LDA model
    */
-  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
-    val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-      checkpointDir, checkpointInterval)
-    var iter = 0
-    val iterationTimes = Array.fill[Double](maxIterations)(0)
-    while (iter < maxIterations) {
-      val start = System.nanoTime()
-      state.next()
-      val elapsedSeconds = (System.nanoTime() - start) / 1e9
-      iterationTimes(iter) = elapsedSeconds
-      iter += 1
+  def run(documents: RDD[(Long, Vector)], mode: LDAMode.Value = LDAMode.EM): LDAModel = {
+    mode match {
+      case LDAMode.EM =>
+        val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
+          checkpointDir, checkpointInterval)
+        var iter = 0
+        val iterationTimes = Array.fill[Double](maxIterations)(0)
+        while (iter < maxIterations) {
+          val start = System.nanoTime()
+          state.next()
+          val elapsedSeconds = (System.nanoTime() - start) / 1e9
+          iterationTimes(iter) = elapsedSeconds
+          iter += 1
+        }
+        state.graphCheckpointer.deleteAllCheckpoints()
+        new DistributedLDAModel(state, iterationTimes)
+      case LDAMode.Online =>
+        // TODO: delete the comment on the next line before merging.
+        // The return type was changed to LDAModel because DistributedLDAModel is backed by a Graph.
+        val vocabSize = documents.first._2.size
+        val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, vocabSize)
+        var iter = 0
+        while (iter < onlineLDA.batchNumber) {
+          onlineLDA.next()
+          iter += 1
+        }
+        new LocalLDAModel(Matrices.fromBreeze(onlineLDA._lambda).transpose)
+      case _ => throw new IllegalArgumentException(s"Do not support mode $mode.")
     }
-    state.graphCheckpointer.deleteAllCheckpoints()
-    new DistributedLDAModel(state, iterationTimes)
   }
 
   /** Java-friendly version of [[run()]] */
-  def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
+  def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
     run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
   }
 }
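Because LDAMode is declared inside class LDA in this patch, a caller reaches the Online value through the LDA instance. A hypothetical call site might look like this (not part of the diff; corpus is assumed to be an RDD[(Long, Vector)] as above):

    // Hypothetical usage, assuming the enumeration stays nested in class LDA:
    val lda = new LDA().setK(10)
    val emModel: LDAModel = lda.run(corpus)                          // default: LDAMode.EM
    val onlineModel: LDAModel = lda.run(corpus, lda.LDAMode.Online)  // built as a LocalLDAModel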
@@ -429,6 +450,97 @@ private[clustering] object LDA {
 
   }
 
+  // TODO: add a reference to the online LDA paper by Hoffman et al.
+  class OnlineLDAOptimizer(
+      val documents: RDD[(Long, Vector)],
+      val k: Int,
+      val vocabSize: Int) extends Serializable {
+
+    private val kappa = 0.5  // (0.5, 1]: how quickly old information is forgotten
+    private val tau0 = 1024  // downweights early iterations
+    private val D = documents.count()
+    private val batchSize =
+      if (D / 1000 > 4096) 4096
+      else if (D / 1000 < 4) 4
+      else D / 1000
+    val batchNumber = (D / batchSize + 1).toInt
+    // TODO: performance killer, needs to be replaced
+    private val batches = documents.randomSplit(Array.fill[Double](batchNumber)(1.0))
+
+    // Initialize the variational distribution q(beta | lambda)
+    var _lambda = getGammaMatrix(k, vocabSize)               // K * V
+    private var _Elogbeta = dirichlet_expectation(_lambda)   // K * V
+    private var _expElogbeta = exp(_Elogbeta)                // K * V
+
+    private var batchCount = 0
+
+    def next(): Unit = {
+      // weight of the mini-batch
+      val rhot = math.pow(tau0 + batchCount, -kappa)
+
+      var stat = BDM.zeros[Double](k, vocabSize)
+      stat = batches(batchCount).aggregate(stat)(seqOp, _ += _)
+
+      stat = stat :* _expElogbeta
+      _lambda = _lambda * (1 - rhot) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * rhot
+      _Elogbeta = dirichlet_expectation(_lambda)
+      _expElogbeta = exp(_Elogbeta)
+      batchCount += 1
+    }
+
+    private def seqOp(other: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+      val termCounts = doc._2
+      val (ids, cts) = termCounts match {
+        case v: DenseVector => ((0 until v.size).toList, v.values)
+        case v: SparseVector => (v.indices.toList, v.values)
+        case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
+      }
+
+      var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t  // 1 * K
+      var Elogthetad = vector_dirichlet_expectation(gammad.t).t    // 1 * K
+      var expElogthetad = exp(Elogthetad.t).t                      // 1 * K
+      val expElogbetad = _expElogbeta(::, ids).toDenseMatrix       // K * ids
+
+      var phinorm = expElogthetad * expElogbetad + 1e-100          // 1 * ids
+      var meanchange = 1D
+      val ctsVector = new BDV[Double](cts).t                       // 1 * ids
+
+      while (meanchange > 1e-6) {
+        val lastgamma = gammad
+        //        1 * K                1 * ids               ids * K
+        gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + 1.0 / k
+        Elogthetad = vector_dirichlet_expectation(gammad.t).t
+        expElogthetad = exp(Elogthetad.t).t
+        phinorm = expElogthetad * expElogbetad + 1e-100
+        meanchange = sum(abs((gammad - lastgamma).t)) / gammad.t.size.toDouble
+      }
+
+      val v1 = expElogthetad.t.toDenseMatrix.t
+      val v2 = (ctsVector / phinorm).t.toDenseMatrix
+      val outerResult = kron(v1, v2)  // K * ids
+      for (i <- 0 until ids.size) {
+        other(::, ids(i)) := (other(::, ids(i)) + outerResult(::, i))
+      }
+      other
+    }
+
+    def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
+      val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
+      val temp = gammaRandomGenerator.sample(row * col).toArray
+      new BDM[Double](col, row, temp).t
+    }
+
+    def dirichlet_expectation(alpha: BDM[Double]): BDM[Double] = {
+      val rowSum = sum(alpha(breeze.linalg.*, ::))
+      val digAlpha = digamma(alpha)
+      val digRowSum = digamma(rowSum)
+      digAlpha(::, breeze.linalg.*) - digRowSum
+    }
+
+    def vector_dirichlet_expectation(v: BDV[Double]): BDV[Double] = {
+      digamma(v) - digamma(sum(v))
+    }
+  }
+
   /**
    * Compute gamma_{wjk}, a distribution over topics k.
    */
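For reference, next() above implements the global step of online variational Bayes for LDA (Hoffman, Blei and Bach, 2010): the sufficient statistics of one mini-batch are rescaled to the whole corpus and blended into lambda with a decaying weight rhot = (tau0 + batchCount)^(-kappa), while dirichlet_expectation computes E[log beta | lambda] = digamma(lambda) - digamma(rowSum(lambda)) row by row. A standalone Breeze sketch of the blend step, with illustrative names that are not part of the patch:

    import breeze.linalg.DenseMatrix

    // Sketch of the lambda update performed in next(); all names are illustrative.
    def blendLambda(
        lambda: DenseMatrix[Double],  // current K x V variational parameters
        stats: DenseMatrix[Double],   // expected word counts from one mini-batch, K x V
        batchCount: Int,              // mini-batches processed so far
        batchSize: Double,
        corpusSize: Double,
        k: Int,
        tau0: Double = 1024,
        kappa: Double = 0.5): DenseMatrix[Double] = {
      // step size: later mini-batches move lambda less
      val rhot = math.pow(tau0 + batchCount, -kappa)
      // rescale the mini-batch statistics to the corpus and add the 1/k prior term
      val lambdaHat = stats * (corpusSize / batchSize) + 1.0 / k
      // convex combination of the previous estimate and the mini-batch estimate
      lambda * (1.0 - rhot) + lambdaHat * rhot
    }

With kappa in (0.5, 1], the step sizes satisfy the usual stochastic-approximation conditions, which is what lets the streaming updates converge toward a local optimum of the variational objective.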

mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ public void distributedLDAModel() {
       .setMaxIterations(5)
       .setSeed(12345);
 
-    DistributedLDAModel model = lda.run(corpus);
+    DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus);
 
     // Check: basic parameters
     LocalLDAModel localModel = model.toLocal();

mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
       .setSeed(12345)
     val corpus = sc.parallelize(tinyCorpus, 2)
 
-    val model: DistributedLDAModel = lda.run(corpus)
+    val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
 
     // Check: basic parameters
     val localModel = model.toLocal
