@@ -73,11 +73,36 @@ def predict(self, x):
73
73
74
74
class LogisticRegressionWithSGD (object ):
75
75
@classmethod
76
- def train (cls , data , iterations = 100 , step = 1.0 , miniBatchFraction = 1.0 , initialWeights = None ):
77
- """Train a logistic regression model on the given data."""
76
+ def train (cls , data , iterations = 100 , step = 1.0 , miniBatchFraction = 1.0 ,
77
+ initialWeights = None , regParam = 1.0 , regType = None , intercept = False ):
78
+ """
79
+ Train a logistic regression model on the given data.
80
+
81
+ @param data: The training data.
82
+ @param iterations: The number of iterations (default: 100).
83
+ @param step: The step parameter used in SGD
84
+ (default: 1.0).
85
+ @param miniBatchFraction: Fraction of data to be used for each SGD
86
+ iteration.
87
+ @param initialWeights: The initial weights (default: None).
88
+ @param regParam: The regularizer parameter (default: 1.0).
89
+ @param regType: The type of regularizer used for training
90
+ our model.
91
+ Allowed values: "l1" for using L1Updater,
92
+ "l2" for using
93
+ SquaredL2Updater,
94
+ "none" for no regularizer.
95
+ (default: "none")
96
+ @param intercept: Boolean parameter which indicates the use
97
+ or not of the augmented representation for
98
+ training data (i.e. whether bias features
99
+ are activated or not).
100
+ """
78
101
sc = data .context
102
+ if regType is None :
103
+ regType = "none"
79
104
train_func = lambda d , i : sc ._jvm .PythonMLLibAPI ().trainLogisticRegressionModelWithSGD (
80
- d ._jrdd , iterations , step , miniBatchFraction , i )
105
+ d ._jrdd , iterations , step , miniBatchFraction , i , regParam , regType , intercept )
81
106
return _regression_train_wrapper (sc , train_func , LogisticRegressionModel , data ,
82
107
initialWeights )
83
108
@@ -115,11 +140,35 @@ def predict(self, x):
115
140
class SVMWithSGD (object ):
116
141
@classmethod
117
142
def train (cls , data , iterations = 100 , step = 1.0 , regParam = 1.0 ,
118
- miniBatchFraction = 1.0 , initialWeights = None ):
119
- """Train a support vector machine on the given data."""
143
+ miniBatchFraction = 1.0 , initialWeights = None , regType = None , intercept = False ):
144
+ """
145
+ Train a support vector machine on the given data.
146
+
147
+ @param data: The training data.
148
+ @param iterations: The number of iterations (default: 100).
149
+ @param step: The step parameter used in SGD
150
+ (default: 1.0).
151
+ @param regParam: The regularizer parameter (default: 1.0).
152
+ @param miniBatchFraction: Fraction of data to be used for each SGD
153
+ iteration.
154
+ @param initialWeights: The initial weights (default: None).
155
+ @param regType: The type of regularizer used for training
156
+ our model.
157
+ Allowed values: "l1" for using L1Updater,
158
+ "l2" for using
159
+ SquaredL2Updater,
160
+ "none" for no regularizer.
161
+ (default: "none")
162
+ @param intercept: Boolean parameter which indicates the use
163
+ or not of the augmented representation for
164
+ training data (i.e. whether bias features
165
+ are activated or not).
166
+ """
120
167
sc = data .context
168
+ if regType is None :
169
+ regType = "none"
121
170
train_func = lambda d , i : sc ._jvm .PythonMLLibAPI ().trainSVMModelWithSGD (
122
- d ._jrdd , iterations , step , regParam , miniBatchFraction , i )
171
+ d ._jrdd , iterations , step , regParam , miniBatchFraction , i , regType , intercept )
123
172
return _regression_train_wrapper (sc , train_func , SVMModel , data , initialWeights )
124
173
125
174
0 commit comments