|
17 | 17 |
|
18 | 18 | package org.apache.spark.ml
|
19 | 19 |
|
20 |
| -import org.apache.spark.ml.param.{ParamMap, Param} |
21 |
| -import org.apache.spark.sql.SchemaRDD |
22 |
| - |
23 | 20 | import scala.collection.mutable.ListBuffer
|
24 | 21 |
|
| 22 | +import org.apache.spark.ml.param.{Param, ParamMap} |
| 23 | +import org.apache.spark.sql.SchemaRDD |
| 24 | + |
| 25 | +/** |
| 26 | + * A stage in a pipeline, either an Estimator or an Transformer. |
| 27 | + */ |
25 | 28 | trait PipelineStage extends Identifiable
|
26 | 29 |
|
27 | 30 | /**
|
28 | 31 | * A simple pipeline, which acts as an estimator.
|
29 | 32 | */
|
30 | 33 | class Pipeline extends Estimator[PipelineModel] {
|
31 | 34 |
|
32 |
| - val stages: Param[Array[PipelineStage]] = |
33 |
| - new Param(this, "stages", "stages of the pipeline") |
34 |
| - |
35 |
| - def setStages(stages: Array[PipelineStage]): this.type = { |
36 |
| - set(this.stages, stages) |
37 |
| - this |
38 |
| - } |
39 |
| - |
40 |
| - def getStages: Array[PipelineStage] = { |
41 |
| - get(stages) |
42 |
| - } |
| 35 | + val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") |
| 36 | + def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } |
| 37 | + def getStages: Array[PipelineStage] = get(stages) |
43 | 38 |
|
44 | 39 | override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = {
|
45 | 40 | val map = this.paramMap ++ paramMap
|
46 | 41 | val theStages = map(stages)
|
47 |
| - // Search for last estimator. |
| 42 | + // Search for the last estimator. |
48 | 43 | var lastIndexOfEstimator = -1
|
49 | 44 | theStages.view.zipWithIndex.foreach { case (stage, index) =>
|
50 | 45 | stage match {
|
@@ -75,10 +70,11 @@ class Pipeline extends Estimator[PipelineModel] {
|
75 | 70 |
|
76 | 71 | new PipelineModel(transformers.toArray)
|
77 | 72 | }
|
78 |
| - |
79 |
| - override def params: Array[Param[_]] = Array.empty |
80 | 73 | }
|
81 | 74 |
|
| 75 | +/** |
| 76 | + * Represents a compiled pipeline. |
| 77 | + */ |
82 | 78 | class PipelineModel(val transformers: Array[Transformer]) extends Model {
|
83 | 79 |
|
84 | 80 | override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
|
|
0 commit comments