Skip to content

Commit 33b68e0

Browse files
committed
a working LR
1 parent c233ab3 commit 33b68e0

File tree

3 files changed

+111
-0
lines changed

3 files changed

+111
-0
lines changed

python/pyspark/ml/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import inspect
2+
3+
from pyspark import SparkContext
4+
5+
# An implementation of PEP3102 (keyword-only arguments) for Python 2.
# Functions declare a `_keyword_only` first parameter defaulting to this
# secret value; any positional call overwrites it, which is detected by
# _assert_keyword_only_args() below.
_keyword_only_secret = 70861589
7+
8+
9+
def _assert_keyword_only_args():
    """
    Validates that the calling function was invoked with keyword arguments
    only, using the `_keyword_only` sentinel trick.

    Raises ValueError if the caller lacks a `_keyword_only` argument, or if
    that argument no longer holds the secret value (i.e. a positional
    argument displaced it).
    """
    # Inspect the frame of the function that invoked this check.
    caller_frame = inspect.currentframe().f_back
    arg_info = inspect.getargvalues(caller_frame)
    if "_keyword_only" not in arg_info.args:
        raise ValueError("Function does not have argument _keyword_only.")
    if arg_info.locals["_keyword_only"] != _keyword_only_secret:
        raise ValueError("Must use keyword arguments instead of positional ones.")
20+
21+
def _jvm():
    """Return the JVM gateway view held by the SparkContext class."""
    return SparkContext._jvm

python/pyspark/ml/classification.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from pyspark.sql import SchemaRDD
2+
from pyspark.ml import _keyword_only_secret, _assert_keyword_only_args, _jvm
3+
from pyspark.ml.param import Param
4+
5+
6+
class LogisticRegression(object):
    """
    Logistic regression.

    Wraps the JVM estimator named by ``_java_class``. Parameter values the
    user sets are stored Python-side in ``paramMap`` keyed by the
    corresponding :class:`Param`; defaults live on the Param objects.
    """

    _java_class = "org.apache.spark.ml.classification.LogisticRegression"

    def __init__(self):
        self._java_obj = _jvm().org.apache.spark.ml.classification.LogisticRegression()
        # Param -> user-supplied value; absent keys fall back to the
        # Param's own defaultValue.
        self.paramMap = {}
        self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
        self.regParam = Param(self, "regParam", "regularization constant", 0.1)

    def set(self, _keyword_only=_keyword_only_secret,
            maxIter=None, regParam=None):
        """
        Sets parameters (keyword arguments only, enforced via the
        `_keyword_only` trick). Returns self for chaining.
        """
        _assert_keyword_only_args()
        if maxIter is not None:
            self.paramMap[self.maxIter] = maxIter
        if regParam is not None:
            self.paramMap[self.regParam] = regParam
        return self

    # Per-parameter setters; like set(), they return self so calls chain.
    def setMaxIter(self, value):
        self.paramMap[self.maxIter] = value
        return self

    def setRegParam(self, value):
        self.paramMap[self.regParam] = value
        return self

    def _getOrDefault(self, param):
        # Shared lookup used by all getters: user-set value if present,
        # otherwise the param's default.
        return self.paramMap.get(param, param.defaultValue)

    def getMaxIter(self):
        return self._getOrDefault(self.maxIter)

    def getRegParam(self):
        return self._getOrDefault(self.regParam)

    def fit(self, dataset):
        """
        Fits this estimator to the dataset and returns a
        LogisticRegressionModel.

        NOTE(review): an empty JVM ParamMap is passed here, so values
        stored in self.paramMap do not reach the JVM fit yet — confirm
        whether that is intended at this stage.
        """
        java_model = self._java_obj.fit(
            dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap())
        return LogisticRegressionModel(java_model)
52+
53+
54+
class LogisticRegressionModel(object):
    """
    Model fitted by LogisticRegression.
    """

    def __init__(self, _java_model):
        # Handle to the fitted JVM model object.
        self._java_model = _java_model

    def transform(self, dataset):
        """Applies the fitted model to a dataset and returns a SchemaRDD."""
        empty_param_map = _jvm().org.apache.spark.ml.param.ParamMap()
        transformed = self._java_model.transform(dataset._jschema_rdd, empty_param_map)
        return SchemaRDD(transformed, dataset.sql_ctx)
64+
65+
# Development smoke test. Guarded so that importing this module does not
# create JVM objects as a side effect — the original ran these lines at
# import time.
if __name__ == "__main__":
    lr = LogisticRegression()
    lr.set(maxIter=10, regParam=0.1)

python/pyspark/ml/param.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
class Param(object):
    """
    A param with self-contained documentation and optionally default value.

    :param parent: owning component (typically an estimator object); its
        string form prefixes this param's string representation
    :param name: param name
    :param doc: documentation string
    :param defaultValue: default value, or None if the param is required
    """

    def __init__(self, parent, name, doc, defaultValue=None):
        self.parent = parent
        self.name = name
        self.doc = doc
        self.defaultValue = defaultValue

    def __str__(self):
        # str() the parent explicitly: parent is usually an object, not a
        # string, and bare `self.parent + "_"` would raise TypeError.
        return str(self.parent) + "_" + self.name

    def __repr__(self):
        # Fixed typo: the original defined `__repr_` (single trailing
        # underscore), which Python never invokes.
        return str(self.parent) + "_" + self.name
17+
18+
19+
class Params(object):
    """
    Components that take parameters.

    Placeholder base type — the parameter-handling API is not implemented
    here yet.
    """

0 commit comments

Comments
 (0)