Skip to content

Commit b3be094

Browse files
committed
refactor schema transform in lr
1 parent 8791e8e commit b3be094

File tree

1 file changed

+35
-30
lines changed

1 file changed

+35
-30
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,39 @@ import org.apache.spark.storage.StorageLevel
3232
*/
3333
private[classification] trait LogisticRegressionParams extends Params
3434
with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol
35-
with HasScoreCol with HasPredictionCol
35+
with HasScoreCol with HasPredictionCol {
36+
37+
/**
38+
* Validates the input schema against the merged param map and returns the output schema.
39+
* The features column must be a vector column; the score and prediction column names must
* not already be present in the input schema.
* @param schema input schema
40+
* @param paramMap additional parameters, combined with this instance's own param map
41+
* @param fitting whether the schema is validated for fitting; when true, the label column
* must additionally be of DoubleType
42+
* @return the input fields followed by the score and prediction fields (both DoubleType)
43+
*/
44+
protected def transformSchema(
45+
schema: StructType,
46+
paramMap: ParamMap,
47+
fitting: Boolean): StructType = {
// Combine the instance-level params with the caller-supplied ones
// (presumably entries in `paramMap` take precedence -- confirm against ParamMap.++).
48+
val map = this.paramMap ++ paramMap
49+
val featuresType = schema(map(featuresCol)).dataType
50+
// TODO: Support casting Array[Double] and Array[Float] to Vector.
51+
require(featuresType.isInstanceOf[VectorUDT],
52+
s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
// The label column is only looked up and type-checked when fitting.
53+
if (fitting) {
54+
val labelType = schema(map(labelCol)).dataType
55+
require(labelType == DoubleType,
56+
s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
57+
}
// Refuse to overwrite existing columns with the generated output columns.
58+
val fieldNames = schema.fieldNames
59+
require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.")
60+
require(!fieldNames.contains(map(predictionCol)),
61+
s"Prediction column ${map(predictionCol)} already exists.")
// Append the score and prediction fields after the existing ones
// (third StructField argument `false` is presumably `nullable` -- confirm).
62+
val outputFields = schema.fields ++ Seq(
63+
StructField(map(scoreCol), DoubleType, false),
64+
StructField(map(predictionCol), DoubleType, false))
65+
StructType(outputFields)
66+
}
67+
}
3668

3769
/**
3870
* Logistic regression.
@@ -71,22 +103,7 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
71103
}
72104

73105
override def transform(schema: StructType, paramMap: ParamMap): StructType = {
74-
val map = this.paramMap ++ paramMap
75-
val featuresType = schema(map(featuresCol)).dataType
76-
// TODO: Support casting Array[Double] and Array[Float] to Vector.
77-
require(featuresType.isInstanceOf[VectorUDT],
78-
s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
79-
val labelType = schema(map(labelCol)).dataType
80-
require(labelType == DoubleType,
81-
s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
82-
val fieldNames = schema.fieldNames
83-
require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.")
84-
require(!fieldNames.contains(map(predictionCol)),
85-
s"Prediction column ${map(predictionCol)} already exists.")
86-
val outputFields = schema.fields ++ Seq(
87-
StructField(map(scoreCol), DoubleType, false),
88-
StructField(map(predictionCol), DoubleType, false))
89-
StructType(outputFields)
// Delegates to the shared transformSchema helper; fitting = true also validates the label column.
106+
transformSchema(schema, paramMap, fitting = true)
90107
}
91108
}
92109

@@ -104,19 +121,7 @@ class LogisticRegressionModel private[ml] (
104121
def setPredictionCol(value: String): this.type = { set(predictionCol, value); this }
105122

106123
override def transform(schema: StructType, paramMap: ParamMap): StructType = {
107-
val map = this.paramMap ++ paramMap
108-
val featuresType = schema(map(featuresCol)).dataType
109-
// TODO: Support casting Array[Double] and Array[Float] to Vector.
110-
require(featuresType.isInstanceOf[VectorUDT],
111-
s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
112-
val fieldNames = schema.fieldNames
113-
require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.")
114-
require(!fieldNames.contains(map(predictionCol)),
115-
s"Prediction column ${map(predictionCol)} already exists.")
116-
val outputFields = schema.fields ++ Seq(
117-
StructField(map(scoreCol), DoubleType, false),
118-
StructField(map(predictionCol), DoubleType, false))
119-
StructType(outputFields)
// Delegates to the shared transformSchema helper; fitting = false skips the label column check.
124+
transformSchema(schema, paramMap, fitting = false)
120125
}
121126

122127
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {

0 commit comments

Comments
 (0)