
Commit 5b23452

Merge pull request #29 from RedisLabs/sql
Add support for Spark SQL
2 parents 3c15858 + 2688047 commit 5b23452

3 files changed: 345 additions & 0 deletions
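
For context, the new com.redislabs.provider.redis.sql package registers a Spark SQL data source backed by Redis hashes. A minimal usage sketch, following the test suites added in this commit (the table name 'rl' and the two-column schema are simply the values those tests use; sqlContext is assumed to be an existing org.apache.spark.sql.SQLContext):

// Register a Redis-backed table and query it with Spark SQL.
sqlContext.sql(
  """
    |CREATE TEMPORARY TABLE rl
    |(name STRING, score INT)
    |USING com.redislabs.provider.redis.sql
    |OPTIONS (table 'rl')
  """.stripMargin)

val df = sqlContext.sql("SELECT name, score FROM rl WHERE score > 10")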
Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
package com.redislabs.provider.redis.sql

import java.util

import scala.collection.JavaConversions._
import com.redislabs.provider.redis._
import com.redislabs.provider.redis.rdd.{Keys, RedisKeysRDD}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import redis.clients.jedis.Protocol
import redis.clients.util.JedisClusterCRC16
import java.security.MessageDigest


case class RedisRelation(parameters: Map[String, String], userSchema: StructType)
                        (@transient val sqlContext: SQLContext)
  extends BaseRelation with PrunedFilteredScan with InsertableRelation with Keys {

  val tableName: String = parameters.getOrElse("table", "PANIC")

  val redisConfig: RedisConfig = {
    new RedisConfig({
        if ((parameters.keySet & Set("host", "port", "auth", "dbNum", "timeout")).size == 0) {
          new RedisEndpoint(sqlContext.sparkContext.getConf)
        } else {
          val host = parameters.getOrElse("host", Protocol.DEFAULT_HOST)
          val port = parameters.getOrElse("port", Protocol.DEFAULT_PORT.toString).toInt
          val auth = parameters.getOrElse("auth", null)
          val dbNum = parameters.getOrElse("dbNum", Protocol.DEFAULT_DATABASE.toString).toInt
          val timeout = parameters.getOrElse("timeout", Protocol.DEFAULT_TIMEOUT.toString).toInt
          new RedisEndpoint(host, port, auth, dbNum, timeout)
        }
      }
    )
  }

  val partitionNum: Int = parameters.getOrElse("partitionNum", 3.toString).toInt

  val schema = userSchema

  def getNode(key: String): RedisNode = {
    val slot = JedisClusterCRC16.getSlot(key)
    /* Master only */
    redisConfig.hosts.filter(node => { node.startSlot <= slot && node.endSlot >= slot }).filter(_.idx == 0)(0)
  }

  def insert(data: DataFrame, overwrite: Boolean): Unit = {
    data.foreachPartition{
      partition => {
        val m: Map[String, Row] = partition.map {
          row => {
            val tn = tableName + ":" + MessageDigest.getInstance("MD5").digest(
              row.getValuesMap(schema.fieldNames).map(_._2.toString).reduce(_ + " " + _).getBytes)
            (tn, row)
          }
        }.toMap
        groupKeysByNode(redisConfig.hosts, m.keysIterator).foreach{
          case(node, keys) => {
            val conn = node.connect
            val pipeline = conn.pipelined
            keys.foreach{
              key => {
                val row = m.get(key).get
                pipeline.hmset(key, row.getValuesMap(row.schema.fieldNames).map(x => (x._1, x._2.toString)))
              }
            }
            pipeline.sync
            conn.close
          }
        }
      }
    }
  }

  def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    val colsForFilter = filters.map(getAttr(_)).sorted.distinct
    val colsForFilterWithIndex = colsForFilter.zipWithIndex.toMap
    val requiredColumnsType = requiredColumns.map(getDataType(_))
    new RedisKeysRDD(sqlContext.sparkContext, redisConfig, tableName + ":*", partitionNum, null).
      mapPartitions {
        partition: Iterator[String] => {
          groupKeysByNode(redisConfig.hosts, partition).flatMap {
            x => {
              val conn = x._1.endpoint.connect()
              val pipeline = conn.pipelined
              val keys: Array[String] = filterKeysByType(conn, x._2, "hash")
              val rowKeys = if (colsForFilter.length == 0) {
                keys
              } else {
                keys.foreach(key => pipeline.hmget(key, colsForFilter:_*))
                keys.zip(pipeline.syncAndReturnAll).filter {
                  x => {
                    val content = x._2.asInstanceOf[util.ArrayList[String]]
                    filters.forall {
                      filter => parseFilter(filter, content(colsForFilterWithIndex.get(getAttr(filter)).get))
                    }
                  }
                }.map(_._1)
              }

              rowKeys.foreach(pipeline.hmget(_, requiredColumns:_*))
              val res = pipeline.syncAndReturnAll.map{
                _.asInstanceOf[util.ArrayList[String]].zip(requiredColumnsType).map {
                  case(col, targetType) => castToTarget(col, targetType)
                }
              }
              conn.close
              res
            }
          }.toIterator.map(Row.fromSeq(_))
        }
      }
  }

  private def getAttr(f: Filter): String = {
    f match {
      case EqualTo(attribute, value) => attribute
      case GreaterThan(attribute, value) => attribute
      case GreaterThanOrEqual(attribute, value) => attribute
      case LessThan(attribute, value) => attribute
      case LessThanOrEqual(attribute, value) => attribute
      case In(attribute, values) => attribute
      case IsNull(attribute) => attribute
      case IsNotNull(attribute) => attribute
      case StringStartsWith(attribute, value) => attribute
      case StringEndsWith(attribute, value) => attribute
      case StringContains(attribute, value) => attribute
    }
  }

  private def castToTarget(value: String, dataType: DataType) = {
    dataType match {
      case IntegerType => value.toString.toInt
      case DoubleType => value.toString.toDouble
      case StringType => value.toString
      case _ => value.toString
    }
  }

  private def getDataType(attr: String) = {
    schema.fields(schema.fieldIndex(attr)).dataType
  }

  private def parseFilter(f: Filter, target: String) = {
    f match {
      case EqualTo(attribute, value) => {
        value.toString == target
      }
      case GreaterThan(attribute, value) => {
        getDataType(attribute) match {
          case IntegerType => value.toString.toInt < target.toInt
          case DoubleType => value.toString.toDouble < target.toDouble
          case StringType => value.toString < target
          case _ => value.toString < target
        }
      }
      case GreaterThanOrEqual(attribute, value) => {
        getDataType(attribute) match {
          case IntegerType => value.toString.toInt <= target.toInt
          case DoubleType => value.toString.toDouble <= target.toDouble
          case StringType => value.toString <= target
          case _ => value.toString <= target
        }
      }
      case LessThan(attribute, value) => {
        getDataType(attribute) match {
          case IntegerType => value.toString.toInt > target.toInt
          case DoubleType => value.toString.toDouble > target.toDouble
          case StringType => value.toString > target
          case _ => value.toString > target
        }
      }
      case LessThanOrEqual(attribute, value) => {
        getDataType(attribute) match {
          case IntegerType => value.toString.toInt >= target.toInt
          case DoubleType => value.toString.toDouble >= target.toDouble
          case StringType => value.toString >= target
          case _ => value.toString >= target
        }
      }
      case In(attribute, values) => {
        getDataType(attribute) match {
          case IntegerType => values.map(_.toString.toInt).contains(target.toInt)
          case DoubleType => values.map(_.toString.toDouble).contains(target.toDouble)
          case StringType => values.map(_.toString).contains(target)
          case _ => values.map(_.toString).contains(target)
        }
      }
      case IsNull(attribute) => target == null
      case IsNotNull(attribute) => target != null
      case StringStartsWith(attribute, value) => target.startsWith(value.toString)
      case StringEndsWith(attribute, value) => target.endsWith(value.toString)
      case StringContains(attribute, value) => target.contains(value.toString)
      case _ => false
    }
  }
}

class DefaultSource extends SchemaRelationProvider {
  def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType) = {
    RedisRelation(parameters, schema)(sqlContext)
  }
}
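
Since DefaultSource is a SchemaRelationProvider, the same relation should also be reachable through the DataFrame reader API rather than a CREATE TEMPORARY TABLE statement. A minimal sketch, assuming Spark 1.4+ and reusing the option keys that RedisRelation reads above; the schema and sqlContext here are illustrative assumptions, not part of this commit:

import org.apache.spark.sql.types._

// Illustrative schema matching the test tables below (name STRING, score INT).
val schema = StructType(Array(
  StructField("name", StringType),
  StructField("score", IntegerType)))

val df = sqlContext.read
  .format("com.redislabs.provider.redis.sql")
  .schema(schema)              // SchemaRelationProvider: the caller must supply the schema
  .option("table", "rl")       // buildScan matches Redis keys with the "rl:*" pattern
  .option("partitionNum", "3") // optional; RedisRelation defaults to 3
  .load()

df.filter(df("score") > 10).show()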
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
package com.redislabs.provider.redis.rdd

import org.apache.spark.{SparkContext, SparkConf}
import org.scalatest.{BeforeAndAfterAll, ShouldMatchers, FunSuite}
import org.apache.spark.sql.SQLContext
import com.redislabs.provider.redis._

class RedisSparkSQLClusterSuite extends FunSuite with ENV with BeforeAndAfterAll with ShouldMatchers {

  var sqlContext: SQLContext = null

  override def beforeAll() {
    super.beforeAll()

    sc = new SparkContext(new SparkConf()
      .setMaster("local").setAppName(getClass.getName)
      .set("redis.host", "127.0.0.1")
      .set("redis.port", "7379")
    )
    redisConfig = new RedisConfig(new RedisEndpoint("127.0.0.1", 7379))

    // Flush all the hosts
    redisConfig.hosts.foreach( node => {
      val conn = node.connect
      conn.flushAll
      conn.close
    })

    sqlContext = new SQLContext(sc)
    sqlContext.sql( s"""
      |CREATE TEMPORARY TABLE rl
      |(name STRING, score INT)
      |USING com.redislabs.provider.redis.sql
      |OPTIONS (table 'rl')
      """.stripMargin)

    (1 to 64).foreach{
      index => {
        sqlContext.sql(s"insert overwrite table rl select t.* from (select 'rl${index}', ${index}) t")
      }
    }
  }

  test("RedisKVRDD - default(cluster)") {
    val df = sqlContext.sql(
      s"""
        |SELECT *
        |FROM rl
      """.stripMargin)
    df.filter(df("score") > 10).count should be (54)
    df.filter(df("score") > 10 and df("score") < 20).count should be (9)
  }

  test("RedisKVRDD - cluster") {
    implicit val c: RedisConfig = redisConfig
    val df = sqlContext.sql(
      s"""
        |SELECT *
        |FROM rl
      """.stripMargin)
    df.filter(df("score") > 10).count should be (54)
    df.filter(df("score") > 10 and df("score") < 20).count should be (9)
  }

  override def afterAll(): Unit = {
    sc.stop
    System.clearProperty("spark.driver.port")
  }
}
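
The expected counts asserted above follow directly from the rows written in beforeAll (scores 1 through 64), not from anything Redis-specific; a quick check of the arithmetic:

(1 to 64).count(_ > 10)                 // 54 rows with score > 10
(1 to 64).count(i => i > 10 && i < 20)  // 9 rows with 10 < score < 20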
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
package com.redislabs.provider.redis.rdd

import com.redislabs.provider.redis._
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, FunSuite, ShouldMatchers}

class RedisSparkSQLStandaloneSuite extends FunSuite with ENV with BeforeAndAfterAll with ShouldMatchers {

  var sqlContext: SQLContext = null

  override def beforeAll() {
    super.beforeAll()

    sc = new SparkContext(new SparkConf()
      .setMaster("local").setAppName(getClass.getName)
      .set("redis.host", "127.0.0.1")
      .set("redis.port", "6379")
      .set("redis.auth", "passwd")
    )
    redisConfig = new RedisConfig(new RedisEndpoint("127.0.0.1", 6379, "passwd"))

    // Flush all the hosts
    redisConfig.hosts.foreach( node => {
      val conn = node.connect
      conn.flushAll
      conn.close
    })

    sqlContext = new SQLContext(sc)
    sqlContext.sql( s"""
      |CREATE TEMPORARY TABLE rl
      |(name STRING, score INT)
      |USING com.redislabs.provider.redis.sql
      |OPTIONS (table 'rl')
      """.stripMargin)

    (1 to 64).foreach{
      index => {
        sqlContext.sql(s"insert overwrite table rl select t.* from (select 'rl${index}', ${index}) t")
      }
    }
  }

  test("RedisKVRDD - default(standalone)") {
    val df = sqlContext.sql(
      s"""
        |SELECT *
        |FROM rl
      """.stripMargin)
    df.filter(df("score") > 10).count should be (54)
    df.filter(df("score") > 10 and df("score") < 20).count should be (9)
  }

  test("RedisKVRDD - standalone") {
    implicit val c: RedisConfig = redisConfig
    val df = sqlContext.sql(
      s"""
        |SELECT *
        |FROM rl
      """.stripMargin)
    df.filter(df("score") > 10).count should be (54)
    df.filter(df("score") > 10 and df("score") < 20).count should be (9)
  }

  override def afterAll(): Unit = {
    sc.stop
    System.clearProperty("spark.driver.port")
  }
}
