Skip to content

Commit 7450a99

Browse files
str-janusmengxr
authored andcommitted
[SPARK-4749] [mllib]: Allow initializing KMeans clusters using a seed
This implements the functionality for SPARK-4749 and provides units tests in Scala and PySpark Author: nate.crosswhite <[email protected]> Author: nxwhite-str <[email protected]> Author: Xiangrui Meng <[email protected]> Closes apache#3610 from nxwhite-str/master and squashes the following commits: a2ebbd3 [nxwhite-str] Merge pull request #1 from mengxr/SPARK-4749-kmeans-seed 7668124 [Xiangrui Meng] minor updates f8d5928 [nate.crosswhite] Addressing PR issues 277d367 [nate.crosswhite] Merge remote-tracking branch 'upstream/master' 9156a57 [nate.crosswhite] Merge remote-tracking branch 'upstream/master' 5d087b4 [nate.crosswhite] Adding KMeans train with seed and Scala unit test 616d111 [nate.crosswhite] Merge remote-tracking branch 'upstream/master' 35c1884 [nate.crosswhite] Add kmeans initial seed to pyspark API
1 parent aa1e22b commit 7450a99

File tree

5 files changed

+84
-12
lines changed

5 files changed

+84
-12
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,12 +266,16 @@ class PythonMLLibAPI extends Serializable {
266266
k: Int,
267267
maxIterations: Int,
268268
runs: Int,
269-
initializationMode: String): KMeansModel = {
269+
initializationMode: String,
270+
seed: java.lang.Long): KMeansModel = {
270271
val kMeansAlg = new KMeans()
271272
.setK(k)
272273
.setMaxIterations(maxIterations)
273274
.setRuns(runs)
274275
.setInitializationMode(initializationMode)
276+
277+
if (seed != null) kMeansAlg.setSeed(seed)
278+
275279
try {
276280
kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
277281
} finally {

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ package org.apache.spark.mllib.clustering
1919

2020
import scala.collection.mutable.ArrayBuffer
2121

22-
import org.apache.spark.annotation.Experimental
2322
import org.apache.spark.Logging
24-
import org.apache.spark.SparkContext._
23+
import org.apache.spark.annotation.Experimental
2524
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2625
import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
2726
import org.apache.spark.mllib.util.MLUtils
2827
import org.apache.spark.rdd.RDD
2928
import org.apache.spark.storage.StorageLevel
29+
import org.apache.spark.util.Utils
3030
import org.apache.spark.util.random.XORShiftRandom
3131

3232
/**
@@ -43,13 +43,14 @@ class KMeans private (
4343
private var runs: Int,
4444
private var initializationMode: String,
4545
private var initializationSteps: Int,
46-
private var epsilon: Double) extends Serializable with Logging {
46+
private var epsilon: Double,
47+
private var seed: Long) extends Serializable with Logging {
4748

4849
/**
4950
* Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
50-
* initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}.
51+
* initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}.
5152
*/
52-
def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)
53+
def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong())
5354

5455
/** Set the number of clusters to create (k). Default: 2. */
5556
def setK(k: Int): this.type = {
@@ -112,6 +113,12 @@ class KMeans private (
112113
this
113114
}
114115

116+
/** Set the random seed for cluster initialization. */
117+
def setSeed(seed: Long): this.type = {
118+
this.seed = seed
119+
this
120+
}
121+
115122
/**
116123
* Train a K-means model on the given set of points; `data` should be cached for high
117124
* performance, because this is an iterative algorithm.
@@ -255,7 +262,7 @@ class KMeans private (
255262
private def initRandom(data: RDD[VectorWithNorm])
256263
: Array[Array[VectorWithNorm]] = {
257264
// Sample all the cluster centers in one pass to avoid repeated scans
258-
val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
265+
val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq
259266
Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
260267
new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
261268
}.toArray)
@@ -273,7 +280,7 @@ class KMeans private (
273280
private def initKMeansParallel(data: RDD[VectorWithNorm])
274281
: Array[Array[VectorWithNorm]] = {
275282
// Initialize each run's center to a random point
276-
val seed = new XORShiftRandom().nextInt()
283+
val seed = new XORShiftRandom(this.seed).nextInt()
277284
val sample = data.takeSample(true, runs, seed).toSeq
278285
val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
279286

@@ -333,7 +340,32 @@ object KMeans {
333340
/**
334341
* Trains a k-means model using the given set of parameters.
335342
*
336-
* @param data training points stored as `RDD[Array[Double]]`
343+
* @param data training points stored as `RDD[Vector]`
344+
* @param k number of clusters
345+
* @param maxIterations max number of iterations
346+
* @param runs number of parallel runs, defaults to 1. The best model is returned.
347+
* @param initializationMode initialization model, either "random" or "k-means||" (default).
348+
* @param seed random seed value for cluster initialization
349+
*/
350+
def train(
351+
data: RDD[Vector],
352+
k: Int,
353+
maxIterations: Int,
354+
runs: Int,
355+
initializationMode: String,
356+
seed: Long): KMeansModel = {
357+
new KMeans().setK(k)
358+
.setMaxIterations(maxIterations)
359+
.setRuns(runs)
360+
.setInitializationMode(initializationMode)
361+
.setSeed(seed)
362+
.run(data)
363+
}
364+
365+
/**
366+
* Trains a k-means model using the given set of parameters.
367+
*
368+
* @param data training points stored as `RDD[Vector]`
337369
* @param k number of clusters
338370
* @param maxIterations max number of iterations
339371
* @param runs number of parallel runs, defaults to 1. The best model is returned.

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,27 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
9090
assert(model.clusterCenters.size === 3)
9191
}
9292

93+
test("deterministic initialization") {
94+
// Create a large-ish set of points for clustering
95+
val points = List.tabulate(1000)(n => Vectors.dense(n, n))
96+
val rdd = sc.parallelize(points, 3)
97+
98+
for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
99+
// Create three deterministic models and compare cluster means
100+
val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
101+
initializationMode = initMode, seed = 42)
102+
val centers1 = model1.clusterCenters
103+
104+
val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
105+
initializationMode = initMode, seed = 42)
106+
val centers2 = model2.clusterCenters
107+
108+
centers1.zip(centers2).foreach { case (c1, c2) =>
109+
assert(c1 ~== c2 absTol 1E-14)
110+
}
111+
}
112+
}
113+
93114
test("single cluster with big dataset") {
94115
val smallData = Array(
95116
Vectors.dense(1.0, 2.0, 6.0),

python/pyspark/mllib/clustering.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ def predict(self, x):
7878
class KMeans(object):
7979

8080
@classmethod
81-
def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
81+
def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None):
8282
"""Train a k-means clustering model."""
8383
model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
84-
runs, initializationMode)
84+
runs, initializationMode, seed)
8585
centers = callJavaFunc(rdd.context, model.clusterCenters)
8686
return KMeansModel([c.toArray() for c in centers])
8787

python/pyspark/mllib/tests.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class ListTests(PySparkTestCase):
140140
as NumPy arrays.
141141
"""
142142

143-
def test_clustering(self):
143+
def test_kmeans(self):
144144
from pyspark.mllib.clustering import KMeans
145145
data = [
146146
[0, 1.1],
@@ -152,6 +152,21 @@ def test_clustering(self):
152152
self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1]))
153153
self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3]))
154154

155+
def test_kmeans_deterministic(self):
156+
from pyspark.mllib.clustering import KMeans
157+
X = range(0, 100, 10)
158+
Y = range(0, 100, 10)
159+
data = [[x, y] for x, y in zip(X, Y)]
160+
clusters1 = KMeans.train(self.sc.parallelize(data),
161+
3, initializationMode="k-means||", seed=42)
162+
clusters2 = KMeans.train(self.sc.parallelize(data),
163+
3, initializationMode="k-means||", seed=42)
164+
centers1 = clusters1.centers
165+
centers2 = clusters2.centers
166+
for c1, c2 in zip(centers1, centers2):
167+
# TODO: Allow small numeric difference.
168+
self.assertTrue(array_equal(c1, c2))
169+
155170
def test_classification(self):
156171
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
157172
from pyspark.mllib.tree import DecisionTree

0 commit comments

Comments
 (0)