Skip to content

Commit 74d1d9e

Browse files
[SC-5504][REDSHIFT] Fix filter pushdown when no columns are selected
Escaping single quotes should be done during creation of the UNLOAD statement. This also fixes a longstanding bug where backslashes were not escaped properly when they appeared in string literals in generated `WHERE` clauses. Author: Josh Rosen <[email protected]> Author: Juliusz Sompolski <[email protected]> Author: Adrian Ionescu <[email protected]> Closes apache#174 from juliuszsompolski/sc5504.
1 parent 26249f1 commit 74d1d9e

File tree

8 files changed

+92
-14
lines changed

8 files changed

+92
-14
lines changed

external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ trait IntegrationSuiteBase
160160
|(0, null, '2015-07-03', 0.0, -1.0, 4141214, 1239012341823719, null, 'f', '2015-07-03 00:00:00.000'),
161161
|(0, false, null, -1234152.12312498, 100000.0, null, 1239012341823719, 24, '___|_123', null),
162162
|(1, false, '2015-07-02', 0.0, 0.0, 42, 1239012341823719, -13, 'asdf', '2015-07-02 00:00:00.000'),
163-
|(1, true, '2015-07-01', 1234152.12312498, 1.0, 42, 1239012341823719, 23, 'Unicode''s樂趣', '2015-07-01 00:00:00.001')
163+
|(1, true, '2015-07-01', 1234152.12312498, 1.0, 42, 1239012341823719, 23, 'Unicode''s樂趣', '2015-07-01 00:00:00.001'),
164+
|(null, null, null, null, null, null, null, null, 'Ba\\\\ckslash\\\\', null)
164165
""".stripMargin
165166
)
166167
// scalastyle:on

external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftReadSuite.scala

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class RedshiftReadSuite extends IntegrationSuiteBase {
110110
test("Can load output of Redshift aggregation queries") {
111111
checkAnswer(
112112
read.option("query", s"select testbool, count(*) from $test_table group by testbool").load(),
113-
Seq(Row(true, 1), Row(false, 2), Row(null, 2)))
113+
Seq(Row(true, 1), Row(false, 2), Row(null, 3)))
114114
}
115115

116116
test("multiple scans on same table") {
@@ -128,6 +128,7 @@ class RedshiftReadSuite extends IntegrationSuiteBase {
128128
checkAnswer(
129129
sqlContext.sql("select testbyte, testbool from test_table"),
130130
Seq(
131+
Row(null, null),
131132
Row(null, null),
132133
Row(0.toByte, null),
133134
Row(0.toByte, false),
@@ -240,4 +241,72 @@ class RedshiftReadSuite extends IntegrationSuiteBase {
240241
.load()
241242
assert(df.schema.fields(0).dataType === LongType)
242243
}
244+
245+
test("properly escape literals in filter pushdown (SC-5504)") {
246+
checkAnswer(
247+
sqlContext.sql("select count(1) from test_table where testint = 4141214"),
248+
Seq(Row(1))
249+
)
250+
checkAnswer(
251+
sqlContext.sql("select count(1) from test_table where testint = 7"),
252+
Seq(Row(0))
253+
)
254+
checkAnswer(
255+
sqlContext.sql("select testint from test_table where testint = 42"),
256+
Seq(Row(42), Row(42))
257+
)
258+
259+
checkAnswer(
260+
sqlContext.sql("select count(1) from test_table where teststring = 'asdf'"),
261+
Seq(Row(1))
262+
)
263+
checkAnswer(
264+
sqlContext.sql("select count(1) from test_table where teststring = 'alamakota'"),
265+
Seq(Row(0))
266+
)
267+
checkAnswer(
268+
sqlContext.sql("select teststring from test_table where teststring = 'asdf'"),
269+
Seq(Row("asdf"))
270+
)
271+
272+
checkAnswer(
273+
sqlContext.sql("select count(1) from test_table where teststring = 'a\\'b'"),
274+
Seq(Row(0))
275+
)
276+
checkAnswer(
277+
sqlContext.sql("select teststring from test_table where teststring = 'a\\'b'"),
278+
Seq()
279+
)
280+
281+
// scalastyle:off
282+
checkAnswer(
283+
sqlContext.sql("select count(1) from test_table where teststring = 'Unicode\\'s樂趣'"),
284+
Seq(Row(1))
285+
)
286+
checkAnswer(
287+
sqlContext.sql("select teststring from test_table where teststring = \"Unicode's樂趣\""),
288+
Seq(Row("Unicode's樂趣"))
289+
)
290+
// scalastyle:on
291+
292+
checkAnswer(
293+
sqlContext.sql("select count(1) from test_table where teststring = 'a\\\\b'"),
294+
Seq(Row(0))
295+
)
296+
checkAnswer(
297+
sqlContext.sql("select teststring from test_table where teststring = 'a\\\\b'"),
298+
Seq()
299+
)
300+
301+
checkAnswer(
302+
sqlContext.sql(
303+
"select count(1) from test_table where teststring = 'Ba\\\\ckslash\\\\'"),
304+
Seq(Row(1))
305+
)
306+
checkAnswer(
307+
sqlContext.sql(
308+
"select teststring from test_table where teststring = \"Ba\\\\ckslash\\\\\""),
309+
Seq(Row("Ba\\ckslash\\"))
310+
)
311+
}
243312
}

external/redshift/src/main/scala/com/databricks/spark/redshift/FilterPushdown.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ private[redshift] object FilterPushdown {
4040
def buildComparison(attr: String, value: Any, comparisonOp: String): Option[String] = {
4141
getTypeForAttribute(schema, attr).map { dataType =>
4242
val sqlEscapedValue: String = dataType match {
43-
case StringType => s"\\'${value.toString.replace("'", "\\'\\'")}\\'"
44-
case DateType => s"\\'${value.asInstanceOf[Date]}\\'"
45-
case TimestampType => s"\\'${value.asInstanceOf[Timestamp]}\\'"
43+
case StringType => s"'${value.toString.replace("'", "''").replace("\\", "\\\\")}'"
44+
case DateType => s"'${value.asInstanceOf[Date]}'"
45+
case TimestampType => s"'${value.asInstanceOf[Timestamp]}'"
4646
case _ => value.toString
4747
}
4848
s""""$attr" $comparisonOp $sqlEscapedValue"""

external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import scala.collection.JavaConverters._
1717
import com.amazonaws.auth.AWSCredentialsProvider
1818
import com.amazonaws.services.s3.AmazonS3Client
1919
import com.databricks.spark.redshift.Parameters.MergedParameters
20+
import com.databricks.spark.redshift.Utils.escapeJdbcString
2021
import com.eclipsesource.json.Json
2122
import org.slf4j.LoggerFactory
2223

@@ -178,8 +179,8 @@ private[redshift] case class RedshiftRelation(
178179
val query = {
179180
// Since the query passed to UNLOAD will be enclosed in single quotes, we need to escape
180181
// any backslashes and single quotes that appear in the query itself
181-
val escapedTableNameOrSubqury = tableNameOrSubquery.replace("\\", "\\\\").replace("'", "\\'")
182-
s"SELECT $columnList FROM $escapedTableNameOrSubqury $whereClause"
182+
s"SELECT $columnList FROM ${escapeJdbcString(tableNameOrSubquery)} " +
183+
s"${escapeJdbcString(whereClause)}"
183184
}
184185
// We need to remove S3 credentials from the unload path URI because they will conflict with
185186
// the credentials passed via `credsString`.

external/redshift/src/main/scala/com/databricks/spark/redshift/Utils.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,4 +197,11 @@ private[redshift] object Utils {
197197
case _ => None
198198
}
199199
}
200+
201+
/**
202+
* Escapes backslashes and single quotes in `s` for use inside a single-quoted SQL literal.
203+
*/
204+
def escapeJdbcString(s: String): String = {
205+
s.replace("\\", "\\\\").replace("'", "\\'")
206+
}
200207
}

external/redshift/src/test/scala/com/databricks/spark/redshift/FilterPushdownSuite.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,8 @@ class FilterPushdownSuite extends SparkFunSuite {
3232
test("buildWhereClause with string literals that contain Unicode characters") {
3333
// scalastyle:off
3434
val whereClause = buildWhereClause(testSchema, Seq(EqualTo("test_string", "Unicode's樂趣")))
35-
// Here, the apostrophe in the string needs to be replaced with two single quotes, '', but we
36-
// also need to escape those quotes with backslashes because this WHERE clause is going to
37-
// eventually be embedded inside of a single-quoted string that's embedded inside of a larger
38-
// Redshift query.
39-
assert(whereClause === """WHERE "test_string" = \'Unicode\'\'s樂趣\'""")
35+
// Here, the apostrophe in the string needs to be replaced with two single quotes, ''.
36+
assert(whereClause === """WHERE "test_string" = 'Unicode''s樂趣'""")
4037
// scalastyle:on
4138
}
4239

@@ -57,7 +54,7 @@ class FilterPushdownSuite extends SparkFunSuite {
5754
val expectedWhereClause =
5855
"""
5956
|WHERE "test_bool" = true
60-
|AND "test_string" = \'Unicode是樂趣\'
57+
|AND "test_string" = 'Unicode是樂趣'
6158
|AND "test_double" > 1000.0
6259
|AND "test_double" < 1.7976931348623157E308
6360
|AND "test_float" >= 1.0

external/redshift/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class RedshiftSourceSuite
144144
|1|f|2015-07-02|0|0.0|42|1239012341823719|-13|asdf|2015-07-02 00:00:00.0
145145
|0||2015-07-03|0.0|-1.0|4141214|1239012341823719||f|2015-07-03 00:00:00
146146
|0|f||-1234152.12312498|100000.0||1239012341823719|24|___\|_123|
147+
|||||||||Ba\\ckslash\\|
147148
||||||||||
148149
""".stripMargin.trim
149150
// scalastyle:on

external/redshift/src/test/scala/com/databricks/spark/redshift/TestUtils.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ object TestUtils {
5252
1239012341823719L, null, "f", TestUtils.toTimestamp(2015, 6, 3, 0, 0, 0)),
5353
Row(0.toByte, false, null, -1234152.12312498, 100000.0f, null, 1239012341823719L, 24.toShort,
5454
"___|_123", null),
55-
Row(List.fill(10)(null): _*))
55+
Row(List.fill(10)(null): _*),
56+
Row(null, null, null, null, null, null, null, null, "Ba\\ckslash\\", null)
57+
)
5658
// scalastyle:on
5759

5860
/**

0 commit comments

Comments
 (0)