Skip to content

Commit c18dca1

Browse files
committed
make the example work
1 parent dadd84e commit c18dca1

File tree

5 files changed

+45
-13
lines changed

5 files changed

+45
-13
lines changed

examples/src/main/python/ml/simple_text_classification_pipeline.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
118
from pyspark import SparkContext
219
from pyspark.sql import SQLContext, Row
320
from pyspark.ml import Pipeline
@@ -8,7 +25,10 @@
825
sc = SparkContext(appName="SimpleTextClassificationPipeline")
926
sqlCtx = SQLContext(sc)
1027
training = sqlCtx.inferSchema(
11-
sc.parallelize([(0L, "a b c d e spark", 1.0), (1L, "b d", 0.0), (2L, "spark f g h", 1.0), (3L, "hadoop mapreduce", 0.0)]) \
28+
sc.parallelize([(0L, "a b c d e spark", 1.0),
29+
(1L, "b d", 0.0),
30+
(2L, "spark f g h", 1.0),
31+
(3L, "hadoop mapreduce", 0.0)]) \
1232
.map(lambda x: Row(id=x[0], text=x[1], label=x[2])))
1333

1434
tokenizer = Tokenizer() \
@@ -26,7 +46,10 @@
2646
model = pipeline.fit(training)
2747

2848
test = sqlCtx.inferSchema(
29-
sc.parallelize([(4L, "spark i j k"), (5L, "l m n"), (6L, "mapreduce spark"), (7L, "apache hadoop")]) \
49+
sc.parallelize([(4L, "spark i j k"),
50+
(5L, "l m n"),
51+
(6L, "mapreduce spark"),
52+
(7L, "apache hadoop")]) \
3053
.map(lambda x: Row(id=x[0], text=x[1])))
3154

3255
for row in model.transform(test).collect():

python/pyspark/ml/__init__.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from abc import ABCMeta, abstractmethod
1919

2020
from pyspark import SparkContext
21-
from pyspark.sql import inherit_doc
21+
from pyspark.sql import inherit_doc # TODO: move inherit_doc to Spark Core
2222
from pyspark.ml.param import Param, Params
2323
from pyspark.ml.util import Identifiable
2424

@@ -37,7 +37,7 @@ class PipelineStage(Params):
3737
"""
3838

3939
def __init__(self):
40-
super.__init__(self)
40+
super(PipelineStage, self).__init__()
4141

4242

4343
@inherit_doc
@@ -49,7 +49,7 @@ class Estimator(PipelineStage):
4949
__metaclass__ = ABCMeta
5050

5151
def __init__(self):
52-
super.__init__(self)
52+
super(Estimator, self).__init__()
5353

5454
@abstractmethod
5555
def fit(self, dataset, params={}):
@@ -74,6 +74,9 @@ class Transformer(PipelineStage):
7474

7575
__metaclass__ = ABCMeta
7676

77+
def __init__(self):
78+
super(Transformer, self).__init__()
79+
7780
@abstractmethod
7881
def transform(self, dataset, params={}):
7982
"""
@@ -109,7 +112,7 @@ class Pipeline(Estimator):
109112
"""
110113

111114
def __init__(self):
112-
super.__init__(self)
115+
super(Pipeline, self).__init__()
113116
#: Param for pipeline stages.
114117
self.stages = Param(self, "stages", "pipeline stages")
115118

@@ -139,13 +142,17 @@ def fit(self, dataset):
139142
model = stage.fit(dataset)
140143
transformers.append(model)
141144
dataset = model.transform(dataset)
145+
else:
146+
raise ValueError(
147+
"Cannot recognize a pipeline stage of type %s." % type(stage).__name__)
142148
return PipelineModel(transformers)
143149

144150

145151
@inherit_doc
146152
class PipelineModel(Transformer):
147153

148154
def __init__(self, transformers):
155+
super(PipelineModel, self).__init__()
149156
self.transformers = transformers
150157

151158
def transform(self, dataset):

python/pyspark/ml/feature.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
#
1717

1818
from pyspark.sql import SchemaRDD, ArrayType, StringType
19-
from pyspark.ml import _jvm
19+
from pyspark.ml import Transformer, _jvm
2020
from pyspark.ml.param import Param
2121

22-
23-
class Tokenizer(object):
22+
class Tokenizer(Transformer):
2423

2524
def __init__(self):
25+
super(Tokenizer, self).__init__()
2626
self.inputCol = Param(self, "inputCol", "input column name", None)
2727
self.outputCol = Param(self, "outputCol", "output column name", None)
2828
self.paramMap = {}
@@ -61,9 +61,10 @@ def transform(self, dataset, params={}):
6161
raise ValueError("The input params must be either a dict or a list.")
6262

6363

64-
class HashingTF(object):
64+
class HashingTF(Transformer):
6565

6666
def __init__(self):
67+
super(HashingTF, self).__init__()
6768
self._java_obj = _jvm().org.apache.spark.ml.feature.HashingTF()
6869
self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
6970
self.inputCol = Param(self, "inputCol", "input column name")

python/pyspark/ml/param.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616
#
1717

18-
from abc import ABCMeta, abstractmethod
18+
from abc import ABCMeta
1919

2020
from pyspark.ml.util import Identifiable
2121

@@ -50,11 +50,10 @@ class Params(Identifiable):
5050
__metaclass__ = ABCMeta
5151

5252
def __init__(self):
53-
super.__init__(self)
53+
super(Params, self).__init__()
5454
#: Internal param map.
5555
self.paramMap = {}
5656

57-
@abstractmethod
5857
def params(self):
5958
"""
6059
Returns all params. The default implementation uses

python/pyspark/ml/util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# limitations under the License.
1616
#
1717

18+
import uuid
19+
1820

1921
class Identifiable(object):
2022
"""

0 commit comments

Comments (0)