17
17
18
18
package org .apache .spark .mllib .fpm
19
19
20
+ import scala .collection .mutable .ArrayBuffer
21
+
20
22
import org .apache .spark .Logging
21
23
import org .apache .spark .annotation .Experimental
22
24
import org .apache .spark .rdd .RDD
@@ -43,7 +45,7 @@ class PrefixSpan private (
43
45
private var minSupport : Double ,
44
46
private var maxPatternLength : Int ) extends Logging with Serializable {
45
47
46
- private val minPatternsBeforeShuffle : Int = 20
48
+ private val minPatternsBeforeLocalProcessing : Int = 20
47
49
48
50
/**
49
51
* Constructs a default instance with default parameters
@@ -88,66 +90,65 @@ class PrefixSpan private (
88
90
val prefixSuffixPairs = getPrefixSuffixPairs(
89
91
lengthOnePatternsAndCounts.map(_._1).collect(), sequences)
90
92
var patternsCount : Long = lengthOnePatternsAndCounts.count()
91
- var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (Array (x._1), x._2))
93
+ var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (ArrayBuffer (x._1), x._2))
92
94
var currentPrefixSuffixPairs = prefixSuffixPairs
93
- while (patternsCount <= minPatternsBeforeShuffle && currentPrefixSuffixPairs.count() != 0 ) {
95
+ var patternLength : Int = 1
96
+ while (patternLength < maxPatternLength &&
97
+ patternsCount <= minPatternsBeforeLocalProcessing &&
98
+ currentPrefixSuffixPairs.count() != 0 ) {
94
99
val (nextPatternAndCounts, nextPrefixSuffixPairs) =
95
100
getPatternCountsAndPrefixSuffixPairs(minCount, currentPrefixSuffixPairs)
96
- patternsCount = nextPatternAndCounts.count().toInt
101
+ patternsCount = nextPatternAndCounts.count()
97
102
currentPrefixSuffixPairs = nextPrefixSuffixPairs
98
103
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
104
+ patternLength = patternLength + 1
99
105
}
100
- if (patternsCount > 0 ) {
106
+ if (patternLength < maxPatternLength && patternsCount > 0 ) {
101
107
val projectedDatabase = currentPrefixSuffixPairs
102
108
.map(x => (x._1.toSeq, x._2))
103
109
.groupByKey()
104
110
.map(x => (x._1.toArray, x._2.toArray))
105
111
val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
106
112
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
107
113
}
108
- allPatternAndCounts
114
+ allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) }
109
115
}
110
116
111
117
/**
112
118
* Get the pattern and counts, and prefix suffix pairs
113
119
* @param minCount minimum count
114
- * @param prefixSuffixPairs prefix and suffix pairs,
115
- * @return pattern and counts, and prefix suffix pairs
116
- * (Array [pattern, count], RDD[prefix, suffix ])
120
+ * @param prefixSuffixPairs prefix (length n) and suffix pairs,
121
+ * @return pattern (length n+1) and counts, and prefix (length n+1) and suffix pairs
122
+ * (RDD [pattern, count], RDD[prefix, suffix ])
117
123
*/
118
124
private def getPatternCountsAndPrefixSuffixPairs (
119
125
minCount : Long ,
120
- prefixSuffixPairs : RDD [(Array [Int ], Array [Int ])]):
121
- (RDD [(Array [Int ], Long )], RDD [(Array [Int ], Array [Int ])]) = {
122
- val prefixAndFreqentItemAndCounts = prefixSuffixPairs
123
- .flatMap { case (prefix, suffix) =>
124
- suffix.distinct.map(y => ((prefix.toSeq, y), 1L ))
125
- }.reduceByKey(_ + _)
126
+ prefixSuffixPairs : RDD [(ArrayBuffer [Int ], Array [Int ])]):
127
+ (RDD [(ArrayBuffer [Int ], Long )], RDD [(ArrayBuffer [Int ], Array [Int ])]) = {
128
+ val prefixAndFrequentItemAndCounts = prefixSuffixPairs
129
+ .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L )) }
130
+ .reduceByKey(_ + _)
126
131
.filter(_._2 >= minCount)
127
- val patternAndCounts = prefixAndFreqentItemAndCounts
128
- .map{ case ((prefix, item), count) => (prefix.toArray :+ item, count) }
129
- val prefixlength = prefixSuffixPairs.first()._1.length
130
- if (prefixlength + 1 >= maxPatternLength) {
131
- (patternAndCounts, prefixSuffixPairs.filter(x => false ))
132
- } else {
133
- val frequentItemsMap = prefixAndFreqentItemAndCounts
134
- .keys
135
- .groupByKey()
136
- .mapValues(_.toSet)
137
- .collect
138
- .toMap
139
- val nextPrefixSuffixPairs = prefixSuffixPairs
140
- .filter(x => frequentItemsMap.contains(x._1))
141
- .flatMap { case (prefix, suffix) =>
142
- val frequentItemSet = frequentItemsMap(prefix)
143
- val filteredSuffix = suffix.filter(frequentItemSet.contains(_))
144
- val nextSuffixes = frequentItemSet.map{ item =>
145
- (item, LocalPrefixSpan .getSuffix(item, filteredSuffix))
146
- }.filter(_._2.nonEmpty)
147
- nextSuffixes.map { case (item, suffix) => (prefix :+ item, suffix) }
132
+ val patternAndCounts = prefixAndFrequentItemAndCounts
133
+ .map { case ((prefix, item), count) => (prefix :+ item, count) }
134
+ val prefixToFrequentNextItemsMap = prefixAndFrequentItemAndCounts
135
+ .keys
136
+ .groupByKey()
137
+ .mapValues(_.toSet)
138
+ .collect()
139
+ .toMap
140
+ val nextPrefixSuffixPairs = prefixSuffixPairs
141
+ .filter(x => prefixToFrequentNextItemsMap.contains(x._1))
142
+ .flatMap { case (prefix, suffix) =>
143
+ val frequentNextItems = prefixToFrequentNextItemsMap(prefix)
144
+ val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
145
+ frequentNextItems.flatMap { item =>
146
+ val suffix = LocalPrefixSpan .getSuffix(item, filteredSuffix)
147
+ if (suffix.isEmpty) None
148
+ else Some (prefix :+ item, suffix)
148
149
}
149
- (patternAndCounts, nextPrefixSuffixPairs)
150
150
}
151
+ (patternAndCounts, nextPrefixSuffixPairs)
151
152
}
152
153
153
154
/**
@@ -181,14 +182,14 @@ class PrefixSpan private (
181
182
*/
182
183
private def getPrefixSuffixPairs (
183
184
frequentPrefixes : Array [Int ],
184
- sequences : RDD [Array [Int ]]): RDD [(Array [Int ], Array [Int ])] = {
185
+ sequences : RDD [Array [Int ]]): RDD [(ArrayBuffer [Int ], Array [Int ])] = {
185
186
val filteredSequences = sequences.map { p =>
186
187
p.filter (frequentPrefixes.contains(_) )
187
188
}
188
189
filteredSequences.flatMap { x =>
189
190
frequentPrefixes.map { y =>
190
191
val sub = LocalPrefixSpan .getSuffix(y, x)
191
- (Array (y), sub)
192
+ (ArrayBuffer (y), sub)
192
193
}.filter(_._2.nonEmpty)
193
194
}
194
195
}
@@ -201,9 +202,9 @@ class PrefixSpan private (
201
202
*/
202
203
private def getPatternsInLocal (
203
204
minCount : Long ,
204
- data : RDD [(Array [Int ], Array [Array [Int ]])]): RDD [(Array [Int ], Long )] = {
205
- data.flatMap { x =>
206
- LocalPrefixSpan .run(minCount, maxPatternLength, x._1, x._2)
207
- }
205
+ data : RDD [(Array [Int ], Array [Array [Int ]])]): RDD [(ArrayBuffer [Int ], Long )] = {
206
+ data
207
+ .flatMap { x => LocalPrefixSpan .run(minCount, maxPatternLength, x._1, x._2) }
208
+ .map { case (pattern, count) => (pattern.to[ ArrayBuffer ], count) }
208
209
}
209
210
}
0 commit comments