Skip to content

Commit 1dd33ad

Browse files
committed
Modified the code according to the review comments.
1 parent 89bc368 commit 1dd33ad

File tree

2 files changed

+201
-108
lines changed

2 files changed

+201
-108
lines changed

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 127 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
package org.apache.spark.mllib.fpm
1919

20+
import org.apache.spark.Logging
2021
import org.apache.spark.annotation.Experimental
2122
import org.apache.spark.rdd.RDD
23+
import org.apache.spark.storage.StorageLevel
2224

2325
/**
2426
*
@@ -37,165 +39,206 @@ import org.apache.spark.rdd.RDD
3739
* (Wikipedia)]]
3840
*/
3941
@Experimental
40-
class PrefixSpan(
42+
class PrefixSpan private (
4143
private var minSupport: Double,
42-
private var maxPatternLength: Int) extends java.io.Serializable {
43-
44-
private var absMinSupport: Int = 0
44+
private var maxPatternLength: Int) extends Logging with Serializable {
4545

4646
/**
4747
* Constructs a default instance with default parameters
48-
* {minSupport: `0.1`, maxPatternLength: 10}.
48+
* {minSupport: `0.1`, maxPatternLength: `10`}.
4949
*/
5050
def this() = this(0.1, 10)
5151

5252
/**
5353
* Sets the minimal support level (default: `0.1`).
5454
*/
5555
def setMinSupport(minSupport: Double): this.type = {
56+
require(minSupport >= 0 && minSupport <= 1)
5657
this.minSupport = minSupport
5758
this
5859
}
5960

6061
/**
61-
* Sets maximal pattern length.
62+
* Sets maximal pattern length (default: `10`).
6263
*/
6364
def setMaxPatternLength(maxPatternLength: Int): this.type = {
65+
require(maxPatternLength >= 1)
6466
this.maxPatternLength = maxPatternLength
6567
this
6668
}
6769

6870
/**
69-
* Calculate sequential patterns:
70-
* a) find and collect length-one patterns
71-
* b) for each length-one patterns and each sequence,
72-
* emit (pattern (prefix), suffix sequence) as key-value pairs
73-
* c) group by key and then map value iterator to array
74-
* d) local PrefixSpan on each prefix
75-
* @return sequential patterns
71+
* Find the complete set of sequential patterns in the input sequences.
72+
* @param sequences input data set, contains a set of sequences,
73+
* a sequence is an ordered list of elements.
74+
* @return a set of sequential pattern pairs,
75+
* the key of pair is pattern (a list of elements),
76+
* the value of pair is the pattern's support value.
7677
*/
77-
def run(sequences: RDD[Array[Int]]): RDD[(Seq[Int], Int)] = {
78-
absMinSupport = getAbsoluteMinSupport(sequences)
78+
def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
79+
if (sequences.getStorageLevel == StorageLevel.NONE) {
80+
logWarning("Input data is not cached.")
81+
}
82+
val minCount = getAbsoluteMinSupport(sequences)
7983
val (lengthOnePatternsAndCounts, prefixAndCandidates) =
80-
findLengthOnePatterns(sequences)
84+
findLengthOnePatterns(minCount, sequences)
8185
val repartitionedRdd = makePrefixProjectedDatabases(prefixAndCandidates)
82-
val nextPatterns = getPatternsInLocal(repartitionedRdd)
83-
val allPatterns = lengthOnePatternsAndCounts.map(x => (Seq(x._1), x._2)) ++ nextPatterns
86+
val nextPatterns = getPatternsInLocal(minCount, repartitionedRdd)
87+
val allPatterns = lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)) ++ nextPatterns
8488
allPatterns
8589
}
8690

87-
private def getAbsoluteMinSupport(sequences: RDD[Array[Int]]): Int = {
88-
val result = if (minSupport <= 0) {
89-
0
90-
} else {
91-
val count = sequences.count()
92-
val support = if (minSupport <= 1) minSupport else 1
93-
(support * count).toInt
94-
}
95-
result
91+
/**
92+
* Get the absolute minimum support value (sequences count * minSupport).
93+
* @param sequences input data set, containing a set of sequences.
94+
* @return absolute minimum support value.
95+
*/
96+
private def getAbsoluteMinSupport(sequences: RDD[Array[Int]]): Long = {
97+
if (minSupport == 0) 0L else (sequences.count() * minSupport).toLong
9698
}
9799

98100
/**
99-
* Find the patterns that it's length is one
101+
* Generates frequent items by filtering the input data using minimal support level.
102+
* @param minCount the absolute minimum support
100103
* @param sequences original sequences data
101-
* @return length-one patterns and projection table
104+
* @return array of frequent patterns ordered by their frequencies
102105
*/
103-
private def findLengthOnePatterns(
104-
sequences: RDD[Array[Int]]): (RDD[(Int, Int)], RDD[(Seq[Int], Array[Int])]) = {
105-
val LengthOnePatternAndCounts = sequences
106-
.flatMap(_.distinct.map((_, 1)))
106+
private def getFreqItemAndCounts(
107+
minCount: Long,
108+
sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
109+
sequences.flatMap(_.distinct.map((_, 1L)))
107110
.reduceByKey(_ + _)
108-
val infrequentLengthOnePatterns: Array[Int] = LengthOnePatternAndCounts
109-
.filter(_._2 < absMinSupport)
110-
.map(_._1)
111-
.collect()
112-
val frequentLengthOnePatterns = LengthOnePatternAndCounts
113-
.filter(_._2 >= absMinSupport)
114-
val frequentLengthOnePatternsArray = frequentLengthOnePatterns
115-
.map(_._1)
116-
.collect()
117-
val filteredSequences =
118-
if (infrequentLengthOnePatterns.isEmpty) {
119-
sequences
120-
} else {
121-
sequences.map { p =>
122-
p.filter { x => !infrequentLengthOnePatterns.contains(x) }
123-
}
124-
}
125-
val prefixAndCandidates = filteredSequences.flatMap { x =>
126-
frequentLengthOnePatternsArray.map { y =>
111+
.filter(_._2 >= minCount)
112+
}
113+
114+
/**
115+
* Generates frequent items by filtering the input data using minimal support level.
116+
* @param minCount the absolute minimum support
117+
* @param sequences sequences data
118+
* @return array of frequent patterns ordered by their frequencies
119+
*/
120+
private def getFreqItemAndCounts(
121+
minCount: Long,
122+
sequences: Array[Array[Int]]): Array[(Int, Long)] = {
123+
sequences.flatMap(_.distinct)
124+
.groupBy(x => x)
125+
.mapValues(_.length.toLong)
126+
.filter(_._2 >= minCount)
127+
.toArray
128+
}
129+
130+
/**
131+
* Get the frequent prefixes' projected database.
132+
* @param frequentPrefixes frequent prefixes
133+
* @param sequences sequences data
134+
* @return prefixes and projected database
135+
*/
136+
private def getPatternAndProjectedDatabase(
137+
frequentPrefixes: Array[Int],
138+
sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
139+
val filteredSequences = sequences.map { p =>
140+
p.filter (frequentPrefixes.contains(_) )
141+
}
142+
filteredSequences.flatMap { x =>
143+
frequentPrefixes.map { y =>
127144
val sub = getSuffix(y, x)
128-
(Seq(y), sub)
145+
(Array(y), sub)
129146
}
130147
}.filter(x => x._2.nonEmpty)
131-
(frequentLengthOnePatterns, prefixAndCandidates)
132148
}
133149

134150
/**
135-
* Re-partition the RDD data, to get better balance and performance.
151+
* Get the frequent prefixes' projected database.
152+
* @param prePrefix the frequent prefixes' prefix
153+
* @param frequentPrefixes frequent prefixes
154+
* @param sequences sequences data
155+
* @return prefixes and projected database
156+
*/
157+
private def getPatternAndProjectedDatabase(
158+
prePrefix: Array[Int],
159+
frequentPrefixes: Array[Int],
160+
sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
161+
val filteredProjectedDatabase = sequences
162+
.map(x => x.filter(frequentPrefixes.contains(_)))
163+
frequentPrefixes.map { x =>
164+
val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
165+
(prePrefix ++ Array(x), sub)
166+
}.filter(x => x._2.nonEmpty)
167+
}
168+
169+
/**
170+
* Find the patterns whose length is one
171+
* @param minCount the absolute minimum support
172+
* @param sequences original sequences data
173+
* @return length-one patterns and projection table
174+
*/
175+
private def findLengthOnePatterns(
176+
minCount: Long,
177+
sequences: RDD[Array[Int]]): (RDD[(Int, Long)], RDD[(Array[Int], Array[Int])]) = {
178+
val frequentLengthOnePatternAndCounts = getFreqItemAndCounts(minCount, sequences)
179+
val prefixAndProjectedDatabase = getPatternAndProjectedDatabase(
180+
frequentLengthOnePatternAndCounts.keys.collect(), sequences)
181+
(frequentLengthOnePatternAndCounts, prefixAndProjectedDatabase)
182+
}
183+
184+
/**
185+
* Constructs prefix-projected databases from (prefix, suffix) pairs.
136186
* @param data patterns and projected sequences data before re-partition
137187
* @return patterns and projected sequences data after re-partition
138188
*/
139189
private def makePrefixProjectedDatabases(
140-
data: RDD[(Seq[Int], Array[Int])]): RDD[(Seq[Int], Array[Array[Int]])] = {
141-
val dataMerged = data
190+
data: RDD[(Array[Int], Array[Int])]): RDD[(Array[Int], Array[Array[Int]])] = {
191+
data.map(x => (x._1.toSeq, x._2))
142192
.groupByKey()
143-
.mapValues(_.toArray)
144-
dataMerged
193+
.map(x => (x._1.toArray, x._2.toArray))
145194
}
146195

147196
/**
148197
* Calculate the patterns locally.
198+
* @param minCount the absolute minimum support
149199
* @param data patterns and projected sequences data
150200
* @return patterns
151201
*/
152202
private def getPatternsInLocal(
153-
data: RDD[(Seq[Int], Array[Array[Int]])]): RDD[(Seq[Int], Int)] = {
154-
val result = data.flatMap { x =>
155-
getPatternsWithPrefix(x._1, x._2)
203+
minCount: Long,
204+
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
205+
data.flatMap { x =>
206+
getPatternsWithPrefix(minCount, x._1, x._2)
156207
}
157-
result
158208
}
159209

160210
/**
161211
* Calculate the patterns with one prefix locally.
212+
* @param minCount the absolute minimum support
162213
* @param prefix prefix
163214
* @param projectedDatabase patterns and projected sequences data
164215
* @return patterns
165216
*/
166217
private def getPatternsWithPrefix(
167-
prefix: Seq[Int],
168-
projectedDatabase: Array[Array[Int]]): Array[(Seq[Int], Int)] = {
169-
val prefixAndCounts = projectedDatabase
170-
.flatMap(_.distinct)
171-
.groupBy(x => x)
172-
.mapValues(_.length)
173-
val frequentPrefixExtensions = prefixAndCounts.filter(x => x._2 >= absMinSupport)
174-
val frequentPrefixesAndCounts = frequentPrefixExtensions
175-
.map(x => (prefix ++ Seq(x._1), x._2))
176-
.toArray
177-
val cleanedSearchSpace = projectedDatabase
178-
.map(x => x.filter(y => frequentPrefixExtensions.contains(y)))
179-
val prefixProjectedDatabases = frequentPrefixExtensions.map { x =>
180-
val sub = cleanedSearchSpace.map(y => getSuffix(x._1, y)).filter(_.nonEmpty)
181-
(prefix ++ Seq(x._1), sub)
182-
}.filter(x => x._2.nonEmpty)
183-
.toArray
218+
minCount: Long,
219+
prefix: Array[Int],
220+
projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
221+
val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
222+
val frequentPatternAndCounts = frequentPrefixAndCounts
223+
.map(x => (prefix ++ Array(x._1), x._2))
224+
val prefixProjectedDatabases = getPatternAndProjectedDatabase(
225+
prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
226+
184227
val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
185228
if (continueProcess) {
186229
val nextPatterns = prefixProjectedDatabases
187-
.map(x => getPatternsWithPrefix(x._1, x._2))
230+
.map(x => getPatternsWithPrefix(minCount, x._1, x._2))
188231
.reduce(_ ++ _)
189-
frequentPrefixesAndCounts ++ nextPatterns
232+
frequentPatternAndCounts ++ nextPatterns
190233
} else {
191-
frequentPrefixesAndCounts
234+
frequentPatternAndCounts
192235
}
193236
}
194237

195238
/**
196239
* Calculate the suffix sequence following a prefix in a sequence.
197240
* @param prefix prefix
198-
* @param sequence original sequence
241+
* @param sequence sequence
199242
* @return suffix sequence
200243
*/
201244
private def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {

0 commit comments

Comments
 (0)