@@ -39,28 +39,28 @@ import org.apache.spark.util.Utils
  *
  * The update algorithm uses the "mini-batch" KMeans rule,
  * generalized to incorporate forgetfulness (i.e. decay).
- * The basic update rule (for each cluster) is:
+ * The update rule (for each cluster) is:
  *
- * c_t+1 = [(c_t * n_t) + (x_t * m_t)] / [n_t + m_t]
- * n_t+1 = n_t + m_t
+ * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
+ * n_t+1 = n_t * a + m_t
  *
  * Where c_t is the previously estimated centroid for that cluster,
  * n_t is the number of points assigned to it thus far, x_t is the centroid
  * estimated on the current batch, and m_t is the number of points assigned
  * to that centroid in the current batch.
  *
- * This update rule is modified with a decay factor 'a' that scales
- * the contribution of the clusters as estimated thus far.
- * If a=1, all batches are weighted equally. If a=0, new centroids
+ * The decay factor 'a' scales the contribution of the clusters as estimated thus far,
+ * applying 'a' as a discount weighting on the previous estimate when evaluating
+ * new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids
  * are determined entirely by recent data. Lower values correspond to
  * more forgetting.
  *
- * Decay can optionally be specified as a decay fraction 'q',
- * which corresponds to the fraction of batches (or points)
- * after which the past will be reduced to a contribution of 0.5.
- * This decay fraction can be specified in units of 'points' or 'batches'.
- * If 'batches', behavior will be independent of the number of points per batch;
- * if 'points', the expected number of points per batch must be specified.
+ * Decay can optionally be specified by a half-life and associated
+ * time unit. The time unit can either be a batch of data or a single
+ * data point. Considering data arriving at time t, the half-life h is defined
+ * such that at time t + h the discount applied to the data from t is 0.5.
+ * The definition remains the same whether the time unit is given
+ * as batches or points.
  *
  */
 @DeveloperApi
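
To make the rule above concrete, here is a minimal sketch in plain Scala of one forgetful update for a single one-dimensional cluster, with the time unit taken as batches. The names and numbers are illustrative only, not part of the diff.

// Sketch of one forgetful mini-batch update for a single 1-D cluster.
object ForgetfulUpdateSketch {
  def main(args: Array[String]): Unit = {
    val a = 0.5    // decay factor
    var c = 10.0   // c_t: current centroid estimate
    var n = 100.0  // n_t: points assigned so far
    val x = 20.0   // x_t: centroid of the new batch
    val m = 50.0   // m_t: points assigned in the new batch

    // c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
    c = (c * n * a + x * m) / (n * a + m)
    // n_t+1 = n_t * a + m_t
    n = n * a + m

    println(f"c_t+1 = $c%.2f, n_t+1 = $n%.1f") // c_t+1 = 15.00, n_t+1 = 100.0
  }
}

With a=1 the same inputs give the ordinary running mean (13.33), and with a=0 the new batch centroid alone (20.0), matching the limits described in the comment.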
@@ -69,7 +69,7 @@ class StreamingKMeansModel(
     val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) with Logging {
 
   /** Perform a k-means update on a batch of data. */
-  def update(data: RDD[Vector], a: Double, units: String): StreamingKMeansModel = {
+  def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
 
     val centers = clusterCenters
     val counts = clusterCounts
@@ -94,12 +94,12 @@ class StreamingKMeansModel(
       val newCount = count
       val newCentroid = mean / newCount.toDouble
       // compute the normalized scale factor that controls forgetting
-      val decayFactor = units match {
-        case "batches" => newCount / (a * oldCount + newCount)
-        case "points" => newCount / (math.pow(a, newCount) * oldCount + newCount)
+      val lambda = timeUnit match {
+        case "batches" => newCount / (decayFactor * oldCount + newCount)
+        case "points" => newCount / (math.pow(decayFactor, newCount) * oldCount + newCount)
       }
       // perform the update
-      val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * decayFactor
+      val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * lambda
       // store the new counts and centers
       counts(label) = oldCount + newCount
       centers(label) = Vectors.fromBreeze(updatedCentroid)
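
The two time units differ in how the discount accrues: with "batches" the old count is discounted once per update, while with "points" the per-point decay compounds over the newCount points seen in the batch. A standalone sketch of that scale factor (the helper below is illustrative, not part of the diff):

// Compare the "batches" and "points" discounting on the same inputs.
object LambdaSketch {
  def lambda(decayFactor: Double, timeUnit: String, oldCount: Long, newCount: Long): Double =
    timeUnit match {
      // one batch elapses per update: discount the old count once
      case "batches" => newCount / (decayFactor * oldCount + newCount)
      // newCount points elapse per update: the discount compounds per point
      case "points" => newCount / (math.pow(decayFactor, newCount) * oldCount + newCount)
      case other => throw new IllegalArgumentException("Invalid time unit: " + other)
    }

  def main(args: Array[String]): Unit = {
    // even with decayFactor near 1, per-point discounting forgets faster
    // because it compounds over the 50 points in the batch
    println(lambda(0.99, "batches", 100L, 50L)) // ~0.336
    println(lambda(0.99, "points", 100L, 50L))  // ~0.452
  }
}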
@@ -134,8 +134,8 @@ class StreamingKMeansModel(
 @DeveloperApi
 class StreamingKMeans(
     var k: Int,
-    var a: Double,
-    var units: String) extends Logging {
+    var decayFactor: Double,
+    var timeUnit: String) extends Logging {
 
   protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
 
@@ -149,30 +149,18 @@ class StreamingKMeans(
 
   /** Set the decay factor directly (for forgetful algorithms). */
   def setDecayFactor(a: Double): this.type = {
-    this.a = a
+    this.decayFactor = a
     this
   }
 
-  /** Set the decay units for forgetful algorithms ("batches" or "points"). */
-  def setUnits(units: String): this.type = {
-    if (units != "batches" && units != "points") {
-      throw new IllegalArgumentException("Invalid units for decay: " + units)
+  /** Set the half-life and time unit ("batches" or "points") for forgetful algorithms. */
+  def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
+    if (timeUnit != "batches" && timeUnit != "points") {
+      throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
     }
-    this.units = units
-    this
-  }
-
-  /** Set decay fraction in units of batches. */
-  def setDecayFractionBatches(q: Double): this.type = {
-    this.a = math.log(1 - q) / math.log(0.5)
-    this.units = "batches"
-    this
-  }
-
-  /** Set decay fraction in units of points. Must specify expected number of points per batch. */
-  def setDecayFractionPoints(q: Double, m: Double): this.type = {
-    this.a = math.pow(math.log(1 - q) / math.log(0.5), 1 / m)
-    this.units = "points"
+    this.decayFactor = math.exp(math.log(0.5) / halfLife)
+    logInfo("Setting decay factor to: %g".format(this.decayFactor))
+    this.timeUnit = timeUnit
    this
  }
 
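The conversion decayFactor = exp(log(0.5) / halfLife) makes the accumulated discount after halfLife time units exactly 0.5, which is the half-life definition from the class comment. A quick plain-Scala check of that property (values illustrative):

// Verify the half-life property of the decay factor conversion.
object HalfLifeSketch {
  def main(args: Array[String]): Unit = {
    val halfLife = 10.0
    val decayFactor = math.exp(math.log(0.5) / halfLife)
    println(decayFactor)                         // ~0.933, i.e. 0.5^(1/10)
    println(math.pow(decayFactor, halfLife))     // 0.5 (up to rounding)
    println(math.pow(decayFactor, 2 * halfLife)) // 0.25: two half-lives elapsed
  }
}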
@@ -216,7 +204,7 @@ class StreamingKMeans(
   def trainOn(data: DStream[Vector]) {
     this.assertInitialized()
     data.foreachRDD { (rdd, time) =>
-      model = model.update(rdd, this.a, this.units)
+      model = model.update(rdd, this.decayFactor, this.timeUnit)
     }
   }
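
For context, a hedged end-to-end usage sketch of the renamed API. The setK and setRandomCenters initialization helpers are assumed here and are not shown in this diff, and the input path is hypothetical:

// Sketch: train a streaming k-means model with a half-life of 5 batches.
import org.apache.spark.SparkConf
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.streaming.{Seconds, StreamingContext}

object StreamingKMeansSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("StreamingKMeansSketch")
    val ssc = new StreamingContext(conf, Seconds(1))

    // parse each line of a text stream into a dense vector
    val trainingData = ssc.textFileStream("/tmp/training") // hypothetical path
      .map(line => Vectors.dense(line.split(' ').map(_.toDouble)))

    val model = new StreamingKMeans()
      .setK(2)                     // assumed setter, not in this diff
      .setHalfLife(5.0, "batches") // past data loses half its weight every 5 batches
      .setRandomCenters(2, 0.0)    // assumed setter: dimension, initial weight

    model.trainOn(trainingData)
    ssc.start()
    ssc.awaitTermination()
  }
}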