Skip to content

Commit 7e69725

Browse files
committed
simplify FPTree and update FPGrowth
1 parent ec21f7d commit 7e69725

File tree

5 files changed

+251
-437
lines changed

5 files changed

+251
-437
lines changed

mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala

Lines changed: 92 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717

1818
package org.apache.spark.mllib.fpm
1919

20-
import scala.collection.mutable.ArrayBuffer
20+
import java.{util => ju}
2121

22-
import org.apache.spark.broadcast.Broadcast
23-
import org.apache.spark.Logging
24-
import org.apache.spark.rdd.RDD
22+
import scala.collection.mutable
2523

24+
import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner}
25+
import org.apache.spark.rdd.RDD
26+
import org.apache.spark.storage.StorageLevel
2627

28+
class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable
2729

2830
/**
2931
* This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data.
@@ -34,125 +36,127 @@ import org.apache.spark.rdd.RDD
3436
*
3537
* @param minSupport the minimal support level of the frequent pattern, any pattern appears
3638
* more than (minSupport * size-of-the-dataset) times will be output
39+
* @param numPartitions number of partitions used by parallel FP-growth
3740
*/
38-
class FPGrowth private(private var minSupport: Double) extends Logging with Serializable {
41+
class FPGrowth private (
42+
private var minSupport: Double,
43+
private var numPartitions: Int) extends Logging with Serializable {
3944

4045
/**
4146
* Constructs a FPGrowth instance with default parameters:
42-
* {minSupport: 0.3}
47+
* {minSupport: 0.3, numPartitions: auto}
4348
*/
44-
def this() = this(0.3)
49+
def this() = this(0.3, -1)
4550

4651
/**
47-
* set the minimal support level, default is 0.3
48-
* @param minSupport minimal support level
52+
* Sets the minimal support level (default: 0.3).
4953
*/
5054
def setMinSupport(minSupport: Double): this.type = {
5155
this.minSupport = minSupport
5256
this
5357
}
5458

5559
/**
56-
* Compute a FPGrowth Model that contains frequent pattern result.
60+
* Sets the number of partitions used by parallel FP-growth (default: same as input data).
61+
*/
62+
def setNumPartitions(numPartitions: Int): this.type = {
63+
this.numPartitions = numPartitions
64+
this
65+
}
66+
67+
/**
68+
* Computes an FP-Growth model that contains frequent itemsets.
5769
* @param data input data set, each element contains a transaction
58-
* @return FPGrowth Model
70+
* @return an [[FPGrowthModel]]
5971
*/
6072
def run(data: RDD[Array[String]]): FPGrowthModel = {
73+
if (data.getStorageLevel == StorageLevel.NONE) {
74+
logWarning("Input data is not cached.")
75+
}
6176
val count = data.count()
62-
val minCount = minSupport * count
63-
val single = generateSingleItem(data, minCount)
64-
val combinations = generateCombinations(data, minCount, single)
65-
val all = single.map(v => (Array[String](v._1), v._2)).union(combinations)
66-
new FPGrowthModel(all.collect())
77+
val minCount = math.ceil(minSupport * count).toLong
78+
val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
79+
val partitioner = new HashPartitioner(numParts)
80+
val freqItems = genFreqItems(data, minCount, partitioner)
81+
val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
82+
new FPGrowthModel(freqItemsets)
6783
}
6884

6985
/**
70-
* Generate single item pattern by filtering the input data using minimal support level
71-
* @return array of frequent pattern with its count
86+
* Generates frequent items by filtering the input data using minimal support level.
87+
* @param minCount minimum count for frequent itemsets
88+
* @param partitioner partitioner used to distribute items
89+
* @return array of frequent pattern ordered by their frequencies
7290
*/
73-
private def generateSingleItem(
91+
private def genFreqItems(
7492
data: RDD[Array[String]],
75-
minCount: Double): RDD[(String, Long)] = {
76-
val single = data.flatMap(v => v.toSet)
77-
.map(v => (v, 1L))
78-
.reduceByKey(_ + _)
93+
minCount: Long,
94+
partitioner: Partitioner): Array[String] = {
95+
data.flatMap { t =>
96+
val uniq = t.toSet
97+
if (t.length != uniq.size) {
98+
throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
99+
}
100+
t
101+
}.map(v => (v, 1L))
102+
.reduceByKey(partitioner, _ + _)
79103
.filter(_._2 >= minCount)
80-
.sortBy(_._2)
81-
single
104+
.collect()
105+
.sortBy(-_._2)
106+
.map(_._1)
82107
}
83108

84109
/**
85-
* Generate combination of frequent pattern by computing on FPTree,
86-
* the computation is done on each FPTree partitions.
87-
* @return array of frequent pattern with its count
110+
* Generate frequent itemsets by building FP-Trees, the extraction is done on each partition.
111+
* @param data transactions
112+
* @param minCount minimum count for frequent itemsets
113+
* @param freqItems frequent items
114+
* @param partitioner partitioner used to distribute transactions
115+
* @return an RDD of (frequent itemset, count)
88116
*/
89-
private def generateCombinations(
117+
private def genFreqItemsets(
90118
data: RDD[Array[String]],
91-
minCount: Double,
92-
singleItem: RDD[(String, Long)]): RDD[(Array[String], Long)] = {
93-
val single = data.context.broadcast(singleItem.collect())
94-
data.flatMap(transaction => createConditionPatternBase(transaction, single))
95-
.aggregateByKey(new FPTree)(
96-
(aggregator, condPattBase) => aggregator.add(condPattBase),
97-
(aggregator1, aggregator2) => aggregator1.merge(aggregator2))
98-
.flatMap(partition => partition._2.mine(minCount, partition._1))
119+
minCount: Long,
120+
freqItems: Array[String],
121+
partitioner: Partitioner): RDD[(Array[String], Long)] = {
122+
val itemToRank = freqItems.zipWithIndex.toMap
123+
data.flatMap { transaction =>
124+
genCondTransactions(transaction, itemToRank, partitioner)
125+
}.aggregateByKey(new FPTree[Int], partitioner.numPartitions)(
126+
(tree, transaction) => tree.add(transaction, 1L),
127+
(tree1, tree2) => tree1.merge(tree2))
128+
.flatMap { case (part, tree) =>
129+
tree.extract(minCount, x => partitioner.getPartition(x) == part)
130+
}.map { case (ranks, count) =>
131+
(ranks.map(i => freqItems(i)).toArray, count)
132+
}
99133
}
100134

101135
/**
102-
* Create FP-Tree partition for the giving basket
103-
* @return an array contains a tuple, whose first element is the single
104-
* item (hash key) and second element is its condition pattern base
136+
* Generates conditional transactions.
137+
* @param transaction a transaction
138+
* @param itemToRank map from item to their rank
139+
* @param partitioner partitioner used to distribute transactions
140+
* @return a map of (target partition, conditional transaction)
105141
*/
106-
private def createConditionPatternBase(
142+
private def genCondTransactions(
107143
transaction: Array[String],
108-
singleBC: Broadcast[Array[(String, Long)]]): Array[(String, Array[String])] = {
109-
var output = ArrayBuffer[(String, Array[String])]()
110-
var combination = ArrayBuffer[String]()
111-
var items = ArrayBuffer[(String, Long)]()
112-
val single = singleBC.value
113-
val singleMap = single.toMap
114-
115-
// Filter the basket by single item pattern and sort
116-
// by single item and its count
117-
val candidates = transaction
118-
.filter(singleMap.contains)
119-
.map(item => (item, singleMap(item)))
120-
.sortBy(_._1)
121-
.sortBy(_._2)
122-
.toArray
123-
124-
val itemIterator = candidates.iterator
125-
while (itemIterator.hasNext) {
126-
combination.clear()
127-
val item = itemIterator.next()
128-
val firstNItems = candidates.take(candidates.indexOf(item))
129-
if (firstNItems.length > 0) {
130-
val iterator = firstNItems.iterator
131-
while (iterator.hasNext) {
132-
val elem = iterator.next()
133-
combination += elem._1
134-
}
135-
output += ((item._1, combination.toArray))
144+
itemToRank: Map[String, Int],
145+
partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
146+
val output = mutable.Map.empty[Int, Array[Int]]
147+
// Filter the basket by frequent items pattern and sort their ranks.
148+
val filtered = transaction.flatMap(itemToRank.get)
149+
ju.Arrays.sort(filtered)
150+
val n = filtered.length
151+
var i = n - 1
152+
while (i >= 0) {
153+
val item = filtered(i)
154+
val part = partitioner.getPartition(item)
155+
if (!output.contains(part)) {
156+
output(part) = filtered.slice(0, i + 1)
136157
}
158+
i -= 1
137159
}
138-
output.toArray
160+
output
139161
}
140-
141162
}
142-
143-
/**
144-
* Top-level methods for calling FPGrowth.
145-
*/
146-
object FPGrowth{
147-
148-
/**
149-
* Generate a FPGrowth Model using the given minimal support level.
150-
*
151-
* @param data input baskets stored as `RDD[Array[String]]`
152-
* @param minSupport minimal support level, for example 0.5
153-
*/
154-
def train(data: RDD[Array[String]], minSupport: Double): FPGrowthModel = {
155-
new FPGrowth().setMinSupport(minSupport).run(data)
156-
}
157-
}
158-

mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthModel.scala

Lines changed: 0 additions & 24 deletions
This file was deleted.

0 commit comments

Comments
 (0)