Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 0d0f3ee

Browse files
committed
replace random split with sliding
1 parent fa408a8 commit 0d0f3ee

File tree

1 file changed

+2
-2
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/clustering

1 file changed

+2
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
3232
import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
3333
import org.apache.spark.rdd.RDD
3434
import org.apache.spark.util.Utils
35+
import org.apache.spark.mllib.rdd.RDDFunctions._
3536

3637

3738
/**
@@ -430,8 +431,7 @@ private[clustering] object LDA {
430431
else if (D / 1000 < 4) 4
431432
else D / 1000
432433
val batchNumber = (D/batchSize + 1).toInt
433-
// todo: performance killer, need to be replaced
434-
private val batches = documents.randomSplit(Array.fill[Double](batchNumber)(1.0))
434+
private val batches = documents.sliding(batchNumber).collect()
435435

436436
// Initialize the variational distribution q(beta|lambda)
437437
var _lambda = getGammaMatrix(k, vocabSize) // K * V

0 commit comments

Comments
 (0)