Skip to content

Commit d7aca6e

Browse files
yanboliang authored and jeanlyn committed
[SPARK-5913] [MLLIB] Python API for ChiSqSelector
Add a Python API for mllib.feature.ChiSqSelector.
JIRA: https://issues.apache.org/jira/browse/SPARK-5913
Author: Yanbo Liang <[email protected]>
Closes apache#5939 from yanboliang/spark-5913 and squashes the following commits:
cdaac99 [Yanbo Liang] Python API for ChiSqSelector
1 parent 34c3532 commit d7aca6e

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -494,6 +494,16 @@ private[python] class PythonMLLibAPI extends Serializable {
494494
new StandardScaler(withMean, withStd).fit(data.rdd)
495495
}
496496

497+
/**
 * Java stub for ChiSqSelector.fit(). Rather than returning the content of
 * the fitted Java object, this stub returns a handle to it.
 * The Python code must take extra care to ensure the handle gets freed on
 * exit; see the Py4J documentation.
 */
def fitChiSqSelector(numTopFeatures: Int, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
  // Build the selector first, then fit it on the underlying Scala RDD.
  val selector = new ChiSqSelector(numTopFeatures)
  selector.fit(data.rdd)
}
506+
497507
/**
498508
* Java stub for IDF.fit(). This stub returns a
499509
* handle to the Java object instead of the content of the Java object.

python/pyspark/mllib/feature.py

Lines changed: 57 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,10 +33,12 @@
3333
from pyspark import SparkContext
3434
from pyspark.rdd import RDD, ignore_unicode_prefix
3535
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
36-
from pyspark.mllib.linalg import Vectors, _convert_to_vector
36+
from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector, _convert_to_vector
37+
from pyspark.mllib.regression import LabeledPoint
3738

3839
__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
39-
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
40+
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
41+
'ChiSqSelector', 'ChiSqSelectorModel']
4042

4143

4244
class VectorTransformer(object):
@@ -199,6 +201,59 @@ def fit(self, dataset):
199201
return StandardScalerModel(jmodel)
200202

201203

204+
class ChiSqSelectorModel(JavaVectorTransformer):
    """
    .. note:: Experimental

    Model of a Chi Squared selector, wrapping the Java-side model object.
    """
    def transform(self, vector):
        """
        Run this model's transformation on the input.

        :param vector: Vector or RDD of Vector to be transformed.
        :return: the transformed vector, or the transformed RDD when
                 an RDD was passed in.
        """
        # Delegate explicitly to the shared Java-backed implementation.
        transformed = JavaVectorTransformer.transform(self, vector)
        return transformed
218+
219+
220+
class ChiSqSelector(object):
    """
    .. note:: Experimental

    Builder for a ChiSquared feature selector. Configure it with the number
    of features to keep, then call :meth:`fit` on a labeled dataset to get
    a :class:`ChiSqSelectorModel`.

    >>> data = [
    ...     LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
    ...     LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})),
    ...     LabeledPoint(1.0, [0.0, 9.0, 8.0]),
    ...     LabeledPoint(2.0, [8.0, 9.0, 5.0])
    ... ]
    >>> model = ChiSqSelector(1).fit(sc.parallelize(data))
    >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
    SparseVector(1, {0: 6.0})
    >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
    DenseVector([5.0])
    """
    def __init__(self, numTopFeatures):
        """
        :param numTopFeatures: number of features that selector will select.
                               The value is coerced with ``int()``.
        """
        # Coerce up front so the value handed to the JVM is always an int.
        top_feature_count = int(numTopFeatures)
        self.numTopFeatures = top_feature_count

    def fit(self, data):
        """
        Fits the selector on a dataset and returns the resulting
        :class:`ChiSqSelectorModel`.

        :param data: an `RDD[LabeledPoint]` containing the labeled dataset
                     with categorical features. Real-valued features will be
                     treated as categorical for each distinct value.
                     Apply feature discretizer before using this function.
        :return: a :class:`ChiSqSelectorModel` wrapping the object returned
                 by the JVM-side ``fitChiSqSelector`` call.
        """
        return ChiSqSelectorModel(
            callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data))
255+
256+
202257
class HashingTF(object):
203258
"""
204259
.. note:: Experimental

0 commit comments

Comments
 (0)