Skip to content

Commit 3ef6c7f

Browse files
committed
In GaussianMixtureModel: Changed name of weight, gaussian to weights, gaussians. Other sources modified accordingly.
1 parent 091e8da commit 3ef6c7f

File tree

4 files changed

+18
-18
lines changed

4 files changed

+18
-18
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ object DenseGmmEM {
5454

5555
for (i <- 0 until clusters.k) {
5656
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
57-
(clusters.weight(i), clusters.gaussian(i).mu, clusters.gaussian(i).sigma))
57+
(clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
5858
}
5959

6060
println("Cluster labels (first <= 100):")

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class GaussianMixtureEM private (
134134
// diagonal covariance matrices using component variances
135135
// derived from the samples
136136
val (weights, gaussians) = initialModel match {
137-
case Some(gmm) => (gmm.weight, gmm.gaussian)
137+
case Some(gmm) => (gmm.weights, gmm.gaussians)
138138

139139
case None => {
140140
val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ import org.apache.spark.mllib.util.MLUtils
3636
* covariance matrix for Gaussian i
3737
*/
3838
class GaussianMixtureModel(
39-
val weight: Array[Double],
40-
val gaussian: Array[MultivariateGaussian]) extends Serializable {
39+
val weights: Array[Double],
40+
val gaussians: Array[MultivariateGaussian]) extends Serializable {
4141

42-
require(weight.length == gaussian.length, "Length of weight and Gaussian arrays must match")
42+
require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
4343

4444
/** Number of gaussians in mixture */
45-
def k: Int = weight.length
45+
def k: Int = weights.length
4646

4747
/** Maps given points to their cluster indices. */
4848
def predict(points: RDD[Vector]): RDD[Int] = {
@@ -56,10 +56,10 @@ class GaussianMixtureModel(
5656
*/
5757
def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
5858
val sc = points.sparkContext
59-
val dists = sc.broadcast(gaussian)
60-
val weights = sc.broadcast(weight)
59+
val bcDists = sc.broadcast(gaussians)
60+
val bcWeights = sc.broadcast(weights)
6161
points.map { x =>
62-
computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k)
62+
computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k)
6363
}
6464
}
6565

mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
4040
val seeds = Array(314589, 29032897, 50181, 494821, 4660)
4141
seeds.foreach { seed =>
4242
val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data)
43-
assert(gmm.weight(0) ~== Ew absTol 1E-5)
44-
assert(gmm.gaussian(0).mu ~== Emu absTol 1E-5)
45-
assert(gmm.gaussian(0).sigma ~== Esigma absTol 1E-5)
43+
assert(gmm.weights(0) ~== Ew absTol 1E-5)
44+
assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5)
45+
assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5)
4646
}
4747
}
4848

@@ -73,11 +73,11 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
7373
.setInitialModel(initialGmm)
7474
.run(data)
7575

76-
assert(gmm.weight(0) ~== Ew(0) absTol 1E-3)
77-
assert(gmm.weight(1) ~== Ew(1) absTol 1E-3)
78-
assert(gmm.gaussian(0).mu ~== Emu(0) absTol 1E-3)
79-
assert(gmm.gaussian(1).mu ~== Emu(1) absTol 1E-3)
80-
assert(gmm.gaussian(0).sigma ~== Esigma(0) absTol 1E-3)
81-
assert(gmm.gaussian(1).sigma ~== Esigma(1) absTol 1E-3)
76+
assert(gmm.weights(0) ~== Ew(0) absTol 1E-3)
77+
assert(gmm.weights(1) ~== Ew(1) absTol 1E-3)
78+
assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3)
79+
assert(gmm.gaussians(1).mu ~== Emu(1) absTol 1E-3)
80+
assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
81+
assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
8282
}
8383
}

0 commit comments

Comments
 (0)