
Commit 4ddf479

Author: Feynman Liang (committed)
Parallelize freqItemCounts
1 parent ad23aa9 commit 4ddf479

File tree

1 file changed: +11 −6 lines


mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 11 additions & 6 deletions
@@ -49,6 +49,7 @@ class PrefixSpan private (
    * The maximum number of items allowed in a projected database before local processing. If a
    * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
    */
+  // TODO: make configurable with a better default value, 10000 may be too small
   private val maxLocalProjDBSize: Long = 10000
 
   /**
@@ -61,7 +62,7 @@
    * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered
    * frequent).
    */
-  def getMinSupport(): Double = this.minSupport
+  def getMinSupport: Double = this.minSupport
 
   /**
    * Sets the minimal support level (default: `0.1`).
@@ -75,7 +76,7 @@
   /**
    * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider.
    */
-  def getMaxPatternLength(): Double = this.maxPatternLength
+  def getMaxPatternLength: Double = this.maxPatternLength
 
   /**
    * Sets maximal pattern length (default: `10`).
@@ -96,6 +97,8 @@
    * the value of pair is the pattern's count.
    */
   def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+    val sc = sequences.sparkContext
+
     if (sequences.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
@@ -108,10 +111,11 @@
       .flatMap(seq => seq.distinct.map(item => (item, 1L)))
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
+      .collect()
 
     // Pairs of (length 1 prefix, suffix consisting of frequent items)
     val itemSuffixPairs = {
-      val freqItems = freqItemCounts.keys.collect().toSet
+      val freqItems = freqItemCounts.map(_._1).toSet
       sequences.flatMap { seq =>
         val filteredSeq = seq.filter(freqItems.contains(_))
         freqItems.flatMap { item =>
@@ -141,13 +145,14 @@
       pairsForDistributed = largerPairsPart
       pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
       pairsForLocal ++= smallerPairsPart
-      resultsAccumulator ++= nextPatternAndCounts
+      resultsAccumulator ++= nextPatternAndCounts.collect()
     }
 
     // Process the small projected databases locally
-    resultsAccumulator ++= getPatternsInLocal(minCount, pairsForLocal.groupByKey())
+    val remainingResults = getPatternsInLocal(minCount, pairsForLocal.groupByKey())
 
-    resultsAccumulator.map { case (pattern, count) => (pattern.toArray, count) }
+    (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
+      .map { case (pattern, count) => (pattern.toArray, count) }
   }
 
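Note: the heart of this commit is collecting the (small) frequent-item counts to the driver while still returning an RDD to the caller: driver-side partial results are pushed back into the cluster with sc.parallelize and unioned with the distributed remainder. Below is a minimal, self-contained sketch of that pattern, assuming a local SparkContext and toy data; the object and variable names are illustrative and not the actual PrefixSpan internals.

// A minimal sketch of the pattern (assumed names and toy data, not the
// PrefixSpan source): collect a small aggregate to the driver, accumulate
// driver-side results, then parallelize them back into an RDD.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

import scala.collection.mutable.ArrayBuffer

object CollectThenParallelizeSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[*]"))
    val minCount = 2L

    val sequences: RDD[Array[Int]] =
      sc.parallelize(Seq(Array(1, 2, 3), Array(1, 3), Array(2, 3, 4)))

    // Frequent single items are few, so collecting them to the driver is cheap.
    val freqItemCounts: Array[(Int, Long)] = sequences
      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
      .reduceByKey(_ + _)
      .filter(_._2 >= minCount)
      .collect()

    // Once collected, the frequent-item set is built with a plain local map,
    // not another Spark job (cf. freqItemCounts.map(_._1).toSet in the diff).
    val freqItems: Set[Int] = freqItemCounts.map(_._1).toSet
    println(s"frequent items: ${freqItems.mkString(", ")}")

    // Driver-side buffer of patterns found so far (here just the length-1 patterns).
    val resultsAccumulator = ArrayBuffer.empty[(List[Int], Long)]
    resultsAccumulator ++= freqItemCounts.map { case (item, count) => (List(item), count) }

    // Stand-in for the patterns that are still produced distributively elsewhere.
    val remainingResults: RDD[(List[Int], Long)] = sc.emptyRDD[(List[Int], Long)]

    // Push the driver-side results back into the cluster and union with the rest,
    // so the public method can keep returning an RDD[(Array[Int], Long)].
    val allPatterns: RDD[(Array[Int], Long)] =
      (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
        .map { case (pattern, count) => (pattern.toArray, count) }

    allPatterns.collect().foreach { case (p, c) => println(p.mkString(",") + " -> " + c) }
    sc.stop()
  }
}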

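The parentheses dropped from getMinSupport and getMaxPatternLength follow the usual Scala convention that side-effect-free accessors are declared, and therefore called, without (). A tiny illustration with a hypothetical class and field, not PrefixSpan code:

// Hypothetical example of the parameterless-accessor convention.
class Params(private var minSupport: Double = 0.1) {
  // Accessor has no side effects, so no parentheses.
  def getMinSupport: Double = this.minSupport

  // Mutator keeps an argument list and returns this for chaining.
  def setMinSupport(value: Double): this.type = {
    this.minSupport = value
    this
  }
}

// Callers write params.getMinSupport, not params.getMinSupport().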