
Commit 785734c

hhbyyh authored and liuchuanqi committed
[SPARK-7090] [MLLIB] Introduce LDAOptimizer to LDA to further improve extensibility
jira: https://issues.apache.org/jira/browse/SPARK-7090 LDA was implemented with extensibility in mind. And with the development of OnlineLDA and Gibbs Sampling, we are collecting more detailed requirements from different algorithms. As Joseph Bradley jkbradley proposed in apache#4807 and with some further discussion, we'd like to adjust the code structure a little to present the common interface and extension point clearly. Basically class LDA would be a common entrance for LDA computing. And each LDA object will refer to a LDAOptimizer for the concrete algorithm implementation. Users can customize LDAOptimizer with specific parameters and assign it to LDA. Concrete changes: 1. Add a trait `LDAOptimizer`, which defines the common iterface for concrete implementations. Each subClass is a wrapper for a specific LDA algorithm. 2. Move EMOptimizer to file LDAOptimizer and inherits from LDAOptimizer, rename to EMLDAOptimizer. (in case a more generic EMOptimizer comes in the future) -adjust the constructor of EMOptimizer, since all the parameters should be passed in through initialState method. This can avoid unwanted confusion or overwrite. -move the code from LDA.initalState to initalState of EMLDAOptimizer 3. Add property ldaOptimizer to LDA and its getter/setter, and EMLDAOptimizer is the default Optimizer. 4. Change the return type of LDA.run from DistributedLDAModel to LDAModel. Further work: add OnlineLDAOptimizer and other possible Optimizers once ready. Author: Yuhao Yang <[email protected]> Closes apache#5661 from hhbyyh/ldaRefactor and squashes the following commits: 0e2e006 [Yuhao Yang] respond to review comments 08a45da [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor e756ce4 [Yuhao Yang] solve mima exception d74fd8f [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor 0bb8400 [Yuhao Yang] refactor LDA with Optimizer ec2f857 [Yuhao Yang] protoptype for discussion
1 parent 1bab377 commit 785734c
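In user-facing terms, the change boils down to the snippet below. This is a minimal sketch assuming an existing corpus RDD of (documentId, termCountVector) pairs; the helper name trainLDA is ours, but the calls (setK, setOptimizer, run, and the cast to DistributedLDAModel) come straight from the diff.

import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, LDAModel}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Hypothetical helper illustrating the new API surface added in this commit.
def trainLDA(corpus: RDD[(Long, Vector)]): DistributedLDAModel = {
  val lda = new LDA()
    .setK(3)
    .setOptimizer(new EMLDAOptimizer)  // or .setOptimizer("em"); EM is also the default
  // run() now returns the base type LDAModel ...
  val model: LDAModel = lda.run(corpus)
  // ... so callers that need the EM-specific model cast, as the updated examples below do.
  model.asInstanceOf[DistributedLDAModel]
}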

8 files changed (+256, -151 lines)


examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
     corpus.cache();
 
     // Cluster the documents into three topics using LDA
-    DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);
+    DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus);
 
     // Output topics. Each is a distribution over words (matching word count vectors)
     System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()

examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ import scopt.OptionParser
 import org.apache.log4j.{Level, Logger}
 
 import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.rdd.RDD
 
@@ -137,7 +137,7 @@ object LDAExample {
       sc.setCheckpointDir(params.checkpointDir.get)
     }
     val startTime = System.nanoTime()
-    val ldaModel = lda.run(corpus)
+    val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
     val elapsed = (System.nanoTime() - startTime) / 1e9
 
     println(s"Finished training LDA model. Summary:")

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

Lines changed: 36 additions & 145 deletions
@@ -17,16 +17,11 @@
 
 package org.apache.spark.mllib.clustering
 
-import java.util.Random
-
-import breeze.linalg.{DenseVector => BDV, normalize}
-
+import breeze.linalg.{DenseVector => BDV}
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaPairRDD
 import org.apache.spark.graphx._
-import org.apache.spark.graphx.impl.GraphImpl
-import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
@@ -42,16 +37,9 @@ import org.apache.spark.util.Utils
  *  - "token": instance of a term appearing in a document
  *  - "topic": multinomial distribution over words representing some concept
  *
- * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented
- * according to the Asuncion et al. (2009) paper referenced below.
- *
  * References:
  *  - Original LDA paper (journal version):
  *    Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
- *    - This class implements their "smoothed" LDA model.
- *  - Paper which clearly explains several algorithms, including EM:
- *    Asuncion, Welling, Smyth, and Teh.
- *    "On Smoothing and Inference for Topic Models." UAI, 2009.
  *
  * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation
  *       (Wikipedia)]]
@@ -63,10 +51,11 @@ class LDA private (
     private var docConcentration: Double,
     private var topicConcentration: Double,
     private var seed: Long,
-    private var checkpointInterval: Int) extends Logging {
+    private var checkpointInterval: Int,
+    private var ldaOptimizer: LDAOptimizer) extends Logging {
 
   def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
-    seed = Utils.random.nextLong(), checkpointInterval = 10)
+    seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer)
 
   /**
    * Number of topics to infer. I.e., the number of soft cluster centers.
@@ -220,6 +209,32 @@ class LDA private (
     this
   }
 
+
+  /** LDAOptimizer used to perform the actual calculation */
+  def getOptimizer: LDAOptimizer = ldaOptimizer
+
+  /**
+   * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer)
+   */
+  def setOptimizer(optimizer: LDAOptimizer): this.type = {
+    this.ldaOptimizer = optimizer
+    this
+  }
+
+  /**
+   * Set the LDAOptimizer used to perform the actual calculation by algorithm name.
+   * Currently "em" is supported.
+   */
+  def setOptimizer(optimizerName: String): this.type = {
+    this.ldaOptimizer =
+      optimizerName.toLowerCase match {
+        case "em" => new EMLDAOptimizer
+        case other =>
+          throw new IllegalArgumentException(s"Only em is supported but got $other.")
+      }
+    this
+  }
+
   /**
    * Learn an LDA model using the given dataset.
    *
@@ -229,9 +244,9 @@
    * Document IDs must be unique and >= 0.
    * @return Inferred LDA model
    */
-  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
-    val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-      checkpointInterval)
+  def run(documents: RDD[(Long, Vector)]): LDAModel = {
+    val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration,
+      seed, checkpointInterval)
     var iter = 0
     val iterationTimes = Array.fill[Double](maxIterations)(0)
     while (iter < maxIterations) {
@@ -241,12 +256,11 @@
       iterationTimes(iter) = elapsedSeconds
       iter += 1
     }
-    state.graphCheckpointer.deleteAllCheckpoints()
-    new DistributedLDAModel(state, iterationTimes)
+    state.getLDAModel(iterationTimes)
   }
 
   /** Java-friendly version of [[run()]] */
-  def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
+  def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
     run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
   }
 }
@@ -320,88 +334,10 @@ private[clustering] object LDA {
 
   private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
 
-  /**
-   * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters.
-   *
-   * @param graph  EM graph, storing current parameter estimates in vertex descriptors and
-   *               data (token counts) in edge descriptors.
-   * @param k  Number of topics
-   * @param vocabSize  Number of unique terms
-   * @param docConcentration  "alpha"
-   * @param topicConcentration  "beta" or "eta"
-   */
-  private[clustering] class EMOptimizer(
-      var graph: Graph[TopicCounts, TokenCount],
-      val k: Int,
-      val vocabSize: Int,
-      val docConcentration: Double,
-      val topicConcentration: Double,
-      checkpointInterval: Int) {
-
-    private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
-      graph, checkpointInterval)
-
-    def next(): EMOptimizer = {
-      val eta = topicConcentration
-      val W = vocabSize
-      val alpha = docConcentration
-
-      val N_k = globalTopicTotals
-      val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit =
-        (edgeContext) => {
-          // Compute N_{wj} gamma_{wjk}
-          val N_wj = edgeContext.attr
-          // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count
-          // N_{wj}.
-          val scaledTopicDistribution: TopicCounts =
-            computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj
-          edgeContext.sendToDst((false, scaledTopicDistribution))
-          edgeContext.sendToSrc((false, scaledTopicDistribution))
-        }
-      // This is a hack to detect whether we could modify the values in-place.
-      // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438)
-      val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
-        (m0, m1) => {
-          val sum =
-            if (m0._1) {
-              m0._2 += m1._2
-            } else if (m1._1) {
-              m1._2 += m0._2
-            } else {
-              m0._2 + m1._2
-            }
-          (true, sum)
-        }
-      // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
-      val docTopicDistributions: VertexRDD[TopicCounts] =
-        graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg)
-          .mapValues(_._2)
-      // Update the vertex descriptors with the new counts.
-      val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
-      graph = newGraph
-      graphCheckpointer.updateGraph(newGraph)
-      globalTopicTotals = computeGlobalTopicTotals()
-      this
-    }
-
-    /**
-     * Aggregate distributions over topics from all term vertices.
-     *
-     * Note: This executes an action on the graph RDDs.
-     */
-    var globalTopicTotals: TopicCounts = computeGlobalTopicTotals()
-
-    private def computeGlobalTopicTotals(): TopicCounts = {
-      val numTopics = k
-      graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _)
-    }
-
-  }
-
   /**
    * Compute gamma_{wjk}, a distribution over topics k.
    */
-  private def computePTopic(
+  private[clustering] def computePTopic(
       docTopicCounts: TopicCounts,
       termTopicCounts: TopicCounts,
       totalTopicCounts: TopicCounts,
@@ -427,49 +363,4 @@ private[clustering] object LDA {
     // normalize
     BDV(gamma_wj) /= sum
   }
-
-  /**
-   * Compute bipartite term/doc graph.
-   */
-  private def initialState(
-      docs: RDD[(Long, Vector)],
-      k: Int,
-      docConcentration: Double,
-      topicConcentration: Double,
-      randomSeed: Long,
-      checkpointInterval: Int): EMOptimizer = {
-    // For each document, create an edge (Document -> Term) for each unique term in the document.
-    val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
-      // Add edges for terms with non-zero counts.
-      termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) =>
-        Edge(docID, term2index(term), cnt)
-      }
-    }
-
-    val vocabSize = docs.take(1).head._2.size
-
-    // Create vertices.
-    // Initially, we use random soft assignments of tokens to topics (random gamma).
-    def createVertices(): RDD[(VertexId, TopicCounts)] = {
-      val verticesTMP: RDD[(VertexId, TopicCounts)] =
-        edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
-          val random = new Random(partIndex + randomSeed)
-          partEdges.flatMap { edge =>
-            val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
-            val sum = gamma * edge.attr
-            Seq((edge.srcId, sum), (edge.dstId, sum))
-          }
-        }
-      verticesTMP.reduceByKey(_ + _)
-    }
-
-    val docTermVertices = createVertices()
-
-    // Partition such that edges are grouped by document
-    val graph = Graph(docTermVertices, edges)
-      .partitionBy(PartitionStrategy.EdgePartition1D)
-
-    new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval)
-  }
-
 }
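The new LDAOptimizer.scala file is among the changed files not shown in this excerpt, but the calls made from LDA.run above (initialState, next inside the unchanged iteration loop, and getLDAModel) imply a contract roughly like the following sketch; the exact modifiers and signatures are inferred, not copied from the commit.

import org.apache.spark.mllib.clustering.LDAModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Inferred sketch of the extension point described in the commit message.
trait LDAOptimizer {

  // Called once by LDA.run: build the optimizer's internal state from the corpus
  // and the hyperparameters held by the LDA instance.
  def initialState(
      docs: RDD[(Long, Vector)],
      k: Int,
      docConcentration: Double,
      topicConcentration: Double,
      randomSeed: Long,
      checkpointInterval: Int): LDAOptimizer

  // Called once per iteration inside LDA.run's while loop.
  def next(): LDAOptimizer

  // Called after the last iteration to produce the concrete model
  // (a DistributedLDAModel in the EM case).
  def getLDAModel(iterationTimes: Array[Double]): LDAModel
}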

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala

Lines changed: 1 addition & 1 deletion
@@ -203,7 +203,7 @@ class DistributedLDAModel private (
 
   import LDA._
 
-  private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = {
+  private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = {
     this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration,
       state.topicConcentration, iterationTimes)
   }
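With the constructor above now taking an EMLDAOptimizer, the two lines removed from LDA.run (checkpoint cleanup plus model construction) naturally move behind the optimizer's getLDAModel call. A plausible shape for that method, inferred from the removed code rather than copied from the new LDAOptimizer.scala:

// Inside EMLDAOptimizer (sketch): finish training and package the result.
// graphCheckpointer and the other fields mirror the EMOptimizer members
// removed from LDA.scala above.
def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
  // Drop intermediate checkpoints now that the final graph is in hand.
  graphCheckpointer.deleteAllCheckpoints()
  // Wrap the trained state using the DistributedLDAModel constructor shown above.
  new DistributedLDAModel(this, iterationTimes)
}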
