
Commit d183fdb

Merge pull request #133 from RedisLabs/fix-issue-132
fix issue #132: scala.MatchError on dataframe read
2 parents: 6aaaba2 + 68c8f2a

2 files changed: +46 −12 lines changed


src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala

Lines changed: 10 additions & 10 deletions
@@ -157,13 +157,13 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
         new GenericRow(Array[Any]())
       }
     } else {
-      val filteredSchema = {
-        val requiredColumnsSet = Set(requiredColumns: _*)
-        val filteredFields = schema.fields
-          .filter { f =>
-            requiredColumnsSet.contains(f.name)
-          }
-        StructType(filteredFields)
+      // filter schema columns, it should be in the same order as given 'requiredColumns'
+      val requiredSchema = {
+        val fieldsMap = schema.fields.map(f => (f.name, f)).toMap
+        val requiredFields = requiredColumns.map { c =>
+          fieldsMap(c)
+        }
+        StructType(requiredFields)
       }
       val keyType =
         if (persistenceModel == SqlOptionModelBinary) {
@@ -173,12 +173,12 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
         }
       keysRdd.mapPartitions { partition =>
         // grouped iterator to only allocate memory for a portion of rows
-        partition.grouped(iteratorGroupingSize).map { batch =>
+        partition.grouped(iteratorGroupingSize).flatMap { batch =>
           groupKeysByNode(redisConfig.hosts, batch.iterator)
             .flatMap { case (node, keys) =>
-              scanRows(node, keys, keyType, filteredSchema, requiredColumns)
+              scanRows(node, keys, keyType, requiredSchema, requiredColumns)
             }
-        }.flatten
+        }
       }
     }
   }
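Why the fix works: Spark hands buildScan a requiredColumns array whose order need not match the schema's declaration order. The old code filtered schema.fields (membership-tested against a Set), which keeps the fields in declaration order, so the rows produced by scanRows could end up paired with fields of the wrong type downstream, surfacing as the scala.MatchError from issue #132. A minimal sketch of the two approaches, with hypothetical column names that are not from this repo:

    import org.apache.spark.sql.types._

    val schema = StructType(Seq(
      StructField("id", StringType),
      StructField("age", IntegerType)))

    // Spark may request a projection in a different order than the schema declares:
    val requiredColumns = Array("age", "id")

    // old approach: filtering keeps declaration order -> fields (id, age), misaligned
    val filteredSchema = StructType(schema.fields.filter(f => requiredColumns.contains(f.name)))

    // fixed approach: mapping over requiredColumns keeps the requested order -> fields (age, id)
    val fieldsMap = schema.fields.map(f => (f.name, f)).toMap
    val requiredSchema = StructType(requiredColumns.map(fieldsMap))

The second hunk is a behavior-preserving cleanup: for an iterator, partition.grouped(n).flatMap(f) yields the same elements as partition.grouped(n).map(f).flatten.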

src/test/scala/com/redislabs/provider/redis/df/HashDataframeSuite.scala

Lines changed: 36 additions & 2 deletions
@@ -1,18 +1,21 @@
 package com.redislabs.provider.redis.df
 
 import java.sql.{Date, Timestamp}
+import java.util.UUID
 
 import com.redislabs.provider.redis.toRedisContext
 import com.redislabs.provider.redis.util.Person.{data, _}
 import com.redislabs.provider.redis.util.TestUtils._
 import com.redislabs.provider.redis.util.{EntityId, Person}
 import org.apache.spark.SparkException
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.redis.RedisSourceRelation.tableDataKeyPattern
 import org.apache.spark.sql.redis._
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{StructField, _}
 import org.scalatest.Matchers
 
+import scala.util.Random
+
 /**
   * @author The Viet Nguyen
   */
@@ -295,6 +298,37 @@ trait HashDataframeSuite extends RedisDataframeSuite with Matchers {
     }
   }
 
+  /**
+    * A test case for https://github.com/RedisLabs/spark-redis/issues/132
+    */
+  test("RedisSourceRelation.buildScan columns ordering") {
+    val schema = {
+      StructType(Array(
+        StructField("id", StringType),
+        StructField("int", IntegerType),
+        StructField("float", FloatType),
+        StructField("double", DoubleType),
+        StructField("str", StringType)))
+    }
+
+    val rowsNum = 8
+    val rdd = spark.sparkContext.parallelize(1 to rowsNum, 2).map { _ =>
+      def genStr = UUID.randomUUID().toString
+      def genInt = Random.nextInt()
+      def genDouble = Random.nextDouble()
+      def genFloat = Random.nextFloat()
+      Row.fromSeq(Seq(genStr, genInt, genFloat, genDouble, genStr))
+    }
+
+    val df = spark.createDataFrame(rdd, schema)
+    val tableName = generateTableName("cols-ordering")
+    df.write.format(RedisFormat).option(SqlOptionTableName, tableName).save()
+    val loadedDf = spark.read.format(RedisFormat).option(SqlOptionTableName, tableName).load()
+    loadedDf.schema shouldBe schema
+    loadedDf.collect().length shouldBe rowsNum
+    loadedDf.show()
+  }
+
   def saveMap(tableName: String): Unit = {
     Person.dataMaps.foreach { person =>
       saveMap(tableName, person("name"), person)
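The regression test round-trips a DataFrame whose columns deliberately mix types (string, int, float, double), so any misalignment between the requested and declared column order shows up as a type error rather than passing silently. A usage sketch of a read that exercises the ordering path follows; it is a sketch only, assuming a SparkSession wired to Redis as in this suite and a hypothetical table name:

    // read back through the spark-redis data source, then project columns
    // in an order that differs from the stored schema's declaration order
    val loadedDf = spark.read
      .format("org.apache.spark.sql.redis")
      .option("table", "cols-ordering")
      .load()
    loadedDf.select("str", "double", "float", "int", "id").show()

Before this fix, such a read could fail with scala.MatchError; with the order-preserving schema the projected columns come back correctly typed.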
