
Commit b93350f

Streaming KMeans with decay
- Used trainOn and predictOn pattern, similar to StreamingLinearAlgorithm
- Decay factor can be set explicitly, or via fractional decay parameters expressed in units of number of batches or number of points
- Unit tests for basic functionality and decay settings
1 parent 31f0b07 commit b93350f
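
A minimal usage sketch of the trainOn/predictOn pattern described in the commit message, using only the API added in this diff. The StreamingContext wiring and the input streams (trainingStream, testStream, read here from hypothetical text directories) are illustrative assumptions, not part of the commit:

import org.apache.spark.SparkConf
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.streaming.{Seconds, StreamingContext}

object StreamingKMeansExample {
  def main(args: Array[String]) {
    val conf = new SparkConf().setMaster("local[2]").setAppName("StreamingKMeansExample")
    val ssc = new StreamingContext(conf, Seconds(1))

    // hypothetical sources of whitespace-delimited feature vectors
    def parse(line: String) = Vectors.dense(line.split(' ').map(_.toDouble))
    val trainingStream = ssc.textFileStream("training/").map(parse)
    val testStream = ssc.textFileStream("test/").map(parse)

    // a decay factor of 1.0 weights all batches equally; values below 1.0 forget older data
    val model = new StreamingKMeans()
      .setK(2)
      .setDecayFactor(1.0)
      .setRandomCenters(3)  // 3-dimensional random initial centers

    model.trainOn(trainingStream)        // update cluster centers on every training batch
    model.predictOn(testStream).print()  // emit a cluster index for every test point

    ssc.start()
    ssc.awaitTermination()
  }
}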

File tree: 2 files changed, +337 -0 lines changed

2 files changed

+337
-0
lines changed
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
package org.apache.spark.mllib.clustering

import breeze.linalg.{Vector => BV}

import scala.reflect.ClassTag
import scala.util.Random._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.StreamingContext._

@DeveloperApi
class StreamingKMeansModel(
    override val clusterCenters: Array[Vector],
    val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) {

  /** Do a sequential KMeans update on a batch of data. */
  def update(data: RDD[Vector], a: Double, units: String): StreamingKMeansModel = {

    val centers = clusterCenters
    val counts = clusterCounts

    // find nearest cluster to each point
    val closest = data.map(point => (this.predict(point), (point.toBreeze, 1.toLong)))

    // get sums and counts for updating each cluster
    type WeightedPoint = (BV[Double], Long)
    def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
      (p1._1 += p2._1, p1._2 + p2._2)
    }
    val pointStats: Array[(Int, (BV[Double], Long))] =
      closest.reduceByKey(mergeContribs).collectAsMap().toArray

    // implement update rule
    for (newP <- pointStats) {
      // store old count and centroid
      val oldCount = counts(newP._1)
      val oldCentroid = centers(newP._1).toBreeze
      // get new count and centroid
      val newCount = newP._2._2
      val newCentroid = newP._2._1 / newCount.toDouble
      // compute the normalized scale factor that controls forgetting
      val decayFactor = units match {
        case "batches" => newCount / (a * oldCount + newCount)
        case "points" => newCount / (math.pow(a, newCount) * oldCount + newCount)
      }
      // perform the update
      val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * decayFactor
      // store the new counts and centers
      counts(newP._1) = oldCount + newCount
      centers(newP._1) = Vectors.fromBreeze(updatedCentroid)
    }

    new StreamingKMeansModel(centers, counts)
  }

}

@DeveloperApi
class StreamingKMeans(
    var k: Int,
    var a: Double,
    var units: String) extends Logging {

  protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)

  def this() = this(2, 1.0, "batches")

  def setK(k: Int): this.type = {
    this.k = k
    this
  }

  def setDecayFactor(a: Double): this.type = {
    this.a = a
    this
  }

  def setUnits(units: String): this.type = {
    this.units = units
    this
  }

  def setDecayFractionBatches(q: Double): this.type = {
    this.a = math.log(1 - q) / math.log(0.5)
    this.units = "batches"
    this
  }

  def setDecayFractionPoints(q: Double, m: Double): this.type = {
    this.a = math.pow(math.log(1 - q) / math.log(0.5), 1.0 / m)
    this.units = "points"
    this
  }

  def setInitialCenters(initialCenters: Array[Vector]): this.type = {
    val clusterCounts = Array.fill(this.k)(0L)
    this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
    this
  }

  def setRandomCenters(d: Int): this.type = {
    val initialCenters = (0 until k).map(_ => Vectors.dense(Array.fill(d)(nextGaussian()))).toArray
    // one count per cluster, initialized to zero
    val clusterCounts = Array.fill(this.k)(0L)
    this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
    this
  }

  def latestModel(): StreamingKMeansModel = {
    model
  }

  def trainOn(data: DStream[Vector]) {
    this.isInitialized
    data.foreachRDD { (rdd, time) =>
      model = model.update(rdd, this.a, this.units)
    }
  }

  def predictOn(data: DStream[Vector]): DStream[Int] = {
    this.isInitialized
    data.map(model.predict)
  }

  def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
    this.isInitialized
    data.mapValues(model.predict)
  }

  def isInitialized: Boolean = {
    if (Option(model.clusterCenters) == None) {
      logError("Initial cluster centers must be set before starting predictions")
      throw new IllegalArgumentException
    } else {
      true
    }
  }

}
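
The forgetting rule in StreamingKMeansModel.update can be sanity-checked without Spark. The sketch below (plain Scala, illustrative only; the object name and batch values are made up) applies the same "batches" update, new = old + (batchMean - old) * n / (a * oldCount + n), and shows that a decay factor of 1.0 reduces to the running grand mean, which is what the "equivalence to grand average" test below relies on:

object DecayRuleCheck {
  def main(args: Array[String]) {
    val a = 1.0       // decay factor; 1.0 means no forgetting
    var center = 0.0  // current centroid (one-dimensional for simplicity)
    var count = 0L    // number of points seen so far
    val batches = Seq(Seq(1.0, 3.0), Seq(5.0, 7.0), Seq(9.0, 11.0))

    for (batch <- batches) {
      val newCount = batch.size.toLong
      val batchMean = batch.sum / newCount
      // same scale factor as the "batches" case in update()
      val decayFactor = newCount / (a * count + newCount)
      center = center + (batchMean - center) * decayFactor
      count += newCount
    }

    val points = batches.flatten
    println(s"streaming center = $center, grand mean = ${points.sum / points.size}")
    // both print 6.0 when a = 1.0; smaller a pulls the center toward recent batches
  }
}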
Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
package org.apache.spark.mllib.clustering

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.TestSuiteBase

class StreamingKMeansSuite extends FunSuite with TestSuiteBase {

  override def maxWaitTimeMillis = 30000

  test("accuracy for single center and equivalence to grand average") {

    // set parameters
    val numBatches = 10
    val numPoints = 50
    val k = 1
    val d = 5
    val r = 0.1

    // create model with one cluster
    val model = new StreamingKMeans()
      .setK(1)
      .setDecayFactor(1.0)
      .setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0)))

    // generate random data for kmeans
    val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)

    // setup and run the model training
    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
      model.trainOn(inputDStream)
      inputDStream.count()
    })
    runStreams(ssc, numBatches, numBatches)

    // estimated center should be close to true center
    assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1)

    // estimated center from streaming should exactly match the arithmetic mean of all data points
    val grandMean = input.flatten.map(x => x.toBreeze).reduce(_ + _) / (numBatches * numPoints).toDouble
    assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5)

  }

  test("accuracy for two centers") {

    val numBatches = 10
    val numPoints = 5
    val k = 2
    val d = 5
    val r = 0.1

    // create model with two clusters
    val model = new StreamingKMeans()
      .setK(2)
      .setDecayFactor(1.0)
      .setInitialCenters(Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1),
        Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1)))

    // generate random data for kmeans
    val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)

    // setup and run the model training
    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
      model.trainOn(inputDStream)
      inputDStream.count()
    })
    runStreams(ssc, numBatches, numBatches)

    // check that estimated centers are close to true centers
    // NOTE this depends on the initialization! allow for binary flip
    assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1)
    assert(centers(1) ~== model.latestModel().clusterCenters(1) absTol 1E-1)

  }

  test("drifting with fractional decay in units of batches") {

    val numBatches1 = 50
    val numBatches2 = 50
    val numPoints = 1
    val q = 0.25
    val k = 1
    val d = 1
    val r = 2.0

    // create model with one cluster to track the drift between the two data centers
    val model = new StreamingKMeans()
      .setK(1)
      .setDecayFractionBatches(q)
      .setInitialCenters(Array(Vectors.dense(0.0)))

    // create two sets of batches with different, pre-specified centers
    // to simulate a transition from one cluster to another
    val (input1, centers1) = StreamingKMeansDataGenerator(
      numPoints, numBatches1, k, d, r, 42, initCenters = Array(Vectors.dense(100.0)))
    val (input2, centers2) = StreamingKMeansDataGenerator(
      numPoints, numBatches2, k, d, r, 84, initCenters = Array(Vectors.dense(0.0)))

    // store the history
    val history = new ArrayBuffer[Double](numBatches1 + numBatches2)

    // setup and run the model training
    val ssc = setupStreams(input1 ++ input2, (inputDStream: DStream[Vector]) => {
      model.trainOn(inputDStream)
      // extract the center (in this case one-dimensional)
      inputDStream.foreachRDD(x => history.append(model.latestModel().clusterCenters(0)(0)))
      inputDStream.count()
    })
    runStreams(ssc, numBatches1 + numBatches2, numBatches1 + numBatches2)

    // check that the fraction of batches required to reach 50
    // equals the setting of q, by finding the index of the first batch
    // below 50 and comparing to the total number of batches received
    val halvedIndex = history.zipWithIndex.filter(x => x._1 < 50)(0)._2.toDouble
    val fraction = (halvedIndex - numBatches1.toDouble) / halvedIndex
    assert(fraction ~== q absTol 1E-1)

  }

  test("drifting with fractional decay in units of points") {

    val numBatches1 = 50
    val numBatches2 = 50
    val numPoints = 10
    val q = 0.25
    val k = 1
    val d = 1
    val r = 2.0

    // create model with one cluster to track the drift between the two data centers
    val model = new StreamingKMeans()
      .setK(1)
      .setDecayFractionPoints(q, numPoints)
      .setInitialCenters(Array(Vectors.dense(0.0)))

    // create two sets of batches with different, pre-specified centers
    // to simulate a transition from one cluster to another
    val (input1, centers1) = StreamingKMeansDataGenerator(
      numPoints, numBatches1, k, d, r, 42, initCenters = Array(Vectors.dense(100.0)))
    val (input2, centers2) = StreamingKMeansDataGenerator(
      numPoints, numBatches2, k, d, r, 84, initCenters = Array(Vectors.dense(0.0)))

    // store the history
    val history = new ArrayBuffer[Double](numBatches1 + numBatches2)

    // setup and run the model training
    val ssc = setupStreams(input1 ++ input2, (inputDStream: DStream[Vector]) => {
      model.trainOn(inputDStream)
      // extract the center (in this case one-dimensional)
      inputDStream.foreachRDD(x => history.append(model.latestModel().clusterCenters(0)(0)))
      inputDStream.count()
    })
    runStreams(ssc, numBatches1 + numBatches2, numBatches1 + numBatches2)

    // check that the fraction of batches required to reach 50
    // equals the setting of q, by finding the index of the first batch
    // below 50 and comparing to the total number of batches received
    val halvedIndex = history.zipWithIndex.filter(x => x._1 < 50)(0)._2.toDouble
    val fraction = (halvedIndex - numBatches1.toDouble) / halvedIndex
    assert(fraction ~== q absTol 1E-1)

  }

  def StreamingKMeansDataGenerator(
      numPoints: Int,
      numBatches: Int,
      k: Int,
      d: Int,
      r: Double,
      seed: Int,
      initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[Vector]], Array[Vector]) = {
    val rand = new Random(seed)
    val centers = initCenters match {
      case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian())))
      case _ => initCenters
    }
    val data = (0 until numBatches).map { i =>
      (0 until numPoints).map { idx =>
        val center = centers(idx % k)
        Vectors.dense(Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r))
      }
    }
    (data, centers)
  }

}
