
Commit b85edfb

ALS.

1 parent 8c37f0a commit b85edfb

3 files changed (+20, -23 lines)

examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala

Lines changed: 1 addition & 1 deletion

@@ -143,7 +143,7 @@ object MovieLensALS {
 
     // Evaluate the model.
    // TODO: Create an evaluator to compute RMSE.
-    val mse = predictions.select('rating, 'prediction)
+    val mse = predictions.select("rating", "prediction").rdd
      .flatMap { case Row(rating: Float, prediction: Float) =>
        val err = rating.toDouble - prediction
        val err2 = err * err
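
In the new API, select takes column names as plain strings and returns a DataFrame rather than a SchemaRDD, so the example drops back to an RDD of Rows with .rdd before pattern matching. A minimal sketch of the resulting metric computation, assuming a predictions DataFrame with Float-typed "rating" and "prediction" columns as above (the map/mean reduction is illustrative, not the exact example code):

    import org.apache.spark.sql.Row

    // Select the two columns by name, drop to an RDD[Row], and average the squared error.
    val mse = predictions.select("rating", "prediction").rdd
      .map { case Row(rating: Float, prediction: Float) =>
        val err = rating.toDouble - prediction
        err * err
      }
      .mean()
    val rmse = math.sqrt(mse)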

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 17 additions & 20 deletions

@@ -29,10 +29,8 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.catalyst.dsl._
-import org.apache.spark.sql.catalyst.expressions.Cast
-import org.apache.spark.sql.catalyst.plans.LeftOuter
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.dsl._
 import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
@@ -112,21 +110,21 @@ class ALSModel private[ml] (
 
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
-  override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     import dataset.sqlContext._
     import org.apache.spark.ml.recommendation.ALSModel.Factor
     val map = this.paramMap ++ paramMap
     // TODO: Add DSL to simplify the code here.
     val instanceTable = s"instance_$uid"
     val userTable = s"user_$uid"
     val itemTable = s"item_$uid"
-    val instances = dataset.as(Symbol(instanceTable))
+    val instances = dataset.as(instanceTable)
     val users = userFactors.map { case (id, features) =>
       Factor(id, features)
-    }.as(Symbol(userTable))
+    }.as(userTable)
     val items = itemFactors.map { case (id, features) =>
       Factor(id, features)
-    }.as(Symbol(itemTable))
+    }.as(itemTable)
     val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
       if (userFeatures != null && itemFeatures != null) {
         blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
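
The alias API changes here from Scala Symbols to plain strings: DataFrame.as takes a String, so the Symbol(...) wrapping is dropped. The aliases still matter because the user and item factor tables share the column names "id" and "features", and the uid-based table names keep the later join and select unambiguous. A minimal sketch of the pattern, using the identifiers from the hunk above:

    // Old Catalyst DSL: dataset.as(Symbol(instanceTable))
    // New DataFrame API: the alias is a plain String.
    val instances = dataset.as(instanceTable)
    // Both factor tables expose "id" and "features"; the uid-based aliases keep
    // later references like $"user_<uid>.id" unambiguous after the joins.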
@@ -135,12 +133,12 @@ class ALSModel private[ml] (
       }
     }
     val inputColumns = dataset.schema.fieldNames
-    val prediction =
-      predict.call(s"$userTable.features".attr, s"$itemTable.features".attr) as map(predictionCol)
-    val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction
+    val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
+      .as(map(predictionCol))
+    val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
     instances
-      .join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr))
-      .join(items, LeftOuter, Some(map(itemCol).attr === s"$itemTable.id".attr))
+      .join(users, "left", Column(map(userCol)) === $"$userTable.id")
+      .join(items, "left", Column(map(itemCol)) === $"$itemTable.id")
       .select(outputColumns: _*)
   }
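
The left joins keep every instance row even when a user or item has no learned factors; the unmatched side then carries null features, which the predict function above has to guard against. A reconstruction sketch of the full scoring function, noting that the else branch falls outside the visible hunk, so Float.NaN as the missing-value marker is an assumption consistent with those left joins:

    // k is the rank; blas.sdot computes the dot product of the two factor vectors.
    val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
      if (userFeatures != null && itemFeatures != null) {
        blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
      } else {
        Float.NaN  // assumed: no factors for this user or item after the left joins
      }
    }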

@@ -209,14 +207,13 @@ class ALS extends Estimator[ALSModel] with ALSParams {
   setMaxIter(20)
   setRegParam(1.0)
 
-  override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = {
-    import dataset.sqlContext._
+  override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
     val map = this.paramMap ++ paramMap
-    val ratings =
-      dataset.select(map(userCol).attr, map(itemCol).attr, Cast(map(ratingCol).attr, FloatType))
-      .map { row =>
-        new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
-      }
+    val ratings = dataset
+      .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
+      .map { row =>
+        new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
+      }
     val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
       numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
       maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
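
On the fit side, the Catalyst Cast expression gives way to the public Column.cast method, and raw column names are wrapped in Column(...) instead of relying on the .attr implicit. A minimal standalone sketch of the same selection, assuming hypothetical column names "user", "item", and "rating":

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.types.FloatType

    // Cast the rating column to Float through the public API
    // rather than constructing a Catalyst Cast node directly.
    val ratings = dataset
      .select(Column("user"), Column("item"), Column("rating").cast(FloatType))
      .map(row => new Rating(row.getInt(0), row.getInt(1), row.getFloat(2)))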

mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala

Lines changed: 2 additions & 2 deletions

@@ -350,7 +350,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
       numItemBlocks: Int = 3,
       targetRMSE: Double = 0.05): Unit = {
     val sqlContext = this.sqlContext
-    import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute}
+    import sqlContext.createSchemaRDD
     val als = new ALS()
       .setRank(rank)
       .setRegParam(regParam)
@@ -360,7 +360,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     val alpha = als.getAlpha
     val model = als.fit(training)
     val predictions = model.transform(test)
-      .select('rating, 'prediction)
+      .select("rating", "prediction")
       .map { case Row(rating: Float, prediction: Float) =>
        (rating.toDouble, prediction.toDouble)
      }
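
With string-based select, the symbolToUnresolvedAttribute implicit is no longer needed, which is why the import shrinks in the first hunk. From the (rating, prediction) pairs the test can then reduce to an RMSE and compare it against targetRMSE; a sketch of that final check, assuming predictions is the pair RDD built above (the exact assertion lives outside the visible hunk):

    val rmse = math.sqrt(predictions.map { case (rating, prediction) =>
      val err = rating - prediction
      err * err
    }.mean())
    assert(rmse < targetRMSE)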
