
Commit 8b3045c

manishamde authored and mateiz committed
MLI-1 Decision Trees
Joint work with @hirakendu, @etrain, @atalwalkar and @harsha2010.

Key features:
+ Supports binary classification and regression
+ Supports gini, entropy and variance for information gain calculation
+ Supports both continuous and categorical features

The algorithm has gone through several development iterations over the last few months leading to a highly optimized implementation. Optimizations include:

1. Level-wise training to reduce passes over the entire dataset.
2. Bin-wise split calculation to reduce computation overhead.
3. Aggregation over partitions before combining to reduce communication overhead.

Author: Manish Amde <[email protected]>
Author: manishamde <[email protected]>
Author: Xiangrui Meng <[email protected]>

Closes alteryx#79 from manishamde/tree and squashes the following commits:

1e8c704 [Manish Amde] remove numBins field in the Strategy class
7d54b4f [manishamde] Merge pull request alteryx#4 from mengxr/dtree
f536ae9 [Xiangrui Meng] another pass on code style
e1dd86f [Manish Amde] implementing code style suggestions
62dc723 [Manish Amde] updating javadoc and converting helper methods to package private to allow unit testing
201702f [Manish Amde] making some more methods private
f963ef5 [Manish Amde] making methods private
c487e6a [manishamde] Merge pull request #1 from mengxr/dtree
24500c5 [Xiangrui Meng] minor style updates
4576b64 [Manish Amde] documentation and for to while loop conversion
ff363a7 [Manish Amde] binary search for bins and while loop for categorical feature bins
632818f [Manish Amde] removing threshold for classification predict method
2116360 [Manish Amde] removing dummy bin calculation for categorical variables
6068356 [Manish Amde] ensuring num bins is always greater than max number of categories
62c2562 [Manish Amde] fixing comment indentation
ad1fc21 [Manish Amde] incorporated mengxr's code style suggestions
d1ef4f6 [Manish Amde] more documentation
794ff4d [Manish Amde] minor improvements to docs and style
eb8fcbe [Manish Amde] minor code style updates
cd2c2b4 [Manish Amde] fixing code style based on feedback
63e786b [Manish Amde] added multiple train methods for java compatability
d3023b3 [Manish Amde] adding more docs for nested methods
84f85d6 [Manish Amde] code documentation
9372779 [Manish Amde] code style: max line lenght <= 100
dd0c0d7 [Manish Amde] minor: some docs
0dd7659 [manishamde] basic doc
5841c28 [Manish Amde] unit tests for categorical features
f067d68 [Manish Amde] minor cleanup
c0e522b [Manish Amde] updated predict and split threshold logic
b09dc98 [Manish Amde] minor refactoring
6b7de78 [Manish Amde] minor refactoring and tests
d504eb1 [Manish Amde] more tests for categorical features
dbb7ac1 [Manish Amde] categorical feature support
6df35b9 [Manish Amde] regression predict logic
53108ed [Manish Amde] fixing index for highest bin
e23c2e5 [Manish Amde] added regression support
c8f6d60 [Manish Amde] adding enum for feature type
b0e3e76 [Manish Amde] adding enum for feature type
154aa77 [Manish Amde] enums for configurations
733d6dd [Manish Amde] fixed tests
02c595c [Manish Amde] added command line parsing
98ec8d5 [Manish Amde] tree building and prediction logic
b0eb866 [Manish Amde] added logic to handle leaf nodes
80e8c66 [Manish Amde] working version of multi-level split calculation
4798aae [Manish Amde] added gain stats class
dad0afc [Manish Amde] decison stump functionality working
03f534c [Manish Amde] some more tests
0012a77 [Manish Amde] basic stump working
8bca1e2 [Manish Amde] additional code for creating intermediate RDD
92cedce [Manish Amde] basic building blocks for intermediate RDD calculation. untested.
cd53eae [Manish Amde] skeletal framework
1 parent 45df912 commit 8b3045c


17 files changed: +2188 -0 lines changed


mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 1150 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+This package contains the default implementation of the decision tree algorithm.
+
+The decision tree algorithm supports:
++ Binary classification
++ Regression
++ Information loss calculation with entropy and gini for classification and variance for regression
++ Both continuous and categorical features
+
+# Tree improvements
++ Node model pruning
++ Printing to dot files
+
+# Future Ensemble Extensions
+
++ Random forests
++ Boosting
++ Extremely randomized trees
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+/**
+ * Enum to select the algorithm for the decision tree
+ */
+object Algo extends Enumeration {
+  type Algo = Value
+  val Classification, Regression = Value
+}
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+/**
+ * Enum to describe whether a feature is "continuous" or "categorical"
+ */
+object FeatureType extends Enumeration {
+  type FeatureType = Value
+  val Continuous, Categorical = Value
+}
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+/**
+ * Enum for selecting the quantile calculation strategy
+ */
+object QuantileStrategy extends Enumeration {
+  type QuantileStrategy = Value
+  val Sort, MinMax, ApproxHist = Value
+}
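The commit message mentions bin-wise split calculation and a binary search for bins, but the `DecisionTree.scala` diff is not rendered on this page, so what `Sort` does in practice is not visible here. The following is only a rough sketch of one plausible reading: sort a sample of a continuous feature, take evenly spaced values as thresholds, and locate a value's bin by binary search. `SortQuantileSketch`, `thresholds` and `binIndex` are hypothetical names, not part of the commit.

```scala
object SortQuantileSketch {

  /** Candidate thresholds for one continuous feature, read off the sorted
    * sample at evenly spaced positions. Returns at most numBins - 1 values. */
  def thresholds(sample: Seq[Double], numBins: Int): Array[Double] = {
    require(numBins >= 2, "need at least two bins")
    require(sample.nonEmpty, "sample must not be empty")
    val sorted = sample.sorted.toArray
    val stride = sorted.length.toDouble / numBins
    (1 until numBins)
      .map(i => sorted(math.min((i * stride).toInt, sorted.length - 1)))
      .distinct // repeated sample values would otherwise yield empty bins
      .toArray
  }

  /** Index of the bin holding `value`, where bin i covers (ts(i - 1), ts(i)]
    * and the last bin is unbounded above. Found by binary search. */
  def binIndex(ts: Array[Double], value: Double): Int = {
    var lo = 0
    var hi = ts.length
    while (lo < hi) {
      val mid = (lo + hi) >>> 1
      if (value <= ts(mid)) hi = mid else lo = mid + 1
    }
    lo
  }

  def main(args: Array[String]): Unit = {
    val ts = thresholds((1 to 8).map(_.toDouble), numBins = 4)
    println(ts.mkString(", "))
    println(Seq(1.0, 4.0, 7.0, 9.0).map(binIndex(ts, _)).mkString(", "))
  }
}
```

With `numBins` bins there are at most `numBins - 1` thresholds, so each lookup costs a logarithmic number of comparisons instead of a linear scan.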
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+
+/**
+ * Stores all the configuration options for tree construction
+ * @param algo classification or regression
+ * @param impurity criterion used for information gain calculation
+ * @param maxDepth maximum depth of the tree
+ * @param maxBins maximum number of bins used for splitting features
+ * @param quantileCalculationStrategy algorithm for calculating quantiles
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
+ *                                number of discrete values they take. For example, an entry (n ->
+ *                                k) implies the feature n is categorical with k categories 0,
+ *                                1, 2, ... , k-1. It's important to note that features are
+ *                                zero-indexed.
+ */
+class Strategy (
+    val algo: Algo,
+    val impurity: Impurity,
+    val maxDepth: Int,
+    val maxBins: Int = 100,
+    val quantileCalculationStrategy: QuantileStrategy = Sort,
+    val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable
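To illustrate how the `categoricalFeaturesInfo` map and the default arguments might be used, here is a small sketch that builds a `Strategy` for classification. So that it compiles without Spark on the classpath, it declares local stand-ins that mirror the definitions in this commit; the wrapper `StrategySketch`, the empty `Impurity` and `Gini` placeholders, and the helpers `isCategorical` and `numCategories` are assumptions made for the example.

```scala
object StrategySketch {
  // Local stand-ins mirroring the enums and classes added in this commit.
  object Algo extends Enumeration {
    type Algo = Value
    val Classification, Regression = Value
  }
  object QuantileStrategy extends Enumeration {
    type QuantileStrategy = Value
    val Sort, MinMax, ApproxHist = Value
  }
  trait Impurity extends Serializable
  object Gini extends Impurity // placeholder: only the type matters here

  import Algo._
  import QuantileStrategy._

  class Strategy(
      val algo: Algo,
      val impurity: Impurity,
      val maxDepth: Int,
      val maxBins: Int = 100,
      val quantileCalculationStrategy: QuantileStrategy = Sort,
      val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable

  // Feature 0 takes categories 0, 1, 2 and feature 4 takes categories 0, 1.
  // Every feature index missing from the map is treated as continuous.
  val strategy = new Strategy(
    algo = Classification,
    impurity = Gini,
    maxDepth = 5,
    categoricalFeaturesInfo = Map(0 -> 3, 4 -> 2))

  // Hypothetical helpers showing how the map can be queried.
  def isCategorical(featureIndex: Int): Boolean =
    strategy.categoricalFeaturesInfo.contains(featureIndex)

  def numCategories(featureIndex: Int): Option[Int] =
    strategy.categoricalFeaturesInfo.get(featureIndex)
}
```

Because `maxBins` and `quantileCalculationStrategy` are left out of the call, they fall back to the defaults `100` and `Sort` declared in the constructor.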
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
+ * binary classification.
+ */
+object Entropy extends Impurity {
+
+  def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
+
+  /**
+   * entropy calculation
+   * @param c0 count of instances with label 0
+   * @param c1 count of instances with label 1
+   * @return entropy value
+   */
+  def calculate(c0: Double, c1: Double): Double = {
+    if (c0 == 0 || c1 == 0) {
+      0
+    } else {
+      val total = c0 + c1
+      val f0 = c0 / total
+      val f1 = c1 / total
+      -(f0 * log2(f0)) - (f1 * log2(f1))
+    }
+  }
+
+  def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+    throw new UnsupportedOperationException("Entropy.calculate")
+}
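As a quick check of the binary entropy formula, the sketch below restates it in a standalone object and evaluates it on three nodes. `EntropyExample` and `entropy` are names chosen for this illustration; the arithmetic follows `Entropy.calculate` above.

```scala
object EntropyExample {
  def log2(x: Double): Double = math.log(x) / math.log(2)

  // Same arithmetic as Entropy.calculate: the early return avoids log2(0).
  def entropy(c0: Double, c1: Double): Double =
    if (c0 == 0 || c1 == 0) {
      0.0
    } else {
      val total = c0 + c1
      val f0 = c0 / total
      val f1 = c1 / total
      -(f0 * log2(f0)) - (f1 * log2(f1))
    }

  def main(args: Array[String]): Unit = {
    println(entropy(5, 5))  // evenly mixed node: the maximum, one bit
    println(entropy(10, 0)) // pure node: zero
    println(entropy(1, 3))  // skewed node: strictly between zero and one
  }
}
```

The function is symmetric in its arguments, so `entropy(1, 3)` and `entropy(3, 1)` agree.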
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Class for calculating the
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
+ * during binary classification.
+ */
+object Gini extends Impurity {
+
+  /**
+   * Gini impurity calculation
+   * @param c0 count of instances with label 0
+   * @param c1 count of instances with label 1
+   * @return Gini impurity value
+   */
+  override def calculate(c0: Double, c1: Double): Double = {
+    if (c0 == 0 || c1 == 0) {
+      0
+    } else {
+      val total = c0 + c1
+      val f0 = c0 / total
+      val f1 = c1 / total
+      1 - f0 * f0 - f1 * f1
+    }
+  }
+
+  def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+    throw new UnsupportedOperationException("Gini.calculate")
+}
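For the two-class case, the expression `1 - f0 * f0 - f1 * f1` can be simplified by hand. Writing f_0 and f_1 for the label fractions `f0` and `f1` in the code, and using the fact that they sum to one:

```latex
\begin{aligned}
G(f_0, f_1) &= 1 - f_0^2 - f_1^2, \qquad f_0 + f_1 = 1 \\
            &= (f_0 + f_1)^2 - f_0^2 - f_1^2 \\
            &= 2 f_0 f_1
\end{aligned}
```

The value is 0 when either fraction is 0, which agrees with the early-return branch, and it peaks at 0.5 when f_0 = f_1 = 0.5. For example, counts c0 = 1 and c1 = 3 give 2 × 0.25 × 0.75 = 0.375.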
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Trait for calculating information gain.
+ */
+trait Impurity extends Serializable {
+
+  /**
+   * information calculation for binary classification
+   * @param c0 count of instances with label 0
+   * @param c1 count of instances with label 1
+   * @return information value
+   */
+  def calculate(c0: Double, c1: Double): Double
+
+  /**
+   * information calculation for regression
+   * @param count number of instances
+   * @param sum sum of labels
+   * @param sumSquares summation of squares of the labels
+   * @return information value
+   */
+  def calculate(count: Double, sum: Double, sumSquares: Double): Double
+
+}
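The trait only returns an impurity value for a single node; how the tree turns that into a gain lives in `DecisionTree.scala`, which is not rendered on this page. The sketch below therefore uses the textbook definition as an assumption: the gain of a split is the parent's impurity minus the size-weighted impurities of the two children. `InfoGainSketch`, `gini` and `gain` are hypothetical names for this example.

```scala
object InfoGainSketch {

  // Binary Gini impurity, same arithmetic as Gini.calculate in this commit.
  def gini(c0: Double, c1: Double): Double =
    if (c0 == 0 || c1 == 0) {
      0.0
    } else {
      val total = c0 + c1
      val f0 = c0 / total
      val f1 = c1 / total
      1 - f0 * f0 - f1 * f1
    }

  /** Textbook information gain for a binary split.
    * `left` and `right` hold the (label 0, label 1) counts of each child. */
  def gain(
      impurity: (Double, Double) => Double,
      left: (Double, Double),
      right: (Double, Double)): Double = {
    val (l0, l1) = left
    val (r0, r1) = right
    val nLeft = l0 + l1
    val nRight = r0 + r1
    val n = nLeft + nRight
    require(n > 0, "split must contain at least one instance")
    // The parent's counts are the sums of the children's counts.
    impurity(l0 + r0, l1 + r1) -
      (nLeft / n) * impurity(l0, l1) -
      (nRight / n) * impurity(r0, r1)
  }

  def main(args: Array[String]): Unit = {
    // Parent (5, 5) split into (4, 1) and (1, 4): 0.5 - 0.5 * 0.32 - 0.5 * 0.32.
    println(gain(gini, (4.0, 1.0), (1.0, 4.0)))
    // A split that separates nothing leaves the gain at zero.
    println(gain(gini, (2.0, 2.0), (3.0, 3.0)))
  }
}
```

Taking the impurity as a function argument mirrors the role of the `impurity` field in `Strategy`: the same gain routine works with entropy by passing a different function.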
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Class for calculating variance during regression
+ */
+object Variance extends Impurity {
+  override def calculate(c0: Double, c1: Double): Double =
+    throw new UnsupportedOperationException("Variance.calculate")
+
+  /**
+   * variance calculation
+   * @param count number of instances
+   * @param sum sum of labels
+   * @param sumSquares summation of squares of the labels
+   */
+  override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
+    val squaredLoss = sumSquares - (sum * sum) / count
+    squaredLoss / count
+  }
+}
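`Variance.calculate` works from three running totals instead of the raw labels, and those totals can be added together across batches before the variance is taken. The sketch below checks this on a tiny example and compares the result with a direct two-pass computation. `VarianceSketch`, `stats`, `merge` and `twoPass` are illustrative names; whether the tree combines per-partition totals exactly this way is an assumption, since `DecisionTree.scala` is not rendered here.

```scala
object VarianceSketch {
  type Stats = (Double, Double, Double) // (count, sum, sumSquares)

  // Same arithmetic as Variance.calculate: population variance (divides by count).
  def variance(count: Double, sum: Double, sumSquares: Double): Double = {
    val squaredLoss = sumSquares - (sum * sum) / count
    squaredLoss / count
  }

  // Running totals for one batch of labels.
  def stats(labels: Seq[Double]): Stats =
    (labels.size.toDouble, labels.sum, labels.map(y => y * y).sum)

  // Totals from two batches combine by component-wise addition.
  def merge(a: Stats, b: Stats): Stats =
    (a._1 + b._1, a._2 + b._2, a._3 + b._3)

  // Reference implementation: mean first, then mean squared deviation.
  def twoPass(labels: Seq[Double]): Double = {
    val mean = labels.sum / labels.size
    labels.map(y => (y - mean) * (y - mean)).sum / labels.size
  }

  def main(args: Array[String]): Unit = {
    val (count, sum, sumSquares) = merge(stats(Seq(1.0, 2.0)), stats(Seq(3.0, 4.0)))
    println(variance(count, sum, sumSquares)) // (30 - 100 / 4) / 4 = 1.25
    println(twoPass(Seq(1.0, 2.0, 3.0, 4.0)))
  }
}
```

Note that a `count` of zero makes the division undefined, so callers would need to guard against empty nodes.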
