@@ -49,6 +49,7 @@ class PrefixSpan private (
   * The maximum number of items allowed in a projected database before local processing. If a
   * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
   */
+  // TODO: make configurable with a better default value, 10000 may be too small
  private val maxLocalProjDBSize: Long = 10000

  /**
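The doc comment above describes when mining of a projected database switches from another distributed pass to local processing. A minimal sketch of that decision, using hypothetical helper names rather than the actual code in this file:

// Hypothetical sketch: projected databases at or below the threshold are
// mined locally; larger ones go through another distributed iteration.
object ProjDBSizeSketch {
  val maxLocalProjDBSize: Long = 10000

  def processLocally(projDBSize: Long): Boolean = projDBSize <= maxLocalProjDBSize
}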
@@ -61,7 +62,7 @@ class PrefixSpan private (
   * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered
   * frequent).
   */
-  def getMinSupport(): Double = this.minSupport
+  def getMinSupport: Double = this.minSupport

  /**
   * Sets the minimal support level (default: `0.1`).
@@ -75,7 +76,7 @@ class PrefixSpan private (
  /**
   * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider.
   */
-  def getMaxPatternLength(): Double = this.maxPatternLength
+  def getMaxPatternLength: Double = this.maxPatternLength

  /**
   * Sets maximal pattern length (default: `10`).
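Both getter changes above drop the empty parameter list: in Scala, a side-effect-free accessor is conventionally declared without parentheses, so callers write prefixSpan.getMinSupport rather than prefixSpan.getMinSupport(). An illustrative sketch of the convention (the class and setter below are made up, not the code in this file):

// Illustrative only: a parameterless accessor reads like a field at the call
// site; the mutator keeps its parameter list and returns this for chaining.
class ParamsSketch(private var minSupport: Double) {
  def getMinSupport: Double = minSupport
  def setMinSupport(value: Double): this.type = {
    minSupport = value
    this
  }
}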
@@ -96,6 +97,8 @@ class PrefixSpan private (
   * the value of pair is the pattern's count.
   */
  def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+    val sc = sequences.sparkContext
+
    if (sequences.getStorageLevel == StorageLevel.NONE) {
      logWarning("Input data is not cached.")
    }
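The new val sc = sequences.sparkContext line captures the SparkContext of the input RDD at the top of run, so the driver-side result buffer can later be turned back into an RDD (the sc.parallelize(resultsAccumulator, 1) call in the last hunk). A minimal sketch of that pattern, with hypothetical names:

// Sketch only: reuse the input RDD's context to re-distribute a small
// driver-side buffer as an RDD with a single partition.
import org.apache.spark.rdd.RDD

object ContextFromInputSketch {
  def emit(input: RDD[Array[Int]], buffered: Seq[(List[Int], Long)]): RDD[(Array[Int], Long)] = {
    val sc = input.sparkContext
    sc.parallelize(buffered, 1).map { case (pattern, count) => (pattern.toArray, count) }
  }
}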
@@ -108,10 +111,11 @@ class PrefixSpan private (
      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
      .reduceByKey(_ + _)
      .filter(_._2 >= minCount)
+      .collect()

    // Pairs of (length 1 prefix, suffix consisting of frequent items)
    val itemSuffixPairs = {
-      val freqItems = freqItemCounts.keys.collect().toSet
+      val freqItems = freqItemCounts.map(_._1).toSet
      sequences.flatMap { seq =>
        val filteredSeq = seq.filter(freqItems.contains(_))
        freqItems.flatMap { item =>
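With the added .collect(), freqItemCounts is no longer an RDD but a local Array[(Int, Long)] on the driver, which is why the next line switches from the pair-RDD call freqItemCounts.keys.collect() to a plain freqItemCounts.map(_._1); the resulting freqItems set is then captured in the closure passed to sequences.flatMap. A rough sketch of the type change, with made-up values:

// Rough sketch (example data): after collect(), the frequent item ids are
// extracted with a local map instead of an RDD action.
val freqItemCounts: Array[(Int, Long)] = Array((1, 5L), (3, 4L))
val freqItems: Set[Int] = freqItemCounts.map(_._1).toSet  // Set(1, 3)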
@@ -141,13 +145,14 @@ class PrefixSpan private (
      pairsForDistributed = largerPairsPart
      pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
      pairsForLocal ++= smallerPairsPart
-      resultsAccumulator ++= nextPatternAndCounts
+      resultsAccumulator ++= nextPatternAndCounts.collect()
    }

    // Process the small projected databases locally
-    resultsAccumulator ++= getPatternsInLocal(minCount, pairsForLocal.groupByKey())
+    val remainingResults = getPatternsInLocal(minCount, pairsForLocal.groupByKey())

-    resultsAccumulator.map { case (pattern, count) => (pattern.toArray, count) }
+    (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
+      .map { case (pattern, count) => (pattern.toArray, count) }
  }

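With these changes, run still takes an RDD[Array[Int]] of sequences and returns an RDD[(Array[Int], Long)] of frequent sequential patterns with their counts; only the way driver-side and distributed partial results are combined differs. A hedged usage sketch (the no-arg constructor, the setMinSupport/setMaxPatternLength setters, and the org.apache.spark.mllib.fpm package path are assumed from the surrounding class, and the input data is made up):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.fpm.PrefixSpan

object PrefixSpanUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("PrefixSpanSketch").setMaster("local[2]"))

    // Made-up item-id sequences; cached to avoid the "Input data is not cached." warning in run().
    val sequences = sc.parallelize(Seq(
      Array(1, 2, 3),
      Array(1, 3, 2),
      Array(1, 2)
    )).cache()

    val prefixSpan = new PrefixSpan()   // assumed public no-arg constructor
      .setMinSupport(0.5)               // assumed setter (default 0.1 per the doc comment)
      .setMaxPatternLength(5)           // assumed setter (default 10 per the doc comment)

    prefixSpan.run(sequences)
      .collect()
      .foreach { case (pattern, count) => println(pattern.mkString("[", ",", "]") + s" -> $count") }

    sc.stop()
  }
}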