
Commit 46eea43

a pipeline in python
1 parent 33b68e0 commit 46eea43

File tree: 7 files changed, +200 -35 lines
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+from pyspark import SparkContext
+from pyspark.sql import SQLContext, Row
+from pyspark.ml import Pipeline
+from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.ml.classification import LogisticRegression
+
+if __name__ == "__main__":
+    sc = SparkContext(appName="SimpleTextClassificationPipeline")
+    sqlCtx = SQLContext(sc)
+    training = sqlCtx.inferSchema(
+        sc.parallelize([(0L, "a b c d e spark", 1.0), (1L, "b d", 0.0), (2L, "spark f g h", 1.0), (3L, "hadoop mapreduce", 0.0)]) \
+            .map(lambda x: Row(id=x[0], text=x[1], label=x[2])))
+
+    tokenizer = Tokenizer() \
+        .setInputCol("text") \
+        .setOutputCol("words")
+    hashingTF = HashingTF() \
+        .setInputCol(tokenizer.getOutputCol()) \
+        .setOutputCol("features")
+    lr = LogisticRegression() \
+        .setMaxIter(10) \
+        .setRegParam(0.01)
+    pipeline = Pipeline() \
+        .setStages([tokenizer, hashingTF, lr])
+
+    model = pipeline.fit(training)
+
+    test = sqlCtx.inferSchema(
+        sc.parallelize([(4L, "spark i j k"), (5L, "l m n"), (6L, "mapreduce spark"), (7L, "apache hadoop")]) \
+            .map(lambda x: Row(id=x[0], text=x[1])))
+
+    for row in model.transform(test).collect():
+        print row
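
For reference, the same flow written out without the Pipeline wrapper looks roughly like the sketch below; it reuses the tokenizer, hashingTF, lr, training, and test objects defined in the example above, and is an illustration rather than part of the commit:

    words = tokenizer.transform(training)
    features = hashingTF.transform(words)
    lrModel = lr.fit(features)
    # apply the same feature stages to the test data, then the fitted model
    tested = hashingTF.transform(tokenizer.transform(test))
    for row in lrModel.transform(tested).collect():
        print row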

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 7 additions & 0 deletions
@@ -164,6 +164,13 @@ trait Params extends Identifiable with Serializable {
     this
   }
 
+  /**
+   * Sets a parameter (by name) in the embedded param map.
+   */
+  private[ml] def set(param: String, value: Any): this.type = {
+    set(getParam(param), value)
+  }
+
   /**
    * Gets the value of a parameter in the embedded param map.
    */
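
The Python wrappers rely on the existing getParam(name) lookup on the JVM side to map a Python Param onto its Scala counterpart (see HashingTF.transform in python/pyspark/ml/feature.py below). A minimal sketch of that translation, assuming an active SparkContext so the Py4J gateway is available; the parameter and value here are illustrative only:

    from pyspark.ml import _jvm
    from pyspark.ml.feature import HashingTF

    hashingTF = HashingTF()
    javaParamMap = _jvm().org.apache.spark.ml.param.ParamMap()
    # resolve the JVM-side Param by name, then bind an illustrative value to it
    javaParamMap.put(hashingTF._java_obj.getParam("numFeatures"), 1 << 10)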

python/pyspark/ml/__init__.py

Lines changed: 40 additions & 0 deletions
@@ -1,6 +1,9 @@
 import inspect
 
 from pyspark import SparkContext
+from pyspark.ml.param import Param
+
+__all__ = ["Pipeline"]
 
 # An implementation of PEP3102 for Python 2.
 _keyword_only_secret = 70861589
@@ -20,3 +23,40 @@ def _assert_keyword_only_args():
 
 def _jvm():
     return SparkContext._jvm
+
+class Pipeline(object):
+
+    def __init__(self):
+        self.stages = Param(self, "stages", "pipeline stages")
+        self.paramMap = {}
+
+    def setStages(self, value):
+        self.paramMap[self.stages] = value
+        return self
+
+    def getStages(self):
+        if self.stages in self.paramMap:
+            return self.paramMap[self.stages]
+
+    def fit(self, dataset):
+        transformers = []
+        for stage in self.getStages():
+            if hasattr(stage, "transform"):
+                transformers.append(stage)
+                dataset = stage.transform(dataset)
+            elif hasattr(stage, "fit"):
+                model = stage.fit(dataset)
+                transformers.append(model)
+                dataset = model.transform(dataset)
+        return PipelineModel(transformers)
+
+
+class PipelineModel(object):
+
+    def __init__(self, transformers):
+        self.transformers = transformers
+
+    def transform(self, dataset):
+        for t in self.transformers:
+            dataset = t.transform(dataset)
+        return dataset
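
A minimal, JVM-free sketch of how Pipeline.fit dispatches its stages: a stage with a transform method is applied as-is, while a stage with only a fit method is fitted and its returned model applied. The toy stages below are defined here purely for illustration and are not part of the commit; the sketch only assumes pyspark is importable:

    from pyspark.ml import Pipeline

    class Upper(object):               # transformer-like stage: has transform()
        def transform(self, dataset):
            return [s.upper() for s in dataset]

    class UpperEstimator(object):      # estimator-like stage: fit() returns a transformer
        def fit(self, dataset):
            return Upper()

    model = Pipeline().setStages([Upper(), UpperEstimator()]).fit(["a", "b"])
    print model.transform(["c"])       # prints ['C']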

python/pyspark/ml/classification.py

Lines changed: 20 additions & 29 deletions
@@ -1,5 +1,5 @@
 from pyspark.sql import SchemaRDD
-from pyspark.ml import _keyword_only_secret, _assert_keyword_only_args, _jvm
+from pyspark.ml import _jvm
 from pyspark.ml.param import Param
 
 
@@ -8,45 +8,39 @@ class LogisticRegression(object):
     Logistic regression.
     """
 
-    _java_class = "org.apache.spark.ml.classification.LogisticRegression"
+    # _java_class = "org.apache.spark.ml.classification.LogisticRegression"
 
     def __init__(self):
         self._java_obj = _jvm().org.apache.spark.ml.classification.LogisticRegression()
-        self.paramMap = {}
         self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
         self.regParam = Param(self, "regParam", "regularization constant", 0.1)
+        self.featuresCol = Param(self, "featuresCol", "features column name", "features")
 
-    def set(self, _keyword_only=_keyword_only_secret,
-            maxIter=None, regParam=None):
-        _assert_keyword_only_args()
-        if maxIter is not None:
-            self.paramMap[self.maxIter] = maxIter
-        if regParam is not None:
-            self.paramMap[self.regParam] = regParam
-        return self
-
-    # cannot chained
     def setMaxIter(self, value):
-        self.paramMap[self.maxIter] = value
+        self._java_obj.setMaxIter(value)
         return self
 
+    def getMaxIter(self):
+        return self._java_obj.getMaxIter()
+
     def setRegParam(self, value):
-        self.paramMap[self.regParam] = value
+        self._java_obj.setRegParam(value)
         return self
 
-    def getMaxIter(self):
-        if self.maxIter in self.paramMap:
-            return self.paramMap[self.maxIter]
-        else:
-            return self.maxIter.defaultValue
-
     def getRegParam(self):
-        if self.regParam in self.paramMap:
-            return self.paramMap[self.regParam]
-        else:
-            return self.regParam.defaultValue
+        return self._java_obj.getRegParam()
+
+    def setFeaturesCol(self, value):
+        self._java_obj.setFeaturesCol(value)
+        return self
 
-    def fit(self, dataset):
+    def getFeaturesCol(self):
+        return self._java_obj.getFeaturesCol()
+
+    def fit(self, dataset, params=None):
+        """
+        Fits a dataset with optional parameters.
+        """
         java_model = self._java_obj.fit(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap())
         return LogisticRegressionModel(java_model)
 
@@ -62,6 +56,3 @@ def __init__(self, _java_model):
     def transform(self, dataset):
         return SchemaRDD(self._java_model.transform(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx)
 
-lr = LogisticRegression()
-
-lr.set(maxIter=10, regParam=0.1)
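
With the setters now delegating to the wrapped JVM object, configuration can be chained and the getters read the values back from Scala. A quick sketch, assuming an active SparkContext; the values mirror the example script above:

    from pyspark.ml.classification import LogisticRegression

    lr = LogisticRegression().setMaxIter(10).setRegParam(0.01)
    print lr.getMaxIter(), lr.getRegParam()   # values read back from the JVM object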

python/pyspark/ml/feature.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+from pyspark.sql import SchemaRDD, ArrayType, StringType
+from pyspark.ml import _jvm
+from pyspark.ml.param import Param
+
+
+class Tokenizer(object):
+
+    def __init__(self):
+        self.inputCol = Param(self, "inputCol", "input column name", None)
+        self.outputCol = Param(self, "outputCol", "output column name", None)
+        self.paramMap = {}
+
+    def setInputCol(self, value):
+        self.paramMap[self.inputCol] = value
+        return self
+
+    def getInputCol(self):
+        if self.inputCol in self.paramMap:
+            return self.paramMap[self.inputCol]
+
+    def setOutputCol(self, value):
+        self.paramMap[self.outputCol] = value
+        return self
+
+    def getOutputCol(self):
+        if self.outputCol in self.paramMap:
+            return self.paramMap[self.outputCol]
+
+    def transform(self, dataset, params={}):
+        sqlCtx = dataset.sql_ctx
+        if isinstance(params, dict):
+            paramMap = self.paramMap.copy()
+            paramMap.update(params)
+            inputCol = paramMap[self.inputCol]
+            outputCol = paramMap[self.outputCol]
+            # TODO: make names unique
+            sqlCtx.registerFunction("tokenize", lambda text: text.split(),
+                                    ArrayType(StringType(), False))
+            dataset.registerTempTable("dataset")
+            return sqlCtx.sql("SELECT *, tokenize(%s) AS %s FROM dataset" % (inputCol, outputCol))
+        elif isinstance(params, list):
+            return [self.transform(dataset, paramMap) for paramMap in params]
+        else:
+            raise ValueError("The input params must be either a dict or a list.")
+
+
+class HashingTF(object):
+
+    def __init__(self):
+        self._java_obj = _jvm().org.apache.spark.ml.feature.HashingTF()
+        self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
+        self.inputCol = Param(self, "inputCol", "input column name")
+        self.outputCol = Param(self, "outputCol", "output column name")
+
+    def setNumFeatures(self, value):
+        self._java_obj.setNumFeatures(value)
+        return self
+
+    def getNumFeatures(self):
+        return self._java_obj.getNumFeatures()
+
+    def setInputCol(self, value):
+        self._java_obj.setInputCol(value)
+        return self
+
+    def getInputCol(self):
+        return self._java_obj.getInputCol()
+
+    def setOutputCol(self, value):
+        self._java_obj.setOutputCol(value)
+        return self
+
+    def getOutputCol(self):
+        return self._java_obj.getOutputCol()
+
+    def transform(self, dataset, paramMap={}):
+        if isinstance(paramMap, dict):
+            javaParamMap = _jvm().org.apache.spark.ml.param.ParamMap()
+            for k, v in paramMap.items():
+                param = self._java_obj.getParam(k.name)
+                javaParamMap.put(param, v)
+            return SchemaRDD(self._java_obj.transform(dataset._jschema_rdd, javaParamMap),
+                             dataset.sql_ctx)
+        else:
+            raise ValueError("paramMap must be a dict.")

python/pyspark/ml/param.py

Lines changed: 0 additions & 6 deletions
@@ -14,9 +14,3 @@ def __str__(self):
 
     def __repr_(self):
         return self.parent + "_" + self.name
-
-
-class Params(object):
-    """
-    Components that take parameters.
-    """

python/pyspark/ml/test.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+import subprocess
+
+def funcA(dataset, **kwargs):
+    """
+    funcA
+    :param dataset:
+    :param kwargs:
+
+    :return:
+    """
+    pass
+
+
+dataset = []
+funcA(dataset, )
