Commit 0bc8847

zero323 authored and jkbradley committed
[SPARK-19281][PYTHON][ML] spark.ml Python API for FPGrowth
## What changes were proposed in this pull request?

- Add `HasSupport` and `HasConfidence` `Params`.
- Add new module `pyspark.ml.fpm`.
- Add `FPGrowth` / `FPGrowthModel` wrappers.
- Provide tests for new features.

## How was this patch tested?

Unit tests.

Author: zero323 <[email protected]>

Closes #17218 from zero323/SPARK-19281.
1 parent 617ab64 commit 0bc8847

File tree

4 files changed: +273, -9 lines
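Before the per-file diffs, a minimal usage sketch of the new API. It mirrors the doctest added in python/pyspark/ml/fpm.py below; the baskets and thresholds here are illustrative, not taken from the patch:

from pyspark.sql import SparkSession
from pyspark.ml.fpm import FPGrowth

spark = SparkSession.builder.getOrCreate()

# Each row is one transaction; the column name must match itemsCol ("items" by default).
baskets = spark.createDataFrame(
    [(["a", "b"],), (["a", "b", "c"],), (["a", "c"],)], ["items"])

fp = FPGrowth(minSupport=0.5, minConfidence=0.7)
model = fp.fit(baskets)

model.freqItemsets.show()        # columns: items, freq
model.associationRules.show()    # columns: antecedent, consequent, confidence
model.transform(baskets).show()  # appends a "prediction" column of suggested items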


dev/sparktestsupport/modules.py

Lines changed: 3 additions & 2 deletions

@@ -423,15 +423,16 @@ def __hash__(self):
         "python/pyspark/ml/"
     ],
     python_test_goals=[
-        "pyspark.ml.feature",
         "pyspark.ml.classification",
         "pyspark.ml.clustering",
+        "pyspark.ml.evaluation",
+        "pyspark.ml.feature",
+        "pyspark.ml.fpm",
         "pyspark.ml.linalg.__init__",
         "pyspark.ml.recommendation",
         "pyspark.ml.regression",
         "pyspark.ml.tuning",
         "pyspark.ml.tests",
-        "pyspark.ml.evaluation",
     ],
     blacklisted_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy and it isn't available there

python/docs/pyspark.ml.rst

Lines changed: 8 additions & 0 deletions

@@ -80,3 +80,11 @@ pyspark.ml.evaluation module
     :members:
     :undoc-members:
     :inherited-members:
+
+pyspark.ml.fpm module
+----------------------------
+
+.. automodule:: pyspark.ml.fpm
+    :members:
+    :undoc-members:
+    :inherited-members:

python/pyspark/ml/fpm.py

Lines changed: 216 additions & 0 deletions

@@ -0,0 +1,216 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark import keyword_only, since
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *

__all__ = ["FPGrowth", "FPGrowthModel"]


class HasSupport(Params):
    """
    Mixin for param support.
    """

    minSupport = Param(
        Params._dummy(),
        "minSupport",
        """Minimal support level of the frequent pattern. [0.0, 1.0].
        Any pattern that appears more than (minSupport * size-of-the-dataset)
        times will be output""",
        typeConverter=TypeConverters.toFloat)

    def setMinSupport(self, value):
        """
        Sets the value of :py:attr:`minSupport`.
        """
        return self._set(minSupport=value)

    def getMinSupport(self):
        """
        Gets the value of minSupport or its default value.
        """
        return self.getOrDefault(self.minSupport)


class HasConfidence(Params):
    """
    Mixin for param confidence.
    """

    minConfidence = Param(
        Params._dummy(),
        "minConfidence",
        """Minimal confidence for generating Association Rule. [0.0, 1.0]
        Note that minConfidence has no effect during fitting.""",
        typeConverter=TypeConverters.toFloat)

    def setMinConfidence(self, value):
        """
        Sets the value of :py:attr:`minConfidence`.
        """
        return self._set(minConfidence=value)

    def getMinConfidence(self):
        """
        Gets the value of minConfidence or its default value.
        """
        return self.getOrDefault(self.minConfidence)


class HasItemsCol(Params):
    """
    Mixin for param itemsCol: items column name.
    """

    itemsCol = Param(Params._dummy(), "itemsCol",
                     "items column name", typeConverter=TypeConverters.toString)

    def setItemsCol(self, value):
        """
        Sets the value of :py:attr:`itemsCol`.
        """
        return self._set(itemsCol=value)

    def getItemsCol(self):
        """
        Gets the value of itemsCol or its default value.
        """
        return self.getOrDefault(self.itemsCol)


class FPGrowthModel(JavaModel, JavaMLWritable, JavaMLReadable):
    """
    .. note:: Experimental

    Model fitted by FPGrowth.

    .. versionadded:: 2.2.0
    """
    @property
    @since("2.2.0")
    def freqItemsets(self):
        """
        DataFrame with two columns:
        * `items` - Itemset of the same type as the input column.
        * `freq` - Frequency of the itemset (`LongType`).
        """
        return self._call_java("freqItemsets")

    @property
    @since("2.2.0")
    def associationRules(self):
        """
        Data with three columns:
        * `antecedent` - Array of the same type as the input column.
        * `consequent` - Array of the same type as the input column.
        * `confidence` - Confidence for the rule (`DoubleType`).
        """
        return self._call_java("associationRules")


class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol,
               HasSupport, HasConfidence, JavaMLWritable, JavaMLReadable):
    """
    .. note:: Experimental

    A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in
    Li et al., PFP: Parallel FP-Growth for Query Recommendation [LI2008]_.
    PFP distributes computation in such a way that each worker executes an
    independent group of mining tasks. The FP-Growth algorithm is described in
    Han et al., Mining frequent patterns without candidate generation [HAN2000]_

    .. [LI2008] http://dx.doi.org/10.1145/1454008.1454027
    .. [HAN2000] http://dx.doi.org/10.1145/335191.335372

    .. note:: null values in the feature column are ignored during fit().
    .. note:: Internally `transform` `collects` and `broadcasts` association rules.

    >>> from pyspark.sql.functions import split
    >>> data = (spark.read
    ...     .text("data/mllib/sample_fpgrowth.txt")
    ...     .select(split("value", "\s+").alias("items")))
    >>> data.show(truncate=False)
    +------------------------+
    |items                   |
    +------------------------+
    |[r, z, h, k, p]         |
    |[z, y, x, w, v, u, t, s]|
    |[s, x, o, n, r]         |
    |[x, z, y, m, t, s, q, e]|
    |[z]                     |
    |[x, z, y, r, q, t, p]   |
    +------------------------+
    >>> fp = FPGrowth(minSupport=0.2, minConfidence=0.7)
    >>> fpm = fp.fit(data)
    >>> fpm.freqItemsets.show(5)
    +---------+----+
    |    items|freq|
    +---------+----+
    |      [s]|   3|
    |   [s, x]|   3|
    |[s, x, z]|   2|
    |   [s, z]|   2|
    |      [r]|   3|
    +---------+----+
    only showing top 5 rows
    >>> fpm.associationRules.show(5)
    +----------+----------+----------+
    |antecedent|consequent|confidence|
    +----------+----------+----------+
    |    [t, s]|       [y]|       1.0|
    |    [t, s]|       [x]|       1.0|
    |    [t, s]|       [z]|       1.0|
    |       [p]|       [r]|       1.0|
    |       [p]|       [z]|       1.0|
    +----------+----------+----------+
    only showing top 5 rows
    >>> new_data = spark.createDataFrame([(["t", "s"], )], ["items"])
    >>> sorted(fpm.transform(new_data).first().prediction)
    ['x', 'y', 'z']

    .. versionadded:: 2.2.0
    """
    @keyword_only
    def __init__(self, minSupport=0.3, minConfidence=0.8, itemsCol="items",
                 predictionCol="prediction", numPartitions=None):
        """
        __init__(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", \
                 predictionCol="prediction", numPartitions=None)
        """
        super(FPGrowth, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.FPGrowth", self.uid)
        self._setDefault(minSupport=0.3, minConfidence=0.8,
                         itemsCol="items", predictionCol="prediction")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("2.2.0")
    def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items",
                  predictionCol="prediction", numPartitions=None):
        """
        setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", \
                  predictionCol="prediction", numPartitions=None)
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return FPGrowthModel(java_model)
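The note that `transform` collects and broadcasts the association rules hints at how predictions are produced: for each input basket, the consequents of every rule whose antecedent is contained in the basket are unioned, and items already present are dropped. A small pure-Python sketch of that behavior (illustrative only, not the actual JVM implementation):

def predict(basket, rules):
    """Mimic FPGrowthModel.transform for a single basket.

    `rules` is a list of (antecedent, consequent) pairs, e.g. the rows of
    model.associationRules collected to the driver. Items already in the
    basket are not suggested again.
    """
    basket = set(basket)
    suggested = set()
    for antecedent, consequent in rules:
        # A rule fires only if its whole antecedent is present in the basket.
        if set(antecedent).issubset(basket):
            suggested.update(consequent)
    return sorted(suggested - basket)

# Matches the doctest above: rules mined with minConfidence=0.7
rules = [(["t", "s"], ["y"]), (["t", "s"], ["x"]), (["t", "s"], ["z"])]
print(predict(["t", "s"], rules))  # ['x', 'y', 'z']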

python/pyspark/ml/tests.py

Lines changed: 46 additions & 7 deletions

@@ -42,26 +42,28 @@
 import array as pyarray
 import numpy as np
 from numpy import (
-    array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones)
+    abs, all, arange, array, array_equal, dot, exp, inf, mean, ones, random, tile, zeros)
 from numpy import sum as array_sum
 import inspect

 from pyspark import keyword_only, SparkContext
 from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
 from pyspark.ml.classification import *
 from pyspark.ml.clustering import *
+from pyspark.ml.common import _java2py, _py2java
 from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
 from pyspark.ml.feature import *
-from pyspark.ml.linalg import Vector, SparseVector, DenseVector, VectorUDT,\
-    DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT, _convert_to_vector
+from pyspark.ml.fpm import FPGrowth, FPGrowthModel
+from pyspark.ml.linalg import (
+    DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT,
+    SparseMatrix, SparseVector, Vector, VectorUDT, Vectors, _convert_to_vector)
 from pyspark.ml.param import Param, Params, TypeConverters
-from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
+from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed
 from pyspark.ml.recommendation import ALS
-from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \
-    GeneralizedLinearRegression
+from pyspark.ml.regression import (
+    DecisionTreeRegressor, GeneralizedLinearRegression, LinearRegression)
 from pyspark.ml.tuning import *
 from pyspark.ml.wrapper import JavaParams, JavaWrapper
-from pyspark.ml.common import _java2py, _py2java
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import rand

@@ -1243,6 +1245,43 @@ def test_tweedie_distribution(self):
         self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4))


+class FPGrowthTests(SparkSessionTestCase):
+    def setUp(self):
+        super(FPGrowthTests, self).setUp()
+        self.data = self.spark.createDataFrame(
+            [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )],
+            ["items"])
+
+    def test_association_rules(self):
+        fp = FPGrowth()
+        fpm = fp.fit(self.data)
+
+        expected_association_rules = self.spark.createDataFrame(
+            [([3], [1], 1.0), ([2], [1], 1.0)],
+            ["antecedent", "consequent", "confidence"]
+        )
+        actual_association_rules = fpm.associationRules
+
+        self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0)
+        self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0)
+
+    def test_freq_itemsets(self):
+        fp = FPGrowth()
+        fpm = fp.fit(self.data)
+
+        expected_freq_itemsets = self.spark.createDataFrame(
+            [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)],
+            ["items", "freq"]
+        )
+        actual_freq_itemsets = fpm.freqItemsets
+
+        self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0)
+        self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0)
+
+    def tearDown(self):
+        del self.data
+
+
 class ALSTest(SparkSessionTestCase):

     def test_storage_levels(self):
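The new tests compare DataFrames as unordered sets by subtracting in both directions, because an empty actual.subtract(expected) alone only shows that every actual row appears in expected. A small helper capturing the same pattern (the helper is illustrative, not part of the patch):

def assert_df_set_equal(test_case, actual, expected):
    """Assert two DataFrames hold the same rows, ignoring order.

    subtract() drops rows of the left DataFrame that occur in the right one,
    so both directions must come back empty for the row sets to match.
    """
    test_case.assertEqual(actual.subtract(expected).count(), 0)    # actual is a subset of expected
    test_case.assertEqual(expected.subtract(actual).count(), 0)    # expected is a subset of actual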
