Skip to content

Commit d8bfc02

Browse files
yu-iskw authored and davies committed
[SPARK-11566] [MLLIB] [PYTHON] Refactoring GaussianMixtureModel.gaussians in Python
cc jkbradley

Author: Yu ISHIKAWA <[email protected]>

Closes #9534 from yu-iskw/SPARK-11566.

(cherry picked from commit c0e48df)
Signed-off-by: Davies Liu <[email protected]>
1 parent 9ccd1bb commit d8bfc02

File tree

2 files changed: 7 additions and 16 deletions.

mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala

Lines changed: 6 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -17,14 +17,11 @@
1717

1818
package org.apache.spark.mllib.api.python
1919

20-
import java.util.{List => JList}
21-
22-
import scala.collection.JavaConverters._
23-
import scala.collection.mutable.ArrayBuffer
20+
import scala.collection.JavaConverters
2421

2522
import org.apache.spark.SparkContext
26-
import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix}
2723
import org.apache.spark.mllib.clustering.GaussianMixtureModel
24+
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2825

2926
/**
3027
* Wrapper around GaussianMixtureModel to provide helper methods in Python
@@ -36,17 +33,11 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
3633
/**
3734
* Returns gaussians as a List of Vectors and Matrices corresponding to each MultivariateGaussian
3835
*/
39-
val gaussians: JList[Object] = {
40-
val modelGaussians = model.gaussians
41-
var i = 0
42-
var mu = ArrayBuffer.empty[Vector]
43-
var sigma = ArrayBuffer.empty[Matrix]
44-
while (i < k) {
45-
mu += modelGaussians(i).mu
46-
sigma += modelGaussians(i).sigma
47-
i += 1
36+
val gaussians: Array[Byte] = {
37+
val modelGaussians = model.gaussians.map { gaussian =>
38+
Array[Any](gaussian.mu, gaussian.sigma)
4839
}
49-
List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
40+
SerDe.dumps(JavaConverters.seqAsJavaListConverter(modelGaussians).asJava)
5041
}
5142

5243
def save(sc: SparkContext, path: String): Unit = model.save(sc, path)

python/pyspark/mllib/clustering.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -266,7 +266,7 @@ def gaussians(self):
266266
"""
267267
return [
268268
MultivariateGaussian(gaussian[0], gaussian[1])
269-
for gaussian in zip(*self.call("gaussians"))]
269+
for gaussian in self.call("gaussians")]
270270

271271
@property
272272
@since('1.4.0')

0 commit comments

Comments
 (0)