Skip to content

Commit acf3e17

Browse files
committed
update checkInputColumn to print more info if needed
1 parent 8993c0e commit acf3e17

File tree

2 files changed

+4
-16
lines changed

2 files changed

+4
-16
lines changed

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ import scala.annotation.varargs
2424
import scala.collection.mutable
2525

2626
import org.apache.spark.annotation.AlphaComponent
27-
import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
28-
import org.apache.spark.sql.types.{DataType, StructType}
27+
import org.apache.spark.ml.util.Identifiable
2928

3029
/**
3130
* :: AlphaComponent ::
@@ -381,18 +380,6 @@ trait Params extends Identifiable with Serializable {
381380
this
382381
}
383382

384-
/**
385-
* Check whether the given schema contains an input column.
386-
* @param colName Input column name
387-
* @param dataType Input column DataType
388-
*/
389-
protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = {
390-
val actualDataType = schema(colName).dataType
391-
SchemaUtils.checkColumnType(schema, colName, dataType)
392-
require(actualDataType.equals(dataType), s"Input column Name: $colName Description: ${getParam(colName)}")
393-
}
394-
395-
396383
/**
397384
* Gets the default value of a parameter.
398385
*/

mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,11 @@ object SchemaUtils {
3434
* @param colName column name
3535
* @param dataType required column data type
3636
*/
37-
def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = {
37+
def checkColumnType(schema: StructType, colName: String, dataType: DataType,
38+
msg: String = ""): Unit = {
3839
val actualDataType = schema(colName).dataType
3940
require(actualDataType.equals(dataType),
40-
s"Column $colName must be of type $dataType but was actually $actualDataType.")
41+
s"Column $colName must be of type $dataType but was actually $actualDataType.$msg")
4142
}
4243

4344
/**

0 commit comments

Comments
 (0)