
Commit f367cc9: change to optimization

1 parent 8cb16a6

File tree

  • mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

Lines changed: 41 additions & 34 deletions
@@ -249,32 +249,24 @@ class LDA private (
 
 
   /**
+   * TODO: add API to take documents paths once tokenizer is ready.
    * Learn an LDA model using the given dataset, using online variational Bayes (VB) algorithm.
    *
    * @param documents RDD of documents, which are term (word) count vectors paired with IDs.
    *                  The term count vectors are "bags of words" with a fixed-size vocabulary
    *                  (where the vocabulary size is the length of the vector).
    *                  Document IDs must be unique and >= 0.
-   * @param batchNumber Number of batches. For each batch, recommendation size is [4, 16384].
-   *                    -1 for automatic batchNumber.
+   * @param batchNumber Number of batches to split the input corpus into. The recommended size
+   *                    for each batch is [4, 16384]. Use -1 to choose batchNumber automatically.
    * @return Inferred LDA model
    */
   def runOnlineLDA(documents: RDD[(Long, Vector)], batchNumber: Int = -1): LDAModel = {
-    val D = documents.count().toInt
-    val batchSize =
-      if (batchNumber == -1) { // auto mode
-        if (D / 100 > 16384) 16384
-        else if (D / 100 < 4) 4
-        else D / 100
-      }
-      else {
-        require(batchNumber > 0, "batchNumber should be positive or -1")
-        D / batchNumber
-      }
+    require(batchNumber > 0 || batchNumber == -1,
+      s"batchNumber must be positive or -1, but was set to $batchNumber")
 
-    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchSize)
-    (0 until onlineLDA.actualBatchNumber).map(_ => onlineLDA.next())
-    new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
+    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchNumber)
+    val model = onlineLDA.optimize()
+    new LocalLDAModel(Matrices.fromBreeze(model).transpose)
   }
 
   /** Java-friendly version of [[run()]] */
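After this change the caller no longer computes a batch size; it just passes `batchNumber` (or leaves the default of -1). A minimal driver sketch follows; the SparkContext `sc`, the toy three-term corpus, and `setK(10)` are illustrative assumptions, not part of this commit:

  // Hypothetical driver, assuming an existing SparkContext `sc`.
  import org.apache.spark.mllib.linalg.{Vector, Vectors}
  import org.apache.spark.rdd.RDD

  // Toy corpus: (documentId, termCountVector) pairs over a 3-term vocabulary.
  val documents: RDD[(Long, Vector)] = sc.parallelize(Seq(
    (0L, Vectors.dense(1.0, 0.0, 2.0)),
    (1L, Vectors.dense(0.0, 3.0, 1.0))
  ))

  val lda = new LDA().setK(10)
  // batchNumber = -1 selects auto mode: batchSize becomes D / 100 clamped to [4, 16384].
  val model = lda.runOnlineLDA(documents, batchNumber = -1)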
@@ -437,39 +429,54 @@ private[clustering] object LDA {
   private[clustering] class OnlineLDAOptimizer(
       private val documents: RDD[(Long, Vector)],
       private val k: Int,
-      private val batchSize: Int) extends Serializable{
+      private val batchNumber: Int) extends Serializable {
 
     private val vocabSize = documents.first._2.size
     private val D = documents.count().toInt
-    val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt
+    private val batchSize =
+      if (batchNumber == -1) { // auto mode
+        if (D / 100 > 16384) 16384
+        else if (D / 100 < 4) 4
+        else D / 100
+      }
+      else {
+        D / batchNumber
+      }
 
     // Initialize the variational distribution q(beta|lambda)
-    var lambda = getGammaMatrix(k, vocabSize) // K * V
+    private var lambda = getGammaMatrix(k, vocabSize) // K * V
     private var Elogbeta = dirichlet_expectation(lambda) // K * V
     private var expElogbeta = exp(Elogbeta) // K * V
 
-    private var batchId = 0
-    def next(): Unit = {
-      require(batchId < actualBatchNumber)
-      // weight of the mini-batch. 1024 down-weights early iterations
-      val weight = math.pow(1024 + batchId, -0.5)
-      val batch = documents.sample(true, batchSize.toDouble / D)
-      batch.cache()
-      // Given a mini-batch of documents, estimates the parameters gamma controlling the
-      // variational distribution over the topic weights for each document in the mini-batch.
-      var stat = BDM.zeros[Double](k, vocabSize)
-      stat = batch.aggregate(stat)(seqOp, _ += _)
-      stat = stat :* expElogbeta
+    def optimize(): BDM[Double] = {
+      val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt
+      for (i <- 1 to actualBatchNumber) {
+        val batch = documents.sample(true, batchSize.toDouble / D)
+
+        // Given a mini-batch of documents, estimate the parameters gamma controlling the
+        // variational distribution over the topic weights for each document in the mini-batch.
+        var stat = BDM.zeros[Double](k, vocabSize)
+        stat = batch.treeAggregate(stat)(gradient, _ += _)
+        update(stat, i)
+      }
+      lambda
+    }
+
+    private def update(raw: BDM[Double], iter: Int): Unit = {
+      // Weight of the mini-batch; the 1024 offset down-weights early iterations.
+      val weight = math.pow(1024 + iter, -0.5)
+
+      // This step finishes computing the sufficient statistics for the M step.
+      val stat = raw :* expElogbeta
 
       // Update lambda based on documents.
       lambda = lambda * (1 - weight) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * weight
       Elogbeta = dirichlet_expectation(lambda)
       expElogbeta = exp(Elogbeta)
-      batchId += 1
     }
 
     // For each document d, update that document's gamma and phi.
-    private def seqOp(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+    private def gradient(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
       val termCounts = doc._2
       val (ids, cts) = termCounts match {
         case v: DenseVector => (((0 until v.size).toList), v.values)
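Two aspects of the new optimize/update split are worth noting. First, `treeAggregate` combines the per-partition K * V statistic matrices in a tree pattern instead of sending them all straight to the driver, as plain `aggregate` does. Second, the update keeps the standard online VB learning-rate schedule: mini-batch i is blended in with weight rho(i) = (1024 + i)^(-0.5). A standalone sketch of that schedule, plain Scala with no Spark and illustrative names:

  // rho(i) starts near 0.031 and decays slowly, so no single
  // mini-batch can dominate the running estimate of lambda.
  def rho(iter: Int): Double = math.pow(1024 + iter, -0.5)

  // One elementwise step of lambda = (1 - rho) * lambda + rho * batchStat,
  // shown here for a single entry of the K x V matrix.
  def blend(lambdaOld: Double, batchStat: Double, iter: Int): Double = {
    val w = rho(iter)
    (1 - w) * lambdaOld + w * batchStat
  }

  // rho(1) is about 0.0312, rho(100) about 0.0298, rho(10000) about 0.0095.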
@@ -488,7 +495,7 @@ private[clustering] object LDA {
       val ctsVector = new BDV[Double](cts).t // 1 * ids
 
       // Iterate between gamma and phi until convergence
-      while (meanchange > 1e-6) {
+      while (meanchange > 1e-5) {
         val lastgamma = gammad
         // 1*K       1 * ids            ids * k
         gammad = (expElogthetad :* ((ctsVector / phinorm) * (expElogbetad.t))) + 1.0/k
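The inner loop's convergence threshold is loosened from 1e-6 to 1e-5, trading a little per-document accuracy for fewer gamma/phi iterations per mini-batch. In this style of implementation, `meanchange` is typically the mean absolute difference between successive gamma estimates; a minimal Breeze sketch of that check, with hypothetical names:

  import breeze.linalg.{DenseVector => BDV, sum}
  import breeze.numerics.abs

  // Mean absolute movement between successive gamma estimates; the loop
  // stops once the average per-topic change falls below the threshold.
  def meanChange(current: BDV[Double], last: BDV[Double]): Double =
    sum(abs(current - last)) / current.length

  // e.g. while (meanChange(gammad, lastgamma) > 1e-5) { /* update gammad, phinorm */ }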
