Skip to content

Commit d501774

Browse files
committed
Made Model.parent transient. Added Model.hasParent to test for null parent
1 parent 814b3da commit d501774

File tree

3 files changed

+7
-1
lines changed

3 files changed

+7
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/Model.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ abstract class Model[M <: Model[M]] extends Transformer {
3232
* The parent estimator that produced this model.
3333
* Note: For ensembles' component Models, this value can be null.
3434
*/
35-
var parent: Estimator[M] = _
35+
@transient var parent: Estimator[M] = _
3636

3737
/**
3838
* Sets the parent of this model (Java API).
@@ -42,6 +42,9 @@ abstract class Model[M <: Model[M]] extends Transformer {
4242
this.asInstanceOf[M]
4343
}
4444

45+
/** Indicates whether this [[Model]] has a corresponding parent. */
46+
def hasParent: Boolean = parent != null
47+
4548
override def copy(extra: ParamMap): M = {
4649
// The default implementation of Params.copy doesn't work for models.
4750
throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
8383
assert(model.getRawPredictionCol === "rawPrediction")
8484
assert(model.getProbabilityCol === "probability")
8585
assert(model.intercept !== 0.0)
86+
assert(model.hasParent)
8687
}
8788

8889
test("logistic regression doesn't fit intercept when fitIntercept is off") {

mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,5 +162,7 @@ private object RandomForestClassifierSuite {
162162
val oldModelAsNew = RandomForestClassificationModel.fromOld(
163163
oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
164164
TreeTests.checkEqual(oldModelAsNew, newModel)
165+
assert(newModel.hasParent)
166+
assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
165167
}
166168
}

0 commit comments

Comments
 (0)