This repository was archived by the owner on May 9, 2024. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +14
-1
lines changed
mllib/src/main/scala/org/apache/spark/ml/param Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -24,7 +24,8 @@ import scala.annotation.varargs
24
24
import scala .collection .mutable
25
25
26
26
import org .apache .spark .annotation .AlphaComponent
27
- import org .apache .spark .ml .util .Identifiable
27
+ import org .apache .spark .ml .util .{SchemaUtils , Identifiable }
28
+ import org .apache .spark .sql .types .{DataType , StructType }
28
29
29
30
/**
30
31
* :: AlphaComponent ::
@@ -380,6 +381,18 @@ trait Params extends Identifiable with Serializable {
380
381
this
381
382
}
382
383
384
+ /**
385
+ * Check whether the given schema contains an input column.
386
+ * @param colName Input column name
387
+ * @param dataType Input column DataType
388
+ */
389
+ protected def checkInputColumn (schema : StructType , colName : String , dataType : DataType ): Unit = {
390
+ val actualDataType = schema(colName).dataType
391
+ SchemaUtils .checkColumnType(schema, colName, dataType)
392
+ require(actualDataType.equals(dataType), s " Input column Name: $colName Description: ${getParam(colName)}" )
393
+ }
394
+
395
+
383
396
/**
384
397
* Gets the default value of a parameter.
385
398
*/
You can’t perform that action at this time.
0 commit comments