Skip to content

Commit e04a5aa

Browse files
committed
boostingstrategy.defaultParam string algo to enumeration.
1 parent 3453d57 commit e04a5aa

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,25 @@ object BoostingStrategy {
8888
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
8989
}
9090
}
91+
92+
/**
93+
* Returns default configuration for the boosting algorithm
94+
* @param algo Learning goal. Supported:
95+
* [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
96+
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
97+
* @return Configuration for boosting algorithm
98+
*/
99+
def defaultParams(algo: Algo): BoostingStrategy = {
100+
val treeStragtegy = Strategy.defaultStategy(algo)
101+
treeStragtegy.maxDepth = 3
102+
algo match {
103+
case Algo.Classification =>
104+
treeStragtegy.numClasses = 2
105+
new BoostingStrategy(treeStragtegy, LogLoss)
106+
case Algo.Regression =>
107+
new BoostingStrategy(treeStragtegy, SquaredError)
108+
case _ =>
109+
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
110+
}
111+
}
91112
}

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,4 +181,17 @@ object Strategy {
181181
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
182182
numClasses = 0)
183183
}
184+
185+
/**
186+
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
187+
* @param algo Algo.Classification or Algo.Regression
188+
*/
189+
def defaultStategy(algo: Algo): Strategy = algo match {
190+
case Algo.Classification =>
191+
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
192+
numClasses = 2)
193+
case Algo.Regression =>
194+
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
195+
numClasses = 0)
196+
}
184197
}

0 commit comments

Comments
 (0)