Skip to content

Commit 6e86d98

Browse files
committed
some code clean-up
1 parent 2d040b3 commit 6e86d98

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

mllib/src/main/scala/org/apache/spark/ml/Model.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
package org.apache.spark.ml
1919

20+
/**
21+
* A trained model.
22+
*/
2023
abstract class Model extends Transformer {
2124
// def parent: Estimator
2225
// def trainingParameters: ParamMap

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,34 +17,29 @@
1717

1818
package org.apache.spark.ml
1919

20-
import org.apache.spark.ml.param.{ParamMap, Param}
21-
import org.apache.spark.sql.SchemaRDD
22-
2320
import scala.collection.mutable.ListBuffer
2421

22+
import org.apache.spark.ml.param.{Param, ParamMap}
23+
import org.apache.spark.sql.SchemaRDD
24+
25+
/**
26+
* A stage in a pipeline, either an Estimator or an Transformer.
27+
*/
2528
trait PipelineStage extends Identifiable
2629

2730
/**
2831
* A simple pipeline, which acts as an estimator.
2932
*/
3033
class Pipeline extends Estimator[PipelineModel] {
3134

32-
val stages: Param[Array[PipelineStage]] =
33-
new Param(this, "stages", "stages of the pipeline")
34-
35-
def setStages(stages: Array[PipelineStage]): this.type = {
36-
set(this.stages, stages)
37-
this
38-
}
39-
40-
def getStages: Array[PipelineStage] = {
41-
get(stages)
42-
}
35+
val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
36+
def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
37+
def getStages: Array[PipelineStage] = get(stages)
4338

4439
override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = {
4540
val map = this.paramMap ++ paramMap
4641
val theStages = map(stages)
47-
// Search for last estimator.
42+
// Search for the last estimator.
4843
var lastIndexOfEstimator = -1
4944
theStages.view.zipWithIndex.foreach { case (stage, index) =>
5045
stage match {
@@ -75,10 +70,11 @@ class Pipeline extends Estimator[PipelineModel] {
7570

7671
new PipelineModel(transformers.toArray)
7772
}
78-
79-
override def params: Array[Param[_]] = Array.empty
8073
}
8174

75+
/**
76+
* Represents a compiled pipeline.
77+
*/
8278
class PipelineModel(val transformers: Array[Transformer]) extends Model {
8379

8480
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {

0 commit comments

Comments
 (0)