Skip to content

Commit d19ef55

Browse files
committed
change OnlineLDA to class
1 parent 97b9e1a commit d19ef55

File tree

1 file changed

+72
-31
lines changed

1 file changed

+72
-31
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/OnlineLDA.scala

Lines changed: 72 additions & 31 deletions
Original file line number · Diff line number · Diff line change
@@ -24,7 +24,6 @@ import org.apache.spark.annotation.Experimental
2424
import org.apache.spark.mllib.linalg._
2525
import org.apache.spark.rdd.RDD
2626

27-
2827
/**
2928
* :: Experimental ::
3029
* Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
@@ -37,7 +36,58 @@ import org.apache.spark.rdd.RDD
3736
* Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
3837
*/
3938
@Experimental
40-
object OnlineLDA{
39+
class OnlineLDA(
40+
private var k: Int,
41+
private var numIterations: Int,
42+
private var miniBatchFraction: Double,
43+
private var tau_0: Double,
44+
private var kappa: Double) {
45+
46+
def this() = this(k = 10, numIterations = 100, miniBatchFraction = 0.01,
47+
tau_0 = 1024, kappa = 0.5)
48+
49+
/**
50+
* Number of topics to infer. I.e., the number of soft cluster centers.
51+
* (default = 10)
52+
*/
53+
def setK(k: Int): this.type = {
54+
require(k > 0, s"OnlineLDA k (number of clusters) must be > 0, but was set to $k")
55+
this.k = k
56+
this
57+
}
58+
59+
/**
60+
* Set the number of iterations for OnlineLDA. Default 100.
61+
*/
62+
def setNumIterations(iters: Int): this.type = {
63+
this.numIterations = iters
64+
this
65+
}
66+
67+
/**
68+
* Set fraction of data to be used for each iteration. Default 0.01.
69+
*/
70+
def setMiniBatchFraction(fraction: Double): this.type = {
71+
this.miniBatchFraction = fraction
72+
this
73+
}
74+
75+
/**
76+
* A (positive) learning parameter that downweights early iterations. Default 1024.
77+
*/
78+
def setTau_0(t: Double): this.type = {
79+
this.tau_0 = t
80+
this
81+
}
82+
83+
/**
84+
* Learning rate: exponential decay rate. Default 0.5.
85+
*/
86+
def setKappa(kappa: Double): this.type = {
87+
this.kappa = kappa
88+
this
89+
}
90+
4191

4292
/**
4393
* Learns an LDA model from the given data set, using online variational Bayes (VB) algorithm.
@@ -49,33 +99,18 @@ object OnlineLDA{
4999
* The term count vectors are "bags of words" with a fixed-size vocabulary
50100
* (where the vocabulary size is the length of the vector).
51101
* Document IDs must be unique and >= 0.
52-
* @param k Number of topics to infer.
53-
* @param batchNumber Number of batches to split input corpus. For each batch, recommendation
54-
* size is [4, 16384]. -1 for automatic batchNumber.
55102
* @return Inferred LDA model
56103
*/
57-
def run(documents: RDD[(Long, Vector)], k: Int, batchNumber: Int = -1): LDAModel = {
58-
require(batchNumber > 0 || batchNumber == -1,
59-
s"batchNumber must be greater or -1, but was set to $batchNumber")
60-
require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k")
61-
104+
def run(documents: RDD[(Long, Vector)]): LDAModel = {
62105
val vocabSize = documents.first._2.size
63106
val D = documents.count().toInt // total documents count
64-
val batchSize =
65-
if (batchNumber == -1) { // auto mode
66-
if (D / 100 > 16384) 16384
67-
else if (D / 100 < 4) 4
68-
else D / 100
69-
}
70-
else {
71-
D / batchNumber
72-
}
73-
74-
val onlineLDA = new OnlineLDAOptimizer(k, D, vocabSize)
75-
val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt
76-
for(i <- 1 to actualBatchNumber){
77-
val batch = documents.sample(true, batchSize.toDouble / D)
78-
onlineLDA.submitMiniBatch(batch)
107+
val onlineLDA = new OnlineLDAOptimizer(k, D, vocabSize, tau_0, kappa)
108+
109+
val arr = Array.fill(math.ceil(1.0 / miniBatchFraction).toInt)(miniBatchFraction)
110+
for(i <- 0 until numIterations){
111+
val splits = documents.randomSplit(arr)
112+
val index = i % splits.size
113+
onlineLDA.submitMiniBatch(splits(index))
79114
}
80115
onlineLDA.getTopicDistribution()
81116
}
@@ -93,10 +128,12 @@ object OnlineLDA{
93128
* Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
94129
*/
95130
@Experimental
96-
class OnlineLDAOptimizer (
131+
private[clustering] class OnlineLDAOptimizer (
97132
private var k: Int,
98133
private var D: Int,
99-
private val vocabSize:Int) extends Serializable {
134+
private val vocabSize: Int,
135+
private val tau_0: Double,
136+
private val kappa: Double) extends Serializable {
100137

101138
// Initialize the variational distribution q(beta|lambda)
102139
private var lambda = getGammaMatrix(k, vocabSize) // K * V
@@ -115,7 +152,11 @@ class OnlineLDAOptimizer (
115152
* Document IDs must be unique and >= 0.
116153
* @return Inferred LDA model
117154
*/
118-
def submitMiniBatch(documents: RDD[(Long, Vector)]): Unit = {
155+
private[clustering] def submitMiniBatch(documents: RDD[(Long, Vector)]): Unit = {
156+
if(documents.isEmpty()){
157+
return
158+
}
159+
119160
var stat = BDM.zeros[Double](k, vocabSize)
120161
stat = documents.treeAggregate(stat)(gradient, _ += _)
121162
update(stat, i, documents.count().toInt)
@@ -125,13 +166,13 @@ class OnlineLDAOptimizer (
125166
/**
126167
* get the topic-term distribution
127168
*/
128-
def getTopicDistribution(): LDAModel ={
169+
private[clustering] def getTopicDistribution(): LDAModel ={
129170
new LocalLDAModel(Matrices.fromBreeze(lambda).transpose)
130171
}
131172

132173
private def update(raw: BDM[Double], iter:Int, batchSize: Int): Unit ={
133-
// weight of the mini-batch. 1024 helps down weights early iterations
134-
val weight = math.pow(1024 + iter, -0.5)
174+
// weight of the mini-batch.
175+
val weight = math.pow(tau_0 + iter, -kappa)
135176

136177
// This step finishes computing the sufficient statistics for the M step
137178
val stat = raw :* expElogbeta

0 commit comments

Comments (0)