[SPARK-7654][MLLIB] Migrate MLlib to the DataFrame reader/writer API #6281

Closed · wants to merge 1 commit
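The same one-line substitution recurs in each model loader below: the deprecated `SQLContext.parquetFile` call is replaced by the `DataFrameReader` API. A minimal before/after sketch of the pattern, assuming a Spark 1.4-era `SQLContext` and a hypothetical `modelDir` path:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

val sc = new SparkContext("local[2]", "reader-migration-sketch") // hypothetical app
val sqlContext = new SQLContext(sc)
val modelDir = "/tmp/model/data" // hypothetical path

// Before: direct method on SQLContext, deprecated in Spark 1.4.
val oldStyle = sqlContext.parquetFile(modelDir)

// After: go through the DataFrameReader returned by sqlContext.read.
val newStyle = sqlContext.read.parquet(modelDir)
```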
```diff
@@ -153,7 +153,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
     def load(sc: SparkContext, path: String): NaiveBayesModel = {
       val sqlContext = new SQLContext(sc)
       // Load Parquet data.
-      val dataRDD = sqlContext.parquetFile(dataPath(path))
+      val dataRDD = sqlContext.read.parquet(dataPath(path))
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
       checkSchema[Data](dataRDD.schema)
       val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1)
@@ -199,7 +199,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
     def load(sc: SparkContext, path: String): NaiveBayesModel = {
       val sqlContext = new SQLContext(sc)
       // Load Parquet data.
-      val dataRDD = sqlContext.parquetFile(dataPath(path))
+      val dataRDD = sqlContext.read.parquet(dataPath(path))
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
       checkSchema[Data](dataRDD.schema)
       val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
```
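Both hunks keep the explicit `checkSchema[Data]` call; they patch the V2.0 and V1.0 NaiveBayes loaders respectively (the older format predates the `modelType` column). The erasure comment refers to the fact that a runtime `match`-`case` cannot see element types such as `Seq[Double]` vs `Seq[String]`, so the schema is validated up front instead. A rough sketch of that kind of check, with a hypothetical `Data` case class standing in for MLlib's private one:

```scala
import org.apache.spark.sql.{DataFrame, SQLContext}

case class Data(labels: Array[Double], pi: Array[Double]) // hypothetical stand-in

def checkSchemaSketch(sqlContext: SQLContext, loaded: DataFrame): Unit = {
  // Derive the expected schema from the case class via an empty DataFrame.
  val expected = sqlContext.createDataFrame(Seq.empty[Data]).schema
  require(
    expected.fieldNames.forall(loaded.schema.fieldNames.contains),
    s"Loaded schema ${loaded.schema} is missing fields of $expected")
}
```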
```diff
@@ -75,7 +75,7 @@ private[classification] object GLMClassificationModel {
     def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
       val datapath = Loader.dataPath(path)
       val sqlContext = new SQLContext(sc)
-      val dataRDD = sqlContext.parquetFile(datapath)
+      val dataRDD = sqlContext.read.parquet(datapath)
       val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
       assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
       val data = dataArray(0)
```
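For context, these private `loadData` helpers sit behind the public `save`/`load` pair on the GLM classifiers, so a typical round trip looks like the sketch below (assuming a fitted `LogisticRegressionModel` named `model` and a hypothetical output directory):

```scala
import org.apache.spark.mllib.classification.LogisticRegressionModel

val path = "/tmp/lr-model" // hypothetical directory
model.save(sc, path)                                  // writes metadata plus Parquet data
val restored = LogisticRegressionModel.load(sc, path) // dispatches to loadData above
```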
```diff
@@ -132,7 +132,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
     def load(sc: SparkContext, path: String): GaussianMixtureModel = {
       val dataPath = Loader.dataPath(path)
       val sqlContext = new SQLContext(sc)
-      val dataFrame = sqlContext.parquetFile(dataPath)
+      val dataFrame = sqlContext.read.parquet(dataPath)
       val dataArray = dataFrame.select("weight", "mu", "sigma").collect()

       // Check schema explicitly since erasure makes it hard to use match-case for checking.
```
```diff
@@ -120,7 +120,7 @@ object KMeansModel extends Loader[KMeansModel] {
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)
       val k = (metadata \ "k").extract[Int]
-      val centriods = sqlContext.parquetFile(Loader.dataPath(path))
+      val centriods = sqlContext.read.parquet(Loader.dataPath(path))
       Loader.checkSchema[Cluster](centriods.schema)
       val localCentriods = centriods.map(Cluster.apply).collect()
       assert(k == localCentriods.size)
```
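For reference, `read.parquet` is shorthand for the generic data-source API on `DataFrameReader`; the two calls in this sketch should be equivalent (reusing the loader's `path`):

```scala
val shorthand = sqlContext.read.parquet(Loader.dataPath(path))
val generic   = sqlContext.read.format("parquet").load(Loader.dataPath(path))
```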
```diff
@@ -556,7 +556,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
     def load(sc: SparkContext, path: String): Word2VecModel = {
       val dataPath = Loader.dataPath(path)
       val sqlContext = new SQLContext(sc)
-      val dataFrame = sqlContext.parquetFile(dataPath)
+      val dataFrame = sqlContext.read.parquet(dataPath)

       val dataArray = dataFrame.select("word", "vector").collect()
```
```diff
@@ -292,11 +292,11 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)
       val rank = (metadata \ "rank").extract[Int]
-      val userFeatures = sqlContext.parquetFile(userPath(path))
+      val userFeatures = sqlContext.read.parquet(userPath(path))
         .map { case Row(id: Int, features: Seq[_]) =>
           (id, features.asInstanceOf[Seq[Double]].toArray)
         }
-      val productFeatures = sqlContext.parquetFile(productPath(path))
+      val productFeatures = sqlContext.read.parquet(productPath(path))
         .map { case Row(id: Int, features: Seq[_]) =>
           (id, features.asInstanceOf[Seq[Double]].toArray)
         }
```
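The `Row` pattern match works because each row here is `(id: Int, features: Seq[Double])`; erasure again forces the `Seq[_]` wildcard plus a cast. An alternative sketch that indexes fields instead (hypothetical, not what this PR does; `getSeq` still trusts the element type):

```scala
import org.apache.spark.sql.Row

val userFeaturesAlt = sqlContext.read.parquet(userPath(path)).map { row: Row =>
  (row.getInt(0), row.getSeq[Double](1).toArray)
}
```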
```diff
@@ -189,7 +189,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {

     def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
       val sqlContext = new SQLContext(sc)
-      val dataRDD = sqlContext.parquetFile(dataPath(path))
+      val dataRDD = sqlContext.read.parquet(dataPath(path))

       checkSchema[Data](dataRDD.schema)
       val dataArray = dataRDD.select("boundary", "prediction").collect()
```
```diff
@@ -72,7 +72,7 @@ private[regression] object GLMRegressionModel {
     def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
       val datapath = Loader.dataPath(path)
       val sqlContext = new SQLContext(sc)
-      val dataRDD = sqlContext.parquetFile(datapath)
+      val dataRDD = sqlContext.read.parquet(datapath)
       val dataArray = dataRDD.select("weights", "intercept").take(1)
       assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
       val data = dataArray(0)
```
```diff
@@ -230,7 +230,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
       val datapath = Loader.dataPath(path)
       val sqlContext = new SQLContext(sc)
       // Load Parquet data.
-      val dataRDD = sqlContext.parquetFile(datapath)
+      val dataRDD = sqlContext.read.parquet(datapath)
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
       Loader.checkSchema[NodeData](dataRDD.schema)
       val nodes = dataRDD.map(NodeData.apply)
```
```diff
@@ -437,7 +437,7 @@ private[tree] object TreeEnsembleModel extends Logging {
         treeAlgo: String): Array[DecisionTreeModel] = {
       val datapath = Loader.dataPath(path)
       val sqlContext = new SQLContext(sc)
-      val nodes = sqlContext.parquetFile(datapath).map(NodeData.apply)
+      val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply)
       val trees = constructTrees(nodes)
       trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo)))
     }
```
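The hunks above cover the read side; the PR title also names the writer API. The save-side counterpart, sketched on a hypothetical DataFrame `df` (`saveAsParquetFile` is the call deprecated in Spark 1.4):

```scala
// Before: deprecated direct method on DataFrame.
df.saveAsParquetFile(Loader.dataPath(path))

// After: the DataFrameWriter returned by df.write.
df.write.parquet(Loader.dataPath(path))
```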