
Commit 2e682c0

take discount on previous weights; use BLAS; detect dying clusters
1 parent 0411bf5 commit 2e682c0

File tree

3 files changed: +148 -91 lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala

Lines changed: 1 addition & 4 deletions
@@ -50,7 +50,6 @@ import org.apache.spark.streaming.{Seconds, StreamingContext}
 object StreamingKMeans {
 
   def main(args: Array[String]) {
-
     if (args.length != 5) {
       System.err.println(
         "Usage: StreamingKMeans " +
@@ -67,14 +66,12 @@ object StreamingKMeans {
     val model = new StreamingKMeans()
       .setK(args(3).toInt)
       .setDecayFactor(1.0)
-      .setRandomCenters(args(4).toInt)
+      .setRandomCenters(args(4).toInt, 0.0)
 
     model.trainOn(trainingData)
     model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
 
     ssc.start()
     ssc.awaitTermination()
-
   }
-
 }
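For orientation, here is a sketch of the full example driver after this change, assuming the usual layout of this example (argument order taken from the usage string above; the StreamingContext and input-stream setup are not part of the diff and are reconstructed here). Only the setRandomCenters call changes: it now takes an initial per-center weight in addition to the number of dimensions.

// Sketch only: args = (trainingDir, testDir, batchDurationSeconds, numClusters, numDimensions)
import org.apache.spark.SparkConf
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.streaming.{Seconds, StreamingContext}

object StreamingKMeansDriverSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("StreamingKMeans")
    val ssc = new StreamingContext(conf, Seconds(args(2).toLong))

    // training stream of plain vectors, test stream of labeled points
    val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse)
    val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)

    val model = new StreamingKMeans()
      .setK(args(3).toInt)
      .setDecayFactor(1.0)
      .setRandomCenters(args(4).toInt, 0.0)  // dimensions, initial weight per center

    model.trainOn(trainingData)
    model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

    ssc.start()
    ssc.awaitTermination()
  }
}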

mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala

Lines changed: 92 additions & 66 deletions
@@ -19,16 +19,15 @@ package org.apache.spark.mllib.clustering
 
 import scala.reflect.ClassTag
 
-import breeze.linalg.{Vector => BV}
-
-import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.Logging
-import org.apache.spark.mllib.linalg.{Vectors, Vector}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
-import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
 
 /**
  * :: DeveloperApi ::
@@ -66,55 +65,81 @@ import org.apache.spark.util.Utils
 @DeveloperApi
 class StreamingKMeansModel(
     override val clusterCenters: Array[Vector],
-    val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) with Logging {
+    val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging {
 
   /** Perform a k-means update on a batch of data. */
   def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
 
-    val centers = clusterCenters
-    val counts = clusterCounts
-
     // find nearest cluster to each point
-    val closest = data.map(point => (this.predict(point), (point.toBreeze, 1.toLong)))
+    val closest = data.map(point => (this.predict(point), (point, 1L)))
 
     // get sums and counts for updating each cluster
-    type WeightedPoint = (BV[Double], Long)
-    def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
-      (p1._1 += p2._1, p1._2 + p2._2)
+    val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
+      BLAS.axpy(1.0, p2._1, p1._1)
+      (p1._1, p1._2 + p2._2)
     }
-    val pointStats: Array[(Int, (BV[Double], Long))] =
-      closest.reduceByKey(mergeContribs).collect()
+    val dim = clusterCenters(0).size
+    val pointStats: Array[(Int, (Vector, Long))] = closest
+      .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
+      .collect()
+
+    val discount = timeUnit match {
+      case StreamingKMeans.BATCHES => decayFactor
+      case StreamingKMeans.POINTS =>
+        val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
+          n
+        }.sum
+        math.pow(decayFactor, numNewPoints)
+    }
+
+    // apply discount to weights
+    BLAS.scal(discount, Vectors.dense(clusterWeights))
 
     // implement update rule
-    pointStats.foreach { case (label, (mean, count)) =>
-      // store old count and centroid
-      val oldCount = counts(label)
-      val oldCentroid = centers(label).toBreeze
-      // get new count and centroid
-      val newCount = count
-      val newCentroid = mean / newCount.toDouble
-      // compute the normalized scale factor that controls forgetting
-      val lambda = timeUnit match {
-        case "batches" => newCount / (decayFactor * oldCount + newCount)
-        case "points" => newCount / (math.pow(decayFactor, newCount) * oldCount + newCount)
-      }
-      // perform the update
-      val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * lambda
-      // store the new counts and centers
-      counts(label) = oldCount + newCount
-      centers(label) = Vectors.fromBreeze(updatedCentroid)
+    pointStats.foreach { case (label, (sum, count)) =>
+      val centroid = clusterCenters(label)
+
+      val updatedWeight = clusterWeights(label) + count
+      val lambda = count / math.max(updatedWeight, 1e-16)
+
+      clusterWeights(label) = updatedWeight
+      BLAS.scal(1.0 - lambda, centroid)
+      BLAS.axpy(lambda / count, sum, centroid)
 
       // display the updated cluster centers
-      val display = centers(label).size match {
-        case x if x > 100 => centers(label).toArray.take(100).mkString("[", ",", "...")
-        case _ => centers(label).toArray.mkString("[", ",", "]")
+      val display = clusterCenters(label).size match {
+        case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
+        case _ => centroid.toArray.mkString("[", ",", "]")
+      }
+
+      logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
+    }
+
+    // Check whether the smallest cluster is dying. If so, split the largest cluster.
+    val weightsWithIndex = clusterWeights.view.zipWithIndex
+    val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
+    val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
+    if (minWeight < 1e-8 * maxWeight) {
+      logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
+      val weight = (maxWeight + minWeight) / 2.0
+      clusterWeights(largest) = weight
+      clusterWeights(smallest) = weight
+      val largestClusterCenter = clusterCenters(largest)
+      val smallestClusterCenter = clusterCenters(smallest)
+      var j = 0
+      while (j < dim) {
+        val x = largestClusterCenter(j)
+        val p = 1e-14 * math.max(math.abs(x), 1.0)
+        largestClusterCenter.toBreeze(j) = x + p
+        smallestClusterCenter.toBreeze(j) = x - p
+        j += 1
       }
-      logInfo("Cluster %d updated: %s ".format (label, display))
     }
-    new StreamingKMeansModel(centers, counts)
-  }
 
+    this
+  }
 }
+
 /**
  * :: DeveloperApi ::
  * StreamingKMeans provides methods for configuring a
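In effect, each batch now applies one exponentially discounted moving average per cluster: previous weights are multiplied by the discount, then each cluster moves toward its batch mean by lambda = count / updatedWeight. A minimal sketch of that arithmetic on plain arrays (the helper name and argument names are illustrative, not from the patch):

// Sketch of the per-cluster update used above, on plain doubles.
// For one cluster with previous weight w and centroid c, given a batch
// contribution (sum, count) and a discount a (decayFactor per batch, or
// decayFactor^numNewPoints when the time unit is "points"):
def updatedCluster(
    c: Array[Double], w: Double,
    sum: Array[Double], count: Long,
    discount: Double): (Array[Double], Double) = {
  val discountedWeight = discount * w            // old evidence is discounted first
  val updatedWeight = discountedWeight + count   // then the new points are added
  val lambda = count / math.max(updatedWeight, 1e-16)
  // c_new = (1 - lambda) * c_old + lambda * (sum / count)
  val cNew = c.indices.map(i => (1.0 - lambda) * c(i) + lambda * (sum(i) / count)).toArray
  (cNew, updatedWeight)
}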
@@ -128,7 +153,7 @@ class StreamingKMeansModel(
  * val model = new StreamingKMeans()
  *   .setDecayFactor(0.5)
  *   .setK(3)
- *   .setRandomCenters(5)
+ *   .setRandomCenters(5, 100.0)
  *   .trainOn(DStream)
  */
 @DeveloperApi
@@ -137,9 +162,9 @@ class StreamingKMeans(
     var decayFactor: Double,
     var timeUnit: String) extends Logging {
 
-  protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
+  def this() = this(2, 1.0, StreamingKMeans.BATCHES)
 
-  def this() = this(2, 1.0, "batches")
+  protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
 
   /** Set the number of clusters. */
   def setK(k: Int): this.type = {
@@ -155,7 +180,7 @@ class StreamingKMeans(
 
   /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
   def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
-    if (timeUnit != "batches" && timeUnit != "points") {
+    if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) {
       throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
     }
     this.decayFactor = math.exp(math.log(0.5) / halfLife)
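The half-life setter converts to a per-unit decay factor via decayFactor = exp(log(0.5) / halfLife) = 0.5^(1 / halfLife), so accumulated weights are halved after halfLife batches (or points). A quick illustrative check (not part of the patch):

// decayFactor = exp(log(0.5) / halfLife) = 0.5^(1 / halfLife)
val halfLife = 5.0
val decayFactor = math.exp(math.log(0.5) / halfLife)
// after halfLife units the cumulative discount is ~0.5
assert(math.abs(math.pow(decayFactor, halfLife) - 0.5) < 1e-12)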
@@ -165,26 +190,23 @@ class StreamingKMeans(
   }
 
   /** Specify initial centers directly. */
-  def setInitialCenters(initialCenters: Array[Vector]): this.type = {
-    val clusterCounts = new Array[Long](this.k)
-    this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
+  def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+    model = new StreamingKMeansModel(centers, weights)
     this
   }
 
-  /** Initialize random centers, requiring only the number of dimensions.
-   *
-   * @param dim Number of dimensions
-   * @param seed Random seed
-   * */
-  def setRandomCenters(dim: Int, seed: Long = Utils.random.nextLong): this.type = {
-
-    val random = Utils.random
-    random.setSeed(seed)
-
-    val initialCenters = (0 until k)
-      .map(_ => Vectors.dense(Array.fill(dim)(random.nextGaussian()))).toArray
-    val clusterCounts = new Array[Long](this.k)
-    this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
+  /**
+   * Initialize random centers, requiring only the number of dimensions.
+   *
+   * @param dim Number of dimensions
+   * @param weight Weight for each center
+   * @param seed Random seed
+   */
+  def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+    val random = new XORShiftRandom(seed)
+    val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
+    val weights = Array.fill(k)(weight)
+    model = new StreamingKMeansModel(centers, weights)
     this
   }
 
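Both initializers now take explicit weights. A short usage sketch with illustrative values (the centers, weights, and seed below are made up):

// Explicit centers with per-cluster weights (e.g. counts carried over from an offline fit).
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors

val warmStart = new StreamingKMeans()
  .setK(2)
  .setDecayFactor(1.0)
  .setInitialCenters(
    Array(Vectors.dense(0.0, 0.0), Vectors.dense(5.0, 5.0)),
    Array(100.0, 100.0))

// Or random Gaussian centers of a given dimension, all starting at weight 0.0.
val coldStart = new StreamingKMeans()
  .setK(2)
  .setRandomCenters(dim = 3, weight = 0.0, seed = 42L)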

@@ -202,9 +224,9 @@ class StreamingKMeans(
    * @param data DStream containing vector data
    */
   def trainOn(data: DStream[Vector]) {
-    this.assertInitialized()
+    assertInitialized()
     data.foreachRDD { (rdd, time) =>
-      model = model.update(rdd, this.decayFactor, this.timeUnit)
+      model = model.update(rdd, decayFactor, timeUnit)
     }
   }
 
@@ -215,7 +237,7 @@ class StreamingKMeans(
    * @return DStream containing predictions
    */
   def predictOn(data: DStream[Vector]): DStream[Int] = {
-    this.assertInitialized()
+    assertInitialized()
     data.map(model.predict)
   }
 
@@ -227,16 +249,20 @@ class StreamingKMeans(
    * @return DStream containing the input keys and the predictions as values
    */
   def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
-    this.assertInitialized()
+    assertInitialized()
     data.mapValues(model.predict)
   }
 
   /** Check whether cluster centers have been initialized. */
-  def assertInitialized(): Unit = {
-    if (Option(model.clusterCenters) == None) {
+  private[this] def assertInitialized(): Unit = {
+    if (model.clusterCenters == null) {
       throw new IllegalStateException(
         "Initial cluster centers must be set before starting predictions")
     }
   }
+}
 
+private[clustering] object StreamingKMeans {
+  final val BATCHES = "batches"
+  final val POINTS = "points"
 }
