Skip to content

Commit fcb3e18

Browse files
Peishen-Jiamengxr
authored andcommitted
[SPARK-5317]Set BoostingStrategy.defaultParams With Enumeration Algo.Classification or Algo.Regression
JIRA Issue: https://issues.apache.org/jira/browse/SPARK-5317 When setting the BoostingStrategy.defaultParams("Classification"), It's more straightforward to set it with the Enumeration Algo.Classification, just like BoostingStragety.defaultParams(Algo.Classification). I overload the method BoostingStragety.defaultParams(). Author: Basin <[email protected]> Closes #4103 from Peishen-Jia/stragetyAlgo and squashes the following commits: 87bab1c [Basin] Docs and Code documentations updated. 3b72875 [Basin] defaultParams(algoStr: String) call defaultParams(algo: Algo). 7c1e6ee [Basin] Doc of Java updated. algo -> algoStr instead. d5c8a2e [Basin] Merge branch 'stragetyAlgo' of github.com:Peishen-Jia/spark into stragetyAlgo 65f96ce [Basin] mllib-ensembles doc modified. e04a5aa [Basin] boostingstrategy.defaultParam string algo to enumeration. 68cf544 [Basin] mllib-ensembles doc modified. a4aea51 [Basin] boostingstrategy.defaultParam string algo to enumeration.
1 parent ca7910d commit fcb3e18

File tree

2 files changed

+28
-11
lines changed

2 files changed

+28
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,22 +68,31 @@ case class BoostingStrategy(
6868
@Experimental
6969
object BoostingStrategy {
7070

71+
/**
72+
* Returns default configuration for the boosting algorithm
73+
* @param algo Learning goal. Supported: "Classification" or "Regression"
74+
* @return Configuration for boosting algorithm
75+
*/
76+
def defaultParams(algo: String): BoostingStrategy = {
77+
defaultParams(Algo.fromString(algo))
78+
}
79+
7180
/**
7281
* Returns default configuration for the boosting algorithm
7382
* @param algo Learning goal. Supported:
7483
* [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
7584
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
7685
* @return Configuration for boosting algorithm
7786
*/
78-
def defaultParams(algo: String): BoostingStrategy = {
79-
val treeStrategy = Strategy.defaultStrategy(algo)
80-
treeStrategy.maxDepth = 3
87+
def defaultParams(algo: Algo): BoostingStrategy = {
88+
val treeStragtegy = Strategy.defaultStategy(algo)
89+
treeStragtegy.maxDepth = 3
8190
algo match {
82-
case "Classification" =>
83-
treeStrategy.numClasses = 2
84-
new BoostingStrategy(treeStrategy, LogLoss)
85-
case "Regression" =>
86-
new BoostingStrategy(treeStrategy, SquaredError)
91+
case Algo.Classification =>
92+
treeStragtegy.numClasses = 2
93+
new BoostingStrategy(treeStragtegy, LogLoss)
94+
case Algo.Regression =>
95+
new BoostingStrategy(treeStragtegy, SquaredError)
8796
case _ =>
8897
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
8998
}

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,19 @@ object Strategy {
173173
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
174174
* @param algo "Classification" or "Regression"
175175
*/
176-
def defaultStrategy(algo: String): Strategy = algo match {
177-
case "Classification" =>
176+
def defaultStrategy(algo: String): Strategy = {
177+
defaultStategy(Algo.fromString(algo))
178+
}
179+
180+
/**
181+
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
182+
* @param algo Algo.Classification or Algo.Regression
183+
*/
184+
def defaultStategy(algo: Algo): Strategy = algo match {
185+
case Algo.Classification =>
178186
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
179187
numClasses = 2)
180-
case "Regression" =>
188+
case Algo.Regression =>
181189
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
182190
numClasses = 0)
183191
}

0 commit comments

Comments
 (0)