Skip to content

Commit b97fe00

Browse files
committed
Minor fixes and tweaks.
1 parent 1de73f3 commit b97fe00

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ object DenseGmmEM {
3737
}
3838
}
3939

40-
def run(inputFile: String, k: Int, convergenceTol: Double) {
40+
private def run(inputFile: String, k: Int, convergenceTol: Double) {
4141
val conf = new SparkConf().setAppName("Spark EM Sample")
4242
val ctx = new SparkContext(conf)
4343

4444
val data = ctx.textFile(inputFile).map{ line =>
4545
Vectors.dense(line.trim.split(' ').map(_.toDouble))
46-
}.cache
46+
}.cache()
4747

4848
val clusters = new GaussianMixtureModelEM()
4949
.setK(k)
@@ -55,11 +55,11 @@ object DenseGmmEM {
5555
(clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
5656
}
5757

58-
println("Cluster labels:")
58+
println("Cluster labels (first <= 100):")
5959
val (responsibilityMatrix, clusterLabels) = clusters.predict(data)
60-
for (x <- clusterLabels.collect) {
60+
clusterLabels.take(100).foreach{ x =>
6161
print(" " + x)
6262
}
63-
println
63+
println()
6464
}
6565
}

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,12 @@ class GaussianMixtureModelEM private (
7171
// (U, U) => U for aggregation
7272
private def addExpectationSums(m1: ExpectationSum, m2: ExpectationSum): ExpectationSum = {
7373
m1._1(0) += m2._1(0)
74-
for (i <- 0 until m1._2.length) {
74+
var i = 0
75+
while (i < m1._2.length) {
7576
m1._2(i) += m2._2(i)
7677
m1._3(i) += m2._3(i)
7778
m1._4(i) += m2._4(i)
79+
i = i + 1
7880
}
7981
m1
8082
}
@@ -90,11 +92,13 @@ class GaussianMixtureModelEM private (
9092
val pSum = p.sum
9193
sums._1(0) += math.log(pSum)
9294
val xxt = x * new Transpose(x)
93-
for (i <- 0 until k) {
95+
var i = 0
96+
while (i < k) {
9497
p(i) /= pSum
9598
sums._2(i) += p(i)
9699
sums._3(i) += x * p(i)
97100
sums._4(i) += xxt * p(i)
101+
i = i + 1
98102
}
99103
sums
100104
}
@@ -123,7 +127,7 @@ class GaussianMixtureModelEM private (
123127
}
124128

125129
/** Return the user-supplied initial GMM, if one was provided */
126-
def getInitialiGmm: Option[GaussianMixtureModel] = initialGmm
130+
def getInitialGmm: Option[GaussianMixtureModel] = initialGmm
127131

128132
/** Set the number of Gaussians in the mixture model. Default: 2 */
129133
def setK(k: Int): this.type = {
@@ -182,7 +186,7 @@ class GaussianMixtureModelEM private (
182186

183187
case None => {
184188
val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
185-
((0 until k).map(_ => 1.0 / k).toArray, (0 until k).map{ i =>
189+
(Array.fill[Double](k)(1.0 / k), (0 until k).map{ i =>
186190
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
187191
new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
188192
}.toArray)

0 commit comments

Comments
 (0)