Skip to content

Commit 8791e8e

Browse files
committed
rename copyValues to inheritValues and make it do the right thing
1 parent 51f1c06 commit 8791e8e

File tree

4 files changed

+21
-15
lines changed

4 files changed

+21
-15
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
6666
val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights)
6767
instances.unpersist()
6868
// copy model params
69-
Params.copyValues(this, lrm)
69+
Params.inheritValues(map, this, lrm)
7070
lrm
7171
}
7272

mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
4949
}
5050
val scaler = new feature.StandardScaler().fit(input)
5151
val model = new StandardScalerModel(this, map, scaler)
52-
Params.copyValues(this, model)
52+
Params.inheritValues(map, this, model)
5353
model
5454
}
5555

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -143,15 +143,10 @@ trait Params extends Identifiable with Serializable {
143143
m.invoke(this).asInstanceOf[Param[Any]]
144144
}
145145

146-
/**
147-
* Internal param map.
148-
*/
149-
protected val paramMap: ParamMap = ParamMap.empty
150-
151146
/**
152147
* Sets a parameter in the own parameter map.
153148
*/
154-
protected def set[T](param: Param[T], value: T): this.type = {
149+
private[ml] def set[T](param: Param[T], value: T): this.type = {
155150
require(param.parent.eq(this))
156151
paramMap.put(param.asInstanceOf[Param[Any]], value)
157152
this
@@ -160,10 +155,15 @@ trait Params extends Identifiable with Serializable {
160155
/**
161156
* Gets the value of a parameter.
162157
*/
163-
protected def get[T](param: Param[T]): T = {
158+
private[ml] def get[T](param: Param[T]): T = {
164159
require(param.parent.eq(this))
165160
paramMap(param)
166161
}
162+
163+
/**
164+
* Internal param map.
165+
*/
166+
protected val paramMap: ParamMap = ParamMap.empty
167167
}
168168

169169
private[ml] object Params {
@@ -174,12 +174,18 @@ private[ml] object Params {
174174
val empty: Params = new Params {}
175175

176176
/**
177-
* Copy parameter values that are explicitly set from one Params instance to another.
177+
* Copies parameter values from the parent estimator to the child model it produced.
178+
* @param paramMap the param map that holds parameters of the parent
179+
* @param parent the parent estimator
180+
* @param child the child model
178181
*/
179-
private[ml] def copyValues[F <: Params, T <: F](from: F, to: T): Unit = {
180-
from.params.foreach { param =>
181-
if (from.isSet(param)) {
182-
to.set(to.getParam(param.name), from.get(param))
182+
private[ml] def inheritValues[E <: Params, M <: E](
183+
paramMap: ParamMap,
184+
parent: E,
185+
child: M): Unit = {
186+
parent.params.foreach { param =>
187+
if (paramMap.contains(param)) {
188+
child.set(child.getParam(param.name), paramMap(param))
183189
}
184190
}
185191
}

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
9595
logInfo(s"Best cross-validation metric: $bestMetric.")
9696
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
9797
val cvModel = new CrossValidatorModel(this, map, bestModel)
98-
Params.copyValues(this, cvModel)
98+
Params.inheritValues(map, this, cvModel)
9999
cvModel
100100
}
101101

0 commit comments

Comments
 (0)