
Commit 91fd7e6

Add new algorithm PrefixSpan and test file.
1 parent f9c448d commit 91fd7e6

2 files changed: 230 additions, 0 deletions
Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fpm

import org.apache.spark.rdd.RDD

/**
 *
 * A parallel PrefixSpan algorithm to mine sequential patterns.
 * The PrefixSpan algorithm is described in
 * [[http://web.engr.illinois.edu/~hanj/pdf/span01.pdf]].
 *
 * @param sequences original sequences data
 * @param minSupport the minimal support level of the sequential pattern; any pattern that
 *                   appears at least minSupport times will be output
 * @param maxPatternLength the maximal length of the sequential pattern; any pattern no longer
 *                         than maxPatternLength will be output
 *
 * @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining
 *      (Wikipedia)]]
 */
class Prefixspan(
    val sequences: RDD[Array[Int]],
    val minSupport: Int = 2,
    val maxPatternLength: Int = 50) extends java.io.Serializable {

  /**
   * Calculate sequential patterns:
   *  a) find and collect length-one patterns
   *  b) for each length-one pattern and each sequence,
   *     emit (pattern (prefix), suffix sequence) as key-value pairs
   *  c) group by key and then map the value iterator to an array
   *  d) run local PrefixSpan on each prefix
   * @return sequential patterns
   */
  def run(): RDD[(Seq[Int], Int)] = {
    val (patternsOneLength, prefixAndCandidates) = findPatternsLengthOne()
    val repartitionedRdd = repartitionSequences(prefixAndCandidates)
    val nextPatterns = getPatternsInLocal(repartitionedRdd)
    val allPatterns = patternsOneLength.map(x => (Seq(x._1), x._2)) ++ nextPatterns
    allPatterns
  }

  /**
   * Find the patterns whose length is one.
   * @return length-one patterns and projection table
   */
  private def findPatternsLengthOne(): (RDD[(Int, Int)], RDD[(Seq[Int], Array[Int])]) = {
    val patternsOneLength = sequences
      .map(_.distinct)
      .flatMap(p => p)
      .map((_, 1))
      .reduceByKey(_ + _)

    val removedElements: Array[Int] = patternsOneLength
      .filter(_._2 < minSupport)
      .map(_._1)
      .collect()

    val savedElements = patternsOneLength.filter(_._2 >= minSupport)

    val savedElementsArray = savedElements
      .map(_._1)
      .collect()

    val filteredSequences =
      if (removedElements.isEmpty) {
        sequences
      } else {
        sequences.map { p =>
          p.filter { x => !removedElements.contains(x) }
        }
      }

    val prefixAndCandidates = filteredSequences.flatMap { x =>
      savedElementsArray.map { y =>
        val sub = getSuffix(y, x)
        (Seq(y), sub)
      }
    }

    (savedElements, prefixAndCandidates)
  }
  /**
   * Re-partition the RDD data to get better balance and performance.
   * @param data patterns and projected sequences data before re-partition
   * @return patterns and projected sequences data after re-partition
   */
  private def repartitionSequences(
      data: RDD[(Seq[Int], Array[Int])]): RDD[(Seq[Int], Array[Array[Int]])] = {
    val dataRemovedEmptyLine = data.filter(x => x._2.nonEmpty)
    val dataMerged = dataRemovedEmptyLine
      .groupByKey()
      .map(x => (x._1, x._2.toArray))
    dataMerged
  }

  /**
   * Calculate the patterns locally.
   * @param data patterns and projected sequences data
   * @return patterns
   */
  private def getPatternsInLocal(
      data: RDD[(Seq[Int], Array[Array[Int]])]): RDD[(Seq[Int], Int)] = {
    val result = data.flatMap { x =>
      getPatternsWithPrefix(x._1, x._2)
    }
    result
  }
  /**
   * Calculate the patterns for one prefix locally.
   * @param prefix prefix
   * @param data patterns and projected sequences data
   * @return patterns
   */
  private def getPatternsWithPrefix(
      prefix: Seq[Int],
      data: Array[Array[Int]]): Array[(Seq[Int], Int)] = {
    val elements = data
      .map(x => x.distinct)
      .flatMap(x => x)
      .groupBy(x => x)
      .map(x => (x._1, x._2.length))

    val selectedSingleElements = elements.filter(x => x._2 >= minSupport)

    val selectedElements = selectedSingleElements
      .map(x => (prefix ++ Seq(x._1), x._2))
      .toArray

    val cleanedSearchSpace = data
      .map(x => x.filter(y => selectedSingleElements.contains(y)))

    val newSearchSpace = selectedSingleElements.map { x =>
      val sub = cleanedSearchSpace.map(y => getSuffix(x._1, y)).filter(_.nonEmpty)
      (prefix ++ Seq(x._1), sub)
    }.filter(x => x._2.nonEmpty)
      .toArray

    val continueProcess = newSearchSpace.nonEmpty && prefix.length + 1 < maxPatternLength

    if (continueProcess) {
      val nextPatterns = newSearchSpace
        .map(x => getPatternsWithPrefix(x._1, x._2))
        .reduce(_ ++ _)
      selectedElements ++ nextPatterns
    } else {
      selectedElements
    }
  }

  /**
   * Calculate the suffix sequence following a prefix in a sequence.
   * @param prefix prefix
   * @param sequence original sequence
   * @return suffix sequence
   */
  private def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
    val index = sequence.indexOf(prefix)
    if (index == -1) {
      Array()
    } else {
      sequence.takeRight(sequence.length - index - 1)
    }
  }
}
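The run() Scaladoc above describes the mining in four steps: collect frequent length-one patterns, project each sequence onto each of those prefixes, group the projected suffixes by prefix, and continue PrefixSpan locally on each group. The plain-Scala sketch below is not part of the commit; the object name and the hard-coded data are illustrative only. It mirrors the projection and counting logic (getSuffix plus the distinct-then-count step) on local arrays to show what a projected database looks like.

// Illustrative local sketch (not part of the commit): mirrors the projection logic
// of Prefixspan.getSuffix and the length-one counting on plain Scala collections.
object PrefixProjectionSketch {

  // Everything after the first occurrence of `prefix`; empty if `prefix` is absent.
  // Equivalent to the committed getSuffix, which uses takeRight(length - index - 1).
  def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
    val index = sequence.indexOf(prefix)
    if (index == -1) Array() else sequence.drop(index + 1)
  }

  def main(args: Array[String]): Unit = {
    val db = Array(Array(1, 3, 4, 5), Array(3, 4, 4, 3), Array(2, 4, 1))

    // Step a) support of each item, counted at most once per sequence (hence distinct).
    val supports = db.flatMap(_.distinct).groupBy(identity).map { case (k, v) => (k, v.length) }
    // supports == Map(1 -> 2, 2 -> 1, 3 -> 2, 4 -> 3, 5 -> 1)

    // Step b) project the database onto the frequent prefix 4, dropping empty suffixes.
    val projectedOn4 = db.map(getSuffix(4, _)).filter(_.nonEmpty)
    projectedOn4.foreach(s => println(s.mkString(",")))  // prints 5, then 4,3, then 1
  }
}

Grouping such projected suffixes by their prefix and repeating the same counting on each group is what repartitionSequences and getPatternsWithPrefix do in the distributed version.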
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.mllib.fpm

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("Prefixspan sequences mining using Integer type") {
    val sequences = Array(
      Array(3, 1, 3, 4, 5),
      Array(2, 3, 1),
      Array(3, 4, 4, 3),
      Array(1, 3, 4, 5),
      Array(2, 4, 1),
      Array(6, 5, 3))

    val rdd = sc.parallelize(sequences, 2).cache()

    val prefixspan1 = new Prefixspan(rdd, 2, 50)
    val result1 = prefixspan1.run()
    assert(result1.count() == 19)

    val prefixspan2 = new Prefixspan(rdd, 3, 50)
    val result2 = prefixspan2.run()
    assert(result2.count() == 5)

    val prefixspan3 = new Prefixspan(rdd, 2, 2)
    val result3 = prefixspan3.run()
    assert(result3.count() == 14)
  }
}
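As a rough by-hand check of the suite's data (illustrative only, not part of the commit), the snippet below counts how many of the six test sequences contain each item, mirroring the distinct-then-count step in findPatternsLengthOne. Because the filter keeps items with support >= minSupport, only item 6 is pruned from the length-one patterns at minSupport = 2, while items 2 and 6 are both pruned at minSupport = 3.

// Illustrative sanity check on the suite's data (not part of the commit).
val sequences = Seq(
  Array(3, 1, 3, 4, 5),
  Array(2, 3, 1),
  Array(3, 4, 4, 3),
  Array(1, 3, 4, 5),
  Array(2, 4, 1),
  Array(6, 5, 3))

// Per-item support, counting each item at most once per sequence.
val supports = sequences.flatMap(_.distinct).groupBy(identity).map { case (k, v) => (k, v.size) }
// supports == Map(1 -> 4, 2 -> 2, 3 -> 5, 4 -> 4, 5 -> 3, 6 -> 1)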
