17
17
18
18
package org .apache .spark .mllib .fpm
19
19
20
- import scala . collection . mutable . ArrayBuffer
20
+ import java .{ util => ju }
21
21
22
- import org .apache .spark .broadcast .Broadcast
23
- import org .apache .spark .Logging
24
- import org .apache .spark .rdd .RDD
22
+ import scala .collection .mutable
25
23
24
+ import org .apache .spark .{SparkException , HashPartitioner , Logging , Partitioner }
25
+ import org .apache .spark .rdd .RDD
26
+ import org .apache .spark .storage .StorageLevel
26
27
28
+ class FPGrowthModel (val freqItemsets : RDD [(Array [String ], Long )]) extends Serializable
27
29
28
30
/**
29
31
* This class implements the Parallel FP-growth algorithm to mine frequent patterns from input data.
@@ -34,125 +36,127 @@ import org.apache.spark.rdd.RDD
34
36
*
35
37
* @param minSupport the minimal support level of a frequent pattern; any pattern that appears
36
38
* at least (minSupport * size-of-the-dataset) times will be output
39
+ * @param numPartitions number of partitions used by parallel FP-growth
37
40
*/
38
- class FPGrowth private (private var minSupport : Double ) extends Logging with Serializable {
41
+ class FPGrowth private (
42
+ private var minSupport : Double ,
43
+ private var numPartitions : Int ) extends Logging with Serializable {
39
44
40
45
/**
41
46
* Constructs an FPGrowth instance with default parameters:
42
- * {minSupport: 0.3}
47
+ * {minSupport: 0.3, numPartitions: auto }
43
48
*/
44
- def this () = this (0.3 )
49
+ def this () = this (0.3 , - 1 )
45
50
46
51
/**
47
- * set the minimal support level, default is 0.3
48
- * @param minSupport minimal support level
52
+ * Sets the minimal support level (default: 0.3).
49
53
*/
50
54
def setMinSupport (minSupport : Double ): this .type = {
51
55
this .minSupport = minSupport
52
56
this
53
57
}
54
58
55
59
/**
56
- * Compute a FPGrowth Model that contains frequent pattern result.
60
+ * Sets the number of partitions used by parallel FP-growth (default: same as input data).
61
+ */
62
+ def setNumPartitions (numPartitions : Int ): this .type = {
63
+ this .numPartitions = numPartitions
64
+ this
65
+ }
66
+
67
+ /**
68
+ * Computes an FP-Growth model that contains frequent itemsets.
57
69
* @param data input data set, each element contains a transaction
58
- * @return FPGrowth Model
70
+ * @return an [[ FPGrowthModel ]]
59
71
*/
60
72
def run (data : RDD [Array [String ]]): FPGrowthModel = {
73
+ if (data.getStorageLevel == StorageLevel .NONE ) {
74
+ logWarning(" Input data is not cached." )
75
+ }
61
76
val count = data.count()
62
- val minCount = minSupport * count
63
- val single = generateSingleItem(data, minCount)
64
- val combinations = generateCombinations(data, minCount, single)
65
- val all = single.map(v => (Array [String ](v._1), v._2)).union(combinations)
66
- new FPGrowthModel (all.collect())
77
+ val minCount = math.ceil(minSupport * count).toLong
78
+ val numParts = if (numPartitions > 0 ) numPartitions else data.partitions.length
79
+ val partitioner = new HashPartitioner (numParts)
80
+ val freqItems = genFreqItems(data, minCount, partitioner)
81
+ val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
82
+ new FPGrowthModel (freqItemsets)
67
83
}
68
84
69
85
/**
70
- * Generate single item pattern by filtering the input data using minimal support level
71
- * @return array of frequent pattern with its count
86
+ * Generates frequent items by filtering the input data using minimal support level.
87
+ * @param minCount minimum count for frequent itemsets
88
+ * @param partitioner partitioner used to distribute items
89
+ * @return array of frequent items ordered by descending frequency
72
90
*/
73
- private def generateSingleItem (
91
+ private def genFreqItems (
74
92
data : RDD [Array [String ]],
75
- minCount : Double ): RDD [(String , Long )] = {
76
- val single = data.flatMap(v => v.toSet)
77
- .map(v => (v, 1L ))
78
- .reduceByKey(_ + _)
93
+ minCount : Long ,
94
+ partitioner : Partitioner ): Array [String ] = {
95
+ data.flatMap { t =>
96
+ val uniq = t.toSet
97
+ if (t.length != uniq.size) {
98
+ throw new SparkException (s " Items in a transaction must be unique but got ${t.toSeq}. " )
99
+ }
100
+ t
101
+ }.map(v => (v, 1L ))
102
+ .reduceByKey(partitioner, _ + _)
79
103
.filter(_._2 >= minCount)
80
- .sortBy(_._2)
81
- single
104
+ .collect()
105
+ .sortBy(- _._2)
106
+ .map(_._1)
82
107
}
83
108
84
109
/**
85
- * Generate combination of frequent pattern by computing on FPTree,
86
- * the computation is done on each FPTree partitions.
87
- * @return array of frequent pattern with its count
110
+ * Generates frequent itemsets by building FP-Trees; the extraction is done on each partition.
111
+ * @param data transactions
112
+ * @param minCount minimum count for frequent itemsets
113
+ * @param freqItems frequent items
114
+ * @param partitioner partitioner used to distribute transactions
115
+ * @return an RDD of (frequent itemset, count)
88
116
*/
89
- private def generateCombinations (
117
+ private def genFreqItemsets (
90
118
data : RDD [Array [String ]],
91
- minCount : Double ,
92
- singleItem : RDD [(String , Long )]): RDD [(Array [String ], Long )] = {
93
- val single = data.context.broadcast(singleItem.collect())
94
- data.flatMap(transaction => createConditionPatternBase(transaction, single))
95
- .aggregateByKey(new FPTree )(
96
- (aggregator, condPattBase) => aggregator.add(condPattBase),
97
- (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
98
- .flatMap(partition => partition._2.mine(minCount, partition._1))
119
+ minCount : Long ,
120
+ freqItems : Array [String ],
121
+ partitioner : Partitioner ): RDD [(Array [String ], Long )] = {
122
+ val itemToRank = freqItems.zipWithIndex.toMap
123
+ data.flatMap { transaction =>
124
+ genCondTransactions(transaction, itemToRank, partitioner)
125
+ }.aggregateByKey(new FPTree [Int ], partitioner.numPartitions)(
126
+ (tree, transaction) => tree.add(transaction, 1L ),
127
+ (tree1, tree2) => tree1.merge(tree2))
128
+ .flatMap { case (part, tree) =>
129
+ tree.extract(minCount, x => partitioner.getPartition(x) == part)
130
+ }.map { case (ranks, count) =>
131
+ (ranks.map(i => freqItems(i)).toArray, count)
132
+ }
99
133
}
100
134
101
135
/**
102
- * Create FP-Tree partition for the giving basket
103
- * @return an array contains a tuple, whose first element is the single
104
- * item (hash key) and second element is its condition pattern base
136
+ * Generates conditional transactions.
137
+ * @param transaction a transaction
138
+ * @param itemToRank map from each item to its rank
139
+ * @param partitioner partitioner used to distribute transactions
140
+ * @return a map of (target partition, conditional transaction)
105
141
*/
106
- private def createConditionPatternBase (
142
+ private def genCondTransactions (
107
143
transaction : Array [String ],
108
- singleBC : Broadcast [Array [(String , Long )]]): Array [(String , Array [String ])] = {
109
- var output = ArrayBuffer [(String , Array [String ])]()
110
- var combination = ArrayBuffer [String ]()
111
- var items = ArrayBuffer [(String , Long )]()
112
- val single = singleBC.value
113
- val singleMap = single.toMap
114
-
115
- // Filter the basket by single item pattern and sort
116
- // by single item and its count
117
- val candidates = transaction
118
- .filter(singleMap.contains)
119
- .map(item => (item, singleMap(item)))
120
- .sortBy(_._1)
121
- .sortBy(_._2)
122
- .toArray
123
-
124
- val itemIterator = candidates.iterator
125
- while (itemIterator.hasNext) {
126
- combination.clear()
127
- val item = itemIterator.next()
128
- val firstNItems = candidates.take(candidates.indexOf(item))
129
- if (firstNItems.length > 0 ) {
130
- val iterator = firstNItems.iterator
131
- while (iterator.hasNext) {
132
- val elem = iterator.next()
133
- combination += elem._1
134
- }
135
- output += ((item._1, combination.toArray))
144
+ itemToRank : Map [String , Int ],
145
+ partitioner : Partitioner ): mutable.Map [Int , Array [Int ]] = {
146
+ val output = mutable.Map .empty[Int , Array [Int ]]
147
+ // Keep only the frequent items of the basket, mapped to their ranks, and sort the ranks.
148
+ val filtered = transaction.flatMap(itemToRank.get)
149
+ ju.Arrays .sort(filtered)
150
+ val n = filtered.length
151
+ var i = n - 1
152
+ while (i >= 0 ) {
153
+ val item = filtered(i)
154
+ val part = partitioner.getPartition(item)
155
+ if (! output.contains(part)) {
156
+ output(part) = filtered.slice(0 , i + 1 )
136
157
}
158
+ i -= 1
137
159
}
138
- output.toArray
160
+ output
139
161
}
140
-
141
162
}
142
-
143
- /**
144
- * Top-level methods for calling FPGrowth.
145
- */
146
- object FPGrowth {
147
-
148
- /**
149
- * Generate a FPGrowth Model using the given minimal support level.
150
- *
151
- * @param data input baskets stored as `RDD[Array[String]]`
152
- * @param minSupport minimal support level, for example 0.5
153
- */
154
- def train (data : RDD [Array [String ]], minSupport : Double ): FPGrowthModel = {
155
- new FPGrowth ().setMinSupport(minSupport).run(data)
156
- }
157
- }
158
-
0 commit comments