
Commit d0c5bb8

Commit message: a working copy
Parent: bce72f4

7 files changed: +127 additions, -108 deletions

examples/src/main/python/ml/simple_text_classification_pipeline.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
 from pyspark.ml.feature import HashingTF, Tokenizer
 from pyspark.ml.classification import LogisticRegression
 
+
 if __name__ == "__main__":
     sc = SparkContext(appName="SimpleTextClassificationPipeline")
     sqlCtx = SQLContext(sc)
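
For context, this example drives the reworked Python ML pipeline API. A hedged sketch of the wiring, reconstructed from the imports above (column names, hyperparameters, and the Pipeline.setStages call are assumptions, and training/test stand for SchemaRDDs built elsewhere in the script):

# Sketch only, not a verbatim excerpt of the example file.
from pyspark.ml import Pipeline
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.ml.classification import LogisticRegression

tokenizer = Tokenizer().setInputCol("text").setOutputCol("words")
hashingTF = HashingTF().setInputCol(tokenizer.getOutputCol()).setOutputCol("features")
lr = LogisticRegression().setMaxIter(10).setRegParam(0.01)
pipeline = Pipeline().setStages([tokenizer, hashingTF, lr])
model = pipeline.fit(training)        # returns a PipelineModel of fitted stages
prediction = model.transform(test)    # appends a prediction column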

python/pyspark/ml/__init__.py

Lines changed: 67 additions & 8 deletions
@@ -15,10 +15,10 @@
 # limitations under the License.
 #
 
-from abc import ABCMeta, abstractmethod
+from abc import ABCMeta, abstractmethod, abstractproperty
 
 from pyspark import SparkContext
-from pyspark.sql import inherit_doc  # TODO: move inherit_doc to Spark Core
+from pyspark.sql import SchemaRDD, inherit_doc  # TODO: move inherit_doc to Spark Core
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import Identifiable
 
@@ -146,16 +146,17 @@ def getStages(self):
         if self.stages in self.paramMap:
             return self.paramMap[self.stages]
 
-    def fit(self, dataset):
+    def fit(self, dataset, params={}):
+        map = self._merge_params(params)
         transformers = []
         for stage in self.getStages():
             if isinstance(stage, Transformer):
                 transformers.append(stage)
-                dataset = stage.transform(dataset)
+                dataset = stage.transform(dataset, map)
             elif isinstance(stage, Estimator):
-                model = stage.fit(dataset)
+                model = stage.fit(dataset, map)
                 transformers.append(model)
-                dataset = model.transform(dataset)
+                dataset = model.transform(dataset, map)
             else:
                 raise ValueError(
                     "Cannot recognize a pipeline stage of type %s." % type(stage).__name__)
@@ -169,7 +170,65 @@ def __init__(self, transformers):
         super(PipelineModel, self).__init__()
         self.transformers = transformers
 
-    def transform(self, dataset):
+    def transform(self, dataset, params={}):
+        map = self._merge_params(params)
         for t in self.transformers:
-            dataset = t.transform(dataset)
+            dataset = t.transform(dataset, map)
         return dataset
+
+
+@inherit_doc
+class JavaWrapper(object):
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(JavaWrapper, self).__init__()
+
+    @abstractproperty
+    def _java_class(self):
+        raise NotImplementedError
+
+    def _create_java_obj(self):
+        java_obj = _jvm()
+        for name in self._java_class.split("."):
+            java_obj = getattr(java_obj, name)
+        return java_obj()
+
+
+@inherit_doc
+class JavaEstimator(Estimator, JavaWrapper):
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(JavaEstimator, self).__init__()
+
+    @abstractmethod
+    def _create_model(self, java_model):
+        raise NotImplementedError
+
+    def _fit_java(self, dataset, params={}):
+        java_obj = self._create_java_obj()
+        self._transfer_params_to_java(params, java_obj)
+        return java_obj.fit(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap())
+
+    def fit(self, dataset, params={}):
+        java_model = self._fit_java(dataset, params)
+        return self._create_model(java_model)
+
+
+@inherit_doc
+class JavaTransformer(Transformer, JavaWrapper):
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(JavaTransformer, self).__init__()
+
+    def transform(self, dataset, params={}):
+        java_obj = self._create_java_obj()
+        self._transfer_params_to_java(params, java_obj)
+        return SchemaRDD(java_obj.transform(dataset._jschema_rdd,
+                                            _jvm().org.apache.spark.ml.param.ParamMap()),
+                         dataset.sql_ctx)
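
The three new base classes carry the Py4J plumbing so that concrete wrappers only declare what differs. A minimal sketch of the estimator contract, with hypothetical names (LogisticRegression in classification.py below is the real instance of this pattern):

# Hypothetical subclass, for illustration only.
@inherit_doc
class MyEstimator(JavaEstimator):

    @property
    def _java_class(self):
        # Fully qualified name of the Scala estimator to instantiate;
        # JavaWrapper._create_java_obj() resolves it attribute by attribute
        # on the JVM gateway. The name below is a placeholder.
        return "org.apache.spark.ml.SomeScalaEstimator"

    def _create_model(self, java_model):
        # Wrap the fitted Java model in a Python Transformer (hypothetical class).
        return MyModel(java_model)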

python/pyspark/ml/classification.py

Lines changed: 13 additions & 15 deletions
@@ -16,42 +16,40 @@
 #
 
 from pyspark.sql import SchemaRDD, inherit_doc
-from pyspark.ml import Estimator, Transformer, _jvm
+from pyspark.ml import JavaEstimator, Transformer, _jvm
 from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
     HasRegParam
 
 
 @inherit_doc
-class LogisticRegression(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
+class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                          HasRegParam):
     """
     Logistic regression.
     """
 
-    # _java_class = "org.apache.spark.ml.classification.LogisticRegression"
-
     def __init__(self):
         super(LogisticRegression, self).__init__()
-        self._java_obj = _jvm().org.apache.spark.ml.classification.LogisticRegression()
-
-    def fit(self, dataset, params=None):
-        """
-        Fits a dataset with optional parameters.
-        """
-        java_model = self._java_obj.fit(dataset._jschema_rdd,
-                                        _jvm().org.apache.spark.ml.param.ParamMap())
+
+    @property
+    def _java_class(self):
+        return "org.apache.spark.ml.classification.LogisticRegression"
+
+    def _create_model(self, java_model):
         return LogisticRegressionModel(java_model)
 
 
+@inherit_doc
 class LogisticRegressionModel(Transformer):
     """
     Model fitted by LogisticRegression.
    """
 
-    def __init__(self, _java_model):
-        self._java_model = _java_model
+    def __init__(self, java_model):
+        self._java_model = java_model
 
-    def transform(self, dataset):
+    def transform(self, dataset, params={}):
+        # TODO: handle params here.
         return SchemaRDD(self._java_model.transform(
             dataset._jschema_rdd,
             _jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx)
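
With this change the Python fit path is: the generated setters fill paramMap, JavaEstimator.fit() instantiates the Java LogisticRegression, transfers the merged params, and wraps the fitted model. A rough usage sketch (the dataset is assumed to be a SchemaRDD with the default "label" and "features" columns):

# Rough usage sketch; `dataset` is an assumed SchemaRDD with "label"/"features" columns.
lr = LogisticRegression().setMaxIter(10).setRegParam(0.01)
model = lr.fit(dataset)                 # JavaEstimator.fit -> _fit_java -> _create_model
predictions = model.transform(dataset)  # LogisticRegressionModel delegates to the Java model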

python/pyspark/ml/feature.py

Lines changed: 18 additions & 68 deletions
@@ -15,91 +15,41 @@
 # limitations under the License.
 #
 
-from pyspark.sql import SchemaRDD, ArrayType, StringType, inherit_doc
-from pyspark.ml import Transformer, _jvm
+from pyspark.sql import inherit_doc
+from pyspark.ml import JavaTransformer
 from pyspark.ml.param import Param
+from pyspark.ml.param.shared import HasInputCol, HasOutputCol
+
 
 @inherit_doc
-class Tokenizer(Transformer):
+class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
 
     def __init__(self):
         super(Tokenizer, self).__init__()
-        self.inputCol = Param(self, "inputCol", "input column name", None)
-        self.outputCol = Param(self, "outputCol", "output column name", None)
-        self.paramMap = {}
-
-    def setInputCol(self, value):
-        self.paramMap[self.inputCol] = value
-        return self
-
-    def getInputCol(self):
-        if self.inputCol in self.paramMap:
-            return self.paramMap[self.inputCol]
 
-    def setOutputCol(self, value):
-        self.paramMap[self.outputCol] = value
-        return self
-
-    def getOutputCol(self):
-        if self.outputCol in self.paramMap:
-            return self.paramMap[self.outputCol]
-
-    def transform(self, dataset, params={}):
-        sqlCtx = dataset.sql_ctx
-        if isinstance(params, dict):
-            paramMap = self.paramMap.copy()
-            paramMap.update(params)
-            inputCol = paramMap[self.inputCol]
-            outputCol = paramMap[self.outputCol]
-            # TODO: make names unique
-            sqlCtx.registerFunction("tokenize", lambda text: text.split(),
-                                    ArrayType(StringType(), False))
-            dataset.registerTempTable("dataset")
-            return sqlCtx.sql("SELECT *, tokenize(%s) AS %s FROM dataset" % (inputCol, outputCol))
-        elif isinstance(params, list):
-            return [self.transform(dataset, paramMap) for paramMap in params]
-        else:
-            raise ValueError("The input params must be either a dict or a list.")
+    @property
+    def _java_class(self):
+        return "org.apache.spark.ml.feature.Tokenizer"
 
 
 @inherit_doc
-class HashingTF(Transformer):
+class HashingTF(JavaTransformer, HasInputCol, HasOutputCol):
 
     def __init__(self):
         super(HashingTF, self).__init__()
-        self._java_obj = _jvm().org.apache.spark.ml.feature.HashingTF()
+        #: param for number of features
         self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
-        self.inputCol = Param(self, "inputCol", "input column name")
-        self.outputCol = Param(self, "outputCol", "output column name")
+
+    @property
+    def _java_class(self):
+        return "org.apache.spark.ml.feature.HashingTF"
 
     def setNumFeatures(self, value):
-        self._java_obj.setNumFeatures(value)
+        self.paramMap[self.numFeatures] = value
         return self
 
     def getNumFeatures(self):
-        return self._java_obj.getNumFeatures()
-
-    def setInputCol(self, value):
-        self._java_obj.setInputCol(value)
-        return self
-
-    def getInputCol(self):
-        return self._java_obj.getInputCol()
-
-    def setOutputCol(self, value):
-        self._java_obj.setOutputCol(value)
-        return self
-
-    def getOutputCol(self):
-        return self._java_obj.getOutputCol()
-
-    def transform(self, dataset, paramMap={}):
-        if isinstance(paramMap, dict):
-            javaParamMap = _jvm().org.apache.spark.ml.param.ParamMap()
-            for k, v in paramMap.items():
-                param = self._java_obj.getParam(k.name)
-                javaParamMap.put(param, v)
-            return SchemaRDD(self._java_obj.transform(dataset._jschema_rdd, javaParamMap),
-                             dataset.sql_ctx)
+        if self.numFeatures in self.paramMap:
+            return self.paramMap[self.numFeatures]
         else:
-            raise ValueError("paramMap must be a dict.")
+            return self.numFeatures.defaultValue
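
A short sketch of how the reworked params behave on the Python side after this change; the values are illustrative:

# Illustrative values; nothing touches the JVM until transform() is called.
hashingTF = HashingTF()
hashingTF.getNumFeatures()        # 1 << 18, the declared default
hashingTF.setNumFeatures(1000)    # stored in self.paramMap only
hashingTF.getNumFeatures()        # 1000
# JavaTransformer.transform() later creates the Java HashingTF and copies the
# merged params onto it via _transfer_params_to_java().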

python/pyspark/ml/param/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -63,3 +63,14 @@ def params(self):
         :py:class:`Param`.
         """
         return filter(lambda x: isinstance(x, Param), map(lambda x: getattr(self, x), dir(self)))
+
+    def _merge_params(self, params):
+        map = self.paramMap.copy()
+        map.update(params)
+        return map
+
+    def _transfer_params_to_java(self, params, java_obj):
+        map = self._merge_params(params)
+        for param in self.params():
+            if param in map:
+                java_obj.set(param.name, map[param])
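
For illustration, how the two helpers compose during a call such as fit(dataset, params); the maxIter param and its values are only an example:

# Illustration only: an estimator whose maxIter was set to 10 via its setter,
# then overridden to 5 for a single fit() call.
est.paramMap[est.maxIter] = 10                 # what setMaxIter(10) does
merged = est._merge_params({est.maxIter: 5})
merged[est.maxIter]                            # 5: call-time params win (dict.update last)
# _transfer_params_to_java(params, java_obj) then walks self.params() and calls
# java_obj.set("maxIter", 5) for every declared param present in the merged map.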

python/pyspark/ml/param/_gen_shared_params.py

Lines changed: 3 additions & 3 deletions
@@ -54,13 +54,13 @@ def set%s(self, value):
         self.paramMap[self.%s] = value
         return self
 
-    def get%s(self, value):
+    def get%s(self):
         if self.%s in self.paramMap:
             return self.paramMap[self.%s]
         else:
-            return self.defaultValue""" % (
+            return self.%s.defaultValue""" % (
         upperCamelName, upperCamelName, doc, name, name, doc, defaultValue, upperCamelName, name,
-        upperCamelName, name, name)
+        upperCamelName, name, name, name)
 
 if __name__ == "__main__":
     print header
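
For reference, with the two template fixes (getter takes no value argument, and the default is read from the param itself) the generated getter/setter pair for a shared param named, say, maxIter comes out roughly like this, reconstructed from the template above with the class header and docstrings omitted:

    def setMaxIter(self, value):
        self.paramMap[self.maxIter] = value
        return self

    def getMaxIter(self):
        if self.maxIter in self.paramMap:
            return self.paramMap[self.maxIter]
        else:
            return self.maxIter.defaultValue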
