Commit 7668124

minor updates

1 parent f8d5928 commit 7668124

3 files changed, 23 insertions(+), 25 deletions(-)

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 9 additions & 9 deletions

@@ -19,14 +19,14 @@ package org.apache.spark.mllib.clustering
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
-import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
 
 /**
@@ -48,9 +48,9 @@ class KMeans private (
 
   /**
    * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
-   * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, System.nanoTime()}.
+   * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}.
   */
-  def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, System.nanoTime())
+  def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong())
 
   /** Set the number of clusters to create (k). Default: 2. */
   def setK(k: Int): this.type = {
@@ -345,17 +345,20 @@ object KMeans {
   * @param maxIterations max number of iterations
   * @param runs number of parallel runs, defaults to 1. The best model is returned.
   * @param initializationMode initialization model, either "random" or "k-means||" (default).
+  * @param seed random seed value for cluster initialization
   */
  def train(
      data: RDD[Vector],
      k: Int,
      maxIterations: Int,
      runs: Int,
-     initializationMode: String): KMeansModel = {
+     initializationMode: String,
+     seed: Long): KMeansModel = {
    new KMeans().setK(k)
      .setMaxIterations(maxIterations)
      .setRuns(runs)
      .setInitializationMode(initializationMode)
+     .setSeed(seed)
      .run(data)
  }
 
@@ -367,20 +370,17 @@ object KMeans {
   * @param maxIterations max number of iterations
   * @param runs number of parallel runs, defaults to 1. The best model is returned.
   * @param initializationMode initialization model, either "random" or "k-means||" (default).
-  * @param seed random seed value for cluster initialization
   */
  def train(
      data: RDD[Vector],
      k: Int,
      maxIterations: Int,
      runs: Int,
-     initializationMode: String,
-     seed: Long): KMeansModel = {
+     initializationMode: String): KMeansModel = {
    new KMeans().setK(k)
      .setMaxIterations(maxIterations)
      .setRuns(runs)
      .setInitializationMode(initializationMode)
-     .setSeed(seed)
      .run(data)
  }
 
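For reference, a minimal sketch of how the new seeded overload could be called. The SparkContext `sc` and the sample points are assumptions for illustration, not part of this commit:

    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors

    // Assumes an existing SparkContext `sc`; the points are made-up toy data.
    val points = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
      Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)))

    // Passing the same seed twice should yield identical initial centers,
    // and hence identical models, for a fixed dataset and parameters.
    val modelA = KMeans.train(points, k = 2, maxIterations = 10, runs = 1,
      initializationMode = "k-means||", seed = 42L)
    val modelB = KMeans.train(points, k = 2, maxIterations = 10, runs = 1,
      initializationMode = "k-means||", seed = 42L)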

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala

Lines changed: 7 additions & 7 deletions

@@ -97,17 +97,17 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
 
     for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
       // Create three deterministic models and compare cluster means
-      val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, initializationMode = initMode, seed = 42)
+      val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
+        initializationMode = initMode, seed = 42)
       val centers1 = model1.clusterCenters
 
-      val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, initializationMode = initMode, seed = 42)
+      val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
+        initializationMode = initMode, seed = 42)
       val centers2 = model2.clusterCenters
 
-      val model3 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, initializationMode = initMode, seed = 42)
-      val centers3 = model3.clusterCenters
-
-      assert(centers1.deep == centers2.deep)
-      assert(centers1.deep == centers3.deep)
+      centers1.zip(centers2).foreach { case (c1, c2) =>
+        assert(c1 ~== c2 absTol 1E-14)
+      }
     }
   }
 
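The `~==` / `absTol` syntax comes from MLlib's test helper `TestingUtils`; replacing exact `.deep` equality with a small absolute tolerance guards against harmless floating-point noise between runs. Roughly, the check behaves like this sketch (an approximation, not the actual implementation):

    // Sketch: absolute-tolerance equality for two center vectors,
    // approximating what `c1 ~== c2 absTol 1E-14` verifies.
    def approxEqual(a: Array[Double], b: Array[Double], eps: Double = 1e-14): Boolean =
      a.length == b.length &&
        a.zip(b).forall { case (x, y) => math.abs(x - y) <= eps }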

python/pyspark/mllib/tests.py

Lines changed: 7 additions & 9 deletions

@@ -117,7 +117,7 @@ class ListTests(PySparkTestCase):
     as NumPy arrays.
     """
 
-    def test_clustering(self):
+    def test_kmeans(self):
         from pyspark.mllib.clustering import KMeans
         data = [
             [0, 1.1],
@@ -129,7 +129,7 @@ def test_clustering(self):
         self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1]))
         self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3]))
 
-    def test_clustering_deterministic(self):
+    def test_kmeans_deterministic(self):
         from pyspark.mllib.clustering import KMeans
         X = range(0, 100, 10)
         Y = range(0, 100, 10)
@@ -138,13 +138,11 @@ def test_clustering_deterministic(self):
                                  3, initializationMode="k-means||", seed=42)
         clusters2 = KMeans.train(self.sc.parallelize(data),
                                  3, initializationMode="k-means||", seed=42)
-        clusters3 = KMeans.train(self.sc.parallelize(data),
-                                 3, initializationMode="k-means||", seed=42)
-        centers1 = array(clusters1.centers).flatten().tolist()
-        centers2 = array(clusters2.centers).flatten().tolist()
-        centers3 = array(clusters3.centers).flatten().tolist()
-        self.assertListEqual(centers1, centers2)
-        self.assertListEqual(centers1, centers3)
+        centers1 = clusters1.centers
+        centers2 = clusters2.centers
+        for c1, c2 in zip(centers1, centers2):
+            # TODO: Allow small numeric difference.
+            self.assertTrue(array_equal(c1, c2))
 
     def test_classification(self):
         from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
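The Python test pins seed=42 for the same reason as the Scala suite: k-means initialization draws from a seeded RNG, so equal seeds reproduce the same choice of initial centers. A toy illustration of that principle, using scala.util.Random purely for demonstration (MLlib itself uses XORShiftRandom):

    import scala.util.Random

    // Two generators built from the same seed produce identical sequences,
    // which is what makes seeded k-means initialization repeatable.
    val r1 = new Random(42L)
    val r2 = new Random(42L)
    assert(Seq.fill(5)(r1.nextInt(100)) == Seq.fill(5)(r2.nextInt(100)))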
