@@ -150,47 +150,48 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
150
150
}
151
151
}
152
152
153
- private object SaveLoadV1_0 {
153
+ private [tree] object SaveLoadV1_0 {
154
154
155
155
/** Format version string written by `save` into the "version" metadata column. */
def thisFormatVersion: String = "1.0"
156
156
157
/**
 * Class name written by `save` into the "class" metadata column.
 * Hard-coded as a string literal in case the class name changes in the future.
 */
def thisClassName: String = "org.apache.spark.mllib.tree.DecisionTreeModel"
158
159
159
/** Plain-data copy of the `predict` and `prob` fields of a [[Predict]]. */
case class PredictData(predict: Double, prob: Double)

object PredictData {
  /** Build a PredictData from the `predict` and `prob` values of `p`. */
  def apply(p: Predict): PredictData = new PredictData(p.predict, p.prob)
}
164
165
165
/**
 * Plain-data copy of an [[InformationGainStats]]; the left/right predictions are held as
 * [[PredictData]] values.
 */
case class InformationGainStatsData(
    gain: Double,
    impurity: Double,
    leftImpurity: Double,
    rightImpurity: Double,
    leftPredict: PredictData,
    rightPredict: PredictData)

object InformationGainStatsData {
  /** Copy every field of `i`, converting both child predictions to [[PredictData]]. */
  def apply(i: InformationGainStats): InformationGainStatsData = {
    val left = PredictData(i.leftPredict)
    val right = PredictData(i.rightPredict)
    new InformationGainStatsData(
      i.gain, i.impurity, i.leftImpurity, i.rightImpurity, left, right)
  }
}
179
180
180
/**
 * Plain-data copy of a [[Split]]. `featureType` holds the `id` of the split's feature type
 * rather than the feature type value itself.
 */
case class SplitData(
    feature: Int,
    threshold: Double,
    featureType: Int,
    categories: Seq[Double]) // TODO: Change to List once SPARK-3365 is fixed

object SplitData {
  /** Copy the fields of `s`, storing its feature type as that type's integer id. */
  def apply(s: Split): SplitData = {
    val featureTypeId = s.featureType.id
    new SplitData(s.feature, s.threshold, featureTypeId, s.categories)
  }
}
191
192
192
193
/** Model data for model import/export */
193
- private case class NodeData (
194
+ case class NodeData (
194
195
id : Int ,
195
196
predict : PredictData ,
196
197
impurity : Double ,
@@ -200,7 +201,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
200
201
rightNodeId : Option [Int ],
201
202
stats : Option [InformationGainStatsData ])
202
203
203
- private object NodeData {
204
+ object NodeData {
204
205
def apply (n : Node ): NodeData = {
205
206
NodeData (n.id, PredictData (n.predict), n.impurity, n.isLeaf, n.split.map(SplitData .apply),
206
207
n.leftNode.map(_.id), n.rightNode.map(_.id), n.stats.map(InformationGainStatsData .apply))
@@ -212,27 +213,46 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
212
213
import sqlContext .implicits ._
213
214
214
215
// Create JSON metadata.
215
- val metadataRDD =
216
- sc.parallelize( Seq ((thisClassName, thisFormatVersion, model.algo.toString, model. numNodes)))
217
- .toDataFrame(" class" , " version" , " algo" , " numNodes" )
218
- metadataRDD.toJSON.repartition( 1 ). saveAsTextFile(Loader .metadataPath(path))
216
+ val metadataRDD = sc.parallelize( Seq ((thisClassName, thisFormatVersion, model.algo.toString,
217
+ model.numNodes)), 1 )
218
+ .toDataFrame(" class" , " version" , " algo" , " numNodes" )
219
+ metadataRDD.toJSON.saveAsTextFile(Loader .metadataPath(path))
219
220
220
221
// Create Parquet data.
221
222
val nodeIterator = new DecisionTreeModel .NodeIterator (model)
222
- val dataRDD : DataFrame = sc.parallelize(nodeIterator.toSeq).map(NodeData .apply)
223
+ val dataRDD : DataFrame = sc.parallelize(nodeIterator.toSeq).map(NodeData .apply).toDataFrame
223
224
dataRDD.saveAsParquetFile(Loader .dataPath(path))
224
225
}
225
226
227
/**
 * A node paired with the IDs of its two children. Used while loading data and
 * constructing a tree.
 * The child IDs are relevant iff Node.isLeaf == false.
 *
 * @param node          the node itself
 * @param leftChildId   ID of the left child
 * @param rightChildId  ID of the right child
 */
case class NodeWithKids(
    node: Node,
    leftChildId: Int,
    rightChildId: Int)
232
+
226
233
/**
 * Load a tree from the data saved under `path`.
 *
 * @param sc        context used to create the SQLContext that reads the Parquet data
 * @param path      base path; the data location is `Loader.dataPath(path)`
 * @param algo      algorithm name, passed on to `constructTree`
 * @param numNodes  number of nodes the loaded tree is expected to have
 * @return the reconstructed model; an assertion fails if its node count differs from `numNodes`
 */
def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
  val datapath = Loader.dataPath(path)
  // Load Parquet data.
  val nodeData = new SQLContext(sc).parquetFile(datapath)
  // Check schema explicitly since erasure makes it hard to use match-case for checking.
  Loader.checkSchema[NodeData](nodeData.schema)
  // Collect tree nodes, and build them into a tree.
  val collectedNodes: Array[NodeWithKids] = readNodes(nodeData).collect()
  val model = constructTree(collectedNodes, algo, datapath)
  assert(model.numNodes == numNodes,
    s"Unable to load DecisionTreeModel data from: $datapath." +
    s" Expected $numNodes nodes but found ${model.numNodes}")
  model
}
248
+
249
+ /**
250
+ * Read nodes from the loaded data, and return each node with its child IDs.
251
+ * NOTE: The caller should check the schema.
252
+ */
253
+ def readNodes (data : DataFrame ): RDD [NodeWithKids ] = {
234
254
val splitsRDD : RDD [Option [Split ]] =
235
- dataRDD .select(" split.feature" , " split.threshold" , " split.featureType" , " split.categories" )
255
+ data .select(" split.feature" , " split.threshold" , " split.featureType" , " split.categories" )
236
256
.map { row : Row =>
237
257
if (row.isNullAt(0 )) {
238
258
None
@@ -246,7 +266,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
246
266
}
247
267
}
248
268
val lrChildNodesRDD : RDD [Option [(Int , Int )]] =
249
- dataRDD .select(" leftNodeId" , " rightNodeId" ).map { row : Row =>
269
+ data .select(" leftNodeId" , " rightNodeId" ).map { row : Row =>
250
270
if (row.isNullAt(0 )) {
251
271
None
252
272
} else {
@@ -256,7 +276,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
256
276
}
257
277
}
258
278
}
259
- val gainStatsRDD : RDD [Option [InformationGainStats ]] = dataRDD .select(
279
+ val gainStatsRDD : RDD [Option [InformationGainStats ]] = data .select(
260
280
" stats.gain" , " stats.impurity" , " stats.leftImpurity" , " stats.rightImpurity" ,
261
281
" stats.leftPredict.predict" , " stats.leftPredict.prob" ,
262
282
" stats.rightPredict.predict" , " stats.rightPredict.prob" ).map { row : Row =>
@@ -265,8 +285,8 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
265
285
} else {
266
286
row match {
267
287
case Row (gain : Double , impurity : Double , leftImpurity : Double , rightImpurity : Double ,
268
- leftPredictPredict : Double , leftPredictProb : Double ,
269
- rightPredictPredict : Double , rightPredictProb : Double ) =>
288
+ leftPredictPredict : Double , leftPredictProb : Double ,
289
+ rightPredictPredict : Double , rightPredictProb : Double ) =>
270
290
Some (new InformationGainStats (gain, impurity, leftImpurity, rightImpurity,
271
291
new Predict (leftPredictPredict, leftPredictProb),
272
292
new Predict (rightPredictPredict, rightPredictProb)))
@@ -275,55 +295,60 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
275
295
}
276
296
// Each NodeWithKids holds (node, leftChildId, rightChildId); the child ids are only relevant if
277
297
// Node.isLeaf == false
278
- val nodesRDD : RDD [(Node , Int , Int )] =
279
- dataRDD.select(" id" , " predict.predict" , " predict.prob" , " impurity" , " isLeaf" ).rdd
280
- .zip(splitsRDD).zip(lrChildNodesRDD).zip(gainStatsRDD).map {
281
- case (((Row (id : Int , predictPredict : Double , predictProb : Double ,
282
- impurity : Double , isLeaf : Boolean ),
283
- split : Option [Split ]), lrChildNodes : Option [(Int , Int )]),
284
- gainStats : Option [InformationGainStats ]) =>
285
- val (leftChildId, rightChildId) = lrChildNodes.getOrElse((- 1 , - 1 ))
286
- (new Node (id, new Predict (predictPredict, predictProb), impurity, isLeaf,
287
- split, None , None , gainStats),
288
- leftChildId, rightChildId)
289
- }
290
- // Collect tree nodes, and build them into a tree.
298
+ data.select(" id" , " predict.predict" , " predict.prob" , " impurity" , " isLeaf" ).rdd
299
+ .zip(splitsRDD).zip(lrChildNodesRDD).zip(gainStatsRDD).map {
300
+ case (((Row (id : Int , predictPredict : Double , predictProb : Double ,
301
+ impurity : Double , isLeaf : Boolean ),
302
+ split : Option [Split ]), lrChildNodes : Option [(Int , Int )]),
303
+ gainStats : Option [InformationGainStats ]) =>
304
+ val (leftChildId, rightChildId) = lrChildNodes.getOrElse((- 1 , - 1 ))
305
+ NodeWithKids (new Node (id, new Predict (predictPredict, predictProb), impurity, isLeaf,
306
+ split, None , None , gainStats),
307
+ leftChildId, rightChildId)
308
+ }
309
+ }
310
+
311
/**
 * Given a list of nodes from a tree, construct the tree.
 * The node with id 1 is used as the root; an assertion fails if it is absent.
 *
 * @param nodes     All nodes in a tree, each with its child IDs.
 * @param algo      Algorithm tree is for.
 * @param datapath  Used for printing debugging messages if an error occurs.
 * @return model whose top node is node 1, with children linked by `linkSubtree`
 */
def constructTree(
    nodes: Iterable[NodeWithKids],
    algo: String,
    datapath: String): DecisionTreeModel = {
  // Index every node (together with its child ids) by node id.
  val nodesById: Map[Int, NodeWithKids] = nodes.map(entry => (entry.node.id, entry)).toMap
  assert(nodesById.contains(1),
    s"DecisionTree missing root node (id = 1) after loading from: $datapath")
  val root = nodesById(1)
  linkSubtree(root, nodesById)
  new DecisionTreeModel(root.node, Algo.fromString(algo))
}
302
- }
303
329
304
/**
 * Link the given node to its children (if any), and recurse down the subtree.
 * A leaf is left untouched. For a non-leaf, an assertion fails if either child id is
 * missing from `nodesMap`.
 *
 * @param nodeWithKids  Node to link. Node.leftNode and Node.rightNode are set if it is
 *                      not a leaf.
 * @param nodesMap      Map storing all nodes as a map: node id -> NodeWithKids
 */
private def linkSubtree(
    nodeWithKids: NodeWithKids,
    nodesMap: Map[Int, NodeWithKids]): Unit = {
  val parent = nodeWithKids.node
  if (!parent.isLeaf) {
    val leftId = nodeWithKids.leftChildId
    val rightId = nodeWithKids.rightChildId
    // Check both children before touching the parent, so a failure leaves it unlinked.
    assert(nodesMap.contains(leftId),
      s"DecisionTreeModel.load could not find child (id=$leftId) of node ${parent.id}.")
    assert(nodesMap.contains(rightId),
      s"DecisionTreeModel.load could not find child (id=$rightId) of node ${parent.id}.")
    val leftKid = nodesMap(leftId)
    val rightKid = nodesMap(rightId)
    parent.leftNode = Some(leftKid.node)
    parent.rightNode = Some(rightKid.node)
    linkSubtree(leftKid, nodesMap)
    linkSubtree(rightKid, nodesMap)
  }
}
327
352
}
328
353
329
354
override def load (sc : SparkContext , path : String ): DecisionTreeModel = {
0 commit comments