Skip to content

Commit 0a2edf0

Browse files
committed
Add inline comments to ALS.train method
1 parent fb8f16d commit 0a2edf0

File tree

1 file changed

+19
-7
lines changed
  • mllib/src/main/scala/org/apache/spark/ml/recommendation

1 file changed

+19
-7
lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -813,32 +813,43 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
813813
checkpointInterval: Int = 10,
814814
seed: Long = 0L)(
815815
implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
816+
816817
require(!ratings.isEmpty(), s"No ratings available from $ratings")
817818
require(intermediateRDDStorageLevel != StorageLevel.NONE,
818819
"ALS is not designed to run without persisting intermediate RDDs.")
820+
819821
val sc = ratings.sparkContext
822+
823+
// Precompute the rating dependencies of each partition
820824
val userPart = new ALSPartitioner(numUserBlocks)
821825
val itemPart = new ALSPartitioner(numItemBlocks)
822-
val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
823-
val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
824-
val solver = if (nonnegative) new NNLSSolver else new CholeskySolver
825826
val blockRatings = partitionRatings(ratings, userPart, itemPart)
826827
.persist(intermediateRDDStorageLevel)
827828
val (userInBlocks, userOutBlocks) =
828829
makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel)
829-
// materialize blockRatings and user blocks
830-
userOutBlocks.count()
830+
userOutBlocks.count() // materialize blockRatings and user blocks
831831
val swappedBlockRatings = blockRatings.map {
832832
case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) =>
833833
((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings))
834834
}
835835
val (itemInBlocks, itemOutBlocks) =
836836
makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel)
837-
// materialize item blocks
838-
itemOutBlocks.count()
837+
itemOutBlocks.count() // materialize item blocks
838+
839+
// Encoders for storing each user/item's partition ID and index within its partition using a
840+
// single integer; used as an optimization
841+
val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
842+
val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
843+
844+
// These are the user and item factor matrices that, once trained, are multiplied together to
845+
// estimate the rating matrix. The two matrices are stored in RDDs, partitioned by column such
846+
// that each factor column resides on the same Spark worker as its corresponding user or item.
839847
val seedGen = new XORShiftRandom(seed)
840848
var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
841849
var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
850+
851+
val solver = if (nonnegative) new NNLSSolver else new CholeskySolver
852+
842853
var previousCheckpointFile: Option[String] = None
843854
val shouldCheckpoint: Int => Boolean = (iter) =>
844855
sc.checkpointDir.isDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0)
@@ -852,6 +863,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
852863
logWarning(s"Cannot delete checkpoint file $file:", e)
853864
}
854865
}
866+
855867
if (implicitPrefs) {
856868
for (iter <- 1 to maxIter) {
857869
userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)

0 commit comments

Comments
 (0)