diff --git a/doc/rdd.md b/doc/rdd.md
index 9ef505c7..7bf66cc2 100644
--- a/doc/rdd.md
+++ b/doc/rdd.md
@@ -137,10 +137,28 @@ sc.toRedisFixedLIST(listRDD, listName, listSize)
 The `listRDD` is an RDD that contains all of the list's string elements in order, and `listName` is the list's key name.
 `listSize` is an integer which specifies the size of the Redis list; it is optional, and will default to an unlimited size.
 
-Use the following to store an RDD of binary values in a Redis List:
+Use the following to store an RDD in multiple Redis Lists:
 
 ```scala
-sc.toRedisByteLIST(byteListRDD)
+sc.toRedisLISTs(rdd)
+```
+
+The `rdd` is an RDD of tuples (`list name`, `list values`).
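+
+For example, you could build and write such an RDD as follows (a minimal sketch; the list names and values are illustrative, and `sc` is assumed to be a `SparkContext` configured for your Redis deployment):
+
+```scala
+val rdd = sc.parallelize(Seq(
+  ("fruits", Seq("apple", "banana")),
+  ("colors", Seq("red", "green", "blue"))
+))
+sc.toRedisLISTs(rdd)
+```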
+
+Use the following to store an RDD of binary values in multiple Redis Lists:
+
+```scala
+sc.toRedisByteLISTs(byteListRDD)
 ```
 
 The `byteListRDD` is an RDD of tuples (`list name`, `list values`) represented as byte arrays.
diff --git a/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala b/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala
index d59394cb..d1e7a22c 100644
--- a/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala
+++ b/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala
@@ -301,18 +301,47 @@ class RedisContext(@transient val sc: SparkContext) extends Serializable {
   }
 
   /**
-    * Write RDD of binary values to Redis List.
+    * Write RDD of (list name, list values) to Redis Lists.
     *
     * @param rdd RDD of tuples (list name, list values)
     * @param ttl time to live
     */
-  def toRedisByteLIST(rdd: RDD[(Array[Byte], Seq[Array[Byte]])], ttl: Int = 0)
-                     (implicit
+  def toRedisLISTs(rdd: RDD[(String, Seq[String])], ttl: Int = 0)
+                  (implicit
       redisConfig: RedisConfig = RedisConfig.fromSparkConf(sc.getConf),
       readWriteConfig: ReadWriteConfig = ReadWriteConfig.fromSparkConf(sc.getConf)) {
     rdd.foreachPartition(partition => setList(partition, ttl, redisConfig, readWriteConfig))
   }
 
+  /**
+    * Write RDD of binary values to Redis Lists.
+    *
+    * @deprecated use toRedisByteLISTs, the method name has changed to make the API consistent
+    *
+    * @param rdd RDD of tuples (list name, list values)
+    * @param ttl time to live
+    */
+  @Deprecated
+  def toRedisByteLIST(rdd: RDD[(Array[Byte], Seq[Array[Byte]])], ttl: Int = 0)
+                     (implicit
+      redisConfig: RedisConfig = RedisConfig.fromSparkConf(sc.getConf),
+      readWriteConfig: ReadWriteConfig = ReadWriteConfig.fromSparkConf(sc.getConf)) {
+    toRedisByteLISTs(rdd, ttl)(redisConfig, readWriteConfig)
+  }
+
+  /**
+    * Write RDD of binary values to Redis Lists.
+    *
+    * @param rdd RDD of tuples (list name, list values)
+    * @param ttl time to live
+    */
+  def toRedisByteLISTs(rdd: RDD[(Array[Byte], Seq[Array[Byte]])], ttl: Int = 0)
+                      (implicit
+      redisConfig: RedisConfig = RedisConfig.fromSparkConf(sc.getConf),
+      readWriteConfig: ReadWriteConfig = ReadWriteConfig.fromSparkConf(sc.getConf)) {
+    rdd.foreachPartition(partition => setByteList(partition, ttl, redisConfig, readWriteConfig))
+  }
+
   /**
     * @param vs RDD of values
     * @param listName target list's name which hold all the vs
@@ -430,7 +459,34 @@
 
   }
 
-  def setList(keyValues: Iterator[(Array[Byte], Seq[Array[Byte]])],
+  def setByteList(keyValues: Iterator[(Array[Byte], Seq[Array[Byte]])],
+                  ttl: Int,
+                  redisConfig: RedisConfig,
+                  readWriteConfig: ReadWriteConfig) {
+    implicit val rwConf: ReadWriteConfig = readWriteConfig
+
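+    // Resolve the Redis node that serves each key, group this partition's
+    // entries by node so that each node is contacted only once, then RPUSH
+    // every list's values through a pipelined connection, applying the
+    // optional TTL with EXPIRE.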
"b2", "c2") - val keyValues = Seq( - ("list1", list1), - ("list2", list2) - ) - val keyValueBytes = keyValues.map {case (k, list) => (k.getBytes, list.map(_.getBytes())) } - val rdd = sc.parallelize(keyValueBytes) - sc.toRedisByteLIST(rdd) - - def verify(list: String, vals: Seq[String]): Unit = { - withConnection(redisConfig.getHost(list).endpoint.connect()) { conn => - conn.lrange(list, 0, vals.size).asScala should be(vals.toList) - } - } - - verify("list1", list1) - verify("list2", list2) - } - test("Expire") { val expireTime = 1 val prefix = s"#expire in $expireTime#:" diff --git a/src/test/scala/com/redislabs/provider/redis/rdd/cluster/RedisRddExtraClusterSuite.scala b/src/test/scala/com/redislabs/provider/redis/rdd/cluster/RedisRddExtraClusterSuite.scala new file mode 100644 index 00000000..bd89fbe7 --- /dev/null +++ b/src/test/scala/com/redislabs/provider/redis/rdd/cluster/RedisRddExtraClusterSuite.scala @@ -0,0 +1,6 @@ +package com.redislabs.provider.redis.rdd.cluster + +import com.redislabs.provider.redis.env.RedisClusterEnv +import com.redislabs.provider.redis.rdd.RedisRddExtraSuite + +class RedisRddExtraClusterSuite extends RedisRddExtraSuite with RedisClusterEnv diff --git a/src/test/scala/com/redislabs/provider/redis/rdd/standalone/RedisRddExtraStandaloneSuite.scala b/src/test/scala/com/redislabs/provider/redis/rdd/standalone/RedisRddExtraStandaloneSuite.scala new file mode 100644 index 00000000..62446360 --- /dev/null +++ b/src/test/scala/com/redislabs/provider/redis/rdd/standalone/RedisRddExtraStandaloneSuite.scala @@ -0,0 +1,6 @@ +package com.redislabs.provider.redis.rdd.standalone + +import com.redislabs.provider.redis.env.RedisStandaloneEnv +import com.redislabs.provider.redis.rdd.RedisRddExtraSuite + +class RedisRddExtraStandaloneSuite extends RedisRddExtraSuite with RedisStandaloneEnv