@@ -29,10 +29,8 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.catalyst.dsl._
-import org.apache.spark.sql.catalyst.expressions.Cast
-import org.apache.spark.sql.catalyst.plans.LeftOuter
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.dsl._
 import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
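The import change above summarizes the whole migration: the internal catalyst types (`Cast`, `LeftOuter`, the catalyst DSL) give way to the public `Column`/`DataFrame` API. As a rough sketch of the correspondence, written against the DSL package that (as far as I know) later shipped as `org.apache.spark.sql.functions`, with made-up column names:

```scala
// Illustration of the replacement, not code from this patch. col comes from
// the stable functions API (org.apache.spark.sql.dsl of this snapshot);
// "rating" is a hypothetical column name. REPL-style snippet.
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.FloatType

val ratingAsFloat = col("rating").cast(FloatType) // replaces Cast(expr, FloatType)
val joinType = "left"                             // replaces catalyst.plans.LeftOuter
```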
@@ -112,21 +110,21 @@ class ALSModel private[ml] (
 
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
-  override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     import dataset.sqlContext._
     import org.apache.spark.ml.recommendation.ALSModel.Factor
     val map = this.paramMap ++ paramMap
     // TODO: Add DSL to simplify the code here.
     val instanceTable = s"instance_$uid"
     val userTable = s"user_$uid"
     val itemTable = s"item_$uid"
-    val instances = dataset.as(Symbol(instanceTable))
+    val instances = dataset.as(instanceTable)
     val users = userFactors.map { case (id, features) =>
       Factor(id, features)
-    }.as(Symbol(userTable))
+    }.as(userTable)
     val items = itemFactors.map { case (id, features) =>
       Factor(id, features)
-    }.as(Symbol(itemTable))
+    }.as(itemTable)
     val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
       if (userFeatures != null && itemFeatures != null) {
         blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
@@ -135,12 +133,12 @@ class ALSModel private[ml] (
       }
     }
     val inputColumns = dataset.schema.fieldNames
-    val prediction =
-      predict.call(s"$userTable.features".attr, s"$itemTable.features".attr) as map(predictionCol)
-    val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction
+    val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
+      .as(map(predictionCol))
+    val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
     instances
-      .join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr))
-      .join(items, LeftOuter, Some(map(itemCol).attr === s"$itemTable.id".attr))
+      .join(users, "left", Column(map(userCol)) === $"$userTable.id")
+      .join(items, "left", Column(map(itemCol)) === $"$itemTable.id")
       .select(outputColumns: _*)
   }
 
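For context, the join-and-score pattern that `transform` now expresses can be sketched standalone against the stable public API (`udf`/`col` from `org.apache.spark.sql.functions` instead of this snapshot's `callUDF` and `$"..."` interpolator). The aliases `u`/`i` and the column names `user`, `item`, `id`, `features` below are hypothetical:

```scala
// Minimal sketch of the scoring pattern, assuming users/items each carry
// (id: Int, features: Seq[Float]) and instances carries user/item columns.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

def score(instances: DataFrame, users: DataFrame, items: DataFrame): DataFrame = {
  // Plain Scala dot product standing in for blas.sdot; the null checks matter
  // because the left outer joins produce nulls for unmatched user/item ids.
  val predict = udf { (u: Seq[Float], i: Seq[Float]) =>
    if (u != null && i != null) {
      u.iterator.zip(i.iterator).map { case (a, b) => a * b }.sum
    } else {
      Float.NaN // this sketch's choice for missing factors
    }
  }
  instances
    .join(users.as("u"), col("user") === col("u.id"), "left_outer")
    .join(items.as("i"), col("item") === col("i.id"), "left_outer")
    .withColumn("prediction", predict(col("u.features"), col("i.features")))
}
```

The left outer joins keep every input row, so instances whose user or item was unseen at training time still come back, just with a missing-factor prediction.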
@@ -209,14 +207,13 @@ class ALS extends Estimator[ALSModel] with ALSParams {
   setMaxIter(20)
   setRegParam(1.0)
 
-  override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = {
-    import dataset.sqlContext._
+  override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
     val map = this.paramMap ++ paramMap
-    val ratings =
-      dataset.select(map(userCol).attr, map(itemCol).attr, Cast(map(ratingCol).attr, FloatType))
-        .map { row =>
-          new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
-        }
+    val ratings = dataset
+      .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
+      .map { row =>
+        new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
+      }
     val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
       numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
       maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
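The ratings extraction in `fit` can likewise be sketched standalone against the stable API: `col` plus `Column#cast` replace the catalyst `Cast` expression, and an explicit `.rdd` precedes `map` (later Spark versions require it, whereas this snapshot maps over the DataFrame directly). The literal column names stand in for `map(userCol)` etc., and the local `Rating` case class mirrors the one defined in the ALS object:

```scala
// Hedged sketch, not this patch's code: extract (user, item, rating) triples
// as an RDD of Rating, casting the rating column to Float on the way out.
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.FloatType

case class Rating(user: Int, item: Int, rating: Float)

def extractRatings(dataset: DataFrame): RDD[Rating] = {
  dataset
    .select(col("user"), col("item"), col("rating").cast(FloatType))
    .rdd
    .map(row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)))
}
```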