@@ -249,32 +249,24 @@ class LDA private (

   /**
+   * TODO: add API to take documents paths once tokenizer is ready.
    * Learn an LDA model using the given dataset, using online variational Bayes (VB) algorithm.
    *
    * @param documents RDD of documents, which are term (word) count vectors paired with IDs.
    *                  The term count vectors are "bags of words" with a fixed-size vocabulary
    *                  (where the vocabulary size is the length of the vector).
    *                  Document IDs must be unique and >= 0.
-   * @param batchNumber Number of batches. For each batch, recommendation size is [4, 16384].
-   *                    -1 for automatic batchNumber.
+   * @param batchNumber Number of batches to split the input corpus into. For each batch, the
+   *                    recommended size is [4, 16384]. -1 for automatic batchNumber.
    * @return Inferred LDA model
    */
   def runOnlineLDA(documents: RDD[(Long, Vector)], batchNumber: Int = -1): LDAModel = {
-    val D = documents.count().toInt
-    val batchSize =
-      if (batchNumber == -1) { // auto mode
-        if (D / 100 > 16384) 16384
-        else if (D / 100 < 4) 4
-        else D / 100
-      }
-      else {
-        require(batchNumber > 0, "batchNumber should be positive or -1")
-        D / batchNumber
-      }
+    require(batchNumber > 0 || batchNumber == -1,
+      s"batchNumber must be greater than 0 or -1, but was set to $batchNumber")

-    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchSize)
-    (0 until onlineLDA.actualBatchNumber).map(_ => onlineLDA.next())
-    new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
+    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchNumber)
+    val model = onlineLDA.optimize()
+    new LocalLDAModel(Matrices.fromBreeze(model).transpose)
   }

   /** Java-friendly version of [[run()]] */
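
As a usage sketch (not part of the patch): the new entry point takes an RDD of `(docId, termCountVector)` pairs and an optional `batchNumber`. The corpus, `setK(2)`, the local Spark context, and the `OnlineLDAExample` object below are illustrative assumptions; only `runOnlineLDA(documents, batchNumber)` comes from this diff.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

object OnlineLDAExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "OnlineLDAExample")

    // Toy corpus: each document is a bag-of-words count vector over a
    // five-term vocabulary, paired with a unique non-negative document ID.
    val documents: RDD[(Long, Vector)] = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 0.0, 3.0, 0.0, 2.0)),
      (1L, Vectors.dense(0.0, 2.0, 0.0, 1.0, 0.0)),
      (2L, Vectors.dense(4.0, 0.0, 0.0, 2.0, 1.0))))

    // batchNumber = -1 asks the optimizer to pick the mini-batch size itself.
    val model = new LDA().setK(2).runOnlineLDA(documents, batchNumber = -1)
    println(s"learned ${model.k} topics over a ${model.vocabSize}-term vocabulary")
    sc.stop()
  }
}
```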
@@ -437,39 +429,54 @@ private[clustering] object LDA {
   private[clustering] class OnlineLDAOptimizer(
       private val documents: RDD[(Long, Vector)],
       private val k: Int,
-      private val batchSize: Int) extends Serializable {
+      private val batchNumber: Int) extends Serializable {

     private val vocabSize = documents.first._2.size
     private val D = documents.count().toInt
-    val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt
+    private val batchSize =
+      if (batchNumber == -1) { // auto mode
+        if (D / 100 > 16384) 16384
+        else if (D / 100 < 4) 4
+        else D / 100
+      }
+      else {
+        D / batchNumber
+      }

     // Initialize the variational distribution q(beta|lambda)
-    var lambda = getGammaMatrix(k, vocabSize)             // K * V
+    private var lambda = getGammaMatrix(k, vocabSize)     // K * V
     private var Elogbeta = dirichlet_expectation(lambda)  // K * V
     private var expElogbeta = exp(Elogbeta)               // K * V

-    private var batchId = 0
-    def next(): Unit = {
-      require(batchId < actualBatchNumber)
-      // weight of the mini-batch. 1024 down weights early iterations
-      val weight = math.pow(1024 + batchId, -0.5)
-      val batch = documents.sample(true, batchSize.toDouble / D)
-      batch.cache()
-      // Given a mini-batch of documents, estimates the parameters gamma controlling the
-      // variational distribution over the topic weights for each document in the mini-batch.
-      var stat = BDM.zeros[Double](k, vocabSize)
-      stat = batch.aggregate(stat)(seqOp, _ += _)
-      stat = stat :* expElogbeta
+    def optimize(): BDM[Double] = {
+      val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt
+      for (i <- 1 to actualBatchNumber) {
+        val batch = documents.sample(true, batchSize.toDouble / D)
+
+        // Given a mini-batch of documents, estimates the parameters gamma controlling the
+        // variational distribution over the topic weights for each document in the mini-batch.
+        var stat = BDM.zeros[Double](k, vocabSize)
+        stat = batch.treeAggregate(stat)(gradient, _ += _)
+        update(stat, i)
+      }
+      lambda
+    }
+
+    private def update(raw: BDM[Double], iter: Int): Unit = {
+      // Weight of the mini-batch. The 1024 offset down-weights early iterations.
+      val weight = math.pow(1024 + iter, -0.5)
+
+      // This step finishes computing the sufficient statistics for the M step.
+      val stat = raw :* expElogbeta

       // Update lambda based on documents.
       lambda = lambda * (1 - weight) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * weight
       Elogbeta = dirichlet_expectation(lambda)
       expElogbeta = exp(Elogbeta)
-      batchId += 1
     }

     // for each document d update that document's gamma and phi
-    private def seqOp(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+    private def gradient(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
       val termCounts = doc._2
       val (ids, cts) = termCounts match {
         case v: DenseVector => (((0 until v.size).toList), v.values)
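
A note on the learning-rate schedule in `update` above: each mini-batch is blended into `lambda` with weight `(1024 + i)^(-0.5)`, i.e. the rho_t = (tau0 + t)^(-kappa) decay from Hoffman et al.'s online VB for LDA with tau0 = 1024 and kappa = 0.5. The self-contained Breeze sketch below (object name, matrix sizes, and the constant statistics are made up for illustration) shows how small that weight is, so early mini-batches nudge `lambda` rather than replace it; only the blending formula mirrors the patch.

```scala
import breeze.linalg.{DenseMatrix => BDM}

object WeightDecaySketch {
  def main(args: Array[String]): Unit = {
    val k = 2
    val vocabSize = 3
    val D = 100        // toy corpus size
    val batchSize = 10 // toy mini-batch size
    var lambda = BDM.ones[Double](k, vocabSize)
    // Pretend every mini-batch produced the same sufficient statistics.
    val stat = BDM.fill[Double](k, vocabSize)(4.0)
    for (i <- 1 to 3) {
      // Same schedule as update(): roughly 0.031 for early batches.
      val weight = math.pow(1024 + i, -0.5)
      lambda = lambda * (1 - weight) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * weight
      println(f"batch $i%d: weight = $weight%.5f, lambda(0,0) = ${lambda(0, 0)}%.4f")
    }
  }
}
```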
@@ -488,7 +495,7 @@ private[clustering] object LDA {
       val ctsVector = new BDV[Double](cts).t // 1 * ids

       // Iterate between gamma and phi until convergence
-      while (meanchange > 1e-6) {
+      while (meanchange > 1e-5) {
         val lastgamma = gammad
         //          1 * K               1 * ids                 ids * K
         gammad = (expElogthetad :* ((ctsVector / phinorm) * (expElogbetad.t))) + 1.0 / k
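
The hunk above loosens the per-document convergence threshold from 1e-6 to 1e-5. For readers unfamiliar with this E step: holding `expElogbeta` fixed, `gradient` iterates the gamma/phi fixed point for one document until the mean change in `gammad` drops below that threshold. A self-contained Breeze sketch follows; the toy `expElogbetad`, term counts, and the all-ones initialization are assumptions for illustration, and only the `gammad` update line mirrors the code above.

```scala
import breeze.linalg.{sum, DenseMatrix => BDM}
import breeze.numerics.{abs, digamma, exp}

object EStepSketch {
  def main(args: Array[String]): Unit = {
    val k = 2 // number of topics
    // Toy inputs for one document with 3 distinct terms:
    // exp(E[log beta]) restricted to this document's term ids (K x ids),
    // and the term counts as a 1 x ids row.
    val expElogbetad = BDM((0.4, 0.3, 0.3), (0.2, 0.5, 0.3))
    val ctsVector = BDM((1.0, 2.0, 1.0))
    var gammad = BDM.ones[Double](1, k) // variational Dirichlet parameters, 1 x K
    var meanchange = 1.0
    while (meanchange > 1e-5) {
      val lastgamma = gammad.copy
      val expElogthetad = exp(digamma(gammad) - digamma(sum(gammad))) // 1 x K
      val phinorm = expElogthetad * expElogbetad + 1e-100             // 1 x ids
      gammad = (expElogthetad :* ((ctsVector :/ phinorm) * expElogbetad.t)) + 1.0 / k
      meanchange = sum(abs(gammad - lastgamma)) / k
    }
    println(s"converged gamma for this document: $gammad")
  }
}
```

The loosened threshold trades a little per-document accuracy for fewer fixed-point iterations per mini-batch, which is the dominant cost of the E step.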