Skip to content

Commit c3b8ce0

Browse files
committed
Adds predict() method
2 parents 2df336b + b99ecc4 commit c3b8ce0

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,9 @@ object DenseGmmEM {
4747
println("weight=%f mu=%s sigma=\n%s\n" format
4848
(clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
4949
}
50+
val (responsibility_matrix, cluster_labels) = clusters.predict(data)
51+
for(x <- cluster_labels.collect()){
52+
print(" " + x)
53+
}
5054
}
5155
}

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.mllib.clustering
1919

20+
import org.apache.spark.rdd.RDD
2021
import org.apache.spark.mllib.linalg.Matrix
2122
import org.apache.spark.mllib.linalg.Vector
2223

@@ -38,4 +39,12 @@ class GaussianMixtureModel(
3839

3940
/** Number of Gaussians in the mixture. */
def k: Int = weight.length
42+
43+
/**
 * Maps the given points to their cluster indices.
 *
 * @param points RDD of input vectors to classify
 * @return a pair (responsibility matrix, cluster labels): the responsibility
 *         matrix holds, for each point, its membership weight in each of the
 *         k mixture components; the label is the index of the component with
 *         the highest membership weight for that point.
 */
def predict(points: RDD[Vector]): (RDD[Array[Double]], RDD[Int]) = {
  // Soft assignments: one Array[Double] of length k per input point.
  val responsibilityMatrix = new GaussianMixtureModelEM()
    .predictClusters(points, mu, sigma, weight, k)
  // Hard assignment: index of the maximum responsibility for each point.
  val clusterLabels = responsibilityMatrix.map(r => r.indexOf(r.max))
  (responsibilityMatrix, clusterLabels)
}
4150
}

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix}
2121
import breeze.linalg.Transpose
2222

2323
import org.apache.spark.rdd.RDD
24-
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
24+
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
2525
import org.apache.spark.mllib.stat.impl.MultivariateGaussian
2626
import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext}
2727
import org.apache.spark.SparkContext.DoubleAccumulatorParam
@@ -208,6 +208,34 @@ class GaussianMixtureModelEM private (
208208
cov
209209
}
210210

211+
/**
212+
Given the input vectors, return the membership value of each vector
213+
to all mixture components.
214+
*/
215+
def predictClusters(points:RDD[Vector],mu:Array[Vector],sigma:Array[Matrix],
216+
weight:Array[Double],k:Int):RDD[Array[Double]]= {
217+
val ctx = points.sparkContext
218+
val dists = ctx.broadcast((0 until k).map(i =>
219+
new MultivariateGaussian(mu(i).toBreeze.toDenseVector,sigma(i).toBreeze.toDenseMatrix))
220+
.toArray)
221+
val weights = ctx.broadcast((0 until k).map(i => weight(i)).toArray)
222+
points.map(x=>compute_log_likelihood(x.toBreeze.toDenseVector,dists.value,weights.value,k))
223+
224+
}
225+
/**
 * Compute the normalized weighted density of a single point under each
 * mixture component: p(i) = eps + weight(i) * N(pt | mu(i), sigma(i)),
 * normalized so the k values sum to 1.
 *
 * NOTE(review): despite the name, no logarithm is taken here — this returns
 * normalized densities (responsibilities), not log-likelihoods.
 *
 * @param pt the point, as a dense Breeze vector
 * @param dists the k component distributions
 * @param weights the k mixing weights
 * @param k number of mixture components
 * @return array of length k of membership values for pt
 */
def compute_log_likelihood(
    pt: DenseDoubleVector,
    dists: Array[MultivariateGaussian],
    weights: Array[Double],
    k: Int): Array[Double] = {
  // eps guards against all densities underflowing to zero, which would
  // otherwise make the normalizing sum zero and the divisions NaN.
  val p = (0 until k).map(i => eps + weights(i) * dists(i).pdf(pt)).toArray
  val pSum = p.sum
  // Normalize in place so the memberships sum to 1.
  (0 until k).foreach(i => p(i) /= pSum)
  p
}
238+
211239
/** AccumulatorParam for Dense Breeze Vectors */
212240
private object DenseDoubleVectorAccumulatorParam extends AccumulatorParam[DenseDoubleVector] {
213241
def zero(initialVector: DenseDoubleVector): DenseDoubleVector = {

0 commit comments

Comments
 (0)