Skip to content

Commit 575995f

Browse files
committed
Modified the code according to the review comments.
1 parent 91fd7e6 commit 575995f

File tree

2 files changed

+278
-0
lines changed

2 files changed

+278
-0
lines changed
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.fpm
19+
20+
import org.apache.spark.annotation.Experimental
21+
import org.apache.spark.rdd.RDD
22+
23+
/**
 *
 * :: Experimental ::
 *
 * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
 * The PrefixSpan algorithm is described in
 * [[http://doi.org/10.1109/ICDE.2001.914830]].
 *
 * @param minSupport the minimal support level of the sequential pattern, any pattern that appears
 *                   more than (minSupport * size-of-the-dataset) times will be output
 * @param maxPatternLength the maximal length of the sequential pattern, any pattern that is not
 *                         longer than maxPatternLength will be output
 *
 * @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining
 *       (Wikipedia)]]
 */
@Experimental
class PrefixSpan(
    private var minSupport: Double,
    private var maxPatternLength: Int) extends Serializable {

  // Absolute (count-based) minimum support, derived from the relative
  // `minSupport` at the start of each run(); 0 until run() is invoked.
  private var absMinSupport: Int = 0

  /**
   * Constructs a default instance with default parameters
   * {minSupport: `0.1`, maxPatternLength: `10`}.
   */
  def this() = this(0.1, 10)

  /**
   * Sets the minimal support level (default: `0.1`).
   */
  def setMinSupport(minSupport: Double): this.type = {
    this.minSupport = minSupport
    this
  }

  /**
   * Sets maximal pattern length (default: `10`).
   */
  def setMaxPatternLength(maxPatternLength: Int): this.type = {
    this.maxPatternLength = maxPatternLength
    this
  }

  /**
   * Calculate sequential patterns:
   *  a) find and collect length-one patterns
   *  b) for each length-one pattern and each sequence,
   *     emit (pattern (prefix), suffix sequence) as key-value pairs
   *  c) group by key and then map value iterator to array
   *  d) local PrefixSpan on each prefix's projected database
   * @param sequences input sequences, one array of item ids per sequence
   * @return sequential patterns paired with their absolute support counts
   */
  def run(sequences: RDD[Array[Int]]): RDD[(Seq[Int], Int)] = {
    absMinSupport = getAbsoluteMinSupport(sequences)
    val (lengthOnePatternsAndCounts, prefixAndCandidates) =
      findLengthOnePatterns(sequences)
    val repartitionedRdd = makePrefixProjectedDatabases(prefixAndCandidates)
    val nextPatterns = getPatternsInLocal(repartitionedRdd)
    lengthOnePatternsAndCounts.map(x => (Seq(x._1), x._2)) ++ nextPatterns
  }

  /**
   * Converts the relative `minSupport` into an absolute sequence count.
   * Values <= 0 yield 0 (every pattern is frequent); values > 1 are clamped to 1.
   */
  private def getAbsoluteMinSupport(sequences: RDD[Array[Int]]): Int = {
    if (minSupport <= 0) {
      0
    } else {
      val count = sequences.count()
      val support = if (minSupport <= 1) minSupport else 1
      (support * count).toInt
    }
  }

  /**
   * Finds the frequent length-one patterns and builds the projection table.
   * @param sequences original sequences data
   * @return length-one patterns with counts, and (prefix, suffix) candidate pairs
   */
  private def findLengthOnePatterns(
      sequences: RDD[Array[Int]]): (RDD[(Int, Int)], RDD[(Seq[Int], Array[Int])]) = {
    val lengthOnePatternAndCounts = sequences
      .flatMap(_.distinct.map((_, 1)))
      .reduceByKey(_ + _)
    val infrequentLengthOnePatterns: Array[Int] = lengthOnePatternAndCounts
      .filter(_._2 < absMinSupport)
      .map(_._1)
      .collect()
    val frequentLengthOnePatterns = lengthOnePatternAndCounts
      .filter(_._2 >= absMinSupport)
    val frequentLengthOnePatternsArray = frequentLengthOnePatterns
      .map(_._1)
      .collect()
    // Drop infrequent items up front: they cannot occur in any frequent pattern
    // (anti-monotonicity of support).
    val filteredSequences =
      if (infrequentLengthOnePatterns.isEmpty) {
        sequences
      } else {
        sequences.map { p =>
          p.filter { x => !infrequentLengthOnePatterns.contains(x) }
        }
      }
    // Pair every frequent length-one prefix with its suffix in each sequence,
    // keeping only non-empty suffixes.
    val prefixAndCandidates = filteredSequences.flatMap { x =>
      frequentLengthOnePatternsArray.map { y =>
        val sub = getSuffix(y, x)
        (Seq(y), sub)
      }
    }.filter(x => x._2.nonEmpty)
    (frequentLengthOnePatterns, prefixAndCandidates)
  }

  /**
   * Re-partitions the RDD data, to get better balance and performance.
   * @param data patterns and projected sequences data before re-partition
   * @return patterns and projected sequences data after re-partition
   */
  private def makePrefixProjectedDatabases(
      data: RDD[(Seq[Int], Array[Int])]): RDD[(Seq[Int], Array[Array[Int]])] = {
    data
      .groupByKey()
      .mapValues(_.toArray)
  }

  /**
   * Calculates the patterns in local.
   * @param data patterns and projected sequences data
   * @return patterns
   */
  private def getPatternsInLocal(
      data: RDD[(Seq[Int], Array[Array[Int]])]): RDD[(Seq[Int], Int)] = {
    data.flatMap { x =>
      getPatternsWithPrefix(x._1, x._2)
    }
  }

  /**
   * Calculates the patterns with one prefix, in local.
   * @param prefix prefix
   * @param projectedDatabase patterns and projected sequences data
   * @return patterns
   */
  private def getPatternsWithPrefix(
      prefix: Seq[Int],
      projectedDatabase: Array[Array[Int]]): Array[(Seq[Int], Int)] = {
    // Support of each candidate extension item: number of projected sequences
    // containing it (distinct per sequence).
    val prefixAndCounts = projectedDatabase
      .flatMap(_.distinct)
      .groupBy(x => x)
      .mapValues(_.length)
    val frequentPrefixExtensions = prefixAndCounts.filter(x => x._2 >= absMinSupport)
    val frequentPrefixesAndCounts = frequentPrefixExtensions
      .map(x => (prefix ++ Seq(x._1), x._2))
      .toArray
    // Remove infrequent extension items before projecting again.
    val cleanedSearchSpace = projectedDatabase
      .map(x => x.filter(y => frequentPrefixExtensions.contains(y)))
    val prefixProjectedDatabases = frequentPrefixExtensions.map { x =>
      val sub = cleanedSearchSpace.map(y => getSuffix(x._1, y)).filter(_.nonEmpty)
      (prefix ++ Seq(x._1), sub)
    }.filter(x => x._2.nonEmpty)
      .toArray
    // Recurse only while there is data left and the length bound permits
    // patterns longer than the current prefix.
    val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
    if (continueProcess) {
      val nextPatterns = prefixProjectedDatabases
        .map(x => getPatternsWithPrefix(x._1, x._2))
        .reduce(_ ++ _)
      frequentPrefixesAndCounts ++ nextPatterns
    } else {
      frequentPrefixesAndCounts
    }
  }

  /**
   * Calculates the suffix sequence following a prefix item in a sequence.
   * @param prefix prefix item
   * @param sequence original sequence
   * @return suffix sequence, empty if the prefix does not occur
   */
  private def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
    val index = sequence.indexOf(prefix)
    if (index == -1) {
      Array.empty[Int]
    } else {
      sequence.drop(index + 1)
    }
  }
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.mllib.fpm
18+
19+
import org.apache.spark.SparkFunSuite
20+
import org.apache.spark.mllib.util.MLlibTestSparkContext
21+
import org.apache.spark.rdd.RDD
22+
23+
class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("Prefixspan sequences mining using Integer type") {
    val sequences = Array(
      Array(3, 1, 3, 4, 5),
      Array(2, 3, 1),
      Array(3, 4, 4, 3),
      Array(1, 3, 4, 5),
      Array(2, 4, 1),
      Array(6, 5, 3))

    val rdd = sc.parallelize(sequences, 2).cache()

    // Renders the mined patterns as a single deterministic string so the
    // result can be compared independently of RDD partitioning/order.
    def formatResultString(data: RDD[(Seq[Int], Int)]): String = {
      data.map(x => x._1.mkString(",") + ": " + x._2)
        .collect()
        .sortWith(_ < _)
        .mkString("; ")
    }

    // Case 1: minSupport 0.34 (absolute support 2 of 6), effectively unbounded length.
    val prefixspan = new PrefixSpan()
      .setMinSupport(0.34)
      .setMaxPatternLength(50)
    val result1 = prefixspan.run(rdd)
    val actualValue1 = formatResultString(result1)
    val expectedValue1 =
      "1,3,4,5: 2; 1,3,4: 2; 1,3,5: 2; 1,3: 2; 1,4,5: 2;" +
        " 1,4: 2; 1,5: 2; 1: 4; 2,1: 2; 2: 2; 3,1: 2; 3,3: 2;" +
        " 3,4,5: 2; 3,4: 3; 3,5: 2; 3: 5; 4,5: 2; 4: 4; 5: 3"
    assert(expectedValue1 == actualValue1)

    // Case 2: raising minSupport to 0.5 (absolute support 3) prunes most patterns.
    prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
    val result2 = prefixspan.run(rdd)
    val expectedValue2 = "1: 4; 3,4: 3; 3: 5; 4: 4; 5: 3"
    val actualValue2 = formatResultString(result2)
    assert(expectedValue2 == actualValue2)

    // Case 3: same support as case 1 but patterns capped at length 2.
    prefixspan.setMinSupport(0.34).setMaxPatternLength(2)
    val result3 = prefixspan.run(rdd)
    val actualValue3 = formatResultString(result3)
    val expectedValue3 =
      "1,3: 2; 1,4: 2; 1,5: 2; 1: 4; 2,1: 2; 2: 2; 3,1: 2;" +
        " 3,3: 2; 3,4: 3; 3,5: 2; 3: 5; 4,5: 2; 4: 4; 5: 3"
    assert(expectedValue3 == actualValue3)
  }
}

0 commit comments

Comments
 (0)