@@ -19,15 +19,17 @@ package org.apache.spark.mllib.clustering
import java.util.Random

- import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy}
+ import breeze.linalg.{DenseVector => BDV, normalize, kron, sum, axpy => brzAxpy, DenseMatrix => BDM}
+ import breeze.numerics.{exp, abs, digamma}
+ import breeze.stats.distributions.Gamma

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
- import org.apache.spark.mllib.linalg.Vector
+ import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -250,6 +252,10 @@ class LDA private (
    this
  }

+   object LDAMode extends Enumeration {
+     val EM, Online = Value
+   }
+
  /**
   * Learn an LDA model using the given dataset.
   *
@@ -259,24 +265,39 @@
   * Document IDs must be unique and >= 0.
   * @return Inferred LDA model
   */
-   def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
-     val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-       checkpointDir, checkpointInterval)
-     var iter = 0
-     val iterationTimes = Array.fill[Double](maxIterations)(0)
-     while (iter < maxIterations) {
-       val start = System.nanoTime()
-       state.next()
-       val elapsedSeconds = (System.nanoTime() - start) / 1e9
-       iterationTimes(iter) = elapsedSeconds
-       iter += 1
+   def run(documents: RDD[(Long, Vector)], mode: LDAMode.Value = LDAMode.EM): LDAModel = {
+     mode match {
+       case LDAMode.EM =>
+         val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
+           checkpointDir, checkpointInterval)
+         var iter = 0
+         val iterationTimes = Array.fill[Double](maxIterations)(0)
+         while (iter < maxIterations) {
+           val start = System.nanoTime()
+           state.next()
+           val elapsedSeconds = (System.nanoTime() - start) / 1e9
+           iterationTimes(iter) = elapsedSeconds
+           iter += 1
+         }
+         state.graphCheckpointer.deleteAllCheckpoints()
+         new DistributedLDAModel(state, iterationTimes)
+       case LDAMode.Online =>
+         // The online optimizer does not build the document-term graph, so it returns a
+         // LocalLDAModel; run() therefore declares the common LDAModel return type.
+         val vocabSize = documents.first._2.size
+         val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, vocabSize)
+         var iter = 0
+         while (iter < onlineLDA.batchNumber) {
+           onlineLDA.next()
+           iter += 1
+         }
+         new LocalLDAModel(Matrices.fromBreeze(onlineLDA._lambda).transpose)
+       case _ => throw new IllegalArgumentException(s"LDA does not support mode $mode.")
    }
-     state.graphCheckpointer.deleteAllCheckpoints()
-     new DistributedLDAModel(state, iterationTimes)
  }

  /** Java-friendly version of [[run()]] */
-   def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
+   def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
    run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
  }
}
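For reference, a minimal sketch of how the two modes would be invoked from user code. The toy corpus, the SparkContext `sc`, and the parameter values are illustrative only, and the qualifier used to reach `LDAMode` depends on where the enumeration finally lives (it is shown here as a member of the `LDA` instance, as in the diff above):

  import org.apache.spark.mllib.clustering.{LDA, LDAModel}
  import org.apache.spark.mllib.linalg.{Vector, Vectors}
  import org.apache.spark.rdd.RDD

  // Hypothetical toy corpus: (docId, term-count vector) pairs over a 5-word vocabulary.
  val corpus: RDD[(Long, Vector)] = sc.parallelize(Seq(
    (0L, Vectors.dense(1.0, 0.0, 2.0, 0.0, 1.0)),
    (1L, Vectors.dense(0.0, 3.0, 0.0, 1.0, 0.0))))

  val lda = new LDA().setK(2).setMaxIterations(10)
  val emModel: LDAModel = lda.run(corpus)                          // default: LDAMode.EM, graph-based EM
  val onlineModel: LDAModel = lda.run(corpus, lda.LDAMode.Online)  // online variational Bayes

Both calls return the common LDAModel type; the EM path still produces a DistributedLDAModel underneath, while the online path produces a LocalLDAModel built from the learned lambda matrix.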
@@ -429,6 +450,97 @@ private[clustering] object LDA {
  }

+   // Online variational Bayes optimizer for LDA (cf. Hoffman, Blei and Bach,
+   // "Online Learning for Latent Dirichlet Allocation", NIPS 2010).
+   class OnlineLDAOptimizer(
+       val documents: RDD[(Long, Vector)],
+       val k: Int,
+       val vocabSize: Int) extends Serializable {
+
+     private val kappa = 0.5   // in (0.5, 1]; controls how quickly old information is forgotten
+     private val tau0 = 1024   // downweights early iterations
+     private val D = documents.count()
+     private val batchSize =
+       if (D / 1000 > 4096) 4096
+       else if (D / 1000 < 4) 4
+       else D / 1000
+     val batchNumber = (D / batchSize + 1).toInt
+     // todo: performance killer, needs to be replaced
+     private val batches = documents.randomSplit(Array.fill[Double](batchNumber)(1.0))
+
+     // Initialize the variational distribution q(beta | lambda)
+     var _lambda = getGammaMatrix(k, vocabSize)                 // K * V
+     private var _Elogbeta = dirichlet_expectation(_lambda)     // K * V
+     private var _expElogbeta = exp(_Elogbeta)                  // K * V
+
+     private var batchCount = 0
+     def next(): Unit = {
+       // weight of the mini-batch
+       val rhot = math.pow(tau0 + batchCount, -kappa)
+
+       var stat = BDM.zeros[Double](k, vocabSize)
+       stat = batches(batchCount).aggregate(stat)(seqOp, _ += _)
+
+       stat = stat :* _expElogbeta
+       _lambda = _lambda * (1 - rhot) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * rhot
+       _Elogbeta = dirichlet_expectation(_lambda)
+       _expElogbeta = exp(_Elogbeta)
+       batchCount += 1
+     }
+
+     private def seqOp(other: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+       val termCounts = doc._2
+       val (ids, cts) = termCounts match {
+         case v: DenseVector => ((0 until v.size).toList, v.values)
+         case v: SparseVector => (v.indices.toList, v.values)
+         case v => throw new IllegalArgumentException("Unsupported vector type " + v.getClass)
+       }
+
+       var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t   // 1 * K
+       var Elogthetad = vector_dirichlet_expectation(gammad.t).t     // 1 * K
+       var expElogthetad = exp(Elogthetad.t).t                       // 1 * K
+       val expElogbetad = _expElogbeta(::, ids).toDenseMatrix        // K * ids
+
+       var phinorm = expElogthetad * expElogbetad + 1e-100           // 1 * ids
+       var meanchange = 1D
+       val ctsVector = new BDV[Double](cts).t                        // 1 * ids
+
+       while (meanchange > 1e-6) {
+         val lastgamma = gammad
+         //         1 * K           1 * ids                ids * K
+         gammad = (expElogthetad :* ((ctsVector / phinorm) * (expElogbetad.t))) + 1.0 / k
+         Elogthetad = vector_dirichlet_expectation(gammad.t).t
+         expElogthetad = exp(Elogthetad.t).t
+         phinorm = expElogthetad * expElogbetad + 1e-100
+         meanchange = sum(abs((gammad - lastgamma).t)) / gammad.t.size.toDouble
+       }
+
+       val v1 = expElogthetad.t.toDenseMatrix.t
+       val v2 = (ctsVector / phinorm).t.toDenseMatrix
+       val outerResult = kron(v1, v2)                                // K * ids
+       for (i <- 0 until ids.size) {
+         other(::, ids(i)) := (other(::, ids(i)) + outerResult(::, i))
+       }
+       other
+     }
+
+     def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
+       val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
+       val temp = gammaRandomGenerator.sample(row * col).toArray
+       (new BDM[Double](col, row, temp)).t
+     }
+
+     def dirichlet_expectation(alpha: BDM[Double]): BDM[Double] = {
+       val rowSum = sum(alpha(breeze.linalg.*, ::))
+       val digAlpha = digamma(alpha)
+       val digRowSum = digamma(rowSum)
+       val result = digAlpha(::, breeze.linalg.*) - digRowSum
+       result
+     }
+
+     def vector_dirichlet_expectation(v: BDV[Double]): BDV[Double] = {
+       digamma(v) - digamma(sum(v))
+     }
+   }
+
  /**
   * Compute gamma_{wjk}, a distribution over topics k.
   */
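For clarity, the update that next() performs above is the usual online variational Bayes step, written here with the code's own names (the constant 1.0 / k plays the role of the symmetric Dirichlet prior):

  rhot    = (tau0 + batchCount)^(-kappa)
  _lambda <- (1 - rhot) * _lambda + rhot * (stat * D / batchSize + 1.0 / k)

where stat already includes the elementwise product with _expElogbeta. The helpers dirichlet_expectation and vector_dirichlet_expectation implement the Dirichlet identity E_q[log x_i] = digamma(alpha_i) - digamma(sum_j alpha_j), applied row-wise to the K x V lambda matrix and to a single gamma vector respectively.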