Skip to content

Commit 9cfc301

Browse files
committed
Make random seed an argument
1 parent 44050a9 commit 9cfc301

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,19 @@ class StreamingKMeans(
183183
this
184184
}
185185

186-
/** Initialize random centers, requiring only the number of dimensions. */
187-
def setRandomCenters(d: Int): this.type = {
188-
val initialCenters = (0 until k).map(_ => Vectors.dense(Array.fill(d)(nextGaussian()))).toArray
189-
val clusterCounts = Array.fill(this.k)(0).map(_.toLong)
186+
/**
 * Initialize random centers, requiring only the number of dimensions.
 *
 * Builds `k` centers whose coordinates are each drawn from `nextGaussian()` on a
 * generator seeded with `seed`, pairs them with zeroed cluster counts, and installs
 * the result as this instance's model.
 *
 * @param dim Number of dimensions
 * @param seed Random seed; when omitted, one is drawn from `Utils.random`
 * @return this instance, so calls can be chained
 */
def setRandomCenters(dim: Int, seed: Long = Utils.random.nextLong): this.type = {
  // Use a generator private to this call. Reseeding the shared `Utils.random` would
  // reset the random stream for every other user of that instance, and a concurrent
  // caller could consume or reseed it mid-initialization, breaking reproducibility.
  // NOTE(review): assumes `Utils.random` is a java.util.Random, in which case the
  // values generated for a given seed are unchanged -- confirm.
  val random = new java.util.Random(seed)

  // Centers are filled in order, one Gaussian draw per coordinate.
  val initialCenters = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
  // A fresh Array[Long] is zero-initialized: no points assigned to any cluster yet.
  val clusterCounts = new Array[Long](k)

  this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
  this
}

0 commit comments

Comments
 (0)