@@ -32,7 +32,39 @@ import org.apache.spark.storage.StorageLevel
32
32
*/
33
33
private[classification] trait LogisticRegressionParams extends Params
  with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol
  with HasScoreCol with HasPredictionCol {

  /**
   * Validates the input schema against the effective parameters and derives the output schema.
   *
   * The effective parameters are this instance's `paramMap` combined with the supplied
   * `paramMap`. The checks performed are:
   *  - the features column must be of type `VectorUDT`;
   *  - when `fitting` is true, the label column must be of type `DoubleType`;
   *  - neither the score column nor the prediction column may already be present.
   *
   * Any failed check is reported through `require`, i.e. as an `IllegalArgumentException`.
   *
   * @param schema input schema
   * @param paramMap additional parameters
   * @param fitting whether the schema is being checked for fitting (true) or for prediction
   *                (false); the label column is only checked when fitting
   * @return the input fields followed by the score and prediction columns, both added as
   *         non-nullable `DoubleType` fields
   */
  protected def transformSchema(
      schema: StructType,
      paramMap: ParamMap,
      fitting: Boolean): StructType = {
    val map = this.paramMap ++ paramMap
    val featuresType = schema(map(featuresCol)).dataType
    // TODO: Support casting Array[Double] and Array[Float] to Vector.
    require(featuresType.isInstanceOf[VectorUDT],
      s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
    if (fitting) {
      // Labels are only consumed while fitting, so the label column is not checked otherwise.
      val labelType = schema(map(labelCol)).dataType
      require(labelType == DoubleType,
        s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
    }
    val fieldNames = schema.fieldNames
    require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.")
    require(!fieldNames.contains(map(predictionCol)),
      s"Prediction column ${map(predictionCol)} already exists.")
    val outputFields = schema.fields ++ Seq(
      StructField(map(scoreCol), DoubleType, false),
      StructField(map(predictionCol), DoubleType, false))
    StructType(outputFields)
  }
}
36
68
37
69
/**
38
70
* Logistic regression.
@@ -71,22 +103,7 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
71
103
}
72
104
73
105
/**
 * Derives the output schema for the estimator.
 *
 * Delegates to `transformSchema` with `fitting = true`, so the label column is checked in
 * addition to the features, score and prediction columns.
 *
 * @param schema input schema
 * @param paramMap additional parameters
 * @return output schema produced by `transformSchema`
 */
override def transform(schema: StructType, paramMap: ParamMap): StructType = {
  transformSchema(schema, paramMap, fitting = true)
}
91
108
}
92
109
@@ -104,19 +121,7 @@ class LogisticRegressionModel private[ml] (
104
121
/**
 * Stores `value` under the `predictionCol` param and returns this instance so that calls
 * can be chained.
 */
def setPredictionCol(value: String): this.type = {
  set(predictionCol, value)
  this
}
105
122
106
123
/**
 * Derives the output schema for the fitted model.
 *
 * Delegates to `transformSchema` with `fitting = false`, so the label column is not checked;
 * only the features, score and prediction columns are validated.
 *
 * @param schema input schema
 * @param paramMap additional parameters
 * @return output schema produced by `transformSchema`
 */
override def transform(schema: StructType, paramMap: ParamMap): StructType = {
  transformSchema(schema, paramMap, fitting = false)
}
121
126
122
127
override def transform (dataset : SchemaRDD , paramMap : ParamMap ): SchemaRDD = {
0 commit comments