Commit 78c4671

add libSVMFile to MLContext
1 parent f0fe616 commit 78c4671

File tree

2 files changed: 112 additions & 0 deletions

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib

import org.apache.spark.SparkContext

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

class MLContext(self: SparkContext) {
  /**
   * Reads labeled data in the LIBSVM format into an RDD[LabeledPoint].
   * The LIBSVM format is a text-based format used by LIBSVM (http://www.csie.ntu.edu.tw/~cjlin/libsvm/).
   * Each line represents a labeled sparse feature vector using the following format:
   * {{{label index1:value1 index2:value2 ...}}}
   * where the indices are one-based and in ascending order.
   * This method parses each line into a [[org.apache.spark.mllib.regression.LabeledPoint]] instance,
   * where the feature indices are converted to zero-based.
   *
   * @param path file or directory path in any Hadoop-supported file system URI
   * @param numFeatures number of features
   * @param labelParser parser for labels, default: _.toDouble
   * @return labeled data stored as an RDD[LabeledPoint]
   */
  def libSVMFile(
      path: String,
      numFeatures: Int,
      labelParser: String => Double = _.toDouble): RDD[LabeledPoint] = {
    self.textFile(path).map(_.trim).filter(!_.isEmpty).map { line =>
      val items = line.split(' ')
      val label = labelParser(items.head)
      val features = Vectors.sparse(numFeatures, items.tail.map { item =>
        val indexAndValue = item.split(':')
        val index = indexAndValue(0).toInt - 1
        val value = indexAndValue(1).toDouble
        (index, value)
      })
      LabeledPoint(label, features)
    }
  }
}

object MLContext {
  implicit def sparkContextToMLContext(sc: SparkContext): MLContext = new MLContext(sc)
}
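
With the implicit conversion defined in the MLContext companion object in scope, libSVMFile can be called directly on a SparkContext. The sketch below is a usage illustration, not part of the commit; the master URL, application name, file paths, and the feature count of 6 are assumed for the example.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.MLContext._

// Assumed setup for the example: a local SparkContext.
val sc = new SparkContext("local", "libSVMFile example")

// Each input line looks like "1 1:1.0 3:2.0 5:3.0"; the one-based indices in the
// file become zero-based indices in the resulting sparse feature vectors.
val points = sc.libSVMFile("data/sample_libsvm.txt", numFeatures = 6)
points.take(2).foreach(println)

// The optional labelParser argument overrides the default _.toDouble, e.g. for a
// hypothetical input whose labels are written as "+1"/"-1":
val binary = sc.libSVMFile("data/binary_libsvm.txt", numFeatures = 6,
  labelParser = s => if (s == "+1") 1.0 else 0.0)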
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib

import org.apache.spark.mllib.MLContext._
import org.apache.spark.mllib.util.LocalSparkContext
import org.scalatest.FunSuite
import com.google.common.io.Files
import java.io.File
import com.google.common.base.Charsets
import org.apache.spark.mllib.linalg.Vectors

class MLContextSuite extends FunSuite with LocalSparkContext {
  test("libSVMFile") {
    val lines =
      """
        |1 1:1.0 3:2.0 5:3.0
        |0 2:4.0 4:5.0 6:6.0
      """.stripMargin
    val tempDir = Files.createTempDir()
    val file = new File(tempDir.getPath, "part-00000")
    Files.write(lines, file, Charsets.US_ASCII)
    val points = sc.libSVMFile(tempDir.toURI.toString, 6).collect()
    assert(points.length === 2)
    assert(points(0).label === 1.0)
    assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
    assert(points(1).label === 0.0)
    assert(points(1).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
    try {
      file.delete()
      tempDir.delete()
    } catch {
      case t: Throwable =>
    }
  }
}
