From cdaac997e2f1508fac5b1e0dc2453153a3d006b4 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Wed, 6 May 2015 19:51:54 +0800
Subject: [PATCH] Python API for ChiSqSelector

---
NOTE(review): line breaks and indentation in this copy were reconstructed from a
whitespace-collapsed rendering; verify the context lines against the target tree
before applying. The From: header carries no e-mail address here -- presumably it
was stripped in transit; confirm against the original commit.

 .../mllib/api/python/PythonMLLibAPI.scala     | 10 ++++
 python/pyspark/mllib/feature.py               | 59 ++++++++++++++++++-
 2 files changed, 67 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 426306d78c1c3..8c30ad4b391ae 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -494,6 +494,16 @@ private[python] class PythonMLLibAPI extends Serializable {
     new StandardScaler(withMean, withStd).fit(data.rdd)
   }
 
+  /**
+   * Java stub for ChiSqSelector.fit(). Fits a selector with `numTopFeatures` on `data`
+   * and returns a handle to the Java model object instead of the content of that object.
+   * Extra care needs to be taken in the Python code to ensure it gets freed on
+   * exit; see the Py4J documentation.
+   */
+  def fitChiSqSelector(numTopFeatures: Int, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
+    new ChiSqSelector(numTopFeatures).fit(data.rdd)
+  }
+
   /**
    * Java stub for IDF.fit(). This stub returns a
    * handle to the Java object instead of the content of the Java object.
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 1140539a24e95..aac305db6c19a 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -33,10 +33,12 @@
 from pyspark import SparkContext
 from pyspark.rdd import RDD, ignore_unicode_prefix
 from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import Vectors, _convert_to_vector
+from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector, _convert_to_vector
+from pyspark.mllib.regression import LabeledPoint
 
 __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
-           'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
+           'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
+           'ChiSqSelector', 'ChiSqSelectorModel']
 
 
 class VectorTransformer(object):
@@ -199,6 +201,59 @@ def fit(self, dataset):
         return StandardScalerModel(jmodel)
 
 
+class ChiSqSelectorModel(JavaVectorTransformer):
+    """
+    .. note:: Experimental
+
+    Chi Squared selector model, as returned by :meth:`ChiSqSelector.fit`.
+    """
+    def transform(self, vector):
+        """
+        Applies the selector to a vector, keeping only the selected features.
+
+        :param vector: Vector or RDD of Vector to be transformed.
+        :return: transformed vector.
+        """
+        return JavaVectorTransformer.transform(self, vector)
+
+
+class ChiSqSelector(object):
+    """
+    .. note:: Experimental
+
+    Creates a ChiSquared feature selector.
+
+    >>> data = [
+    ...     LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
+    ...     LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})),
+    ...     LabeledPoint(1.0, [0.0, 9.0, 8.0]),
+    ...     LabeledPoint(2.0, [8.0, 9.0, 5.0])
+    ... ]
+    >>> model = ChiSqSelector(1).fit(sc.parallelize(data))
+    >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
+    SparseVector(1, {0: 6.0})
+    >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
+    DenseVector([5.0])
+    """
+    def __init__(self, numTopFeatures):
+        """
+        :param numTopFeatures: number of features the selector will select (cast to int).
+        """
+        self.numTopFeatures = int(numTopFeatures)
+
+    def fit(self, data):
+        """
+        Fits the selector on `data` and returns a :class:`ChiSqSelectorModel`.
+
+        :param data: an `RDD[LabeledPoint]` containing the labeled dataset
+                     with categorical features. Real-valued features will be
+                     treated as categorical for each distinct value.
+                     Apply a feature discretizer before using this function.
+        """
+        jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data)
+        return ChiSqSelectorModel(jmodel)
+
+
 class HashingTF(object):
     """
     .. note:: Experimental