Skip to content

Commit 7550029

Browse files
committed
add ALS.setIntermediateDataStorageLevel
1 parent c235b83 commit 7550029

File tree

1 file changed

+30
-15
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/recommendation

1 file changed

+30
-15
lines changed

mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,17 @@ class ALS private (
111111
*/
112112
def this() = this(-1, -1, 10, 10, 0.01, false, 1.0)
113113

114+
/** If true, do alternating nonnegative least squares. */
115+
private var nonnegative = false
116+
117+
/** storage level for user/product in/out links */
118+
private var intermediateDataStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
119+
114120
/**
115121
* Set the number of blocks for both user blocks and product blocks to parallelize the computation
116122
* into; pass -1 for an auto-configured number of blocks. Default: -1.
117123
*/
118-
def setBlocks(numBlocks: Int): ALS = {
124+
def setBlocks(numBlocks: Int): this.type = {
119125
this.numUserBlocks = numBlocks
120126
this.numProductBlocks = numBlocks
121127
this
@@ -124,39 +130,39 @@ class ALS private (
124130
/**
125131
* Set the number of user blocks to parallelize the computation.
126132
*/
127-
def setUserBlocks(numUserBlocks: Int): ALS = {
133+
def setUserBlocks(numUserBlocks: Int): this.type = {
128134
this.numUserBlocks = numUserBlocks
129135
this
130136
}
131137

132138
/**
133139
* Set the number of product blocks to parallelize the computation.
134140
*/
135-
def setProductBlocks(numProductBlocks: Int): ALS = {
141+
def setProductBlocks(numProductBlocks: Int): this.type = {
136142
this.numProductBlocks = numProductBlocks
137143
this
138144
}
139145

140146
/** Set the rank of the feature matrices computed (number of features). Default: 10. */
141-
def setRank(rank: Int): ALS = {
147+
def setRank(rank: Int): this.type = {
142148
this.rank = rank
143149
this
144150
}
145151

146152
/** Set the number of iterations to run. Default: 10. */
147-
def setIterations(iterations: Int): ALS = {
153+
def setIterations(iterations: Int): this.type = {
148154
this.iterations = iterations
149155
this
150156
}
151157

152158
/** Set the regularization parameter, lambda. Default: 0.01. */
153-
def setLambda(lambda: Double): ALS = {
159+
def setLambda(lambda: Double): this.type = {
154160
this.lambda = lambda
155161
this
156162
}
157163

158164
/** Sets whether to use implicit preference. Default: false. */
159-
def setImplicitPrefs(implicitPrefs: Boolean): ALS = {
165+
def setImplicitPrefs(implicitPrefs: Boolean): this.type = {
160166
this.implicitPrefs = implicitPrefs
161167
this
162168
}
@@ -166,29 +172,38 @@ class ALS private (
166172
* Sets the constant used in computing confidence in implicit ALS. Default: 1.0.
167173
*/
168174
@Experimental
169-
def setAlpha(alpha: Double): ALS = {
175+
def setAlpha(alpha: Double): this.type = {
170176
this.alpha = alpha
171177
this
172178
}
173179

174180
/** Sets a random seed to have deterministic results. */
175-
def setSeed(seed: Long): ALS = {
181+
def setSeed(seed: Long): this.type = {
176182
this.seed = seed
177183
this
178184
}
179185

180-
/** If true, do alternating nonnegative least squares. */
181-
private var nonnegative = false
182-
183186
/**
184187
* Set whether the least-squares problems solved at each iteration should have
185188
* nonnegativity constraints.
186189
*/
187-
def setNonnegative(b: Boolean): ALS = {
190+
def setNonnegative(b: Boolean): this.type = {
188191
this.nonnegative = b
189192
this
190193
}
191194

195+
/**
196+
* :: DeveloperApi ::
197+
* Sets storage level for intermediate RDDs (user/product in/out links). The default value is
198+
* `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g., `MEMORY_AND_DISK_SER` and
199+
* set `spark.rdd.compress` to `true` to reduce the space requirement, at the cost of speed.
200+
*/
201+
@DeveloperApi
202+
def setIntermediateDataStorageLevel(storageLevel: StorageLevel): this.type = {
203+
this.intermediateDataStorageLevel = storageLevel
204+
this
205+
}
206+
192207
/**
193208
* Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
194209
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
@@ -441,8 +456,8 @@ class ALS private (
441456
}, preservesPartitioning = true)
442457
val inLinks = links.mapValues(_._1)
443458
val outLinks = links.mapValues(_._2)
444-
inLinks.persist(StorageLevel.MEMORY_AND_DISK)
445-
outLinks.persist(StorageLevel.MEMORY_AND_DISK)
459+
inLinks.persist(intermediateDataStorageLevel)
460+
outLinks.persist(intermediateDataStorageLevel)
446461
(inLinks, outLinks)
447462
}
448463

0 commit comments

Comments
 (0)