|
| 1 | +# |
| 2 | +# Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | +# contributor license agreements. See the NOTICE file distributed with |
| 4 | +# this work for additional information regarding copyright ownership. |
| 5 | +# The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | +# (the "License"); you may not use this file except in compliance with |
| 7 | +# the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | + |
| 18 | +from pyspark import keyword_only, since |
| 19 | +from pyspark.ml.util import * |
| 20 | +from pyspark.ml.wrapper import JavaEstimator, JavaModel |
| 21 | +from pyspark.ml.param.shared import * |
| 22 | + |
| 23 | +__all__ = ["FPGrowth", "FPGrowthModel"] |
| 24 | + |
| 25 | + |
class HasSupport(Params):
    """
    Mixin for param minSupport: minimal support level of a frequent pattern.
    """

    minSupport = Param(
        parent=Params._dummy(),
        name="minSupport",
        doc="""Minimal support level of the frequent pattern. [0.0, 1.0].
        Any pattern that appears more than (minSupport * size-of-the-dataset)
        times will be output""",
        typeConverter=TypeConverters.toFloat)

    def setMinSupport(self, value):
        """
        Sets the value of :py:attr:`minSupport`.

        The value is handed to ``_set`` and the result of that call is returned.
        """
        updates = {"minSupport": value}
        return self._set(**updates)

    def getMinSupport(self):
        """
        Gets the value of minSupport or its default value.
        """
        param = self.minSupport
        return self.getOrDefault(param)
| 50 | + |
| 51 | + |
class HasConfidence(Params):
    """
    Mixin for param minConfidence: minimal confidence of an association rule.
    """

    minConfidence = Param(
        parent=Params._dummy(),
        name="minConfidence",
        doc="""Minimal confidence for generating Association Rule. [0.0, 1.0]
        Note that minConfidence has no effect during fitting.""",
        typeConverter=TypeConverters.toFloat)

    def setMinConfidence(self, value):
        """
        Sets the value of :py:attr:`minConfidence`.

        The value is handed to ``_set`` and the result of that call is returned.
        """
        updates = {"minConfidence": value}
        return self._set(**updates)

    def getMinConfidence(self):
        """
        Gets the value of minConfidence or its default value.
        """
        param = self.minConfidence
        return self.getOrDefault(param)
| 75 | + |
| 76 | + |
class HasItemsCol(Params):
    """
    Mixin for param itemsCol: items column name.
    """

    itemsCol = Param(
        parent=Params._dummy(),
        name="itemsCol",
        doc="items column name",
        typeConverter=TypeConverters.toString)

    def setItemsCol(self, value):
        """
        Sets the value of :py:attr:`itemsCol`.

        The value is handed to ``_set`` and the result of that call is returned.
        """
        updates = {"itemsCol": value}
        return self._set(**updates)

    def getItemsCol(self):
        """
        Gets the value of itemsCol or its default value.
        """
        param = self.itemsCol
        return self.getOrDefault(param)
| 96 | + |
| 97 | + |
class FPGrowthModel(JavaModel, JavaMLWritable, JavaMLReadable):
    """
    .. note:: Experimental

    Model fitted by FPGrowth.

    .. versionadded:: 2.2.0
    """

    @property
    @since("2.2.0")
    def freqItemsets(self):
        """
        Frequent itemsets, fetched from the wrapped Java model.

        DataFrame with two columns:
        * `items` - Itemset of the same type as the input column.
        * `freq`  - Frequency of the itemset (`LongType`).
        """
        itemsets = self._call_java("freqItemsets")
        return itemsets

    @property
    @since("2.2.0")
    def associationRules(self):
        """
        Association rules, fetched from the wrapped Java model.

        DataFrame with three columns:
        * `antecedent`  - Array of the same type as the input column.
        * `consequent`  - Array of the same type as the input column.
        * `confidence`  - Confidence for the rule (`DoubleType`).
        """
        rules = self._call_java("associationRules")
        return rules
| 126 | + |
| 127 | + |
class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol,
               HasSupport, HasConfidence, JavaMLWritable, JavaMLReadable):
    r"""
    .. note:: Experimental

    A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in
    Li et al., PFP: Parallel FP-Growth for Query Recommendation [LI2008]_.
    PFP distributes computation in such a way that each worker executes an
    independent group of mining tasks. The FP-Growth algorithm is described in
    Han et al., Mining frequent patterns without candidate generation [HAN2000]_

    .. [LI2008] http://dx.doi.org/10.1145/1454008.1454027
    .. [HAN2000] http://dx.doi.org/10.1145/335191.335372

    .. note:: null values in the feature column are ignored during fit().
    .. note:: Internally `transform` `collects` and `broadcasts` association rules.

    >>> from pyspark.sql.functions import split
    >>> data = (spark.read
    ...     .text("data/mllib/sample_fpgrowth.txt")
    ...     .select(split("value", "\s+").alias("items")))
    >>> data.show(truncate=False)
    +------------------------+
    |items                   |
    +------------------------+
    |[r, z, h, k, p]         |
    |[z, y, x, w, v, u, t, s]|
    |[s, x, o, n, r]         |
    |[x, z, y, m, t, s, q, e]|
    |[z]                     |
    |[x, z, y, r, q, t, p]   |
    +------------------------+
    >>> fp = FPGrowth(minSupport=0.2, minConfidence=0.7)
    >>> fpm = fp.fit(data)
    >>> fpm.freqItemsets.show(5)
    +---------+----+
    |    items|freq|
    +---------+----+
    |      [s]|   3|
    |   [s, x]|   3|
    |[s, x, z]|   2|
    |   [s, z]|   2|
    |      [r]|   3|
    +---------+----+
    only showing top 5 rows
    >>> fpm.associationRules.show(5)
    +----------+----------+----------+
    |antecedent|consequent|confidence|
    +----------+----------+----------+
    |    [t, s]|       [y]|       1.0|
    |    [t, s]|       [x]|       1.0|
    |    [t, s]|       [z]|       1.0|
    |       [p]|       [r]|       1.0|
    |       [p]|       [z]|       1.0|
    +----------+----------+----------+
    only showing top 5 rows
    >>> new_data = spark.createDataFrame([(["t", "s"], )], ["items"])
    >>> sorted(fpm.transform(new_data).first().prediction)
    ['x', 'y', 'z']

    .. versionadded:: 2.2.0
    """

    # __init__ and setParams accept a ``numPartitions`` keyword and forward it
    # to ``_set``; the Param therefore has to exist on the class, otherwise the
    # keyword cannot be resolved to a Param when it is passed.
    # NOTE(review): doc text follows the Scala FPGrowth param description -
    # confirm against the JVM side.
    numPartitions = Param(
        Params._dummy(),
        "numPartitions",
        "Number of partitions (at least 1) used by parallel FP-growth. " +
        "By default the param is not set, and partition number of the input " +
        "dataset is used.",
        typeConverter=TypeConverters.toInt)

    @keyword_only
    def __init__(self, minSupport=0.3, minConfidence=0.8, itemsCol="items",
                 predictionCol="prediction", numPartitions=None):
        """
        __init__(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", \
                 predictionCol="prediction", numPartitions=None)
        """
        super(FPGrowth, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.FPGrowth", self.uid)
        # numPartitions deliberately gets no default here: it stays unset
        # unless the caller supplies it.
        self._setDefault(minSupport=0.3, minConfidence=0.8,
                         itemsCol="items", predictionCol="prediction")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("2.2.0")
    def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items",
                  predictionCol="prediction", numPartitions=None):
        """
        setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", \
                  predictionCol="prediction", numPartitions=None)
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.2.0")
    def setNumPartitions(self, value):
        """
        Sets the value of :py:attr:`numPartitions`.
        """
        return self._set(numPartitions=value)

    @since("2.2.0")
    def getNumPartitions(self):
        """
        Gets the value of :py:attr:`numPartitions` or its default value.

        No default is registered for this param in ``__init__``.
        """
        return self.getOrDefault(self.numPartitions)

    def _create_model(self, java_model):
        """
        Wraps the fitted Java model object in a :py:class:`FPGrowthModel`.
        """
        return FPGrowthModel(java_model)
0 commit comments