17
17
18
18
package org .apache .spark .mllib .fpm
19
19
20
+ import org .apache .spark .Logging
20
21
import org .apache .spark .annotation .Experimental
21
22
import org .apache .spark .rdd .RDD
23
+ import org .apache .spark .storage .StorageLevel
22
24
23
25
/**
24
26
*
@@ -37,165 +39,206 @@ import org.apache.spark.rdd.RDD
37
39
* (Wikipedia)]]
38
40
*/
39
41
@ Experimental
40
- class PrefixSpan (
42
+ class PrefixSpan private (
41
43
private var minSupport : Double ,
42
- private var maxPatternLength : Int ) extends java.io.Serializable {
43
-
44
- private var absMinSupport : Int = 0
44
+ private var maxPatternLength : Int ) extends Logging with Serializable {
45
45
46
46
  /**
   * Constructs a default instance with default parameters
   * {minSupport: `0.1`, maxPatternLength: `10`}.
   */
  def this() = this(0.1, 10)
51
51
52
52
/**
53
53
* Sets the minimal support level (default: `0.1`).
54
54
*/
55
55
def setMinSupport (minSupport : Double ): this .type = {
56
+ require(minSupport >= 0 && minSupport <= 1 )
56
57
this .minSupport = minSupport
57
58
this
58
59
}
59
60
60
61
/**
61
- * Sets maximal pattern length.
62
+ * Sets maximal pattern length (default: `10`) .
62
63
*/
63
64
def setMaxPatternLength (maxPatternLength : Int ): this .type = {
65
+ require(maxPatternLength >= 1 )
64
66
this .maxPatternLength = maxPatternLength
65
67
this
66
68
}
67
69
68
70
/**
69
- * Calculate sequential patterns:
70
- * a) find and collect length-one patterns
71
- * b) for each length-one patterns and each sequence,
72
- * emit (pattern (prefix), suffix sequence) as key-value pairs
73
- * c) group by key and then map value iterator to array
74
- * d) local PrefixSpan on each prefix
75
- * @return sequential patterns
71
+ * Find the complete set of sequential patterns in the input sequences.
72
+ * @param sequences input data set, contains a set of sequences,
73
+ * a sequence is an ordered list of elements.
74
+ * @return a set of sequential pattern pairs,
75
+ * the key of pair is pattern (a list of elements),
76
+ * the value of pair is the pattern's support value.
76
77
*/
77
- def run (sequences : RDD [Array [Int ]]): RDD [(Seq [Int ], Int )] = {
78
- absMinSupport = getAbsoluteMinSupport(sequences)
78
+ def run (sequences : RDD [Array [Int ]]): RDD [(Array [Int ], Long )] = {
79
+ if (sequences.getStorageLevel == StorageLevel .NONE ) {
80
+ logWarning(" Input data is not cached." )
81
+ }
82
+ val minCount = getAbsoluteMinSupport(sequences)
79
83
val (lengthOnePatternsAndCounts, prefixAndCandidates) =
80
- findLengthOnePatterns(sequences)
84
+ findLengthOnePatterns(minCount, sequences)
81
85
val repartitionedRdd = makePrefixProjectedDatabases(prefixAndCandidates)
82
- val nextPatterns = getPatternsInLocal(repartitionedRdd)
83
- val allPatterns = lengthOnePatternsAndCounts.map(x => (Seq (x._1), x._2)) ++ nextPatterns
86
+ val nextPatterns = getPatternsInLocal(minCount, repartitionedRdd)
87
+ val allPatterns = lengthOnePatternsAndCounts.map(x => (Array (x._1), x._2)) ++ nextPatterns
84
88
allPatterns
85
89
}
86
90
87
- private def getAbsoluteMinSupport (sequences : RDD [Array [Int ]]): Int = {
88
- val result = if (minSupport <= 0 ) {
89
- 0
90
- } else {
91
- val count = sequences.count()
92
- val support = if (minSupport <= 1 ) minSupport else 1
93
- (support * count).toInt
94
- }
95
- result
91
+ /**
92
+ * Get the absolute minimum support value (sequences count * minSupport).
93
+ * @param sequences input data set, contains a set of sequences,
94
+ * @return absolute minimum support value,
95
+ */
96
+ private def getAbsoluteMinSupport (sequences : RDD [Array [Int ]]): Long = {
97
+ if (minSupport == 0 ) 0L else (sequences.count() * minSupport).toLong
96
98
}
97
99
98
100
  /**
   * Generates frequent items by filtering the input data using the minimal
   * support level.
   *
   * Each item is counted at most once per sequence (`distinct`), so a count
   * here is the number of sequences containing the item, i.e. its support.
   *
   * @param minCount the absolute minimum support
   * @param sequences original sequences data
   * @return (item, support) pairs for items with support >= minCount;
   *         no ordering of the result is guaranteed
   */
  private def getFreqItemAndCounts(
      minCount: Long,
      sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
    sequences.flatMap(_.distinct.map((_, 1L)))
      .reduceByKey(_ + _)
      .filter(_._2 >= minCount)
  }
113
+
114
+ /**
115
+ * Generates frequent items by filtering the input data using minimal support level.
116
+ * @param minCount the absolute minimum support
117
+ * @param sequences sequences data
118
+ * @return array of frequent pattern ordered by their frequencies
119
+ */
120
+ private def getFreqItemAndCounts (
121
+ minCount : Long ,
122
+ sequences : Array [Array [Int ]]): Array [(Int , Long )] = {
123
+ sequences.flatMap(_.distinct)
124
+ .groupBy(x => x)
125
+ .mapValues(_.length.toLong)
126
+ .filter(_._2 >= minCount)
127
+ .toArray
128
+ }
129
+
130
+ /**
131
+ * Get the frequent prefixes' projected database.
132
+ * @param frequentPrefixes frequent prefixes
133
+ * @param sequences sequences data
134
+ * @return prefixes and projected database
135
+ */
136
+ private def getPatternAndProjectedDatabase (
137
+ frequentPrefixes : Array [Int ],
138
+ sequences : RDD [Array [Int ]]): RDD [(Array [Int ], Array [Int ])] = {
139
+ val filteredSequences = sequences.map { p =>
140
+ p.filter (frequentPrefixes.contains(_) )
141
+ }
142
+ filteredSequences.flatMap { x =>
143
+ frequentPrefixes.map { y =>
127
144
val sub = getSuffix(y, x)
128
- (Seq (y), sub)
145
+ (Array (y), sub)
129
146
}
130
147
}.filter(x => x._2.nonEmpty)
131
- (frequentLengthOnePatterns, prefixAndCandidates)
132
148
}
133
149
134
150
/**
135
- * Re-partition the RDD data, to get better balance and performance.
151
+ * Get the frequent prefixes' projected database.
152
+ * @param prePrefix the frequent prefixes' prefix
153
+ * @param frequentPrefixes frequent prefixes
154
+ * @param sequences sequences data
155
+ * @return prefixes and projected database
156
+ */
157
+ private def getPatternAndProjectedDatabase (
158
+ prePrefix : Array [Int ],
159
+ frequentPrefixes : Array [Int ],
160
+ sequences : Array [Array [Int ]]): Array [(Array [Int ], Array [Array [Int ]])] = {
161
+ val filteredProjectedDatabase = sequences
162
+ .map(x => x.filter(frequentPrefixes.contains(_)))
163
+ frequentPrefixes.map { x =>
164
+ val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
165
+ (prePrefix ++ Array (x), sub)
166
+ }.filter(x => x._2.nonEmpty)
167
+ }
168
+
169
+ /**
170
+ * Find the patterns that it's length is one
171
+ * @param minCount the absolute minimum support
172
+ * @param sequences original sequences data
173
+ * @return length-one patterns and projection table
174
+ */
175
+ private def findLengthOnePatterns (
176
+ minCount : Long ,
177
+ sequences : RDD [Array [Int ]]): (RDD [(Int , Long )], RDD [(Array [Int ], Array [Int ])]) = {
178
+ val frequentLengthOnePatternAndCounts = getFreqItemAndCounts(minCount, sequences)
179
+ val prefixAndProjectedDatabase = getPatternAndProjectedDatabase(
180
+ frequentLengthOnePatternAndCounts.keys.collect(), sequences)
181
+ (frequentLengthOnePatternAndCounts, prefixAndProjectedDatabase)
182
+ }
183
+
184
+ /**
185
+ * Constructs prefix-projected databases from (prefix, suffix) pairs.
136
186
* @param data patterns and projected sequences data before re-partition
137
187
* @return patterns and projected sequences data after re-partition
138
188
*/
139
189
private def makePrefixProjectedDatabases (
140
- data : RDD [(Seq [Int ], Array [Int ])]): RDD [(Seq [Int ], Array [Array [Int ]])] = {
141
- val dataMerged = data
190
+ data : RDD [(Array [Int ], Array [Int ])]): RDD [(Array [Int ], Array [Array [Int ]])] = {
191
+ data.map(x => (x._1.toSeq, x._2))
142
192
.groupByKey()
143
- .mapValues(_.toArray)
144
- dataMerged
193
+ .map(x => (x._1.toArray, x._2.toArray))
145
194
}
146
195
147
196
/**
148
197
* calculate the patterns in local.
198
+ * @param minCount the absolute minimum support
149
199
* @param data patterns and projected sequences data data
150
200
* @return patterns
151
201
*/
152
202
private def getPatternsInLocal (
153
- data : RDD [(Seq [Int ], Array [Array [Int ]])]): RDD [(Seq [Int ], Int )] = {
154
- val result = data.flatMap { x =>
155
- getPatternsWithPrefix(x._1, x._2)
203
+ minCount : Long ,
204
+ data : RDD [(Array [Int ], Array [Array [Int ]])]): RDD [(Array [Int ], Long )] = {
205
+ data.flatMap { x =>
206
+ getPatternsWithPrefix(minCount, x._1, x._2)
156
207
}
157
- result
158
208
}
159
209
160
210
  /**
   * Calculate the patterns with one prefix in local: recursively grows
   * `prefix` by one frequent item at a time within the projected database.
   *
   * @param minCount the absolute minimum support
   * @param prefix prefix
   * @param projectedDatabase patterns and projected sequences data
   * @return (pattern, support) pairs for every frequent extension of `prefix`
   */
  private def getPatternsWithPrefix(
      minCount: Long,
      prefix: Array[Int],
      projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
    // Items frequent within the projected database; each one can extend the prefix.
    val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
    // Patterns of length prefix.length + 1, inheriting each extension's support.
    val frequentPatternAndCounts = frequentPrefixAndCounts
      .map(x => (prefix ++ Array(x._1), x._2))
    // Projected databases for each extended prefix, feeding the recursive step.
    val prefixProjectedDatabases = getPatternAndProjectedDatabase(
      prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)

    // Recurse only while the next extension stays within maxPatternLength
    // (prefix.length + 1 is the length of the patterns just emitted).
    val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
    if (continueProcess) {
      val nextPatterns = prefixProjectedDatabases
        .map(x => getPatternsWithPrefix(minCount, x._1, x._2))
        .reduce(_ ++ _)
      frequentPatternAndCounts ++ nextPatterns
    } else {
      frequentPatternAndCounts
    }
  }
194
237
195
238
/**
196
239
* calculate suffix sequence following a prefix in a sequence
197
240
* @param prefix prefix
198
- * @param sequence original sequence
241
+ * @param sequence sequence
199
242
* @return suffix sequence
200
243
*/
201
244
private def getSuffix (prefix : Int , sequence : Array [Int ]): Array [Int ] = {
0 commit comments