Skip to content

[SPARK-5913] [MLLIB] Python API for ChiSqSelector #5939

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,16 @@ private[python] class PythonMLLibAPI extends Serializable {
new StandardScaler(withMean, withStd).fit(data.rdd)
}

/**
 * Java stub for ChiSqSelector.fit(), exposed for use from the Python API.
 * Builds a ChiSqSelector configured with `numTopFeatures` and fits it on the
 * Scala RDD underlying `data`. The stub returns a handle to the Java object
 * rather than the content of the Java object, so extra care needs to be taken
 * in the Python code to ensure it gets freed on exit; see the Py4J documentation.
 */
def fitChiSqSelector(numTopFeatures: Int, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
  val selector = new ChiSqSelector(numTopFeatures)
  selector.fit(data.rdd)
}

/**
* Java stub for IDF.fit(). This stub returns a
* handle to the Java object instead of the content of the Java object.
Expand Down
59 changes: 57 additions & 2 deletions python/pyspark/mllib/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@
from pyspark import SparkContext
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
from pyspark.mllib.linalg import Vectors, _convert_to_vector
from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint

__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
'ChiSqSelector', 'ChiSqSelectorModel']


class VectorTransformer(object):
Expand Down Expand Up @@ -199,6 +201,59 @@ def fit(self, dataset):
return StandardScalerModel(jmodel)


class ChiSqSelectorModel(JavaVectorTransformer):
    """
    .. note:: Experimental

    Fitted Chi Squared selector model, as returned by ChiSqSelector.fit().
    """
    def transform(self, vector):
        """
        Apply this model's transformation to the input.

        The work is delegated unchanged to JavaVectorTransformer.transform.

        :param vector: Vector or RDD of Vector to be transformed.
        :return: whatever JavaVectorTransformer.transform produces for
                 the given input (the transformed vector).
        """
        transformed = JavaVectorTransformer.transform(self, vector)
        return transformed


class ChiSqSelector(object):
    """
    .. note:: Experimental

    Builder for a ChiSquared feature selector; call fit() on a labeled
    dataset to obtain a ChiSqSelectorModel.

    >>> data = [
    ...     LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
    ...     LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})),
    ...     LabeledPoint(1.0, [0.0, 9.0, 8.0]),
    ...     LabeledPoint(2.0, [8.0, 9.0, 5.0])
    ... ]
    >>> model = ChiSqSelector(1).fit(sc.parallelize(data))
    >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
    SparseVector(1, {0: 6.0})
    >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
    DenseVector([5.0])
    """
    def __init__(self, numTopFeatures):
        """
        :param numTopFeatures: number of features that selector will select.
                               The value is coerced with int() before
                               being stored.
        """
        top_features = int(numTopFeatures)
        self.numTopFeatures = top_features

    def fit(self, data):
        """
        Fit the selector on a labeled dataset and return the resulting
        ChiSqSelectorModel.

        :param data: an `RDD[LabeledPoint]` containing the labeled dataset
                     with categorical features. Real-valued features will be
                     treated as categorical for each distinct value.
                     Apply feature discretizer before using this function.
        :return: a ChiSqSelectorModel wrapping the object returned by the
                 "fitChiSqSelector" MLlib call.
        """
        return ChiSqSelectorModel(
            callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data))


class HashingTF(object):
"""
.. note:: Experimental
Expand Down