
Commit 6cf6fdf

tgaloppo authored and mengxr committed
SPARK-4156 [MLLIB] EM algorithm for GMMs

Implementation of Expectation-Maximization for Gaussian Mixture Models. This is my maiden contribution to Apache Spark, so I apologize now if I have done anything incorrectly; having said that, this work is my own, and I offer it to the project under the project's open source license.

Author: Travis Galoppo <[email protected]>
Author: Travis Galoppo <[email protected]>
Author: tgaloppo <[email protected]>
Author: FlytxtRnD <[email protected]>

Closes apache#3022 from tgaloppo/master and squashes the following commits:

aaa8f25 [Travis Galoppo] MLUtils: changed privacy of EPSILON from [util] to [mllib]
709e4bf [Travis Galoppo] Fixed usage line to include optional maxIterations parameter
acf1fba [Travis Galoppo] Fixed parameter comment in GaussianMixtureModel; made maximum iterations an optional parameter to DenseGmmEM
9b2fc2a [Travis Galoppo] Style improvements; changed ExpectationSum to a private class
b97fe00 [Travis Galoppo] Minor fixes and tweaks
1de73f3 [Travis Galoppo] Removed redundant array from array creation
578c2d1 [Travis Galoppo] Removed unused import
227ad66 [Travis Galoppo] Moved prediction methods into model class
308c8ad [Travis Galoppo] Numerous changes to improve code
cff73e0 [Travis Galoppo] Replaced accumulators with RDD.aggregate
20ebca1 [Travis Galoppo] Removed unused code
42b2142 [Travis Galoppo] Added functionality to allow setting of GMM starting point; added two-cluster test to testing suite
8b633f3 [Travis Galoppo] Style issue
9be2534 [Travis Galoppo] Style issue
d695034 [Travis Galoppo] Fixed style issues
c3b8ce0 [Travis Galoppo] Merge branch 'master' of https://github.com/tgaloppo/spark; adds predict() method
2df336b [Travis Galoppo] Fixed style issue
b99ecc4 [tgaloppo] Merge pull request #1 from FlytxtRnD/predictBranch
f407b4c [FlytxtRnD] Added predict() to return the cluster labels and membership values
97044cf [Travis Galoppo] Fixed style issues
dc9c742 [Travis Galoppo] Moved MultivariateGaussian utility class
e7d413b [Travis Galoppo] Moved multivariate Gaussian utility class to mllib/stat/impl; improved comments
9770261 [Travis Galoppo] Corrected a variety of style and naming issues
8aaa17d [Travis Galoppo] Added additional train() method to companion object for cluster count and tolerance parameters
676e523 [Travis Galoppo] Fixed to no longer ignore delta value provided on command line
e6ea805 [Travis Galoppo] Merged with master branch; updated test suite with latest context changes; improved cluster initialization strategy
86fb382 [Travis Galoppo] Merge remote-tracking branch 'upstream/master'
719d8cc [Travis Galoppo] Added Scala test suite with basic test
c1a8e16 [Travis Galoppo] Made GaussianMixtureModel class serializable; modified sum function for better performance
5c96c57 [Travis Galoppo] Merge remote-tracking branch 'upstream/master'
c15405c [Travis Galoppo] SPARK-4156
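As a quick orientation, here is a minimal usage sketch of the API this patch adds. It simply mirrors the DenseGmmEM example and the GaussianMixtureEM builder methods shown in the diff below; `sc` is assumed to be an existing SparkContext, and the data values are made up for illustration.

import org.apache.spark.mllib.clustering.GaussianMixtureEM
import org.apache.spark.mllib.linalg.Vectors

// `sc` is an existing SparkContext (assumption for this sketch)
val data = sc.parallelize(Seq(
  Vectors.dense(0.1, 0.2),
  Vectors.dense(0.3, 0.1),
  Vectors.dense(5.0, 5.1),
  Vectors.dense(4.9, 5.2)))

// Fit a 2-component mixture using the same defaults as the new class
val model = new GaussianMixtureEM()
  .setK(2)
  .setConvergenceTol(0.01)
  .setMaxIterations(100)
  .run(data)

// Per-component parameters and cluster assignments,
// used exactly as in the DenseGmmEM example below
(0 until model.k).foreach { i =>
  println(s"weight=${model.weight(i)} mu=${model.mu(i)}")
}
val labels = model.predict(data)
labels.take(10).foreach(println)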
1 parent 9bc0df6 commit 6cf6fdf

File tree: 6 files changed (+517 −1 lines changed)
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.mllib

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.GaussianMixtureEM
import org.apache.spark.mllib.linalg.Vectors

/**
 * An example Gaussian Mixture Model EM app. Run with
 * {{{
 * ./bin/run-example org.apache.spark.examples.mllib.DenseGmmEM <input> <k> <convergenceTol>
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object DenseGmmEM {
  def main(args: Array[String]): Unit = {
    if (args.length < 3) {
      println("usage: DenseGmmEM <input file> <k> <convergenceTol> [maxIterations]")
    } else {
      val maxIterations = if (args.length > 3) args(3).toInt else 100
      run(args(0), args(1).toInt, args(2).toDouble, maxIterations)
    }
  }

  private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) {
    val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example")
    val ctx = new SparkContext(conf)

    val data = ctx.textFile(inputFile).map { line =>
      Vectors.dense(line.trim.split(' ').map(_.toDouble))
    }.cache()

    val clusters = new GaussianMixtureEM()
      .setK(k)
      .setConvergenceTol(convergenceTol)
      .setMaxIterations(maxIterations)
      .run(data)

    for (i <- 0 until clusters.k) {
      println("weight=%f\nmu=%s\nsigma=\n%s\n" format
        (clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
    }

    println("Cluster labels (first <= 100):")
    val clusterLabels = clusters.predict(data)
    clusterLabels.take(100).foreach { x =>
      print(" " + x)
    }
    println()
  }
}
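For context on running the example above: run() parses each input line with line.trim.split(' ').map(_.toDouble), so the expected input is a plain text file with one sample point per line and values separated by single spaces. A purely illustrative two-dimensional input file and invocation might look like this (the file name and numbers are made up; the command form comes from the scaladoc above):

gmm_data.txt:
0.1 0.2
0.3 0.1
5.0 5.2
4.9 5.1

./bin/run-example org.apache.spark.examples.mllib.DenseGmmEM gmm_data.txt 2 0.01 100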
Lines changed: 241 additions & 0 deletions
@@ -0,0 +1,241 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.clustering

import scala.collection.mutable.IndexedSeq

import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose}

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
import org.apache.spark.mllib.stat.impl.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils

/**
 * This class performs expectation maximization for multivariate Gaussian
 * Mixture Models (GMMs). A GMM represents a composite distribution of
 * independent Gaussian distributions with associated "mixing" weights
 * specifying each's contribution to the composite.
 *
 * Given a set of sample points, this class will maximize the log-likelihood
 * for a mixture of k Gaussians, iterating until the log-likelihood changes by
 * less than convergenceTol, or until it has reached the max number of iterations.
 * While this process is generally guaranteed to converge, it is not guaranteed
 * to find a global optimum.
 *
 * @param k The number of independent Gaussians in the mixture model
 * @param convergenceTol The maximum change in log-likelihood at which convergence
 * is considered to have occurred.
 * @param maxIterations The maximum number of iterations to perform
 */
class GaussianMixtureEM private (
    private var k: Int,
    private var convergenceTol: Double,
    private var maxIterations: Int) extends Serializable {

  /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
  def this() = this(2, 0.01, 100)

  // number of samples per cluster to use when initializing Gaussians
  private val nSamples = 5

  // an initializing GMM can be provided rather than using the
  // default random starting point
  private var initialModel: Option[GaussianMixtureModel] = None

  /** Set the initial GMM starting point, bypassing the random initialization.
   *  You must call setK() prior to calling this method, and the condition
   *  (model.k == this.k) must be met; failure will result in an IllegalArgumentException
   */
  def setInitialModel(model: GaussianMixtureModel): this.type = {
    if (model.k == k) {
      initialModel = Some(model)
    } else {
      throw new IllegalArgumentException("mismatched cluster count (model.k != k)")
    }
    this
  }

  /** Return the user supplied initial GMM, if supplied */
  def getInitialModel: Option[GaussianMixtureModel] = initialModel

  /** Set the number of Gaussians in the mixture model. Default: 2 */
  def setK(k: Int): this.type = {
    this.k = k
    this
  }

  /** Return the number of Gaussians in the mixture model */
  def getK: Int = k

  /** Set the maximum number of iterations to run. Default: 100 */
  def setMaxIterations(maxIterations: Int): this.type = {
    this.maxIterations = maxIterations
    this
  }

  /** Return the maximum number of iterations to run */
  def getMaxIterations: Int = maxIterations

  /**
   * Set the largest change in log-likelihood at which convergence is
   * considered to have occurred.
   */
  def setConvergenceTol(convergenceTol: Double): this.type = {
    this.convergenceTol = convergenceTol
    this
  }

  /** Return the largest change in log-likelihood at which convergence is
   *  considered to have occurred.
   */
  def getConvergenceTol: Double = convergenceTol

  /** Perform expectation maximization */
  def run(data: RDD[Vector]): GaussianMixtureModel = {
    val sc = data.sparkContext

    // we will operate on the data as breeze data
    val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()

    // Get length of the input vectors
    val d = breezeData.first.length

    // Determine initial weights and corresponding Gaussians.
    // If the user supplied an initial GMM, we use those values, otherwise
    // we start with uniform weights, a random mean from the data, and
    // diagonal covariance matrices using component variances
    // derived from the samples
    val (weights, gaussians) = initialModel match {
      case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) =>
        new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix)
      })

      case None => {
        val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
        (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
          val slice = samples.view(i * nSamples, (i + 1) * nSamples)
          new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
        })
      }
    }

    var llh = Double.MinValue // current log-likelihood
    var llhp = 0.0            // previous log-likelihood

    var iter = 0
    while(iter < maxIterations && Math.abs(llh-llhp) > convergenceTol) {
      // create and broadcast curried cluster contribution function
      val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_)

      // aggregate the cluster contribution for all sample points
      val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _)

      // Create new distributions based on the partial assignments
      // (often referred to as the "M" step in literature)
      val sumWeights = sums.weights.sum
      var i = 0
      while (i < k) {
        val mu = sums.means(i) / sums.weights(i)
        val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr
        weights(i) = sums.weights(i) / sumWeights
        gaussians(i) = new MultivariateGaussian(mu, sigma)
        i = i + 1
      }

      llhp = llh // current becomes previous
      llh = sums.logLikelihood // this is the freshly computed log-likelihood
      iter += 1
    }

    // Need to convert the breeze matrices to MLlib matrices
    val means = Array.tabulate(k) { i => Vectors.fromBreeze(gaussians(i).mu) }
    val sigmas = Array.tabulate(k) { i => Matrices.fromBreeze(gaussians(i).sigma) }
    new GaussianMixtureModel(weights, means, sigmas)
  }

  /** Average of dense breeze vectors */
  private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = {
    val v = BreezeVector.zeros[Double](x(0).length)
    x.foreach(xi => v += xi)
    v / x.length.toDouble
  }

  /**
   * Construct matrix where diagonal entries are element-wise
   * variance of input vectors (computes biased variance)
   */
  private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = {
    val mu = vectorMean(x)
    val ss = BreezeVector.zeros[Double](x(0).length)
    x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u)
    diag(ss / x.length.toDouble)
  }
}

// companion class to provide zero constructor for ExpectationSum
private object ExpectationSum {
  def zero(k: Int, d: Int): ExpectationSum = {
    new ExpectationSum(0.0, Array.fill(k)(0.0),
      Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d)))
  }

  // compute cluster contributions for each input point
  // (U, T) => U for aggregation
  def add(
      weights: Array[Double],
      dists: Array[MultivariateGaussian])
      (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = {
    val p = weights.zip(dists).map {
      case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x)
    }
    val pSum = p.sum
    sums.logLikelihood += math.log(pSum)
    val xxt = x * new Transpose(x)
    var i = 0
    while (i < sums.k) {
      p(i) /= pSum
      sums.weights(i) += p(i)
      sums.means(i) += x * p(i)
      sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr
      i = i + 1
    }
    sums
  }
}

// Aggregation class for partial expectation results
private class ExpectationSum(
    var logLikelihood: Double,
    val weights: Array[Double],
    val means: Array[BreezeVector[Double]],
    val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {

  val k = weights.length

  def +=(x: ExpectationSum): ExpectationSum = {
    var i = 0
    while (i < k) {
      weights(i) += x.weights(i)
      means(i) += x.means(i)
      sigmas(i) += x.sigmas(i)
      i = i + 1
    }
    logLikelihood += x.logLikelihood
    this
  }
}
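For readers who want the algorithmic picture behind the file above: ExpectationSum.add computes, for each point, the per-component responsibilities (the E step, with MLUtils.EPSILON added to each weighted density for numerical stability), and the while loop in run() then recomputes the parameters from the aggregated sums (the M step). In standard notation, this is what the code computes:

$$
\gamma_{ij} = \frac{w_j\,\mathcal{N}(x_i \mid \mu_j, \Sigma_j)}{\sum_{l=1}^{k} w_l\,\mathcal{N}(x_i \mid \mu_l, \Sigma_l)}, \qquad
\log L = \sum_{i=1}^{N} \log \sum_{j=1}^{k} w_j\,\mathcal{N}(x_i \mid \mu_j, \Sigma_j)
$$

$$
w_j \leftarrow \frac{1}{N}\sum_{i=1}^{N}\gamma_{ij}, \qquad
\mu_j \leftarrow \frac{\sum_i \gamma_{ij}\,x_i}{\sum_i \gamma_{ij}}, \qquad
\Sigma_j \leftarrow \frac{\sum_i \gamma_{ij}\,x_i x_i^{\mathsf{T}}}{\sum_i \gamma_{ij}} - \mu_j \mu_j^{\mathsf{T}}
$$

Iteration stops when the change in log L falls below convergenceTol or after maxIterations passes, matching the while condition in run().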