Skip to content

Commit 49faf60

Browse files
rsotn-mapr authored and ekrivokonmapr committed
MapR [SPARK-469] Fix NPE in generated classes by reverting "[SPARK-23466][SQL] Remove redundant null checks in generated Java code by GenerateUnsafeProjection" (apache#455)
This reverts commit c5583fd.
1 parent a3ce6ef commit 49faf60

File tree

3 files changed

+33
-117
lines changed

3 files changed

+33
-117
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 30 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -32,8 +32,6 @@ import org.apache.spark.sql.types._
3232
*/
3333
object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
3434

35-
case class Schema(dataType: DataType, nullable: Boolean)
36-
3735
/** Returns true iff we support this data type. */
3836
def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match {
3937
case NullType => true
@@ -45,21 +43,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
4543
case _ => false
4644
}
4745

46+
// TODO: if the nullability of field is correct, we can use it to save null check.
4847
private def writeStructToBuffer(
4948
ctx: CodegenContext,
5049
input: String,
5150
index: String,
52-
schemas: Seq[Schema],
51+
fieldTypes: Seq[DataType],
5352
rowWriter: String): String = {
5453
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
5554
val tmpInput = ctx.freshName("tmpInput")
56-
val fieldEvals = schemas.zipWithIndex.map { case (Schema(dt, nullable), i) =>
57-
val isNull = if (nullable) {
58-
JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)")
59-
} else {
60-
FalseLiteral
61-
}
62-
ExprCode(isNull, JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt))
55+
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
56+
ExprCode(
57+
JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"),
58+
JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt))
6359
}
6460

6561
val rowWriterClass = classOf[UnsafeRowWriter].getName
@@ -74,7 +70,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
7470
| // Remember the current cursor so that we can calculate how many bytes are
7571
| // written later.
7672
| final int $previousCursor = $rowWriter.cursor();
77-
| ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, schemas, structRowWriter)}
73+
| ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)}
7874
| $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
7975
|}
8076
""".stripMargin
@@ -84,7 +80,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
8480
ctx: CodegenContext,
8581
row: String,
8682
inputs: Seq[ExprCode],
87-
schemas: Seq[Schema],
83+
inputTypes: Seq[DataType],
8884
rowWriter: String,
8985
isTopLevel: Boolean = false): String = {
9086
val resetWriter = if (isTopLevel) {
@@ -102,8 +98,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
10298
s"$rowWriter.resetRowWriter();"
10399
}
104100

105-
val writeFields = inputs.zip(schemas).zipWithIndex.map {
106-
case ((input, Schema(dataType, nullable)), index) =>
101+
val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
102+
case ((input, dataType), index) =>
107103
val dt = UserDefinedType.sqlType(dataType)
108104

109105
val setNull = dt match {
@@ -114,7 +110,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
114110
}
115111

116112
val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter)
117-
if (!nullable) {
113+
if (input.isNull == FalseLiteral) {
118114
s"""
119115
|${input.code}
120116
|${writeField.trim}
@@ -147,11 +143,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
147143
""".stripMargin
148144
}
149145

146+
// TODO: if the nullability of array element is correct, we can use it to save null check.
150147
private def writeArrayToBuffer(
151148
ctx: CodegenContext,
152149
input: String,
153150
elementType: DataType,
154-
containsNull: Boolean,
155151
rowWriter: String): String = {
156152
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
157153
val tmpInput = ctx.freshName("tmpInput")
@@ -174,18 +170,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
174170

175171
val element = CodeGenerator.getValue(tmpInput, et, index)
176172

177-
val elementAssignment = if (containsNull) {
178-
s"""
179-
|if ($tmpInput.isNullAt($index)) {
180-
| $arrayWriter.setNull${elementOrOffsetSize}Bytes($index);
181-
|} else {
182-
| ${writeElement(ctx, element, index, et, arrayWriter)}
183-
|}
184-
""".stripMargin
185-
} else {
186-
writeElement(ctx, element, index, et, arrayWriter)
187-
}
188-
189173
s"""
190174
|final ArrayData $tmpInput = $input;
191175
|if ($tmpInput instanceof UnsafeArrayData) {
@@ -195,31 +179,30 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
195179
| $arrayWriter.initialize($numElements);
196180
|
197181
| for (int $index = 0; $index < $numElements; $index++) {
198-
| $elementAssignment
182+
| if ($tmpInput.isNullAt($index)) {
183+
| $arrayWriter.setNull${elementOrOffsetSize}Bytes($index);
184+
| } else {
185+
| ${writeElement(ctx, element, index, et, arrayWriter)}
186+
| }
199187
| }
200188
|}
201189
""".stripMargin
202190
}
203191

192+
// TODO: if the nullability of value element is correct, we can use it to save null check.
204193
private def writeMapToBuffer(
205194
ctx: CodegenContext,
206195
input: String,
207196
index: String,
208197
keyType: DataType,
209198
valueType: DataType,
210-
valueContainsNull: Boolean,
211199
rowWriter: String): String = {
212200
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
213201
val tmpInput = ctx.freshName("tmpInput")
214202
val tmpCursor = ctx.freshName("tmpCursor")
215203
val previousCursor = ctx.freshName("previousCursor")
216204

217205
// Writes out unsafe map according to the format described in `UnsafeMapData`.
218-
val keyArray = writeArrayToBuffer(
219-
ctx, s"$tmpInput.keyArray()", keyType, false, rowWriter)
220-
val valueArray = writeArrayToBuffer(
221-
ctx, s"$tmpInput.valueArray()", valueType, valueContainsNull, rowWriter)
222-
223206
s"""
224207
|final MapData $tmpInput = $input;
225208
|if ($tmpInput instanceof UnsafeMapData) {
@@ -236,15 +219,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
236219
| // Remember the current cursor so that we can write numBytes of key array later.
237220
| final int $tmpCursor = $rowWriter.cursor();
238221
|
239-
| $keyArray
222+
| ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)}
240223
|
241224
| // Write the numBytes of key array into the first 8 bytes.
242225
| Platform.putLong(
243226
| $rowWriter.getBuffer(),
244227
| $tmpCursor - 8,
245228
| $rowWriter.cursor() - $tmpCursor);
246229
|
247-
| $valueArray
230+
| ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)}
248231
| $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
249232
|}
250233
""".stripMargin
@@ -257,21 +240,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
257240
dt: DataType,
258241
writer: String): String = dt match {
259242
case t: StructType =>
260-
writeStructToBuffer(
261-
ctx, input, index, t.map(e => Schema(e.dataType, e.nullable)), writer)
243+
writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer)
262244

263-
case ArrayType(et, en) =>
245+
case ArrayType(et, _) =>
264246
val previousCursor = ctx.freshName("previousCursor")
265247
s"""
266248
|// Remember the current cursor so that we can calculate how many bytes are
267249
|// written later.
268250
|final int $previousCursor = $writer.cursor();
269-
|${writeArrayToBuffer(ctx, input, et, en, writer)}
251+
|${writeArrayToBuffer(ctx, input, et, writer)}
270252
|$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
271253
""".stripMargin
272254

273-
case MapType(kt, vt, vn) =>
274-
writeMapToBuffer(ctx, input, index, kt, vt, vn, writer)
255+
case MapType(kt, vt, _) =>
256+
writeMapToBuffer(ctx, input, index, kt, vt, writer)
275257

276258
case DecimalType.Fixed(precision, scale) =>
277259
s"$writer.write($index, $input, $precision, $scale);"
@@ -286,11 +268,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
286268
expressions: Seq[Expression],
287269
useSubexprElimination: Boolean = false): ExprCode = {
288270
val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
289-
val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable))
271+
val exprTypes = expressions.map(_.dataType)
290272

291-
val numVarLenFields = exprSchemas.count {
292-
case Schema(dt, _) => !UnsafeRow.isFixedLength(dt)
273+
val numVarLenFields = exprTypes.count {
274+
case dt if UnsafeRow.isFixedLength(dt) => false
293275
// TODO: consider large decimal and interval type
276+
case _ => true
294277
}
295278

296279
val rowWriterClass = classOf[UnsafeRowWriter].getName
@@ -301,7 +284,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
301284
val evalSubexpr = ctx.subexprFunctions.mkString("\n")
302285

303286
val writeExpressions = writeExpressionsToBuffer(
304-
ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true)
287+
ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true)
305288

306289
val code =
307290
code"""

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -694,7 +694,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
694694
|""".stripMargin
695695
val jsonSchema = new StructType()
696696
.add("a", LongType, nullable = false)
697-
.add("b", StringType, nullable = !forceJsonNullableSchema)
697+
.add("b", StringType, nullable = false)
698698
.add("c", StringType, nullable = false)
699699
val output = InternalRow(1L, null, UTF8String.fromString("foo"))
700700
val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala

Lines changed: 2 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen
2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.InternalRow
2222
import org.apache.spark.sql.catalyst.expressions.BoundReference
23-
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData}
24-
import org.apache.spark.sql.types._
23+
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
24+
import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
2525
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
2626

2727
class GenerateUnsafeProjectionSuite extends SparkFunSuite {
@@ -33,41 +33,6 @@ class GenerateUnsafeProjectionSuite extends SparkFunSuite {
3333
assert(!result.isNullAt(0))
3434
assert(result.getStruct(0, 1).isNullAt(0))
3535
}
36-
37-
test("Test unsafe projection for array/map/struct") {
38-
val dataType1 = ArrayType(StringType, false)
39-
val exprs1 = BoundReference(0, dataType1, nullable = false) :: Nil
40-
val projection1 = GenerateUnsafeProjection.generate(exprs1)
41-
val result1 = projection1.apply(AlwaysNonNull)
42-
assert(!result1.isNullAt(0))
43-
assert(!result1.getArray(0).isNullAt(0))
44-
assert(!result1.getArray(0).isNullAt(1))
45-
assert(!result1.getArray(0).isNullAt(2))
46-
47-
val dataType2 = MapType(StringType, StringType, false)
48-
val exprs2 = BoundReference(0, dataType2, nullable = false) :: Nil
49-
val projection2 = GenerateUnsafeProjection.generate(exprs2)
50-
val result2 = projection2.apply(AlwaysNonNull)
51-
assert(!result2.isNullAt(0))
52-
assert(!result2.getMap(0).keyArray.isNullAt(0))
53-
assert(!result2.getMap(0).keyArray.isNullAt(1))
54-
assert(!result2.getMap(0).keyArray.isNullAt(2))
55-
assert(!result2.getMap(0).valueArray.isNullAt(0))
56-
assert(!result2.getMap(0).valueArray.isNullAt(1))
57-
assert(!result2.getMap(0).valueArray.isNullAt(2))
58-
59-
val dataType3 = (new StructType)
60-
.add("a", StringType, nullable = false)
61-
.add("b", StringType, nullable = false)
62-
.add("c", StringType, nullable = false)
63-
val exprs3 = BoundReference(0, dataType3, nullable = false) :: Nil
64-
val projection3 = GenerateUnsafeProjection.generate(exprs3)
65-
val result3 = projection3.apply(InternalRow(AlwaysNonNull))
66-
assert(!result3.isNullAt(0))
67-
assert(!result3.getStruct(0, 1).isNullAt(0))
68-
assert(!result3.getStruct(0, 2).isNullAt(0))
69-
assert(!result3.getStruct(0, 3).isNullAt(0))
70-
}
7136
}
7237

7338
object AlwaysNull extends InternalRow {
@@ -94,35 +59,3 @@ object AlwaysNull extends InternalRow {
9459
override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
9560
private def notSupported: Nothing = throw new UnsupportedOperationException
9661
}
97-
98-
object AlwaysNonNull extends InternalRow {
99-
private def stringToUTF8Array(stringArray: Array[String]): ArrayData = {
100-
val utf8Array = stringArray.map(s => UTF8String.fromString(s)).toArray
101-
ArrayData.toArrayData(utf8Array)
102-
}
103-
override def numFields: Int = 1
104-
override def setNullAt(i: Int): Unit = {}
105-
override def copy(): InternalRow = this
106-
override def anyNull: Boolean = notSupported
107-
override def isNullAt(ordinal: Int): Boolean = notSupported
108-
override def update(i: Int, value: Any): Unit = notSupported
109-
override def getBoolean(ordinal: Int): Boolean = notSupported
110-
override def getByte(ordinal: Int): Byte = notSupported
111-
override def getShort(ordinal: Int): Short = notSupported
112-
override def getInt(ordinal: Int): Int = notSupported
113-
override def getLong(ordinal: Int): Long = notSupported
114-
override def getFloat(ordinal: Int): Float = notSupported
115-
override def getDouble(ordinal: Int): Double = notSupported
116-
override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
117-
override def getUTF8String(ordinal: Int): UTF8String = UTF8String.fromString("test")
118-
override def getBinary(ordinal: Int): Array[Byte] = notSupported
119-
override def getInterval(ordinal: Int): CalendarInterval = notSupported
120-
override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
121-
override def getArray(ordinal: Int): ArrayData = stringToUTF8Array(Array("1", "2", "3"))
122-
val keyArray = stringToUTF8Array(Array("1", "2", "3"))
123-
val valueArray = stringToUTF8Array(Array("a", "b", "c"))
124-
override def getMap(ordinal: Int): MapData = new ArrayBasedMapData(keyArray, valueArray)
125-
override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
126-
private def notSupported: Nothing = throw new UnsupportedOperationException
127-
128-
}

0 commit comments

Comments
 (0)