Skip to content

Commit 01905c4

Browse files
jkbradleymengxr
authored andcommitted
[SPARK-5597][MLLIB] save/load for decision trees and emsembles
This is based on #4444 from jkbradley with the following changes: 1. Node schema updated to ~~~ treeId: int nodeId: Int predict/ |- predict: Double |- prob: Double impurity: Double isLeaf: Boolean split/ |- feature: Int |- threshold: Double |- featureType: Int |- categories: Array[Double] leftNodeId: Integer rightNodeId: Integer infoGain: Double ~~~ 2. Some refactor of the implementation. Closes #4444. Author: Joseph K. Bradley <[email protected]> Author: Xiangrui Meng <[email protected]> Closes #4493 from mengxr/SPARK-5597 and squashes the following commits: 75e3bb6 [Xiangrui Meng] fix style 2b0033d [Xiangrui Meng] update tree export schema and refactor the implementation 45873a2 [Joseph K. Bradley] org imports 1d4c264 [Joseph K. Bradley] Added save/load for tree ensembles dcdbf85 [Joseph K. Bradley] added save/load for decision tree but need to generalize it to ensembles (cherry picked from commit ef2f55b) Signed-off-by: Xiangrui Meng <[email protected]>
1 parent 663d34e commit 01905c4

File tree

8 files changed

+561
-38
lines changed

8 files changed

+561
-38
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 196 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,17 @@
1717

1818
package org.apache.spark.mllib.tree.model
1919

20+
import scala.collection.mutable
21+
22+
import org.apache.spark.SparkContext
2023
import org.apache.spark.annotation.Experimental
2124
import org.apache.spark.api.java.JavaRDD
2225
import org.apache.spark.mllib.linalg.Vector
26+
import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType}
2327
import org.apache.spark.mllib.tree.configuration.Algo._
28+
import org.apache.spark.mllib.util.{Loader, Saveable}
2429
import org.apache.spark.rdd.RDD
30+
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
2531

2632
/**
2733
* :: Experimental ::
@@ -31,7 +37,7 @@ import org.apache.spark.rdd.RDD
3137
* @param algo algorithm type -- classification or regression
3238
*/
3339
@Experimental
34-
class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {
40+
class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable {
3541

3642
/**
3743
* Predict values for a single data point using the model trained.
@@ -98,4 +104,193 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
98104
header + topNode.subtreeToString(2)
99105
}
100106

107+
override def save(sc: SparkContext, path: String): Unit = {
108+
DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
109+
}
110+
111+
override protected def formatVersion: String = "1.0"
112+
}
113+
114+
object DecisionTreeModel extends Loader[DecisionTreeModel] {
115+
116+
private[tree] object SaveLoadV1_0 {
117+
118+
def thisFormatVersion = "1.0"
119+
120+
// Hard-code class name string in case it changes in the future
121+
def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel"
122+
123+
case class PredictData(predict: Double, prob: Double) {
124+
def toPredict: Predict = new Predict(predict, prob)
125+
}
126+
127+
object PredictData {
128+
def apply(p: Predict): PredictData = PredictData(p.predict, p.prob)
129+
130+
def apply(r: Row): PredictData = PredictData(r.getDouble(0), r.getDouble(1))
131+
}
132+
133+
case class SplitData(
134+
feature: Int,
135+
threshold: Double,
136+
featureType: Int,
137+
categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed
138+
def toSplit: Split = {
139+
new Split(feature, threshold, FeatureType(featureType), categories.toList)
140+
}
141+
}
142+
143+
object SplitData {
144+
def apply(s: Split): SplitData = {
145+
SplitData(s.feature, s.threshold, s.featureType.id, s.categories)
146+
}
147+
148+
def apply(r: Row): SplitData = {
149+
SplitData(r.getInt(0), r.getDouble(1), r.getInt(2), r.getAs[Seq[Double]](3))
150+
}
151+
}
152+
153+
/** Model data for model import/export */
154+
case class NodeData(
155+
treeId: Int,
156+
nodeId: Int,
157+
predict: PredictData,
158+
impurity: Double,
159+
isLeaf: Boolean,
160+
split: Option[SplitData],
161+
leftNodeId: Option[Int],
162+
rightNodeId: Option[Int],
163+
infoGain: Option[Double])
164+
165+
object NodeData {
166+
def apply(treeId: Int, n: Node): NodeData = {
167+
NodeData(treeId, n.id, PredictData(n.predict), n.impurity, n.isLeaf,
168+
n.split.map(SplitData.apply), n.leftNode.map(_.id), n.rightNode.map(_.id),
169+
n.stats.map(_.gain))
170+
}
171+
172+
def apply(r: Row): NodeData = {
173+
val split = if (r.isNullAt(5)) None else Some(SplitData(r.getStruct(5)))
174+
val leftNodeId = if (r.isNullAt(6)) None else Some(r.getInt(6))
175+
val rightNodeId = if (r.isNullAt(7)) None else Some(r.getInt(7))
176+
val infoGain = if (r.isNullAt(8)) None else Some(r.getDouble(8))
177+
NodeData(r.getInt(0), r.getInt(1), PredictData(r.getStruct(2)), r.getDouble(3),
178+
r.getBoolean(4), split, leftNodeId, rightNodeId, infoGain)
179+
}
180+
}
181+
182+
def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = {
183+
val sqlContext = new SQLContext(sc)
184+
import sqlContext.implicits._
185+
186+
// Create JSON metadata.
187+
val metadataRDD = sc.parallelize(
188+
Seq((thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)), 1)
189+
.toDataFrame("class", "version", "algo", "numNodes")
190+
metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
191+
192+
// Create Parquet data.
193+
val nodes = model.topNode.subtreeIterator.toSeq
194+
val dataRDD: DataFrame = sc.parallelize(nodes)
195+
.map(NodeData.apply(0, _))
196+
.toDataFrame
197+
dataRDD.saveAsParquetFile(Loader.dataPath(path))
198+
}
199+
200+
def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
201+
val datapath = Loader.dataPath(path)
202+
val sqlContext = new SQLContext(sc)
203+
// Load Parquet data.
204+
val dataRDD = sqlContext.parquetFile(datapath)
205+
// Check schema explicitly since erasure makes it hard to use match-case for checking.
206+
Loader.checkSchema[NodeData](dataRDD.schema)
207+
val nodes = dataRDD.map(NodeData.apply)
208+
// Build node data into a tree.
209+
val trees = constructTrees(nodes)
210+
assert(trees.size == 1,
211+
"Decision tree should contain exactly one tree but got ${trees.size} trees.")
212+
val model = new DecisionTreeModel(trees(0), Algo.fromString(algo))
213+
assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $datapath." +
214+
s" Expected $numNodes nodes but found ${model.numNodes}")
215+
model
216+
}
217+
218+
def constructTrees(nodes: RDD[NodeData]): Array[Node] = {
219+
val trees = nodes
220+
.groupBy(_.treeId)
221+
.mapValues(_.toArray)
222+
.collect()
223+
.map { case (treeId, data) =>
224+
(treeId, constructTree(data))
225+
}.sortBy(_._1)
226+
val numTrees = trees.size
227+
val treeIndices = trees.map(_._1).toSeq
228+
assert(treeIndices == (0 until numTrees),
229+
s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.")
230+
trees.map(_._2)
231+
}
232+
233+
/**
234+
* Given a list of nodes from a tree, construct the tree.
235+
* @param data array of all node data in a tree.
236+
*/
237+
def constructTree(data: Array[NodeData]): Node = {
238+
val dataMap: Map[Int, NodeData] = data.map(n => n.nodeId -> n).toMap
239+
assert(dataMap.contains(1),
240+
s"DecisionTree missing root node (id = 1).")
241+
constructNode(1, dataMap, mutable.Map.empty)
242+
}
243+
244+
/**
245+
* Builds a node from the node data map and adds new nodes to the input nodes map.
246+
*/
247+
private def constructNode(
248+
id: Int,
249+
dataMap: Map[Int, NodeData],
250+
nodes: mutable.Map[Int, Node]): Node = {
251+
if (nodes.contains(id)) {
252+
return nodes(id)
253+
}
254+
val data = dataMap(id)
255+
val node =
256+
if (data.isLeaf) {
257+
Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf)
258+
} else {
259+
val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes)
260+
val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes)
261+
val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity,
262+
rightNode.impurity, leftNode.predict, rightNode.predict)
263+
new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf,
264+
data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats))
265+
}
266+
nodes += node.id -> node
267+
node
268+
}
269+
}
270+
271+
override def load(sc: SparkContext, path: String): DecisionTreeModel = {
272+
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
273+
val (algo: String, numNodes: Int) = try {
274+
val algo_numNodes = metadata.select("algo", "numNodes").collect()
275+
assert(algo_numNodes.length == 1)
276+
algo_numNodes(0) match {
277+
case Row(a: String, n: Int) => (a, n)
278+
}
279+
} catch {
280+
// Catch both Error and Exception since the checks above can throw either.
281+
case e: Throwable =>
282+
throw new Exception(
283+
s"Unable to load DecisionTreeModel metadata from: ${Loader.metadataPath(path)}."
284+
+ s" Error message: ${e.getMessage}")
285+
}
286+
val classNameV1_0 = SaveLoadV1_0.thisClassName
287+
(loadedClassName, version) match {
288+
case (className, "1.0") if className == classNameV1_0 =>
289+
SaveLoadV1_0.load(sc, path, algo, numNodes)
290+
case _ => throw new Exception(
291+
s"DecisionTreeModel.load did not recognize model with (className, format version):" +
292+
s"($loadedClassName, $version). Supported:\n" +
293+
s" ($classNameV1_0, 1.0)")
294+
}
295+
}
101296
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ class InformationGainStats(
4949
gain == other.gain &&
5050
impurity == other.impurity &&
5151
leftImpurity == other.leftImpurity &&
52-
rightImpurity == other.rightImpurity
52+
rightImpurity == other.rightImpurity &&
53+
leftPredict == other.leftPredict &&
54+
rightPredict == other.rightPredict
5355
}
5456
case _ => false
5557
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ class Node (
166166
}
167167
}
168168

169+
/** Returns an iterator that traverses (DFS, left to right) the subtree of this node. */
170+
private[tree] def subtreeIterator: Iterator[Node] = {
171+
Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++
172+
rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty)
173+
}
169174
}
170175

171176
private[tree] object Node {

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,11 @@ class Predict(
3232
override def toString = {
3333
"predict = %f, prob = %f".format(predict, prob)
3434
}
35+
36+
override def equals(other: Any): Boolean = {
37+
other match {
38+
case p: Predict => predict == p.predict && prob == p.prob
39+
case _ => false
40+
}
41+
}
3542
}

0 commit comments

Comments
 (0)