@@ -19,16 +19,15 @@ package org.apache.spark.mllib.clustering
 
 import scala.reflect.ClassTag
 
-import breeze.linalg.{Vector => BV}
-
-import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.Logging
-import org.apache.spark.mllib.linalg.{Vectors, Vector}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
-import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
 
 /**
  * :: DeveloperApi ::
@@ -66,55 +65,81 @@ import org.apache.spark.util.Utils
 @DeveloperApi
 class StreamingKMeansModel(
     override val clusterCenters: Array[Vector],
-    val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) with Logging {
+    val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging {
 
   /** Perform a k-means update on a batch of data. */
   def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
 
-    val centers = clusterCenters
-    val counts = clusterCounts
-
     // find nearest cluster to each point
-    val closest = data.map(point => (this.predict(point), (point.toBreeze, 1.toLong)))
+    val closest = data.map(point => (this.predict(point), (point, 1L)))
 
     // get sums and counts for updating each cluster
-    type WeightedPoint = (BV[Double], Long)
-    def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
-      (p1._1 += p2._1, p1._2 + p2._2)
+    val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
+      BLAS.axpy(1.0, p2._1, p1._1)
+      (p1._1, p1._2 + p2._2)
     }
-    val pointStats: Array[(Int, (BV[Double], Long))] =
-      closest.reduceByKey(mergeContribs).collect()
+    val dim = clusterCenters(0).size
+    val pointStats: Array[(Int, (Vector, Long))] = closest
+      .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
+      .collect()
+
+    val discount = timeUnit match {
+      case StreamingKMeans.BATCHES => decayFactor
+      case StreamingKMeans.POINTS =>
+        val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
+          n
+        }.sum
+        math.pow(decayFactor, numNewPoints)
+    }
+
+    // apply discount to weights
+    BLAS.scal(discount, Vectors.dense(clusterWeights))
 
     // implement update rule
-    pointStats.foreach { case (label, (mean, count)) =>
-      // store old count and centroid
-      val oldCount = counts(label)
-      val oldCentroid = centers(label).toBreeze
-      // get new count and centroid
-      val newCount = count
-      val newCentroid = mean / newCount.toDouble
-      // compute the normalized scale factor that controls forgetting
-      val lambda = timeUnit match {
-        case "batches" => newCount / (decayFactor * oldCount + newCount)
-        case "points" => newCount / (math.pow(decayFactor, newCount) * oldCount + newCount)
-      }
-      // perform the update
-      val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * lambda
-      // store the new counts and centers
-      counts(label) = oldCount + newCount
-      centers(label) = Vectors.fromBreeze(updatedCentroid)
+    pointStats.foreach { case (label, (sum, count)) =>
+      val centroid = clusterCenters(label)
+
+      val updatedWeight = clusterWeights(label) + count
+      val lambda = count / math.max(updatedWeight, 1e-16)
+
+      clusterWeights(label) = updatedWeight
+      BLAS.scal(1.0 - lambda, centroid)
+      BLAS.axpy(lambda / count, sum, centroid)
 
       // display the updated cluster centers
-      val display = centers(label).size match {
-        case x if x > 100 => centers(label).toArray.take(100).mkString("[", ",", "...")
-        case _ => centers(label).toArray.mkString("[", ",", "]")
+      val display = clusterCenters(label).size match {
+        case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
+        case _ => centroid.toArray.mkString("[", ",", "]")
+      }
+
+      logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
+    }
+
+    // Check whether the smallest cluster is dying. If so, split the largest cluster.
+    val weightsWithIndex = clusterWeights.view.zipWithIndex
+    val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
+    val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
+    if (minWeight < 1e-8 * maxWeight) {
+      logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
+      val weight = (maxWeight + minWeight) / 2.0
+      clusterWeights(largest) = weight
+      clusterWeights(smallest) = weight
+      val largestClusterCenter = clusterCenters(largest)
+      val smallestClusterCenter = clusterCenters(smallest)
+      var j = 0
+      while (j < dim) {
+        val x = largestClusterCenter(j)
+        val p = 1e-14 * math.max(math.abs(x), 1.0)
+        largestClusterCenter.toBreeze(j) = x + p
+        smallestClusterCenter.toBreeze(j) = x - p
+        j += 1
       }
-      logInfo("Cluster %d updated: %s".format(label, display))
     }
-    new StreamingKMeansModel(centers, counts)
-  }
 
+    this
+  }
 }
+
 /**
  * :: DeveloperApi ::
  * StreamingKMeans provides methods for configuring a
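
The rewritten update rule above first discounts every cluster's weight by a global factor, then folds each batch's contribution into its centroid as a convex combination: `c ← (1 - λ)·c + λ·(sum / count)` with `λ = count / (discountedWeight + count)`. Below is a minimal sketch of the same arithmetic using plain arrays, outside Spark and BLAS; the object name, method name, and array-based types are illustrative only, not part of the patch.

```scala
// Minimal sketch of the forgetful update rule, using plain arrays instead of
// RDDs and BLAS. Names and types here are illustrative, not part of the patch.
object ForgetfulUpdateSketch {
  def update(
      centers: Array[Array[Double]], // centroids, mutated in place
      weights: Array[Double],        // per-cluster weights, mutated in place
      sums: Array[Array[Double]],    // per-cluster coordinate sums for this batch
      counts: Array[Long],           // per-cluster point counts for this batch
      discount: Double): Unit = {
    // 1) discount all existing weights before adding this batch's mass
    for (i <- weights.indices) weights(i) *= discount
    // 2) fold each cluster's batch contribution in as a convex combination
    for (label <- centers.indices if counts(label) > 0) {
      val count = counts(label).toDouble
      val updatedWeight = weights(label) + count
      val lambda = count / math.max(updatedWeight, 1e-16)
      weights(label) = updatedWeight
      val c = centers(label)
      for (j <- c.indices) {
        // c := (1 - lambda) * c + lambda * (sum / count), as in the BLAS calls above
        c(j) = (1.0 - lambda) * c(j) + (lambda / count) * sums(label)(j)
      }
    }
  }
}
```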
@@ -128,7 +153,7 @@ class StreamingKMeansModel(
  * val model = new StreamingKMeans()
  *   .setDecayFactor(0.5)
  *   .setK(3)
- *   .setRandomCenters(5)
+ *   .setRandomCenters(5, 100.0)
  *   .trainOn(DStream)
  */
 @DeveloperApi
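
The scaladoc shows only the builder chain. For context, here is a hedged end-to-end sketch of wiring the estimator into a `StreamingContext`; the socket source, host/port, batch interval, and `[x,y,z]` input format are assumptions, not part of this patch.

```scala
// Hypothetical driver showing the builder chain from the scaladoc in context.
import org.apache.spark.SparkConf
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.streaming.{Seconds, StreamingContext}

object StreamingKMeansExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("StreamingKMeansExample")
    val ssc = new StreamingContext(conf, Seconds(5))

    // assume lines like "[0.1,0.2,0.3]" arriving on a socket
    val trainingData = ssc.socketTextStream("localhost", 9999).map(Vectors.parse)

    val model = new StreamingKMeans()
      .setK(3)
      .setDecayFactor(0.5)
      .setRandomCenters(dim = 3, weight = 100.0)

    model.trainOn(trainingData)
    model.predictOn(trainingData).print() // cluster index per input vector

    ssc.start()
    ssc.awaitTermination()
  }
}
```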
@@ -137,9 +162,9 @@ class StreamingKMeans(
     var decayFactor: Double,
     var timeUnit: String) extends Logging {
 
-  protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
+  def this() = this(2, 1.0, StreamingKMeans.BATCHES)
 
-  def this() = this(2, 1.0, "batches")
+  protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
 
   /** Set the number of clusters. */
   def setK(k: Int): this.type = {
@@ -155,7 +180,7 @@ class StreamingKMeans(
   /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
   def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
-    if (timeUnit != "batches" && timeUnit != "points") {
+    if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) {
       throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
     }
     this.decayFactor = math.exp(math.log(0.5) / halfLife)
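
The relation implemented here is `decayFactor = exp(log(0.5) / halfLife)`, so after `halfLife` time units the contribution of old data is scaled by `decayFactor^halfLife = 0.5`. A quick standalone check (the half-life value is arbitrary):

```scala
// half-life sanity check: decayFactor^halfLife == 0.5
val halfLife = 5.0
val decayFactor = math.exp(math.log(0.5) / halfLife)
assert(math.abs(math.pow(decayFactor, halfLife) - 0.5) < 1e-12)
```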
@@ -165,26 +190,23 @@ class StreamingKMeans(
   }
 
   /** Specify initial centers directly. */
-  def setInitialCenters(initialCenters: Array[Vector]): this.type = {
-    val clusterCounts = new Array[Long](this.k)
-    this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
+  def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+    model = new StreamingKMeansModel(centers, weights)
     this
   }
 
-  /** Initialize random centers, requiring only the number of dimensions.
-  *
-  * @param dim Number of dimensions
-  * @param seed Random seed
-  * */
-  def setRandomCenters(dim: Int, seed: Long = Utils.random.nextLong): this.type = {
-
-    val random = Utils.random
-    random.setSeed(seed)
-
-    val initialCenters = (0 until k)
-      .map(_ => Vectors.dense(Array.fill(dim)(random.nextGaussian()))).toArray
-    val clusterCounts = new Array[Long](this.k)
-    this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
+  /**
+   * Initialize random centers, requiring only the number of dimensions.
+   *
+   * @param dim Number of dimensions
+   * @param weight Weight for each center
+   * @param seed Random seed
+   */
+  def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+    val random = new XORShiftRandom(seed)
+    val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
+    val weights = Array.fill(k)(weight)
+    model = new StreamingKMeansModel(centers, weights)
     this
   }
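
Seeding a locally constructed `XORShiftRandom` (rather than mutating the shared `Utils.random`) makes initialization reproducible: the same seed always yields the same centers. A sketch of the idea, with `java.util.Random` standing in for the Spark-internal `XORShiftRandom` (which extends it) and a hypothetical helper name:

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}

// hypothetical helper mirroring setRandomCenters' initialization
def randomCenters(k: Int, dim: Int, seed: Long): Array[Vector] = {
  val random = new java.util.Random(seed) // stand-in for XORShiftRandom
  Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
}

// same seed => identical centers
val a = randomCenters(3, 5, 42L).map(_.toArray.toSeq)
val b = randomCenters(3, 5, 42L).map(_.toArray.toSeq)
assert(a.sameElements(b))
```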
@@ -202,9 +224,9 @@ class StreamingKMeans(
    * @param data DStream containing vector data
    */
   def trainOn(data: DStream[Vector]) {
-    this.assertInitialized()
+    assertInitialized()
     data.foreachRDD { (rdd, time) =>
-      model = model.update(rdd, this.decayFactor, this.timeUnit)
+      model = model.update(rdd, decayFactor, timeUnit)
     }
   }
@@ -215,7 +237,7 @@ class StreamingKMeans(
    * @return DStream containing predictions
    */
   def predictOn(data: DStream[Vector]): DStream[Int] = {
-    this.assertInitialized()
+    assertInitialized()
     data.map(model.predict)
   }
@@ -227,16 +249,20 @@ class StreamingKMeans(
    * @return DStream containing the input keys and the predictions as values
    */
   def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
-    this.assertInitialized()
+    assertInitialized()
     data.mapValues(model.predict)
   }
 
   /** Check whether cluster centers have been initialized. */
-  def assertInitialized(): Unit = {
-    if (Option(model.clusterCenters) == None) {
+  private[this] def assertInitialized(): Unit = {
+    if (model.clusterCenters == null) {
       throw new IllegalStateException(
         "Initial cluster centers must be set before starting predictions")
     }
   }
+}
 
+private[clustering] object StreamingKMeans {
+  final val BATCHES = "batches"
+  final val POINTS = "points"
 }
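
With `BATCHES` and `POINTS` centralized in the companion object, callers compare time units against these constants instead of string literals. As a closing usage note, a hedged sketch of `predictOnValues` on a keyed stream; it reuses the hypothetical `ssc` and `model` from the earlier driver sketch, and the tab-separated `label<TAB>[x,y,z]` input format is an assumption:

```scala
import org.apache.spark.mllib.linalg.Vectors

// keyed test stream: carry a true label alongside each vector
val testData = ssc.socketTextStream("localhost", 9998).map { line =>
  val parts = line.split('\t')
  (parts(0), Vectors.parse(parts(1)))
}
model.predictOnValues(testData).print() // (label, clusterIndex) pairs
```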