Skip to content

Commit 574e56c

Browse files
committed
Add new object LocalPrefixSpan, and do some optimization.
1 parent ba5df34 commit 574e56c

File tree

3 files changed

+158
-102
lines changed

3 files changed

+158
-102
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.fpm
19+
20+
import org.apache.spark.Logging
21+
import org.apache.spark.annotation.Experimental
22+
23+
/**
 * :: Experimental ::
 *
 * Calculate all patterns of a projected database locally.
 */
@Experimental
private[fpm] object LocalPrefixSpan extends Logging with Serializable {

  /**
   * Calculate all patterns of a projected database in local.
   * @param minCount minimum count
   * @param maxPatternLength maximum pattern length
   * @param prefix prefix
   * @param projectedDatabase the projected database
   * @return a set of sequential pattern pairs,
   *         the key of pair is pattern (a list of elements),
   *         the value of pair is the pattern's count.
   */
  def run(
      minCount: Long,
      maxPatternLength: Int,
      prefix: Array[Int],
      projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
    getPatternsWithPrefix(minCount, maxPatternLength, prefix, projectedDatabase)
  }

  /**
   * Calculate the suffix sequence following the first occurrence of a prefix
   * item in a sequence.
   * @param prefix prefix item
   * @param sequence sequence
   * @return suffix sequence (empty if the prefix item does not occur)
   */
  def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
    val index = sequence.indexOf(prefix)
    if (index == -1) {
      // Prefix item not present in this sequence: nothing to project.
      Array.empty[Int]
    } else {
      sequence.drop(index + 1)
    }
  }

  /**
   * Generates frequent items by filtering the input data using minimal count level.
   * @param minCount the absolute minimum count
   * @param sequences sequences data
   * @return array of item and count pairs
   */
  private def getFreqItemAndCounts(
      minCount: Long,
      sequences: Array[Array[Int]]): Array[(Int, Long)] = {
    // Each item is counted at most once per sequence (hence the distinct),
    // so the count is the number of sequences containing the item.
    sequences.flatMap(_.distinct)
      .groupBy(x => x)
      .mapValues(_.length.toLong)
      .filter(_._2 >= minCount)
      .toArray
  }

  /**
   * Get the frequent prefixes' projected database.
   * @param prePrefix the frequent prefixes' prefix
   * @param frequentPrefixes frequent prefixes
   * @param sequences sequences data
   * @return prefixes and projected database
   */
  private def getPatternAndProjectedDatabase(
      prePrefix: Array[Int],
      frequentPrefixes: Array[Int],
      sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
    // Drop infrequent items up front so projected suffixes only carry
    // candidates that can still contribute to frequent patterns.
    val filteredProjectedDatabase = sequences
      .map(x => x.filter(frequentPrefixes.contains(_)))
    frequentPrefixes.map { x =>
      val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
      (prePrefix ++ Array(x), sub)
    }.filter(x => x._2.nonEmpty)
  }

  /**
   * Calculate all patterns of a projected database in local.
   * @param minCount the minimum count
   * @param maxPatternLength maximum pattern length
   * @param prefix prefix
   * @param projectedDatabase projected database
   * @return patterns
   */
  private def getPatternsWithPrefix(
      minCount: Long,
      maxPatternLength: Int,
      prefix: Array[Int],
      projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
    val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
    val frequentPatternAndCounts = frequentPrefixAndCounts
      .map(x => (prefix ++ Array(x._1), x._2))
    val prefixProjectedDatabases = getPatternAndProjectedDatabase(
      prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)

    // Stop recursing when no projection survives or the next pattern would
    // reach maxPatternLength; recursion depth is therefore bounded.
    val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
    if (continueProcess) {
      // flatMap instead of map(...).reduce(_ ++ _): same result without
      // building quadratically many intermediate arrays.
      val nextPatterns = prefixProjectedDatabases
        .flatMap(x => getPatternsWithPrefix(minCount, maxPatternLength, x._1, x._2))
      frequentPatternAndCounts ++ nextPatterns
    } else {
      frequentPatternAndCounts
    }
  }
}

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 27 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ class PrefixSpan private (
5353
* Sets the minimal support level (default: `0.1`).
5454
*/
5555
def setMinSupport(minSupport: Double): this.type = {
56-
require(minSupport >= 0 && minSupport <= 1)
56+
require(minSupport >= 0 && minSupport <= 1,
57+
"The minimum support value must be between 0 and 1, including 0 and 1.")
5758
this.minSupport = minSupport
5859
this
5960
}
@@ -62,7 +63,8 @@ class PrefixSpan private (
6263
* Sets maximal pattern length (default: `10`).
6364
*/
6465
def setMaxPatternLength(maxPatternLength: Int): this.type = {
65-
require(maxPatternLength >= 1)
66+
require(maxPatternLength >= 1,
67+
"The maximum pattern length value must be greater than 0.")
6668
this.maxPatternLength = maxPatternLength
6769
this
6870
}
@@ -73,35 +75,38 @@ class PrefixSpan private (
7375
* a sequence is an ordered list of elements.
7476
* @return a set of sequential pattern pairs,
7577
* the key of pair is pattern (a list of elements),
76-
* the value of pair is the pattern's support value.
78+
* the value of pair is the pattern's count.
7779
*/
7880
def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
7981
if (sequences.getStorageLevel == StorageLevel.NONE) {
8082
logWarning("Input data is not cached.")
8183
}
82-
val minCount = getAbsoluteMinSupport(sequences)
84+
val minCount = getMinCount(sequences)
8385
val (lengthOnePatternsAndCounts, prefixAndCandidates) =
8486
findLengthOnePatterns(minCount, sequences)
85-
val repartitionedRdd = makePrefixProjectedDatabases(prefixAndCandidates)
86-
val nextPatterns = getPatternsInLocal(minCount, repartitionedRdd)
87-
val allPatterns = lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)) ++ nextPatterns
87+
val projectedDatabase = makePrefixProjectedDatabases(prefixAndCandidates)
88+
val nextPatterns = getPatternsInLocal(minCount, projectedDatabase)
89+
val lengthOnePatternsAndCountsRdd =
90+
sequences.sparkContext.parallelize(
91+
lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
92+
val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
8893
allPatterns
8994
}
9095

9196
/**
92-
* Get the absolute minimum support value (sequences count * minSupport).
97+
* Get the minimum count (sequences count * minSupport).
9398
* @param sequences input data set, contains a set of sequences,
94-
* @return absolute minimum support value,
99+
* @return minimum count,
95100
*/
96-
private def getAbsoluteMinSupport(sequences: RDD[Array[Int]]): Long = {
97-
if (minSupport == 0) 0L else (sequences.count() * minSupport).toLong
101+
private def getMinCount(sequences: RDD[Array[Int]]): Long = {
102+
if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
98103
}
99104

100105
/**
101-
* Generates frequent items by filtering the input data using minimal support level.
102-
* @param minCount the absolute minimum support
106+
* Generates frequent items by filtering the input data using minimal count level.
107+
* @param minCount the absolute minimum count
103108
* @param sequences original sequences data
104-
* @return array of frequent pattern ordered by their frequencies
109+
* @return array of item and count pair
105110
*/
106111
private def getFreqItemAndCounts(
107112
minCount: Long,
@@ -111,22 +116,6 @@ class PrefixSpan private (
111116
.filter(_._2 >= minCount)
112117
}
113118

114-
/**
115-
* Generates frequent items by filtering the input data using minimal support level.
116-
* @param minCount the absolute minimum support
117-
* @param sequences sequences data
118-
* @return array of frequent pattern ordered by their frequencies
119-
*/
120-
private def getFreqItemAndCounts(
121-
minCount: Long,
122-
sequences: Array[Array[Int]]): Array[(Int, Long)] = {
123-
sequences.flatMap(_.distinct)
124-
.groupBy(x => x)
125-
.mapValues(_.length.toLong)
126-
.filter(_._2 >= minCount)
127-
.toArray
128-
}
129-
130119
/**
131120
* Get the frequent prefixes' projected database.
132121
* @param frequentPrefixes frequent prefixes
@@ -141,44 +130,25 @@ class PrefixSpan private (
141130
}
142131
filteredSequences.flatMap { x =>
143132
frequentPrefixes.map { y =>
144-
val sub = getSuffix(y, x)
133+
val sub = LocalPrefixSpan.getSuffix(y, x)
145134
(Array(y), sub)
146-
}
147-
}.filter(x => x._2.nonEmpty)
148-
}
149-
150-
/**
151-
* Get the frequent prefixes' projected database.
152-
* @param prePrefix the frequent prefixes' prefix
153-
* @param frequentPrefixes frequent prefixes
154-
* @param sequences sequences data
155-
* @return prefixes and projected database
156-
*/
157-
private def getPatternAndProjectedDatabase(
158-
prePrefix: Array[Int],
159-
frequentPrefixes: Array[Int],
160-
sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
161-
val filteredProjectedDatabase = sequences
162-
.map(x => x.filter(frequentPrefixes.contains(_)))
163-
frequentPrefixes.map { x =>
164-
val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
165-
(prePrefix ++ Array(x), sub)
166-
}.filter(x => x._2.nonEmpty)
135+
}.filter(_._2.nonEmpty)
136+
}
167137
}
168138

169139
/**
170140
* Find the patterns whose length is one
171-
* @param minCount the absolute minimum support
141+
* @param minCount the minimum count
172142
* @param sequences original sequences data
173143
* @return length-one patterns and projection table
174144
*/
175145
private def findLengthOnePatterns(
176146
minCount: Long,
177-
sequences: RDD[Array[Int]]): (RDD[(Int, Long)], RDD[(Array[Int], Array[Int])]) = {
147+
sequences: RDD[Array[Int]]): (Array[(Int, Long)], RDD[(Array[Int], Array[Int])]) = {
178148
val frequentLengthOnePatternAndCounts = getFreqItemAndCounts(minCount, sequences)
179149
val prefixAndProjectedDatabase = getPatternAndProjectedDatabase(
180150
frequentLengthOnePatternAndCounts.keys.collect(), sequences)
181-
(frequentLengthOnePatternAndCounts, prefixAndProjectedDatabase)
151+
(frequentLengthOnePatternAndCounts.collect(), prefixAndProjectedDatabase)
182152
}
183153

184154
/**
@@ -195,58 +165,15 @@ class PrefixSpan private (
195165

196166
/**
197167
* calculate the patterns in local.
198-
* @param minCount the absolute minimum support
168+
* @param minCount the absolute minimum count
199169
* @param data patterns and projected sequences data
200170
* @return patterns
201171
*/
202172
private def getPatternsInLocal(
203173
minCount: Long,
204174
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
205175
data.flatMap { x =>
206-
getPatternsWithPrefix(minCount, x._1, x._2)
207-
}
208-
}
209-
210-
/**
211-
* calculate the patterns with one prefix in local.
212-
* @param minCount the absolute minimum support
213-
* @param prefix prefix
214-
* @param projectedDatabase patterns and projected sequences data
215-
* @return patterns
216-
*/
217-
private def getPatternsWithPrefix(
218-
minCount: Long,
219-
prefix: Array[Int],
220-
projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
221-
val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
222-
val frequentPatternAndCounts = frequentPrefixAndCounts
223-
.map(x => (prefix ++ Array(x._1), x._2))
224-
val prefixProjectedDatabases = getPatternAndProjectedDatabase(
225-
prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
226-
227-
val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
228-
if (continueProcess) {
229-
val nextPatterns = prefixProjectedDatabases
230-
.map(x => getPatternsWithPrefix(minCount, x._1, x._2))
231-
.reduce(_ ++ _)
232-
frequentPatternAndCounts ++ nextPatterns
233-
} else {
234-
frequentPatternAndCounts
235-
}
236-
}
237-
238-
/**
239-
* calculate suffix sequence following a prefix in a sequence
240-
* @param prefix prefix
241-
* @param sequence sequence
242-
* @return suffix sequence
243-
*/
244-
private def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
245-
val index = sequence.indexOf(prefix)
246-
if (index == -1) {
247-
Array()
248-
} else {
249-
sequence.drop(index + 1)
176+
LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
250177
}
251178
}
252179
}

mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
6060
}
6161

6262
val prefixspan = new PrefixSpan()
63-
.setMinSupport(0.34)
63+
.setMinSupport(0.33)
6464
.setMaxPatternLength(50)
6565
val result1 = prefixspan.run(rdd)
6666
val expectedValue1 = Array(
@@ -97,7 +97,7 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
9797
)
9898
assert(compareResult(expectedValue2, result2.collect()))
9999

100-
prefixspan.setMinSupport(0.34).setMaxPatternLength(2)
100+
prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
101101
val result3 = prefixspan.run(rdd)
102102
val expectedValue3 = Array(
103103
(Array(1), 4L),

0 commit comments

Comments
 (0)