Skip to content

Commit c94d062

Browse files
committed
[SPARK-6226][MLLIB] add save/load in PySpark's KMeansModel
Use `_py2java` and `_java2py` to convert Python model to/from Java model. yinxusen Author: Xiangrui Meng <[email protected]> Closes #5049 from mengxr/SPARK-6226-mengxr and squashes the following commits: 570ba81 [Xiangrui Meng] fix python style b10b911 [Xiangrui Meng] add save/load in PySpark's KMeansModel
1 parent d9f3e01 commit c94d062

File tree

3 files changed

+32
-5
lines changed

3 files changed

+32
-5
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.mllib.clustering
1919

20+
import scala.collection.JavaConverters._
21+
2022
import org.json4s._
2123
import org.json4s.JsonDSL._
2224
import org.json4s.jackson.JsonMethods._
@@ -34,6 +36,9 @@ import org.apache.spark.sql.Row
3436
*/
3537
class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
3638

39+
/** A Java-friendly constructor that takes an Iterable of Vectors. */
40+
def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray)
41+
3742
/** Total number of clusters. */
3843
def k: Int = clusterCenters.length
3944

python/pyspark/mllib/clustering.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,16 @@
1919

2020
from pyspark import RDD
2121
from pyspark import SparkContext
22-
from pyspark.mllib.common import callMLlibFunc, callJavaFunc
23-
from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
22+
from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py
23+
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
2424
from pyspark.mllib.stat.distribution import MultivariateGaussian
25+
from pyspark.mllib.util import Saveable, Loader, inherit_doc
2526

2627
__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture']
2728

2829

29-
class KMeansModel(object):
30+
@inherit_doc
31+
class KMeansModel(Saveable, Loader):
3032

3133
"""A clustering model derived from the k-means method.
3234
@@ -55,6 +57,16 @@ class KMeansModel(object):
5557
True
5658
>>> type(model.clusterCenters)
5759
<type 'list'>
60+
>>> import os, tempfile
61+
>>> path = tempfile.mkdtemp()
62+
>>> model.save(sc, path)
63+
>>> sameModel = KMeansModel.load(sc, path)
64+
>>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0])
65+
True
66+
>>> try:
67+
... os.removedirs(path)
68+
... except OSError:
69+
... pass
5870
"""
5971

6072
def __init__(self, centers):
@@ -77,6 +89,16 @@ def predict(self, x):
7789
best_distance = distance
7890
return best
7991

92+
def save(self, sc, path):
93+
java_centers = _py2java(sc, map(_convert_to_vector, self.centers))
94+
java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers)
95+
java_model.save(sc._jsc.sc(), path)
96+
97+
@classmethod
98+
def load(cls, sc, path):
99+
java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel.load(sc._jsc.sc(), path)
100+
return KMeansModel(_java2py(sc, java_model.clusterCenters()))
101+
80102

81103
class KMeans(object):
82104

python/pyspark/mllib/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def _py2java(sc, obj):
7070
obj = _to_java_object_rdd(obj)
7171
elif isinstance(obj, SparkContext):
7272
obj = obj._jsc
73-
elif isinstance(obj, list) and (obj or isinstance(obj[0], JavaObject)):
74-
obj = ListConverter().convert(obj, sc._gateway._gateway_client)
73+
elif isinstance(obj, list):
74+
obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
7575
elif isinstance(obj, JavaObject):
7676
pass
7777
elif isinstance(obj, (int, long, float, bool, basestring)):

0 commit comments

Comments
 (0)