Skip to content

Commit 68c2318

Browse files
committed
add a java ut
1 parent 4041723 commit 68c2318

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,37 @@ public void distributedLDAModel() {
109109
assert(model.logPrior() < 0.0);
110110
}
111111

112+
113+
@Test
114+
public void OnlineOptimizerCompatibility() {
115+
int k = 3;
116+
double topicSmoothing = 1.2;
117+
double termSmoothing = 1.2;
118+
119+
// Train a model
120+
OnlineLDAOptimizer op = new OnlineLDAOptimizer().setTau_0(1024).setKappa(0.51)
121+
.setGammaShape(1e40).setMiniBatchFraction(0.5);
122+
LDA lda = new LDA();
123+
lda.setK(k)
124+
.setDocConcentration(topicSmoothing)
125+
.setTopicConcentration(termSmoothing)
126+
.setMaxIterations(5)
127+
.setSeed(12345)
128+
.setOptimizer(op);
129+
130+
LDAModel model = lda.run(corpus);
131+
132+
// Check: basic parameters
133+
assertEquals(model.k(), k);
134+
assertEquals(model.vocabSize(), tinyVocabSize);
135+
136+
// Check: topic summaries
137+
Tuple2<int[], double[]>[] roundedTopicSummary = model.describeTopics();
138+
assertEquals(roundedTopicSummary.length, k);
139+
Tuple2<int[], double[]>[] roundedLocalTopicSummary = model.describeTopics();
140+
assertEquals(roundedLocalTopicSummary.length, k);
141+
}
142+
112143
private static int tinyK = LDASuite$.MODULE$.tinyK();
113144
private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize();
114145
private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();

0 commit comments

Comments
 (0)