Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 8993c0e

Browse files
committed
SPARK-7137: Add checkInputColumn back to Params and print more info
1 parent e3677c9 commit 8993c0e

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ import scala.annotation.varargs
2424
import scala.collection.mutable
2525

2626
import org.apache.spark.annotation.AlphaComponent
27-
import org.apache.spark.ml.util.Identifiable
27+
import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
28+
import org.apache.spark.sql.types.{DataType, StructType}
2829

2930
/**
3031
* :: AlphaComponent ::
@@ -380,6 +381,18 @@ trait Params extends Identifiable with Serializable {
380381
this
381382
}
382383

384+
/**
385+
* Check whether the given schema contains an input column.
386+
* @param colName Input column name
387+
* @param dataType Input column DataType
388+
*/
389+
protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = {
390+
val actualDataType = schema(colName).dataType
391+
SchemaUtils.checkColumnType(schema, colName, dataType)
392+
require(actualDataType.equals(dataType), s"Input column Name: $colName Description: ${getParam(colName)}")
393+
}
394+
395+
383396
/**
384397
* Gets the default value of a parameter.
385398
*/

0 commit comments

Comments
 (0)