@@ -17,11 +17,17 @@
 
 package org.apache.spark.mllib.tree.model
 
+import scala.collection.mutable
+
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType}
 import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 /**
  * :: Experimental ::
@@ -31,7 +37,7 @@ import org.apache.spark.rdd.RDD
  * @param algo algorithm type -- classification or regression
  */
 @Experimental
-class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {
+class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable {
 
   /**
    * Predict values for a single data point using the model trained.
@@ -98,4 +104,193 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
     header + topNode.subtreeToString(2)
   }
 
+  override def save(sc: SparkContext, path: String): Unit = {
+    DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object DecisionTreeModel extends Loader[DecisionTreeModel] {
+
+  private[tree] object SaveLoadV1_0 {
+
+    def thisFormatVersion = "1.0"
+
+    // Hard-code class name string in case it changes in the future
+    def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel"
+
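+    /** Import/export form of a [[Predict]]: the predicted value plus its probability. */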
+    case class PredictData(predict: Double, prob: Double) {
+      def toPredict: Predict = new Predict(predict, prob)
+    }
+
+    object PredictData {
+      def apply(p: Predict): PredictData = PredictData(p.predict, p.prob)
+
+      def apply(r: Row): PredictData = PredictData(r.getDouble(0), r.getDouble(1))
+    }
+
+    case class SplitData(
+        feature: Int,
+        threshold: Double,
+        featureType: Int,
+        categories: Seq[Double]) {  // TODO: Change to List once SPARK-3365 is fixed
+      def toSplit: Split = {
+        new Split(feature, threshold, FeatureType(featureType), categories.toList)
+      }
+    }
+
+    object SplitData {
+      def apply(s: Split): SplitData = {
+        SplitData(s.feature, s.threshold, s.featureType.id, s.categories)
+      }
+
+      def apply(r: Row): SplitData = {
+        SplitData(r.getInt(0), r.getDouble(1), r.getInt(2), r.getAs[Seq[Double]](3))
+      }
+    }
+
+    /** Model data for model import/export */
+    case class NodeData(
+        treeId: Int,
+        nodeId: Int,
+        predict: PredictData,
+        impurity: Double,
+        isLeaf: Boolean,
+        split: Option[SplitData],
+        leftNodeId: Option[Int],
+        rightNodeId: Option[Int],
+        infoGain: Option[Double])
+
+    object NodeData {
+      def apply(treeId: Int, n: Node): NodeData = {
+        NodeData(treeId, n.id, PredictData(n.predict), n.impurity, n.isLeaf,
+          n.split.map(SplitData.apply), n.leftNode.map(_.id), n.rightNode.map(_.id),
+          n.stats.map(_.gain))
+      }
+
+      def apply(r: Row): NodeData = {
+        val split = if (r.isNullAt(5)) None else Some(SplitData(r.getStruct(5)))
+        val leftNodeId = if (r.isNullAt(6)) None else Some(r.getInt(6))
+        val rightNodeId = if (r.isNullAt(7)) None else Some(r.getInt(7))
+        val infoGain = if (r.isNullAt(8)) None else Some(r.getDouble(8))
+        NodeData(r.getInt(0), r.getInt(1), PredictData(r.getStruct(2)), r.getDouble(3),
+          r.getBoolean(4), split, leftNodeId, rightNodeId, infoGain)
+      }
+    }
+
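+    /** Writes model metadata as JSON and node data as Parquet under `path`. */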
+    def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+
+      // Create JSON metadata.
+      val metadataRDD = sc.parallelize(
+        Seq((thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)), 1)
+        .toDataFrame("class", "version", "algo", "numNodes")
+      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+
+      // Create Parquet data.
+      val nodes = model.topNode.subtreeIterator.toSeq
+      val dataRDD: DataFrame = sc.parallelize(nodes)
+        .map(NodeData.apply(0, _))
+        .toDataFrame
+      dataRDD.saveAsParquetFile(Loader.dataPath(path))
+    }
+
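+    /**
+     * Loads node data from the Parquet files under `path` and rebuilds the tree.
+     * The `algo` and `numNodes` arguments come from the model metadata and are used
+     * to finish constructing and to sanity-check the loaded model.
+     */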
+    def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
+      val datapath = Loader.dataPath(path)
+      val sqlContext = new SQLContext(sc)
+      // Load Parquet data.
+      val dataRDD = sqlContext.parquetFile(datapath)
+      // Check schema explicitly since erasure makes it hard to use match-case for checking.
+      Loader.checkSchema[NodeData](dataRDD.schema)
+      val nodes = dataRDD.map(NodeData.apply)
+      // Build node data into a tree.
+      val trees = constructTrees(nodes)
+      assert(trees.size == 1,
+        s"Decision tree should contain exactly one tree but got ${trees.size} trees.")
+      val model = new DecisionTreeModel(trees(0), Algo.fromString(algo))
+      assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $datapath. " +
+        s"Expected $numNodes nodes but found ${model.numNodes}")
+      model
+    }
+
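+    /**
+     * Groups the flat node data by tree id and reconstructs one tree per group.
+     * Tree ids must be exactly 0 until numTrees; a single DecisionTreeModel uses tree id 0.
+     */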
+    def constructTrees(nodes: RDD[NodeData]): Array[Node] = {
+      val trees = nodes
+        .groupBy(_.treeId)
+        .mapValues(_.toArray)
+        .collect()
+        .map { case (treeId, data) =>
+          (treeId, constructTree(data))
+        }.sortBy(_._1)
+      val numTrees = trees.size
+      val treeIndices = trees.map(_._1).toSeq
+      assert(treeIndices == (0 until numTrees),
+        s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.")
+      trees.map(_._2)
+    }
+
+    /**
+     * Given a list of nodes from a tree, construct the tree.
+     * @param data array of all node data in a tree.
+     */
+    def constructTree(data: Array[NodeData]): Node = {
+      val dataMap: Map[Int, NodeData] = data.map(n => n.nodeId -> n).toMap
+      assert(dataMap.contains(1),
+        s"DecisionTree missing root node (id = 1).")
+      constructNode(1, dataMap, mutable.Map.empty)
+    }
+
+    /**
+     * Builds a node from the node data map and adds new nodes to the input nodes map.
+     */
+    private def constructNode(
+        id: Int,
+        dataMap: Map[Int, NodeData],
+        nodes: mutable.Map[Int, Node]): Node = {
+      if (nodes.contains(id)) {
+        return nodes(id)
+      }
+      val data = dataMap(id)
+      val node =
+        if (data.isLeaf) {
+          Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf)
+        } else {
+          val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes)
+          val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes)
+          val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity,
+            rightNode.impurity, leftNode.predict, rightNode.predict)
+          new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf,
+            data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats))
+        }
+      nodes += node.id -> node
+      node
+    }
+  }
+
+  override def load(sc: SparkContext, path: String): DecisionTreeModel = {
+    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+    val (algo: String, numNodes: Int) = try {
+      val algo_numNodes = metadata.select("algo", "numNodes").collect()
+      assert(algo_numNodes.length == 1)
+      algo_numNodes(0) match {
+        case Row(a: String, n: Int) => (a, n)
+      }
+    } catch {
+      // Catch both Error and Exception since the checks above can throw either.
+      case e: Throwable =>
+        throw new Exception(
+          s"Unable to load DecisionTreeModel metadata from: ${Loader.metadataPath(path)}."
+          + s" Error message: ${e.getMessage}")
+    }
+    val classNameV1_0 = SaveLoadV1_0.thisClassName
+    (loadedClassName, version) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        SaveLoadV1_0.load(sc, path, algo, numNodes)
+      case _ => throw new Exception(
+        s"DecisionTreeModel.load did not recognize model with (className, format version): " +
+        s"($loadedClassName, $version).  Supported:\n" +
+        s"  ($classNameV1_0, 1.0)")
+    }
+  }
 }
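
For reference, a minimal round-trip sketch of the persistence API added in this patch. This assumes a live SparkContext `sc` and an RDD[LabeledPoint] `data` are in scope; `DecisionTree.trainClassifier` is the existing MLlib training entry point, and the output path is arbitrary.

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel

// Train a small classifier (sketch; `sc` and `data` are assumed to exist).
val model = DecisionTree.trainClassifier(data, numClasses = 2,
  categoricalFeaturesInfo = Map[Int, Int](), impurity = "gini", maxDepth = 5, maxBins = 32)

// Save, then load the model back: metadata goes to JSON, node data to Parquet.
model.save(sc, "/tmp/myDecisionTreeModel")
val sameModel = DecisionTreeModel.load(sc, "/tmp/myDecisionTreeModel")
assert(sameModel.numNodes == model.numNodes)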