
Commit 0411bf5

Change decay parameterization
- Use a single halfLife parameter that now determines the decay factor directly
- Allow specification of timeUnit for the halfLife as “batches” or “points”
- Documentation adjusted accordingly
1 parent 9f7aea9 commit 0411bf5

File tree

docs/mllib-clustering.md
mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala

3 files changed: 35 additions, 130 deletions

docs/mllib-clustering.md

Lines changed: 7 additions & 1 deletion
@@ -174,7 +174,13 @@ to the cluster thus far, `$x_t$` is the new cluster center from the current batc
 is the number of points added to the cluster in the current batch. The decay factor `$\alpha$`
 can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning;
 with `$\alpha$=0` only the most recent data will be used. This is analogous to an
-exponentially-weighted moving average.
+exponentially-weighted moving average.
+
+The decay can be specified using a `halfLife` parameter, which determines the
+correct decay factor `a` such that, for data acquired
+at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5.
+The unit of time can be specified either as `batches` or `points` and the update rule
+will be adjusted accordingly.
 
 ### Examples
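For intuition, a half life of h time units maps to the decay factor as a = exp(ln(0.5) / h) = 0.5^(1/h), so that a^h = 0.5. A minimal standalone sketch of the conversion (illustrative Scala; the function name is not part of the patch):

    // Decay factor implied by a half life of `halfLife` time units,
    // chosen so that math.pow(a, halfLife) == 0.5.
    def decayFromHalfLife(halfLife: Double): Double =
      math.exp(math.log(0.5) / halfLife)

    // e.g. a half life of 10 gives a ≈ 0.933, and 0.933^10 ≈ 0.5
    println(decayFromHalfLife(10.0))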

mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala

Lines changed: 28 additions & 40 deletions
@@ -39,28 +39,28 @@ import org.apache.spark.util.Utils
  *
  * The update algorithm uses the "mini-batch" KMeans rule,
  * generalized to incorporate forgetfulness (i.e. decay).
- * The basic update rule (for each cluster) is:
+ * The update rule (for each cluster) is:
  *
- * c_t+1 = [(c_t * n_t) + (x_t * m_t)] / [n_t + m_t]
- * n_t+1 = n_t + m_t
+ * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
+ * n_t+1 = n_t * a + m_t
  *
  * Where c_t is the previously estimated centroid for that cluster,
  * n_t is the number of points assigned to it thus far, x_t is the centroid
  * estimated on the current batch, and m_t is the number of points assigned
  * to that centroid in the current batch.
  *
- * This update rule is modified with a decay factor 'a' that scales
- * the contribution of the clusters as estimated thus far.
- * If a=1, all batches are weighted equally. If a=0, new centroids
+ * The decay factor 'a' scales the contribution of the clusters as estimated thus far,
+ * applying 'a' as a discount weighting on the past estimate when evaluating
+ * new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids
  * are determined entirely by recent data. Lower values correspond to
  * more forgetting.
  *
- * Decay can optionally be specified as a decay fraction 'q',
- * which corresponds to the fraction of batches (or points)
- * after which the past will be reduced to a contribution of 0.5.
- * This decay fraction can be specified in units of 'points' or 'batches';
- * if 'batches', behavior will be independent of the number of points per batch;
- * if 'points', the expected number of points per batch must be specified.
+ * Decay can optionally be specified by a half life and associated
+ * time unit. The time unit can either be a batch of data or a single
+ * data point. Considering data arrived at time t, the half life h is defined
+ * such that at time t + h the discount applied to the data from t is 0.5.
+ * The definition remains the same whether the time unit is given
+ * as batches or points.
  *
  */
 @DeveloperApi
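A numeric sanity check of the rule above (a standalone one-dimensional sketch; the helper name is hypothetical):

    def updatedCenter(cOld: Double, nOld: Long, cBatch: Double, mBatch: Long, a: Double): Double = {
      // lambda = m_t / (a * n_t + m_t), so the result equals
      // [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
      val lambda = mBatch / (a * nOld + mBatch)
      cOld + (cBatch - cOld) * lambda
    }

    // old center 0.0 from 100 points, batch center 1.0 from 100 points, a = 0.5:
    // lambda = 100 / 150 = 2/3, so the new center is ≈ 0.667
    println(updatedCenter(0.0, 100L, 1.0, 100L, 0.5))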
@@ -69,7 +69,7 @@ class StreamingKMeansModel(
     val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) with Logging {
 
   /** Perform a k-means update on a batch of data. */
-  def update(data: RDD[Vector], a: Double, units: String): StreamingKMeansModel = {
+  def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
 
     val centers = clusterCenters
     val counts = clusterCounts
@@ -94,12 +94,12 @@ class StreamingKMeansModel(
       val newCount = count
       val newCentroid = mean / newCount.toDouble
       // compute the normalized scale factor that controls forgetting
-      val decayFactor = units match {
-        case "batches" => newCount / (a * oldCount + newCount)
-        case "points" => newCount / (math.pow(a, newCount) * oldCount + newCount)
+      val lambda = timeUnit match {
+        case "batches" => newCount / (decayFactor * oldCount + newCount)
+        case "points" => newCount / (math.pow(decayFactor, newCount) * oldCount + newCount)
       }
       // perform the update
-      val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * decayFactor
+      val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * lambda
       // store the new counts and centers
       counts(label) = oldCount + newCount
       centers(label) = Vectors.fromBreeze(updatedCentroid)
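The match on `timeUnit` encodes how time advances: one unit per batch, or one unit per point, in which case a batch of m points discounts the past by a^m. An illustrative sketch (the helper is hypothetical):

    // Effective discount applied to the existing clusters for one batch of m points:
    def batchDiscount(a: Double, m: Long, timeUnit: String): Double = timeUnit match {
      case "batches" => a               // time advances one unit per batch
      case "points" => math.pow(a, m)   // time advances one unit per point
    }

    // With a half life of 100 points (a = 0.5^(1/100)) and 50-point batches,
    // each batch discounts the past by 0.5^(50/100) ≈ 0.707.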
@@ -134,8 +134,8 @@ class StreamingKMeansModel(
 @DeveloperApi
 class StreamingKMeans(
     var k: Int,
-    var a: Double,
-    var units: String) extends Logging {
+    var decayFactor: Double,
+    var timeUnit: String) extends Logging {
 
   protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)

@@ -149,30 +149,18 @@ class StreamingKMeans(
 
   /** Set the decay factor directly (for forgetful algorithms). */
   def setDecayFactor(a: Double): this.type = {
-    this.a = a
+    this.decayFactor = a
     this
   }
 
-  /** Set the decay units for forgetful algorithms ("batches" or "points"). */
-  def setUnits(units: String): this.type = {
-    if (units != "batches" && units != "points") {
-      throw new IllegalArgumentException("Invalid units for decay: " + units)
+  /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
+  def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
+    if (timeUnit != "batches" && timeUnit != "points") {
+      throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
     }
-    this.units = units
-    this
-  }
-
-  /** Set decay fraction in units of batches. */
-  def setDecayFractionBatches(q: Double): this.type = {
-    this.a = math.log(1 - q) / math.log(0.5)
-    this.units = "batches"
-    this
-  }
-
-  /** Set decay fraction in units of points. Must specify expected number of points per batch. */
-  def setDecayFractionPoints(q: Double, m: Double): this.type = {
-    this.a = math.pow(math.log(1 - q) / math.log(0.5), 1/m)
-    this.units = "points"
+    this.decayFactor = math.exp(math.log(0.5) / halfLife)
+    logInfo("Setting decay factor to: %g".format(this.decayFactor))
+    this.timeUnit = timeUnit
     this
   }
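A hedged usage sketch of the new configuration API (center values are made up; `setK` and `setInitialCenters` are existing setters used in the test suite below):

    import org.apache.spark.mllib.linalg.Vectors

    val skm = new StreamingKMeans()
      .setK(2)
      .setHalfLife(5.0, "batches") // contributions from past data halve every 5 batches
      .setInitialCenters(Array(Vectors.dense(0.0), Vectors.dense(1.0)))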

@@ -216,7 +204,7 @@ class StreamingKMeans(
   def trainOn(data: DStream[Vector]) {
     this.assertInitialized()
     data.foreachRDD { (rdd, time) =>
-      model = model.update(rdd, this.a, this.units)
+      model = model.update(rdd, this.decayFactor, this.timeUnit)
     }
   }
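Continuing the sketch above, training attaches the per-batch update to a vector DStream (assuming a hypothetical StreamingContext `ssc` and stream `trainingData`):

    skm.trainOn(trainingData) // each incoming RDD triggers model.update(...)
    ssc.start()
    ssc.awaitTermination()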

mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala

Lines changed: 0 additions & 89 deletions
@@ -17,7 +17,6 @@
 
 package org.apache.spark.mllib.clustering
 
-import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
 
 import org.scalatest.FunSuite
@@ -98,94 +97,6 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
 
   }
 
-  test("drifting with fractional decay in units of batches") {
-
-    val numBatches1 = 50
-    val numBatches2 = 50
-    val numPoints = 1
-    val q = 0.25
-    val k = 1
-    val d = 1
-    val r = 2.0
-
-    // create model with two clusters
-    val model = new StreamingKMeans()
-      .setK(1)
-      .setDecayFractionBatches(q)
-      .setInitialCenters(Array(Vectors.dense(0.0)))
-
-    // create two batches of data with different, pre-specified centers
-    // to simulate a transition from one cluster to another
-    val (input1, centers1) = StreamingKMeansDataGenerator(
-      numPoints, numBatches1, k, d, r, 42, initCenters = Array(Vectors.dense(100.0)))
-    val (input2, centers2) = StreamingKMeansDataGenerator(
-      numPoints, numBatches2, k, d, r, 84, initCenters = Array(Vectors.dense(0.0)))
-
-    // store the history
-    val history = new ArrayBuffer[Double](numBatches1 + numBatches2)
-
-    // setup and run the model training
-    val ssc = setupStreams(input1 ++ input2, (inputDStream: DStream[Vector]) => {
-      model.trainOn(inputDStream)
-      // extract the center (in this case one-dimensional)
-      inputDStream.foreachRDD(x => history.append(model.latestModel().clusterCenters(0)(0)))
-      inputDStream.count()
-    })
-    runStreams(ssc, numBatches1 + numBatches2, numBatches1 + numBatches2)
-
-    // check that the fraction of batches required to reach 50
-    // equals the setting of q, by finding the index of the first batch
-    // below 50 and comparing to total number of batches received
-    val halvedIndex = history.zipWithIndex.filter( x => x._1 < 50)(0)._2.toDouble
-    val fraction = (halvedIndex - numBatches1.toDouble) / halvedIndex
-    assert(fraction ~== q absTol 1E-1)
-
-  }
-
-  test("drifting with fractional decay in units of points") {
-
-    val numBatches1 = 50
-    val numBatches2 = 50
-    val numPoints = 10
-    val q = 0.25
-    val k = 1
-    val d = 1
-    val r = 2.0
-
-    // create model with two clusters
-    val model = new StreamingKMeans()
-      .setK(1)
-      .setDecayFractionPoints(q, numPoints)
-      .setInitialCenters(Array(Vectors.dense(0.0)))
-
-    // create two batches of data with different, pre-specified centers
-    // to simulate a transition from one cluster to another
-    val (input1, centers1) = StreamingKMeansDataGenerator(
-      numPoints, numBatches1, k, d, r, 42, initCenters = Array(Vectors.dense(100.0)))
-    val (input2, centers2) = StreamingKMeansDataGenerator(
-      numPoints, numBatches2, k, d, r, 84, initCenters = Array(Vectors.dense(0.0)))
-
-    // store the history
-    val history = new ArrayBuffer[Double](numBatches1 + numBatches2)
-
-    // setup and run the model training
-    val ssc = setupStreams(input1 ++ input2, (inputDStream: DStream[Vector]) => {
-      model.trainOn(inputDStream)
-      // extract the center (in this case one-dimensional)
-      inputDStream.foreachRDD(x => history.append(model.latestModel().clusterCenters(0)(0)))
-      inputDStream.count()
-    })
-    runStreams(ssc, numBatches1 + numBatches2, numBatches1 + numBatches2)
-
-    // check that the fraction of batches required to reach 50
-    // equals the setting of q, by finding the index of the first batch
-    // below 50 and comparing to total number of batches received
-    val halvedIndex = history.zipWithIndex.filter( x => x._1 < 50)(0)._2.toDouble
-    val fraction = (halvedIndex - numBatches1.toDouble) / halvedIndex
-    assert(fraction ~== q absTol 1E-1)
-
-  }
-
   def StreamingKMeansDataGenerator(
       numPoints: Int,
       numBatches: Int,
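The deleted tests exercised the removed `setDecayFractionBatches`/`setDecayFractionPoints` setters. Under the new parameterization, the analogous invariant is that the discount after `halfLife` time units equals 0.5; a hypothetical spot check (not part of this commit) could assert it directly:

    test("decay factor from half life discounts by 0.5 after halfLife units") {
      val halfLife = 10.0
      val a = math.exp(math.log(0.5) / halfLife) // as computed in setHalfLife
      assert(math.abs(math.pow(a, halfLife) - 0.5) < 1e-9)
    }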
