 package org.apache.spark.mllib.clustering

-import java.util.Random
-
-import breeze.linalg.{DenseVector => BDV, normalize}
-
+import breeze.linalg.{DenseVector => BDV}
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaPairRDD
 import org.apache.spark.graphx._
-import org.apache.spark.graphx.impl.GraphImpl
-import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
@@ -42,16 +37,9 @@ import org.apache.spark.util.Utils
 * - "token": instance of a term appearing in a document
 * - "topic": multinomial distribution over words representing some concept
 *
- * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented
- * according to the Asuncion et al. (2009) paper referenced below.
- *
 * References:
 *  - Original LDA paper (journal version):
 *    Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
- *     - This class implements their "smoothed" LDA model.
- *  - Paper which clearly explains several algorithms, including EM:
- *    Asuncion, Welling, Smyth, and Teh.
- *    "On Smoothing and Inference for Topic Models." UAI, 2009.
 *
 * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation
 *       (Wikipedia)]]
@@ -63,10 +51,11 @@ class LDA private (
    private var docConcentration: Double,
    private var topicConcentration: Double,
    private var seed: Long,
-    private var checkpointInterval: Int) extends Logging {
+    private var checkpointInterval: Int,
+    private var ldaOptimizer: LDAOptimizer) extends Logging {

  def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
-    seed = Utils.random.nextLong(), checkpointInterval = 10)
+    seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer)

  /**
   * Number of topics to infer.  I.e., the number of soft cluster centers.
@@ -220,6 +209,32 @@ class LDA private (
    this
  }

+
+  /** LDAOptimizer used to perform the actual calculation */
+  def getOptimizer: LDAOptimizer = ldaOptimizer
+
+  /**
+   * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer)
+   */
+  def setOptimizer(optimizer: LDAOptimizer): this.type = {
+    this.ldaOptimizer = optimizer
+    this
+  }
+
+  /**
+   * Set the LDAOptimizer used to perform the actual calculation by algorithm name.
+   * Currently "em" is supported.
+   */
+  def setOptimizer(optimizerName: String): this.type = {
+    this.ldaOptimizer =
+      optimizerName.toLowerCase match {
+        case "em" => new EMLDAOptimizer
+        case other =>
+          throw new IllegalArgumentException(s"Only em is supported but got $other.")
+      }
+    this
+  }
+
  /**
   * Learn an LDA model using the given dataset.
   *
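[Note] A rough usage sketch of the new optimizer setters, not part of the patch; setK and setMaxIterations are assumed to be the existing builder methods on LDA:

// Select the optimizer by name; only "em" is accepted by this patch.
val lda = new LDA()
  .setK(10)
  .setMaxIterations(20)
  .setOptimizer("em")

// Equivalent, passing an optimizer instance directly.
lda.setOptimizer(new EMLDAOptimizer)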
@@ -229,9 +244,9 @@ class LDA private (
   *          Document IDs must be unique and >= 0.
   * @return  Inferred LDA model
   */
-  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
-    val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-      checkpointInterval)
+  def run(documents: RDD[(Long, Vector)]): LDAModel = {
+    val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration,
+      seed, checkpointInterval)
    var iter = 0
    val iterationTimes = Array.fill[Double](maxIterations)(0)
    while (iter < maxIterations) {
@@ -241,12 +256,11 @@ class LDA private (
      iterationTimes(iter) = elapsedSeconds
      iter += 1
    }
-    state.graphCheckpointer.deleteAllCheckpoints()
-    new DistributedLDAModel(state, iterationTimes)
+    state.getLDAModel(iterationTimes)
  }

  /** Java-friendly version of [[run()]] */
-  def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
+  def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
    run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
  }
}
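[Note] A minimal end-to-end sketch of the changed run() signature, illustrative only: sc is an existing SparkContext, the vectors are bag-of-words term counts, and setK is assumed to be the existing topic-count setter.

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Document IDs must be unique and >= 0; each vector holds term counts over the vocabulary.
val corpus: RDD[(Long, Vector)] = sc.parallelize(Seq(
  (0L, Vectors.dense(1.0, 0.0, 3.0)),
  (1L, Vectors.dense(0.0, 2.0, 1.0))))

// run() now returns the abstract LDAModel; with the default EMLDAOptimizer the concrete
// result is still expected to be a DistributedLDAModel.
val model: LDAModel = new LDA().setK(2).run(corpus)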
@@ -320,88 +334,10 @@ private[clustering] object LDA {
  private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0

-  /**
-   * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters.
-   *
-   * @param graph  EM graph, storing current parameter estimates in vertex descriptors and
-   *               data (token counts) in edge descriptors.
-   * @param k  Number of topics
-   * @param vocabSize  Number of unique terms
-   * @param docConcentration  "alpha"
-   * @param topicConcentration  "beta" or "eta"
-   */
-  private[clustering] class EMOptimizer(
-      var graph: Graph[TopicCounts, TokenCount],
-      val k: Int,
-      val vocabSize: Int,
-      val docConcentration: Double,
-      val topicConcentration: Double,
-      checkpointInterval: Int) {
-
-    private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
-      graph, checkpointInterval)
-
-    def next(): EMOptimizer = {
-      val eta = topicConcentration
-      val W = vocabSize
-      val alpha = docConcentration
-
-      val N_k = globalTopicTotals
-      val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit =
-        (edgeContext) => {
-          // Compute N_{wj} gamma_{wjk}
-          val N_wj = edgeContext.attr
-          // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count
-          // N_{wj}.
-          val scaledTopicDistribution: TopicCounts =
-            computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj
-          edgeContext.sendToDst((false, scaledTopicDistribution))
-          edgeContext.sendToSrc((false, scaledTopicDistribution))
-        }
-      // This is a hack to detect whether we could modify the values in-place.
-      // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438)
-      val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
-        (m0, m1) => {
-          val sum =
-            if (m0._1) {
-              m0._2 += m1._2
-            } else if (m1._1) {
-              m1._2 += m0._2
-            } else {
-              m0._2 + m1._2
-            }
-          (true, sum)
-        }
-      // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
-      val docTopicDistributions: VertexRDD[TopicCounts] =
-        graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg)
-          .mapValues(_._2)
-      // Update the vertex descriptors with the new counts.
-      val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
-      graph = newGraph
-      graphCheckpointer.updateGraph(newGraph)
-      globalTopicTotals = computeGlobalTopicTotals()
-      this
-    }
-
-    /**
-     * Aggregate distributions over topics from all term vertices.
-     *
-     * Note: This executes an action on the graph RDDs.
-     */
-    var globalTopicTotals: TopicCounts = computeGlobalTopicTotals()
-
-    private def computeGlobalTopicTotals(): TopicCounts = {
-      val numTopics = k
-      graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _)
-    }
-
-  }
-
  /**
   * Compute gamma_{wjk}, a distribution over topics k.
   */
-  private def computePTopic(
+  private[clustering] def computePTopic(
      docTopicCounts: TopicCounts,
      termTopicCounts: TopicCounts,
      totalTopicCounts: TopicCounts,
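[Note] The EM optimizer above moves out of LDA.scala, while computePTopic stays in this object (widened to private[clustering]). As a reminder of what it computes, here is a standalone sketch of the smoothed-EM responsibility for one (document, term) pair in the style of Asuncion et al. (2009); the helper name and loop form are illustrative, not the file's exact code:

import breeze.linalg.{DenseVector => BDV}

// gamma_k is proportional to (N_wk + eta - 1) * (N_kj + alpha - 1) / (N_k + W * (eta - 1))
def pTopicSketch(
    docTopicCounts: BDV[Double],    // N_kj: topic counts for this document
    termTopicCounts: BDV[Double],   // N_wk: topic counts for this term
    totalTopicCounts: BDV[Double],  // N_k: topic totals over all terms
    vocabSize: Int,                 // W
    eta: Double,
    alpha: Double): BDV[Double] = {
  val numTopics = docTopicCounts.length
  val gamma = new Array[Double](numTopics)
  var sum = 0.0
  var i = 0
  while (i < numTopics) {
    gamma(i) = (termTopicCounts(i) + eta - 1.0) * (docTopicCounts(i) + alpha - 1.0) /
      (totalTopicCounts(i) + vocabSize * (eta - 1.0))
    sum += gamma(i)
    i += 1
  }
  BDV(gamma) /= sum  // normalize so the result is a distribution over topics
}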
@@ -427,49 +363,4 @@ private[clustering] object LDA {
    // normalize
    BDV(gamma_wj) /= sum
  }
-
-  /**
-   * Compute bipartite term/doc graph.
-   */
-  private def initialState(
-      docs: RDD[(Long, Vector)],
-      k: Int,
-      docConcentration: Double,
-      topicConcentration: Double,
-      randomSeed: Long,
-      checkpointInterval: Int): EMOptimizer = {
-    // For each document, create an edge (Document -> Term) for each unique term in the document.
-    val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
-      // Add edges for terms with non-zero counts.
-      termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) =>
-        Edge(docID, term2index(term), cnt)
-      }
-    }
-
-    val vocabSize = docs.take(1).head._2.size
-
-    // Create vertices.
-    // Initially, we use random soft assignments of tokens to topics (random gamma).
-    def createVertices(): RDD[(VertexId, TopicCounts)] = {
-      val verticesTMP: RDD[(VertexId, TopicCounts)] =
-        edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
-          val random = new Random(partIndex + randomSeed)
-          partEdges.flatMap { edge =>
-            val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
-            val sum = gamma * edge.attr
-            Seq((edge.srcId, sum), (edge.dstId, sum))
-          }
-        }
-      verticesTMP.reduceByKey(_ + _)
-    }
-
-    val docTermVertices = createVertices()
-
-    // Partition such that edges are grouped by document
-    val graph = Graph(docTermVertices, edges)
-      .partitionBy(PartitionStrategy.EdgePartition1D)
-
-    new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval)
-  }
-
}