Skip to content

Commit e6ea805

Browse files
committed
Merged with master branch; updated test suite with latest context changes.
Improved cluster initialization strategy.
1 parent 86fb382 commit e6ea805

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximization.scala

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,9 @@ class GMMExpectationMaximization private (
8181
private type DenseDoubleVector = BreezeVector[Double]
8282
private type DenseDoubleMatrix = BreezeMatrix[Double]
8383

84+
// Number of random data points drawn per cluster when initializing each
// Gaussian's mean (the mean is seeded with the average of these samples).
private val nSamples = 5
86+
8487
// A default instance: 2 Gaussians, 100 EM iterations, 0.01 log-likelihood
// convergence threshold (arguments delegate to the primary constructor).
def this() = this(2, 0.01, 100)
8689

@@ -118,15 +121,15 @@ class GMMExpectationMaximization private (
118121
// Get length of the input vectors
119122
val d = breezeData.first.length
120123

121-
// For each Gaussian, we will initialize the mean as some random
122-
// point from the data. (This could be improved)
123-
val samples = breezeData.takeSample(true, k, scala.util.Random.nextInt)
124+
// For each Gaussian, we will initialize the mean as the average
125+
// of some random samples from the data
126+
val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
124127

125128
// C will be array of (weight, mean, covariance) tuples
126129
// we start with uniform weights, a random mean from the data, and
127130
// identity matrices for covariance
128131
var C = (0 until k).map(i => (1.0/k,
129-
samples(i),
132+
vec_mean(samples.slice(i * nSamples, (i + 1) * nSamples)),
130133
BreezeMatrix.eye[Double](d))).toArray
131134

132135
val acc_w = new Array[Accumulator[Double]](k)
@@ -148,7 +151,7 @@ class GMMExpectationMaximization private (
148151
}
149152

150153
val log_likelihood = ctx.accumulator(0.0)
151-
154+
152155
// broadcast the current weights and distributions to all nodes
153156
val dists = ctx.broadcast((0 until k).map(i =>
154157
new MultivariateGaussian(C(i)._2, C(i)._3)).toArray)
@@ -164,11 +167,12 @@ class GMMExpectationMaximization private (
164167
log_likelihood += math.log(norm)
165168

166169
// accumulate weighted sums
170+
val xxt = x * new Transpose(x)
167171
for(i <- 0 until k){
168172
p(i) /= norm
169173
acc_w(i) += p(i)
170174
acc_mu(i) += x * p(i)
171-
acc_sigma(i) += x * new Transpose(x) * p(i)
175+
acc_sigma(i) += xxt * p(i)
172176
}
173177
})
174178

@@ -205,6 +209,13 @@ class GMMExpectationMaximization private (
205209
s
206210
}
207211

212+
/** Element-wise average of dense Breeze vectors.
  *
  * @param x non-empty array of vectors; all are assumed to have the same
  *          length as x(0) — TODO(review): confirm callers guarantee this
  * @return a new vector containing the component-wise mean of x
  */
private def vec_mean(x : Array[DenseDoubleVector]) : DenseDoubleVector = {
  require(x.nonEmpty, "cannot average an empty array of vectors")
  val v = BreezeVector.zeros[Double](x(0).length)
  // Accumulate in place to avoid allocating an intermediate per addition.
  x.foreach(xi => v += xi)
  // toDouble is the idiomatic numeric conversion; asInstanceOf is a cast
  // and should not be used for Int-to-Double widening.
  v / x.length.toDouble
}
218+
208219
/** AccumulatorParam for Dense Breeze Vectors */
209220
private object DenseDoubleVectorAccumulatorParam extends AccumulatorParam[DenseDoubleVector] {
210221
def zero(initialVector : DenseDoubleVector) : DenseDoubleVector = {

mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,10 +20,10 @@ package org.apache.spark.mllib.clustering
2020
import org.scalatest.FunSuite
2121

2222
import org.apache.spark.mllib.linalg.{Vectors, Matrices}
23-
import org.apache.spark.mllib.util.LocalSparkContext
23+
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
2424
import org.apache.spark.mllib.util.TestingUtils._
2525

26-
class GMMExpectationMaximizationSuite extends FunSuite with LocalSparkContext {
26+
class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContext {
2727
test("single cluster") {
2828
val data = sc.parallelize(Array(
2929
Vectors.dense(6.0, 9.0),

0 commit comments

Comments (0)