Skip to content

Commit 1d4c264

Browse files
committed
Added save/load for tree ensembles
1 parent dcdbf85 commit 1d4c264

File tree

5 files changed

+353
-99
lines changed

5 files changed

+353
-99
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 88 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -150,47 +150,48 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
150150
}
151151
}
152152

153-
private object SaveLoadV1_0 {
153+
private[tree] object SaveLoadV1_0 {
154154

155155
def thisFormatVersion = "1.0"
156156

157+
// Hard-code the class name string in case the class name changes in the future
157158
def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel"
158159

159-
private case class PredictData(predict: Double, prob: Double)
160+
case class PredictData(predict: Double, prob: Double)
160161

161-
private object PredictData {
162+
object PredictData {
162163
def apply(p: Predict): PredictData = PredictData(p.predict, p.prob)
163164
}
164165

165-
private case class InformationGainStatsData(
166+
case class InformationGainStatsData(
166167
gain: Double,
167168
impurity: Double,
168169
leftImpurity: Double,
169170
rightImpurity: Double,
170171
leftPredict: PredictData,
171172
rightPredict: PredictData)
172173

173-
private object InformationGainStatsData {
174+
object InformationGainStatsData {
174175
def apply(i: InformationGainStats): InformationGainStatsData = {
175176
InformationGainStatsData(i.gain, i.impurity, i.leftImpurity, i.rightImpurity,
176177
PredictData(i.leftPredict), PredictData(i.rightPredict))
177178
}
178179
}
179180

180-
private case class SplitData(
181+
case class SplitData(
181182
feature: Int,
182183
threshold: Double,
183184
featureType: Int,
184185
categories: Seq[Double]) // TODO: Change to List once SPARK-3365 is fixed
185186

186-
private object SplitData {
187+
object SplitData {
187188
def apply(s: Split): SplitData = {
188189
SplitData(s.feature, s.threshold, s.featureType.id, s.categories)
189190
}
190191
}
191192

192193
/** Model data for model import/export */
193-
private case class NodeData(
194+
case class NodeData(
194195
id: Int,
195196
predict: PredictData,
196197
impurity: Double,
@@ -200,7 +201,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
200201
rightNodeId: Option[Int],
201202
stats: Option[InformationGainStatsData])
202203

203-
private object NodeData {
204+
object NodeData {
204205
def apply(n: Node): NodeData = {
205206
NodeData(n.id, PredictData(n.predict), n.impurity, n.isLeaf, n.split.map(SplitData.apply),
206207
n.leftNode.map(_.id), n.rightNode.map(_.id), n.stats.map(InformationGainStatsData.apply))
@@ -212,27 +213,46 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
212213
import sqlContext.implicits._
213214

214215
// Create JSON metadata.
215-
val metadataRDD =
216-
sc.parallelize(Seq((thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)))
217-
.toDataFrame("class", "version", "algo", "numNodes")
218-
metadataRDD.toJSON.repartition(1).saveAsTextFile(Loader.metadataPath(path))
216+
val metadataRDD = sc.parallelize(Seq((thisClassName, thisFormatVersion, model.algo.toString,
217+
model.numNodes)), 1)
218+
.toDataFrame("class", "version", "algo", "numNodes")
219+
metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
219220

220221
// Create Parquet data.
221222
val nodeIterator = new DecisionTreeModel.NodeIterator(model)
222-
val dataRDD: DataFrame = sc.parallelize(nodeIterator.toSeq).map(NodeData.apply)
223+
val dataRDD: DataFrame = sc.parallelize(nodeIterator.toSeq).map(NodeData.apply).toDataFrame
223224
dataRDD.saveAsParquetFile(Loader.dataPath(path))
224225
}
225226

227+
/**
228+
* Node with its child IDs. This class is used for loading data and constructing a tree.
229+
* The child IDs are relevant iff Node.isLeaf == false.
230+
*/
231+
case class NodeWithKids(node: Node, leftChildId: Int, rightChildId: Int)
232+
226233
def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
227234
val datapath = Loader.dataPath(path)
228235
val sqlContext = new SQLContext(sc)
229236
// Load Parquet data.
230237
val dataRDD = sqlContext.parquetFile(datapath)
231238
// Check schema explicitly since erasure makes it hard to use match-case for checking.
232239
Loader.checkSchema[NodeData](dataRDD.schema)
233-
// TODO: Extract save/load for 1 tree so that it can be reused for ensembles?
240+
val nodesRDD: RDD[NodeWithKids] = readNodes(dataRDD)
241+
// Collect tree nodes, and build them into a tree.
242+
val tree = constructTree(nodesRDD.collect(), algo, datapath)
243+
assert(tree.numNodes == numNodes,
244+
s"Unable to load DecisionTreeModel data from: $datapath." +
245+
s" Expected $numNodes nodes but found ${tree.numNodes}")
246+
tree
247+
}
248+
249+
/**
250+
* Read nodes from the loaded data, and return each node with its child IDs.
251+
* NOTE: The caller should check the schema.
252+
*/
253+
def readNodes(data: DataFrame): RDD[NodeWithKids] = {
234254
val splitsRDD: RDD[Option[Split]] =
235-
dataRDD.select("split.feature", "split.threshold", "split.featureType", "split.categories")
255+
data.select("split.feature", "split.threshold", "split.featureType", "split.categories")
236256
.map { row: Row =>
237257
if (row.isNullAt(0)) {
238258
None
@@ -246,7 +266,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
246266
}
247267
}
248268
val lrChildNodesRDD: RDD[Option[(Int, Int)]] =
249-
dataRDD.select("leftNodeId", "rightNodeId").map { row: Row =>
269+
data.select("leftNodeId", "rightNodeId").map { row: Row =>
250270
if (row.isNullAt(0)) {
251271
None
252272
} else {
@@ -256,7 +276,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
256276
}
257277
}
258278
}
259-
val gainStatsRDD: RDD[Option[InformationGainStats]] = dataRDD.select(
279+
val gainStatsRDD: RDD[Option[InformationGainStats]] = data.select(
260280
"stats.gain", "stats.impurity", "stats.leftImpurity", "stats.rightImpurity",
261281
"stats.leftPredict.predict", "stats.leftPredict.prob",
262282
"stats.rightPredict.predict", "stats.rightPredict.prob").map { row: Row =>
@@ -265,8 +285,8 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
265285
} else {
266286
row match {
267287
case Row(gain: Double, impurity: Double, leftImpurity: Double, rightImpurity: Double,
268-
leftPredictPredict: Double, leftPredictProb: Double,
269-
rightPredictPredict: Double, rightPredictProb: Double) =>
288+
leftPredictPredict: Double, leftPredictProb: Double,
289+
rightPredictPredict: Double, rightPredictProb: Double) =>
270290
Some(new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
271291
new Predict(leftPredictPredict, leftPredictProb),
272292
new Predict(rightPredictPredict, rightPredictProb)))
@@ -275,55 +295,60 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
275295
}
276296
// Each result holds (Node, leftChildId, rightChildId) where the child ids are only relevant if
277297
// Node.isLeaf == false
278-
val nodesRDD: RDD[(Node, Int, Int)] =
279-
dataRDD.select("id", "predict.predict", "predict.prob", "impurity", "isLeaf").rdd
280-
.zip(splitsRDD).zip(lrChildNodesRDD).zip(gainStatsRDD).map {
281-
case (((Row(id: Int, predictPredict: Double, predictProb: Double,
282-
impurity: Double, isLeaf: Boolean),
283-
split: Option[Split]), lrChildNodes: Option[(Int, Int)]),
284-
gainStats: Option[InformationGainStats]) =>
285-
val (leftChildId, rightChildId) = lrChildNodes.getOrElse((-1, -1))
286-
(new Node(id, new Predict(predictPredict, predictProb), impurity, isLeaf,
287-
split, None, None, gainStats),
288-
leftChildId, rightChildId)
289-
}
290-
// Collect tree nodes, and build them into a tree.
298+
data.select("id", "predict.predict", "predict.prob", "impurity", "isLeaf").rdd
299+
.zip(splitsRDD).zip(lrChildNodesRDD).zip(gainStatsRDD).map {
300+
case (((Row(id: Int, predictPredict: Double, predictProb: Double,
301+
impurity: Double, isLeaf: Boolean),
302+
split: Option[Split]), lrChildNodes: Option[(Int, Int)]),
303+
gainStats: Option[InformationGainStats]) =>
304+
val (leftChildId, rightChildId) = lrChildNodes.getOrElse((-1, -1))
305+
NodeWithKids(new Node(id, new Predict(predictPredict, predictProb), impurity, isLeaf,
306+
split, None, None, gainStats),
307+
leftChildId, rightChildId)
308+
}
309+
}
310+
311+
/**
312+
* Given a list of nodes from a tree, construct the tree.
313+
* @param nodes Array of all nodes in a tree.
314+
* @param algo Name of the algorithm the tree is for; parsed with Algo.fromString.
315+
* @param datapath Used for printing debugging messages if an error occurs.
316+
*/
317+
def constructTree(
318+
nodes: Iterable[NodeWithKids],
319+
algo: String,
320+
datapath: String): DecisionTreeModel = {
291321
// nodesMap: node id -> NodeWithKids (node, leftChildId, rightChildId)
292-
val nodesMap: Map[Int, (Node, Int, Int)] = nodesRDD.collect().map(n => n._1.id -> n).toMap
322+
val nodesMap: Map[Int, NodeWithKids] = nodes.map(n => n.node.id -> n).toMap
293323
assert(nodesMap.contains(1),
294324
s"DecisionTree missing root node (id = 1) after loading from: $datapath")
295325
val topNode = nodesMap(1)
296-
linkSubtree(topNode._1, topNode._2, topNode._3, nodesMap)
297-
assert(nodesMap.size == numNodes,
298-
s"Unable to load DecisionTreeModel data from: $datapath." +
299-
s" Expected $numNodes nodes but found ${nodesMap.size}")
300-
new DecisionTreeModel(topNode._1, Algo.fromString(algo))
326+
linkSubtree(topNode, nodesMap)
327+
new DecisionTreeModel(topNode.node, Algo.fromString(algo))
301328
}
302-
}
303329

304-
/**
305-
* Link the given node to its children (if any), and recurse down the subtree.
306-
* @param node Node to link. Node.leftNode and Node.rightNode will be set if there are children.
307-
* @param leftChildId Id of left child. Ignored if node is a leaf.
308-
* @param rightChildId Id of right child. Ignored if node is a leaf.
309-
* @param nodesMap Map storing all nodes as a map: node id -> (Node, leftChildId, rightChildId).
310-
*/
311-
private def linkSubtree(
312-
node: Node,
313-
leftChildId: Int,
314-
rightChildId: Int,
315-
nodesMap: Map[Int, (Node, Int, Int)]): Unit = {
316-
if (node.isLeaf) return
317-
assert(nodesMap.contains(leftChildId),
318-
s"DecisionTreeModel.load could not find child (id=$leftChildId) of node ${node.id}.")
319-
assert(nodesMap.contains(rightChildId),
320-
s"DecisionTreeModel.load could not find child (id=$rightChildId) of node ${node.id}.")
321-
val leftChild = nodesMap(leftChildId)
322-
val rightChild = nodesMap(rightChildId)
323-
node.leftNode = Some(leftChild._1)
324-
node.rightNode = Some(rightChild._1)
325-
linkSubtree(leftChild._1, leftChild._2, leftChild._3, nodesMap)
326-
linkSubtree(rightChild._1, rightChild._2, rightChild._3, nodesMap)
330+
/**
331+
* Link the given node to its children (if any), and recurse down the subtree.
332+
* @param nodeWithKids Node to link
333+
* @param nodesMap Map storing all nodes as a map: node id -> NodeWithKids (node and child IDs)
334+
*/
335+
private def linkSubtree(
336+
nodeWithKids: NodeWithKids,
337+
nodesMap: Map[Int, NodeWithKids]): Unit = {
338+
val (node, leftChildId, rightChildId) =
339+
(nodeWithKids.node, nodeWithKids.leftChildId, nodeWithKids.rightChildId)
340+
if (node.isLeaf) return
341+
assert(nodesMap.contains(leftChildId),
342+
s"DecisionTreeModel.load could not find child (id=$leftChildId) of node ${node.id}.")
343+
assert(nodesMap.contains(rightChildId),
344+
s"DecisionTreeModel.load could not find child (id=$rightChildId) of node ${node.id}.")
345+
val leftChild = nodesMap(leftChildId)
346+
val rightChild = nodesMap(rightChildId)
347+
node.leftNode = Some(leftChild.node)
348+
node.rightNode = Some(rightChild.node)
349+
linkSubtree(leftChild, nodesMap)
350+
linkSubtree(rightChild, nodesMap)
351+
}
327352
}
328353

329354
override def load(sc: SparkContext, path: String): DecisionTreeModel = {

0 commit comments

Comments
 (0)