Skip to content

Commit c221db9

Browse files
committed
overload StringArrayParam.w
1 parent c81072d commit c221db9

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import java.util.NoSuchElementException
2222

2323
import scala.annotation.varargs
2424
import scala.collection.mutable
25-
import scala.reflect.ClassTag
25+
import scala.collection.JavaConverters._
2626

2727
import org.apache.spark.annotation.AlphaComponent
2828
import org.apache.spark.ml.util.Identifiable
@@ -228,7 +228,8 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array
228228

229229
override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)
230230

231-
private[param] def wCast(value: Seq[String]): ParamPair[Array[String]] = w(value.toArray)
231+
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
232+
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
232233
}
233234

234235
/**
@@ -323,13 +324,7 @@ trait Params extends Identifiable with Serializable {
323324
* Sets a parameter in the embedded param map.
324325
*/
325326
protected final def set[T](param: Param[T], value: T): this.type = {
326-
shouldOwn(param)
327-
if (param.isInstanceOf[StringArrayParam] && value.isInstanceOf[Seq[_]]) {
328-
paramMap.put(param.asInstanceOf[StringArrayParam].wCast(value.asInstanceOf[Seq[String]]))
329-
} else {
330-
paramMap.put(param.w(value))
331-
}
332-
this
327+
set(param -> value)
333328
}
334329

335330
/**
@@ -339,6 +334,15 @@ trait Params extends Identifiable with Serializable {
339334
set(getParam(param), value)
340335
}
341336

337+
/**
338+
* Sets a parameter in the embedded param map.
339+
*/
340+
protected final def set(paramPair: ParamPair[_]): this.type = {
341+
shouldOwn(paramPair.param)
342+
paramMap.put(paramPair)
343+
this
344+
}
345+
342346
/**
343347
* Optionally returns the user-supplied value of a param.
344348
*/

python/pyspark/ml/wrapper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,8 @@ def _transfer_params_to_java(self, params, java_obj):
6868
for param in self.params:
6969
if param in paramMap:
7070
value = paramMap[param]
71-
if isinstance(value, list):
72-
value = _jvm().PythonUtils.toSeq(value)
73-
java_obj.set(param.name, value)
71+
java_param = java_obj.getParam(param.name)
72+
java_obj.set(java_param.w(value))
7473

7574
def _empty_java_param_map(self):
7675
"""
@@ -82,7 +81,8 @@ def _create_java_param_map(self, params, java_obj):
8281
paramMap = self._empty_java_param_map()
8382
for param, value in params.items():
8483
if param.parent is self:
85-
paramMap.put(java_obj.getParam(param.name), value)
84+
java_param = java_obj.getParam(param.name)
85+
paramMap.put(java_param.w(value))
8686
return paramMap
8787

8888

0 commit comments

Comments
 (0)