Skip to content

Commit 20328d1

Browse files
committed
Merge remote-tracking branch 'upstream/master' into ldaonline
i
2 parents aa365d1 + 3f00bb3 commit 20328d1

File tree

7 files changed

+129
-33
lines changed

7 files changed

+129
-33
lines changed

core/src/main/scala/org/apache/spark/Accumulators.scala

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -280,15 +280,24 @@ object AccumulatorParam {
280280

281281
// TODO: The multi-thread support in accumulators is kind of lame; check
282282
// if there's a more intuitive way of doing it right
283-
private[spark] object Accumulators {
284-
// Store a WeakReference instead of a StrongReference because this way accumulators can be
285-
// appropriately garbage collected during long-running jobs and release memory
286-
type WeakAcc = WeakReference[Accumulable[_, _]]
287-
val originals = Map[Long, WeakAcc]()
288-
val localAccums = new ThreadLocal[Map[Long, WeakAcc]]() {
289-
override protected def initialValue() = Map[Long, WeakAcc]()
283+
private[spark] object Accumulators extends Logging {
284+
/**
285+
* This global map holds the original accumulator objects that are created on the driver.
286+
* It keeps weak references to these objects so that accumulators can be garbage-collected
287+
* once the RDDs and user-code that reference them are cleaned up.
288+
*/
289+
val originals = Map[Long, WeakReference[Accumulable[_, _]]]()
290+
291+
/**
292+
* This thread-local map holds per-task copies of accumulators; it is used to collect the set
293+
* of accumulator updates to send back to the driver when tasks complete. After tasks complete,
294+
* this map is cleared by `Accumulators.clear()` (see Executor.scala).
295+
*/
296+
private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
297+
override protected def initialValue() = Map[Long, Accumulable[_, _]]()
290298
}
291-
var lastId: Long = 0
299+
300+
private var lastId: Long = 0
292301

293302
def newId(): Long = synchronized {
294303
lastId += 1
@@ -297,16 +306,16 @@ private[spark] object Accumulators {
297306

298307
def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
299308
if (original) {
300-
originals(a.id) = new WeakAcc(a)
309+
originals(a.id) = new WeakReference[Accumulable[_, _]](a)
301310
} else {
302-
localAccums.get()(a.id) = new WeakAcc(a)
311+
localAccums.get()(a.id) = a
303312
}
304313
}
305314

306315
// Clear the local (non-original) accumulators for the current thread
307316
def clear() {
308317
synchronized {
309-
localAccums.get.clear
318+
localAccums.get.clear()
310319
}
311320
}
312321

@@ -320,12 +329,7 @@ private[spark] object Accumulators {
320329
def values: Map[Long, Any] = synchronized {
321330
val ret = Map[Long, Any]()
322331
for ((id, accum) <- localAccums.get) {
323-
// Since we are now storing weak references, we must check whether the underlying data
324-
// is valid.
325-
ret(id) = accum.get match {
326-
case Some(values) => values.localValue
327-
case None => None
328-
}
332+
ret(id) = accum.localValue
329333
}
330334
return ret
331335
}
@@ -341,6 +345,8 @@ private[spark] object Accumulators {
341345
case None =>
342346
throw new IllegalAccessError("Attempted to access garbage collected Accumulator.")
343347
}
348+
} else {
349+
logWarning(s"Ignoring accumulator update for unknown accumulator id $id")
344350
}
345351
}
346352
}

docs/mllib-collaborative-filtering.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,8 @@ In the following example we load rating data. Each row consists of a user, a pro
200200
We use the default ALS.train() method which assumes ratings are explicit. We evaluate the
201201
recommendation by measuring the Mean Squared Error of rating prediction.
202202

203-
Note that the Python API does not yet support model save/load but will in the future.
204-
205203
{% highlight python %}
206-
from pyspark.mllib.recommendation import ALS, Rating
204+
from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating
207205

208206
# Load and parse the data
209207
data = sc.textFile("data/mllib/als/test.data")
@@ -220,6 +218,10 @@ predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
220218
ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)
221219
MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count()
222220
print("Mean Squared Error = " + str(MSE))
221+
222+
# Save and load model
223+
model.save(sc, "myModelPath")
224+
sameModel = MatrixFactorizationModel.load(sc, "myModelPath")
223225
{% endhighlight %}
224226

225227
If the rating matrix is derived from another source of information (i.e., it is inferred from other

docs/mllib-naive-bayes.md

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,22 +115,28 @@ used for evaluation and prediction.
115115

116116
Note that the Python API does not yet support model save/load but will in the future.
117117

118-
<!-- TODO: Make Python's example consistent with Scala's and Java's. -->
119118
{% highlight python %}
120-
from pyspark.mllib.regression import LabeledPoint
121119
from pyspark.mllib.classification import NaiveBayes
120+
from pyspark.mllib.linalg import Vectors
121+
from pyspark.mllib.regression import LabeledPoint
122+
123+
def parseLine(line):
124+
parts = line.split(',')
125+
label = float(parts[0])
126+
features = Vectors.dense([float(x) for x in parts[1].split(' ')])
127+
return LabeledPoint(label, features)
128+
129+
data = sc.textFile('data/mllib/sample_naive_bayes_data.txt').map(parseLine)
122130

123-
# an RDD of LabeledPoint
124-
data = sc.parallelize([
125-
LabeledPoint(0.0, [0.0, 0.0])
126-
... # more labeled points
127-
])
131+
# Split data approximately into training (60%) and test (40%)
132+
training, test = data.randomSplit([0.6, 0.4], seed = 0)
128133

129134
# Train a naive Bayes model.
130-
model = NaiveBayes.train(data, 1.0)
135+
model = NaiveBayes.train(training, 1.0)
131136

132-
# Make prediction.
133-
prediction = model.predict([0.0, 0.0])
137+
# Make prediction and test accuracy.
138+
predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label))
139+
accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count()
134140
{% endhighlight %}
135141

136142
</div>

mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ trait Saveable {
4848
*
4949
* @param sc Spark context used to save model data.
5050
* @param path Path specifying the directory in which to save this model.
51-
* This directory and any intermediate directory will be created if needed.
51+
* If the directory already exists, this method throws an exception.
5252
*/
5353
def save(sc: SparkContext, path: String): Unit
5454

python/pyspark/mllib/recommendation.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
from pyspark import SparkContext
2121
from pyspark.rdd import RDD
22-
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
22+
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
23+
from pyspark.mllib.util import Saveable, JavaLoader
2324

2425
__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
2526

@@ -39,7 +40,8 @@ def __reduce__(self):
3940
return Rating, (int(self.user), int(self.product), float(self.rating))
4041

4142

42-
class MatrixFactorizationModel(JavaModelWrapper):
43+
@inherit_doc
44+
class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
4345

4446
"""A matrix factorisation model trained by regularized alternating
4547
least-squares.
@@ -81,6 +83,17 @@ class MatrixFactorizationModel(JavaModelWrapper):
8183
>>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
8284
>>> model.predict(2,2)
8385
0.43...
86+
87+
>>> import os, tempfile
88+
>>> path = tempfile.mkdtemp()
89+
>>> model.save(sc, path)
90+
>>> sameModel = MatrixFactorizationModel.load(sc, path)
91+
>>> sameModel.predict(2,2)
92+
0.43...
93+
>>> try:
94+
... os.removedirs(path)
95+
... except:
96+
... pass
8497
"""
8598
def predict(self, user, product):
8699
return self._java_model.predict(int(user), int(product))
@@ -98,6 +111,9 @@ def userFeatures(self):
98111
def productFeatures(self):
99112
return self.call("getProductFeatures")
100113

114+
def save(self, sc, path):
115+
self.call("save", sc._jsc.sc(), path)
116+
101117

102118
class ALS(object):
103119

python/pyspark/mllib/util.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,64 @@ def loadLabeledPoints(sc, path, minPartitions=None):
168168
return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
169169

170170

171+
class Saveable(object):
172+
"""
173+
Mixin for models and transformers which may be saved as files.
174+
"""
175+
176+
def save(self, sc, path):
177+
"""
178+
Save this model to the given path.
179+
180+
This saves:
181+
* human-readable (JSON) model metadata to path/metadata/
182+
* Parquet formatted data to path/data/
183+
184+
The model may be loaded using :py:meth:`Loader.load`.
185+
186+
:param sc: Spark context used to save model data.
187+
:param path: Path specifying the directory in which to save
188+
this model. If the directory already exists,
189+
this method throws an exception.
190+
"""
191+
raise NotImplementedError
192+
193+
194+
class Loader(object):
195+
"""
196+
Mixin for classes which can load saved models from files.
197+
"""
198+
199+
@classmethod
200+
def load(cls, sc, path):
201+
"""
202+
Load a model from the given path. The model should have been
203+
saved using :py:meth:`Saveable.save`.
204+
205+
:param sc: Spark context used for loading model files.
206+
:param path: Path specifying the directory to which the model
207+
was saved.
208+
:return: model instance
209+
"""
210+
raise NotImplemented
211+
212+
213+
class JavaLoader(Loader):
214+
"""
215+
Mixin for classes which can load saved models using their Scala
216+
implementation.
217+
"""
218+
219+
@classmethod
220+
def load(cls, sc, path):
221+
java_package = cls.__module__.replace("pyspark", "org.apache.spark")
222+
java_class = ".".join([java_package, cls.__name__])
223+
java_obj = sc._jvm
224+
for name in java_class.split("."):
225+
java_obj = getattr(java_obj, name)
226+
return cls(java_obj.load(sc._jsc.sc(), path))
227+
228+
171229
def _test():
172230
import doctest
173231
from pyspark.context import SparkContext

sql/core/pom.xml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,5 +109,13 @@
109109
<build>
110110
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
111111
<testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
112+
<resources>
113+
<resource>
114+
<directory>../../python</directory>
115+
<includes>
116+
<include>pyspark/sql/*.py</include>
117+
</includes>
118+
</resource>
119+
</resources>
112120
</build>
113121
</project>

0 commit comments

Comments
 (0)