From d44179403857fafeddd8110c4ce6961a5231ef0c Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Fri, 24 Nov 2017 15:45:17 +0000
Subject: [PATCH 01/26] update API of addMutableState

---
 .../sql/catalyst/expressions/Expression.scala |  3 +-
 .../MonotonicallyIncreasingID.scala           |  6 +-
 .../expressions/SparkPartitionID.scala        |  3 +-
 .../sql/catalyst/expressions/arithmetic.scala |  6 +-
 .../expressions/codegen/CodeGenerator.scala   | 22 ++++--
 .../codegen/GenerateMutableProjection.scala   | 29 ++++----
 .../codegen/GenerateUnsafeProjection.scala    | 19 +++--
 .../expressions/conditionalExpressions.scala  |  3 +-
 .../expressions/datetimeExpressions.scala     | 31 ++++-----
 .../sql/catalyst/expressions/generators.scala |  8 +--
 .../expressions/nullExpressions.scala         |  3 +-
 .../expressions/objects/objects.scala         | 50 +++++++-------
 .../expressions/randomExpressions.scala       |  6 +-
 .../expressions/regexpExpressions.scala       | 38 +++++-----
 .../expressions/stringExpressions.scala       | 20 ++----
 .../sql/execution/ColumnarBatchScan.scala     | 26 +++----
 .../sql/execution/DataSourceScanExec.scala    |  6 +-
 .../apache/spark/sql/execution/SortExec.scala | 17 ++---
 .../sql/execution/WholeStageCodegenExec.scala |  3 +-
 .../aggregate/HashAggregateExec.scala         | 41 +++++------
 .../aggregate/HashMapGenerator.scala          |  6 +-
 .../execution/basicPhysicalOperators.scala    | 69 +++++++++----------
 .../columnar/GenerateColumnAccessor.scala     |  3 +-
 .../joins/BroadcastHashJoinExec.scala         | 12 ++--
 .../execution/joins/SortMergeJoinExec.scala   | 19 +++--
 .../apache/spark/sql/execution/limit.scala    |  6 +-
 26 files changed, 200 insertions(+), 255 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 743782a6453e9..4568714933095 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -119,8 +119,7 @@ abstract class Expression extends TreeNode[Expression] {
     // TODO: support whole stage codegen too
     if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
       val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
-        val globalIsNull = ctx.freshName("globalIsNull")
-        ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull)
+        val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull")
         val localIsNull = eval.isNull
         eval.isNull = globalIsNull
         s"$globalIsNull = $localIsNull;"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 821d784a01342..784eaf8195194 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -65,10 +65,8 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val countTerm = ctx.freshName("count")
-    val partitionMaskTerm = ctx.freshName("partitionMask")
-    ctx.addMutableState(ctx.JAVA_LONG, countTerm)
-    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm)
+    val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
+    val partitionMaskTerm = ctx.addMutableState(ctx.JAVA_LONG,
"partitionMask") ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 4fa18d6b3209b..736ca37c6d54a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -43,8 +43,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override protected def evalInternal(input: InternalRow): Int = partitionId override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val idTerm = ctx.freshName("partitionId") - ctx.addMutableState(ctx.JAVA_INT, idTerm) + val idTerm = ctx.addMutableState(ctx.JAVA_INT, "partitionId") ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1893eec22b65d..d3a8cb5804717 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,8 +602,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val tmpIsNull = ctx.freshName("leastTmpIsNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "leastTmpIsNull") val evals = evalChildren.map(eval => s""" |${eval.code} @@ -683,8 +682,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val tmpIsNull = ctx.freshName("greatestTmpIsNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "greatestTmpIsNull") val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 3a03a65e1af92..37b5ad3665732 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -163,11 +163,24 @@ class CodegenContext { * the list of default imports available. * Also, generic type arguments are accepted but ignored. * @param variableName Name of the field. - * @param initCode The statement(s) to put into the init() method to initialize this field. + * @param codeFunctions Function includes statement(s) to put into the init() method to + * initialize this field. An argument is the name of the mutable state variable * If left blank, the field will be default-initialized. + * @param inline whether the declaration and initialization code may be inlined rather than + * compacted. 
+   * @return the name of the mutable state variable: the original name if the variable is inlined
+   *         to the class, a fresh name otherwise, or an array access if the variable is to be
+   *         stored in an array of variables that share the same type and initialization.
    */
-  def addMutableState(javaType: String, variableName: String, initCode: String = ""): Unit = {
-    mutableStates += ((javaType, variableName, initCode))
+  def addMutableState(
+      javaType: String,
+      variableName: String,
+      codeFunctions: String => String = _ => "",
+      inline: Boolean = false): String = {
+    val newVariableName = if (!inline) freshName(variableName) else variableName
+    val initCode = codeFunctions(newVariableName)
+    mutableStates += ((javaType, newVariableName, initCode))
+    newVariableName
   }
 
   /**
@@ -176,8 +189,7 @@ class CodegenContext {
    * data types like: UTF8String, ArrayData, MapData & InternalRow.
    */
   def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = {
-    val value = freshName(variableName)
-    addMutableState(javaType(dataType), value, "")
+    val value = addMutableState(javaType(dataType), variableName)
     val code = dataType match {
       case StringType => s"$value = $initCode.clone();"
       case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
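
To make the migration pattern concrete before the call-site churn below, here is a
minimal before/after sketch of what this patch applies throughout (the "count"
state is illustrative, not taken from any one hunk):

    // Before: freshName and registration are separate steps, and the caller
    // threads the generated name into the init code by hand.
    val countTerm = ctx.freshName("count")
    ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;")

    // After: one call; the context freshens the name, hands it to the init
    // function, and the caller keeps whatever name is returned.
    val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count", v => s"$v = 0L;")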

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index bd8312eb8b7fe..4aa8a9f455377 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -61,37 +61,34 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
       case (ev, i) =>
         val e = expressions(i)
         if (e.nullable) {
-          val isNull = s"isNull_$i"
-          val value = s"value_$i"
-          ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull, s"$isNull = true;")
-          ctx.addMutableState(ctx.javaType(e.dataType), value,
-            s"$value = ${ctx.defaultValue(e.dataType)};")
-          s"""
+          val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, s"isNull_$i", v => s"$v = true;")
+          val value = ctx.addMutableState(ctx.javaType(e.dataType), s"value_$i",
+            v => s"$v = ${ctx.defaultValue(e.dataType)};")
+          (s"""
             ${ev.code}
            $isNull = ${ev.isNull};
            $value = ${ev.value};
-          """
+          """, isNull, value, i)
         } else {
-          val value = s"value_$i"
-          ctx.addMutableState(ctx.javaType(e.dataType), value,
-            s"$value = ${ctx.defaultValue(e.dataType)};")
-          s"""
+          val value = ctx.addMutableState(ctx.javaType(e.dataType), s"value_$i",
+            v => s"$v = ${ctx.defaultValue(e.dataType)};")
+          (s"""
             ${ev.code}
            $value = ${ev.value};
-          """
+          """, ev.isNull, value, i)
         }
     }
 
     // Evaluate all the subexpressions.
     val evalSubexpr = ctx.subexprFunctions.mkString("\n")
 
-    val updates = validExpr.zip(index).map {
-      case (e, i) =>
-        val ev = ExprCode("", s"isNull_$i", s"value_$i")
+    val updates = validExpr.zip(projectionCodes).map {
+      case (e, (_, isNull, value, i)) =>
+        val ev = ExprCode("", isNull, value)
         ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
     }
 
-    val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes)
+    val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1))
     val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates)
 
     val codeBody = s"""
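
The GenerateMutableProjection hunk above also reworks the data flow: each
projection now carries a (code, isNull, value, index) tuple, because the update
pass can no longer re-derive names as s"isNull_$i"/s"value_$i" once
addMutableState freshens them. A simplified sketch of why (the freshened name is
hypothetical):

    // If registration freshens "value_3" into, say, "value_3_1", rebuilding the
    // name from the index would reference a field that does not exist; passing
    // the returned name through the tuple keeps projection and update in sync.
    val value = ctx.addMutableState(ctx.javaType(e.dataType), s"value_$i")
    val update = ctx.updateColumn("mutableRow", e.dataType, i,
      ExprCode("", ev.isNull, value), e.nullable)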

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index b022457865d50..36ffa8dcdd2b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -73,9 +73,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       bufferHolder: String,
       isTopLevel: Boolean = false): String = {
     val rowWriterClass = classOf[UnsafeRowWriter].getName
-    val rowWriter = ctx.freshName("rowWriter")
-    ctx.addMutableState(rowWriterClass, rowWriter,
-      s"$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});")
+    val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
+      v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});")
 
     val resetWriter = if (isTopLevel) {
       // For top level row writer, it always writes to the beginning of the global buffer holder,
@@ -186,9 +185,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
     val tmpInput = ctx.freshName("tmpInput")
     val arrayWriterClass = classOf[UnsafeArrayWriter].getName
-    val arrayWriter = ctx.freshName("arrayWriter")
-    ctx.addMutableState(arrayWriterClass, arrayWriter,
-      s"$arrayWriter = new $arrayWriterClass();")
+    val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter",
+      v => s"$v = new $arrayWriterClass();")
     val numElements = ctx.freshName("numElements")
     val index = ctx.freshName("index")
 
@@ -318,13 +316,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       case _ => true
     }
 
-    val result = ctx.freshName("result")
-    ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});")
+    val result = ctx.addMutableState("UnsafeRow", "result",
+      v => s"$v = new UnsafeRow(${expressions.length});")
 
-    val holder = ctx.freshName("holder")
     val holderClass = classOf[BufferHolder].getName
-    ctx.addMutableState(holderClass, holder,
-      s"$holder = new $holderClass($result, ${numVarLenFields * 32});")
+    val holder = ctx.addMutableState(holderClass, "holder",
+      v => s"$v = new $holderClass($result, ${numVarLenFields * 32});")
 
     val resetBufferHolder = if (numVarLenFields == 0) {
       ""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 53c3b226895ec..1a9b68222a7f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -190,8 +190,7 @@ case class CaseWhen(
     // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`,
     // We won't go on anymore on the computation.
     val resultState = ctx.freshName("caseWhenResultState")
-    val tmpResult = ctx.freshName("caseWhenTmpResult")
-    ctx.addMutableState(ctx.javaType(dataType), tmpResult)
+    val tmpResult = ctx.addMutableState(ctx.javaType(dataType), "caseWhenTmpResult")
 
     // these blocks are meant to be inside a
     // do {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 44d54a20844a3..c48b92beed788 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -442,9 +442,9 @@ case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCas
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, time => {
       val cal = classOf[Calendar].getName
-      val c = ctx.freshName("cal")
       val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
-      ctx.addMutableState(cal, c, s"""$c = $cal.getInstance($dtu.getTimeZone("UTC"));""")
+      val c = ctx.addMutableState(cal, "cal",
+        v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""")
       s"""
         $c.setTimeInMillis($time * 1000L * 3600L * 24L);
         ${ev.value} = $c.get($cal.DAY_OF_WEEK);
@@ -484,13 +484,12 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, time => {
       val cal = classOf[Calendar].getName
-      val c = ctx.freshName("cal")
       val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
-      ctx.addMutableState(cal, c,
-        s"""
-          $c = $cal.getInstance($dtu.getTimeZone("UTC"));
-          $c.setFirstDayOfWeek($cal.MONDAY);
-          $c.setMinimalDaysInFirstWeek(4);
+      val c = ctx.addMutableState(cal, "cal",
+        v => s"""
+          $v = $cal.getInstance($dtu.getTimeZone("UTC"));
+          $v.setFirstDayOfWeek($cal.MONDAY);
+          $v.setMinimalDaysInFirstWeek(4);
         """)
       s"""
         $c.setTimeInMillis($time * 1000L * 3600L * 24L);
@@ -1014,12 +1013,12 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
           |long ${ev.value} = 0;
         """.stripMargin)
     } else {
-      val tzTerm = ctx.freshName("tz")
-      val utcTerm = ctx.freshName("utc")
       val tzClass = classOf[TimeZone].getName
       val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
-      ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""")
-      ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""")
+      val tzTerm = ctx.addMutableState(tzClass, "tz",
+        v => s"""$v = $dtu.getTimeZone("$tz");""")
+      val utcTerm = ctx.addMutableState(tzClass, "utc",
+        v => s"""$v = $dtu.getTimeZone("UTC");""")
       val eval = left.genCode(ctx)
       ev.copy(code = s"""
         |${eval.code}
@@ -1190,12 +1189,12 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
           |long ${ev.value} = 0;
         """.stripMargin)
     } else {
-      val tzTerm = ctx.freshName("tz")
-      val utcTerm = ctx.freshName("utc")
       val tzClass = classOf[TimeZone].getName
       val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
-      ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""")
-      ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""")
+      val tzTerm = ctx.addMutableState(tzClass, "tz",
+        v => s"""$v = $dtu.getTimeZone("$tz");""")
+      val utcTerm = ctx.addMutableState(tzClass, "utc",
+        v => s"""$v = $dtu.getTimeZone("UTC");""")
       val eval = left.genCode(ctx)
       ev.copy(code = s"""
         |${eval.code}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index cd38783a731ad..aa2c72ceb85b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -199,8 +199,8 @@ case class Stack(children: Seq[Expression]) extends Generator {
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     // Rows - we write these into an array.
-    val rowData = ctx.freshName("rows")
-    ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];")
+    val rowData = ctx.addMutableState("InternalRow[]", "rows",
+      v => s"$v = new InternalRow[$numRows];")
     val values = children.tail
     val dataTypes = values.take(numFields).map(_.dataType)
     val code = ctx.splitExpressionsWithCurrentInputs(Seq.tabulate(numRows) { row =>
@@ -212,12 +212,12 @@ case class Stack(children: Seq[Expression]) extends Generator {
       s"${eval.code}\n$rowData[$row] = ${eval.value};"
     })
 
-    // Create the collection.
+    // Create the collection. Inline it to the outer class.
     val wrapperClass = classOf[mutable.WrappedArray[_]].getName
     ctx.addMutableState(
       s"$wrapperClass<InternalRow>",
       ev.value,
-      s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);")
+      _ => s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);", inline = true)
     ev.copy(code = code, isNull = "false")
   }
 }
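
Stack is the first caller that needs inline = true: the field must keep exactly
the name ev.value, because code generated elsewhere already refers to that name
and a freshened one would not resolve. A short sketch of the two modes (state
names and init bodies illustrative):

    // Default mode: the name may be freshened, so always use the returned value.
    val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "tmpIsNull")

    // inline = true: the field is declared under exactly the requested name,
    // which is also why the init function can ignore its argument here.
    ctx.addMutableState("InternalRow[]", ev.value,
      _ => s"${ev.value} = new InternalRow[1];", inline = true)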

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 294cdcb2e9546..b4f895fffda38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -72,8 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val tmpIsNull = ctx.freshName("coalesceTmpIsNull")
-    ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull)
+    val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "coalesceTmpIsNull")
 
     // all the evals are meant to be in a do { ... } while (false); loop
     val evals = children.map { e =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 4bd395eadcf19..5c20944e65686 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -62,15 +62,13 @@ trait InvokeLike extends Expression with NonSQLExpression {
 
   def prepareArguments(ctx: CodegenContext): (String, String, String) = {
 
     val resultIsNull = if (needNullCheck) {
-      val resultIsNull = ctx.freshName("resultIsNull")
-      ctx.addMutableState(ctx.JAVA_BOOLEAN, resultIsNull)
+      val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull")
       resultIsNull
     } else {
       "false"
     }
     val argValues = arguments.map { e =>
-      val argValue = ctx.freshName("argValue")
-      ctx.addMutableState(ctx.javaType(e.dataType), argValue)
+      val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue")
       argValue
     }
 
@@ -548,7 +546,7 @@ case class MapObjects private(
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val elementJavaType = ctx.javaType(loopVarDataType)
-    ctx.addMutableState(elementJavaType, loopValue)
+    ctx.addMutableState(elementJavaType, loopValue, inline = true)
     val genInputData = inputData.genCode(ctx)
     val genFunction = lambdaFunction.genCode(ctx)
     val dataLength = ctx.freshName("dataLength")
@@ -644,7 +642,7 @@ case class MapObjects private(
     }
 
     val loopNullCheck = if (loopIsNull != "false") {
-      ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull)
+      ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, inline = true)
       inputDataType match {
         case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
         case _ => s"$loopIsNull = $loopValue == null;"
@@ -808,10 +806,10 @@ case class CatalystToExternalMap private(
     val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType]
     val keyElementJavaType = ctx.javaType(mapType.keyType)
-    ctx.addMutableState(keyElementJavaType, keyLoopValue)
+    ctx.addMutableState(keyElementJavaType, keyLoopValue, inline = true)
     val genKeyFunction = keyLambdaFunction.genCode(ctx)
     val valueElementJavaType = ctx.javaType(mapType.valueType)
-    ctx.addMutableState(valueElementJavaType, valueLoopValue)
+    ctx.addMutableState(valueElementJavaType, valueLoopValue, inline = true)
     val genValueFunction = valueLambdaFunction.genCode(ctx)
     val genInputData = inputData.genCode(ctx)
     val dataLength = ctx.freshName("dataLength")
@@ -844,7 +842,7 @@ case class CatalystToExternalMap private(
     val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)
     val valueLoopNullCheck = if (valueLoopIsNull != "false") {
-      ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull)
+      ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, inline = true)
       s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"
     } else {
       ""
@@ -994,8 +992,8 @@ case class ExternalMapToCatalyst private(
     val keyElementJavaType = ctx.javaType(keyType)
     val valueElementJavaType = ctx.javaType(valueType)
-    ctx.addMutableState(keyElementJavaType, key)
-    ctx.addMutableState(valueElementJavaType, value)
+    ctx.addMutableState(keyElementJavaType, key, inline = true)
+    ctx.addMutableState(valueElementJavaType, value, inline = true)
 
     val (defineEntries, defineKeyValue) = child.dataType match {
       case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
@@ -1031,14 +1029,14 @@ case class ExternalMapToCatalyst private(
     }
 
     val keyNullCheck = if (keyIsNull != "false") {
-      ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull)
+      ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, inline = true)
       s"$keyIsNull = $key == null;"
     } else {
       ""
     }
 
     val valueNullCheck = if (valueIsNull != "false") {
-      ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull)
+      ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, inline = true)
       s"$valueIsNull = $value == null;"
     } else {
       ""
@@ -1148,7 +1146,6 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     // Code to initialize the serializer.
-    val serializer = ctx.freshName("serializer")
     val (serializerClass, serializerInstanceClass) = {
       if (kryo) {
         (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
@@ -1159,14 +1156,14 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
     // try conf from env, otherwise create a new one
     val env = s"${classOf[SparkEnv].getName}.get()"
     val sparkConf = s"new ${classOf[SparkConf].getName}()"
-    val serializerInit = s"""
-      if ($env == null) {
-        $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
+    val serializer = ctx.addMutableState(serializerInstanceClass, "serializer",
+      v => s"""
+        if ($env == null) {
+          $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
         } else {
-        $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
+          $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
         }
-      """
-    ctx.addMutableState(serializerInstanceClass, serializer, serializerInit)
+      """)
 
     // Code to serialize.
     val input = child.genCode(ctx)
@@ -1194,7 +1191,6 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     // Code to initialize the serializer.
-    val serializer = ctx.freshName("serializer")
     val (serializerClass, serializerInstanceClass) = {
       if (kryo) {
         (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
@@ -1205,14 +1201,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
     // try conf from env, otherwise create a new one
     val env = s"${classOf[SparkEnv].getName}.get()"
     val sparkConf = s"new ${classOf[SparkConf].getName}()"
-    val serializerInit = s"""
-      if ($env == null) {
-        $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
+    val serializer = ctx.addMutableState(serializerInstanceClass, "serializer",
+      v => s"""
+        if ($env == null) {
+          $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
         } else {
-        $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
+          $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
         }
-      """
-    ctx.addMutableState(serializerInstanceClass, serializer, serializerInit)
+      """)
 
     // Code to deserialize.
     val input = child.genCode(ctx)
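
As the two serializer hunks show, the init function is not limited to a single
assignment: whatever multi-statement block it returns is spliced into init()
as-is, and because the closure receives the freshened name as its argument, the
same block works for both Encode and Decode without a separate freshName call.
A trimmed sketch of that shape (the surrounding vals are as in the hunks above):

    val serializer = ctx.addMutableState(serializerInstanceClass, "serializer",
      v => s"""
        if ($env == null) {
          $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
        } else {
          $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
        }
      """)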

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index b4aefe6cff73e..8bc936fcbfc31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -77,9 +77,8 @@ case class Rand(child: Expression) extends RDG {
   override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val rngTerm = ctx.freshName("rng")
     val className = classOf[XORShiftRandom].getName
-    ctx.addMutableState(className, rngTerm)
+    val rngTerm = ctx.addMutableState(className, "rng")
     ctx.addPartitionInitializationStatement(
       s"$rngTerm = new $className(${seed}L + partitionIndex);")
     ev.copy(code = s"""
@@ -112,9 +111,8 @@ case class Randn(child: Expression) extends RDG {
   override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val rngTerm = ctx.freshName("rng")
     val className = classOf[XORShiftRandom].getName
-    ctx.addMutableState(className, rngTerm)
+    val rngTerm = ctx.addMutableState(className, "rng")
     ctx.addPartitionInitializationStatement(
       s"$rngTerm = new $className(${seed}L + partitionIndex);")
     ev.copy(code = s"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 53d7096dd87d3..f4657797fdcc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -112,15 +112,14 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val patternClass = classOf[Pattern].getName
     val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex"
-    val pattern = ctx.freshName("pattern")
 
     if (right.foldable) {
       val rVal = right.eval()
       if (rVal != null) {
         val regexStr =
           StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
-        ctx.addMutableState(patternClass, pattern,
-          s"""$pattern = ${patternClass}.compile("$regexStr");""")
+        val pattern = ctx.addMutableState(patternClass, "patternLike",
+          v => s"""$v = ${patternClass}.compile("$regexStr");""")
 
         // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
         val eval = left.genCode(ctx)
@@ -139,6 +138,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
         """)
       }
     } else {
+      val pattern = ctx.freshName("pattern")
       val rightStr = ctx.freshName("rightStr")
       nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
         s"""
@@ -187,15 +187,14 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val patternClass = classOf[Pattern].getName
-    val pattern = ctx.freshName("pattern")
 
     if (right.foldable) {
       val rVal = right.eval()
       if (rVal != null) {
         val regexStr =
           StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString())
-        ctx.addMutableState(patternClass, pattern,
-          s"""$pattern = ${patternClass}.compile("$regexStr");""")
+        val pattern = ctx.addMutableState(patternClass, "patternRLike",
+          v => s"""$v = ${patternClass}.compile("$regexStr");""")
 
         // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
         val eval = left.genCode(ctx)
@@ -215,6 +214,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
       }
     } else {
       val rightStr = ctx.freshName("rightStr")
+      val pattern = ctx.freshName("pattern")
       nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
         s"""
           String $rightStr = ${eval2}.toString();
@@ -316,23 +316,19 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
   override def prettyName: String = "regexp_replace"
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val termLastRegex = ctx.freshName("lastRegex")
-    val termPattern = ctx.freshName("pattern")
-
-    val termLastReplacement = ctx.freshName("lastReplacement")
-    val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8")
-    val termResult = ctx.freshName("termResult")
-
     val classNamePattern = classOf[Pattern].getCanonicalName
     val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName
 
     val matcher = ctx.freshName("matcher")
 
-    ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;")
-    ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")
-    ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;")
-    ctx.addMutableState("UTF8String",
-      termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;")
+    val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex", v => s"$v = null;")
+    val termPattern = ctx.addMutableState(classNamePattern, "pattern", v => s"$v = null;")
+    val termLastReplacement = ctx.addMutableState("String", "lastReplacement",
+      v => s"$v = null;")
+    val termLastReplacementInUTF8 = ctx.addMutableState("UTF8String", "lastReplacementInUTF8",
+      v => s"$v = null;")
+    val termResult = ctx.addMutableState(classNameStringBuffer, "result",
+      v => s"$v = new $classNameStringBuffer();")
 
     val setEvNotNull = if (nullable) {
       s"${ev.isNull} = false;"
@@ -414,14 +410,12 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
   override def prettyName: String = "regexp_extract"
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val termLastRegex = ctx.freshName("lastRegex")
-    val termPattern = ctx.freshName("pattern")
     val classNamePattern = classOf[Pattern].getCanonicalName
     val matcher = ctx.freshName("matcher")
     val matchResult = ctx.freshName("matchResult")
 
-    ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;")
-    ctx.addMutableState(classNamePattern, termPattern,
s"${termPattern} = null;") + val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex", v => s"$v = null;") + val termPattern = ctx.addMutableState(classNamePattern, "pattern", v => s"$v = null;") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 8c4d2fd686be5..60e7c99afb8f9 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -532,14 +532,11 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val termLastMatching = ctx.freshName("lastMatching") - val termLastReplace = ctx.freshName("lastReplace") - val termDict = ctx.freshName("dict") val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;") - ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;") - ctx.addMutableState(classNameDict, termDict, s"$termDict = null;") + val termLastMatching = ctx.addMutableState("UTF8String", "lastMatching", v => s"$v = null;") + val termLastReplace = ctx.addMutableState("UTF8String", "lastReplace", v => s"$v = null;") + val termDict = ctx.addMutableState(classNameDict, "dict", v => s"$v = null;") nullSafeCodeGen(ctx, ev, (src, matching, replace) => { val check = if (matchingExpr.foldable && replaceExpr.foldable) { @@ -2065,15 +2062,12 @@ case class FormatNumber(x: Expression, d: Expression) // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. 
     val usLocale = "US"
-    val lastDValue = ctx.freshName("lastDValue")
-    val pattern = ctx.freshName("pattern")
-    val numberFormat = ctx.freshName("numberFormat")
     val i = ctx.freshName("i")
     val dFormat = ctx.freshName("dFormat")
-    ctx.addMutableState(ctx.JAVA_INT, lastDValue, s"$lastDValue = -100;")
-    ctx.addMutableState(sb, pattern, s"$pattern = new $sb();")
-    ctx.addMutableState(df, numberFormat,
-      s"""$numberFormat = new $df("", new $dfs($l.$usLocale));""")
+    val lastDValue = ctx.addMutableState(ctx.JAVA_INT, "lastDValue", v => s"$v = -100;")
+    val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
+    val numberFormat = ctx.addMutableState(df, "numberFormat",
+      v => s"""$v = new $df("", new $dfs($l.$usLocale));""")
 
     s"""
       if ($d >= 0) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index a9bfb634fbdea..fed07df42b37e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -68,30 +68,26 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
    */
   // TODO: return ColumnarBatch.Rows instead
   override protected def doProduce(ctx: CodegenContext): String = {
-    val input = ctx.freshName("input")
     // PhysicalRDD always just has one input
-    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+    val input = ctx.addMutableState("scala.collection.Iterator", "input",
+      v => s"$v = inputs[0];")
 
     // metrics
     val numOutputRows = metricTerm(ctx, "numOutputRows")
     val scanTimeMetric = metricTerm(ctx, "scanTime")
-    val scanTimeTotalNs = ctx.freshName("scanTime")
-    ctx.addMutableState(ctx.JAVA_LONG, scanTimeTotalNs, s"$scanTimeTotalNs = 0;")
+    val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime", v => s"$v = 0;")
 
     val columnarBatchClz = classOf[ColumnarBatch].getName
-    val batch = ctx.freshName("batch")
-    ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
+    val batch = ctx.addMutableState(columnarBatchClz, "batch", v => s"$v = null;")
 
-    val idx = ctx.freshName("batchIdx")
-    ctx.addMutableState(ctx.JAVA_INT, idx, s"$idx = 0;")
-    val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
+    val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx", v => s"$v = 0;")
     val columnVectorClzs = vectorTypes.getOrElse(
-      Seq.fill(colVars.size)(classOf[ColumnVector].getName))
-    val columnAssigns = colVars.zip(columnVectorClzs).zipWithIndex.map {
-      case ((name, columnVectorClz), i) =>
-        ctx.addMutableState(columnVectorClz, name, s"$name = null;")
-        s"$name = ($columnVectorClz) $batch.column($i);"
-    }
+      Seq.fill(output.indices.size)(classOf[ColumnVector].getName))
+    val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map {
+      case (columnVectorClz, i) =>
+        val name = ctx.addMutableState(columnVectorClz, s"colInstance$i", v => s"$v = null;")
+        (name, s"$name = ($columnVectorClz) $batch.column($i);")
+    }.unzip
 
     val nextBatch = ctx.freshName("nextBatch")
     val nextBatchFuncName = ctx.addNewFunction(nextBatch,
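
Because the new API returns the field name, ColumnarBatchScan can build each
column's state and its assignment in one traversal and unzip the pairs, instead
of pre-generating names with freshName and registering state as a side effect.
Reduced to its core (a sketch; the surrounding names are as in the hunk above):

    val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map { case (clz, i) =>
      val name = ctx.addMutableState(clz, s"colInstance$i", v => s"$v = null;")
      (name, s"$name = ($clz) $batch.column($i);")
    }.unzip
    // colVars feeds the per-column value codegen, while columnAssigns is joined
    // into the body that rebinds the vectors whenever a new batch arrives.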
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 747749bc72e66..4c3b1c49f703d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -110,8 +110,7 @@ case class RowDataSourceScanExec(
   override protected def doProduce(ctx: CodegenContext): String = {
     val numOutputRows = metricTerm(ctx, "numOutputRows")
     // PhysicalRDD always just has one input
-    val input = ctx.freshName("input")
-    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
     val exprRows = output.zipWithIndex.map{ case (a, i) =>
       BoundReference(i, a.dataType, a.nullable)
     }
@@ -353,8 +352,7 @@ case class FileSourceScanExec(
     }
     val numOutputRows = metricTerm(ctx, "numOutputRows")
     // PhysicalRDD always just has one input
-    val input = ctx.freshName("input")
-    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
     val row = ctx.freshName("row")
 
     ctx.INPUT_ROW = row
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index c0e21343ae623..9b05127942419 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -133,20 +133,17 @@ case class SortExec(
   override def needStopCheck: Boolean = false
 
   override protected def doProduce(ctx: CodegenContext): String = {
-    val needToSort = ctx.freshName("needToSort")
-    ctx.addMutableState(ctx.JAVA_BOOLEAN, needToSort, s"$needToSort = true;")
+    val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;")
 
     // Initialize the class member variables. This includes the instance of the Sorter and
     // the iterator to return sorted rows.
     val thisPlan = ctx.addReferenceObj("plan", this)
-    sorterVariable = ctx.freshName("sorter")
-    ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable,
-      s"$sorterVariable = $thisPlan.createSorter();")
-    val metrics = ctx.freshName("metrics")
-    ctx.addMutableState(classOf[TaskMetrics].getName, metrics,
-      s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();")
-    val sortedIterator = ctx.freshName("sortedIter")
-    ctx.addMutableState("scala.collection.Iterator", sortedIterator, "")
+    sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter",
+      v => s"$v = $thisPlan.createSorter();")
+    val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics",
+      v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();")
+    val sortedIterator = ctx.addMutableState("scala.collection.Iterator", "sortedIter",
+      _ => "")
 
     val addToSorter = ctx.freshName("addToSorter")
     val addToSorterFuncName = ctx.addNewFunction(addToSorter,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 7166b7771e4db..94c7e6a4899e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -282,9 +282,8 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
   }
 
   override def doProduce(ctx: CodegenContext): String = {
-    val input = ctx.freshName("input")
     // Right now, InputAdapter is only used when there is one input RDD.
-    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
     val row = ctx.freshName("row")
     s"""
       | while ($input.hasNext() && !stopEarly()) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 9cadd13999e72..54f2a2e2b392a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -178,8 +178,7 @@ case class HashAggregateExec(
   private var bufVars: Seq[ExprCode] = _
 
   private def doProduceWithoutKeys(ctx: CodegenContext): String = {
-    val initAgg = ctx.freshName("initAgg")
-    ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;")
+    val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg", v => s"$v = false;")
 
     // The generated function doesn't have input row in the code context.
     ctx.INPUT_ROW = null
@@ -187,10 +186,8 @@ case class HashAggregateExec(
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
     val initExpr = functions.flatMap(f => f.initialValues)
     bufVars = initExpr.map { e =>
-      val isNull = ctx.freshName("bufIsNull")
-      val value = ctx.freshName("bufValue")
-      ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull)
-      ctx.addMutableState(ctx.javaType(e.dataType), value)
+      val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull")
+      val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue")
       // The initial expression should not access any column
       val ev = e.genCode(ctx)
       val initVars = s"""
@@ -568,8 +565,7 @@ case class HashAggregateExec(
   }
 
   private def doProduceWithKeys(ctx: CodegenContext): String = {
-    val initAgg = ctx.freshName("initAgg")
-    ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;")
+    val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg", v => s"$v = false;")
     if (sqlContext.conf.enableTwoLevelAggMap) {
       enableTwoLevelHashMap(ctx)
     } else {
@@ -583,42 +579,38 @@ case class HashAggregateExec(
     val thisPlan = ctx.addReferenceObj("plan", this)
 
     // Create a name for the iterator from the fast hash map.
-    val iterTermForFastHashMap = ctx.freshName("fastHashMapIter")
-    if (isFastHashMapEnabled) {
+    val iterTermForFastHashMap = if (isFastHashMapEnabled) {
       // Generates the fast hash map class and creates the fash hash map term.
-      fastHashMapTerm = ctx.freshName("fastHashMap")
       val fastHashMapClassName = ctx.freshName("FastHashMap")
       if (isVectorizedHashMapEnabled) {
         val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions,
           fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
         ctx.addInnerClass(generatedMap)
-        ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
-          s"$fastHashMapTerm = new $fastHashMapClassName();")
-        ctx.addMutableState(s"java.util.Iterator", iterTermForFastHashMap)
+        fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "vectorizedFastHashMap",
+          v => s"$v = new $fastHashMapClassName();")
+        ctx.addMutableState(s"java.util.Iterator", "vectorizedFastHashMapIter")
       } else {
         val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
           fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
         ctx.addInnerClass(generatedMap)
-        ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
-          s"$fastHashMapTerm = new $fastHashMapClassName(" +
+        fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "fastHashMap",
+          v => s"$v = new $fastHashMapClassName(" +
             s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());")
         ctx.addMutableState(
           "org.apache.spark.unsafe.KVIterator",
-          iterTermForFastHashMap)
+          "fastHashMapIter")
       }
     }
 
     // Create a name for the iterator from the regular hash map.
-    val iterTerm = ctx.freshName("mapIter")
-    ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm)
+    val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, "mapIter")
     // create hashMap
-    hashMapTerm = ctx.freshName("hashMap")
     val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
-    ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
-    sorterTerm = ctx.freshName("sorter")
-    ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm)
+    hashMapTerm = ctx.addMutableState(hashMapClassName, "hashMap",
+      v => s"$v = $thisPlan.createHashMap();")
+    sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, "sorter")
 
     val doAgg = ctx.freshName("doAggregateWithKeys")
     val peakMemory = metricTerm(ctx, "peakMemory")
@@ -758,8 +750,7 @@ case class HashAggregateExec(
     val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
     incCounter) = if (testFallbackStartsAt.isDefined) {
-      val countTerm = ctx.freshName("fallbackCounter")
-      ctx.addMutableState(ctx.JAVA_INT, countTerm, s"$countTerm = 0;")
+      val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter", v => s"$v = 0;")
       (s"$countTerm < ${testFallbackStartsAt.get._1}",
         s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;")
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
index 85b4529501ea8..1c613b19c4ab1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
@@ -46,10 +46,8 @@ abstract class HashMapGenerator(
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
     val initExpr = functions.flatMap(f => f.initialValues)
     initExpr.map { e =>
-      val isNull = ctx.freshName("bufIsNull")
-      val value = ctx.freshName("bufValue")
-      ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull)
-      ctx.addMutableState(ctx.javaType(e.dataType), value)
+      val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull")
+      val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue")
 
       val ev = e.genCode(ctx)
       val initVars = s"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index c9a15147e30d0..2d7b9b6441625 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -279,29 +279,29 @@ case class SampleExec(
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
-    val sampler = ctx.freshName("sampler")
 
     if (withReplacement) {
       val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
       val initSampler = ctx.freshName("initSampler")
-      val initSamplerFuncName = ctx.addNewFunction(initSampler,
-        s"""
-          | private void $initSampler() {
-          |   $sampler = new $samplerClass($upperBound - $lowerBound, false);
-          |   java.util.Random random = new java.util.Random(${seed}L);
-          |   long randomSeed = random.nextLong();
-          |   int loopCount = 0;
-          |   while (loopCount < partitionIndex) {
-          |     randomSeed = random.nextLong();
-          |     loopCount += 1;
-          |   }
-          |   $sampler.setSeed(randomSeed);
-          | }
-         """.stripMargin.trim)
-
-      ctx.addMutableState(s"$samplerClass", sampler,
-        s"$initSamplerFuncName();")
+      val sampler = ctx.addMutableState(s"$samplerClass", "sampleReplace",
+        v => {
+          val initSamplerFuncName = ctx.addNewFunction(initSampler,
+            s"""
+              | private void $initSampler() {
+              |   $v = new $samplerClass($upperBound - $lowerBound, false);
+              |   java.util.Random random = new java.util.Random(${seed}L);
+              |   long randomSeed = random.nextLong();
+              |   int loopCount = 0;
+              |   while (loopCount < partitionIndex) {
+              |     randomSeed = random.nextLong();
+              |     loopCount += 1;
+              |   }
+              |   $v.setSeed(randomSeed);
+              | }
             """.stripMargin.trim)
+          s"$initSamplerFuncName();"
+        })
 
       val samplingCount = ctx.freshName("samplingCount")
       s"""
@@ -313,10 +313,10 @@ case class SampleExec(
       """.stripMargin.trim
     } else {
       val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName
-      ctx.addMutableState(s"$samplerClass", sampler,
-        s"""
-          | $sampler = new $samplerClass($lowerBound, $upperBound, false);
-          | $sampler.setSeed(${seed}L + partitionIndex);
+      val sampler = ctx.addMutableState(s"$samplerClass", "sampler",
+        v => s"""
+          | $v = new $samplerClass($lowerBound, $upperBound, false);
+          | $v.setSeed(${seed}L + partitionIndex);
         """.stripMargin.trim)
 
       s"""
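
One subtlety the SampleExec hunk relies on: the init closure runs once, at
registration time, so it can itself register a helper via ctx.addNewFunction and
return just the invocation. A minimal sketch of the pattern (the sampler class
and names are hypothetical):

    val sampler = ctx.addMutableState("MySampler", "sampler", v => {
      val fn = ctx.addNewFunction("initSampler",
        s"private void initSampler() { $v = new MySampler(); }")
      s"$fn();"
    })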
@@ -363,20 +363,17 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
   protected override def doProduce(ctx: CodegenContext): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
 
-    val initTerm = ctx.freshName("initRange")
-    ctx.addMutableState(ctx.JAVA_BOOLEAN, initTerm, s"$initTerm = false;")
-    val number = ctx.freshName("number")
-    ctx.addMutableState(ctx.JAVA_LONG, number, s"$number = 0L;")
+    val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange", v => s"$v = false;")
+    val number = ctx.addMutableState(ctx.JAVA_LONG, "number", v => s"$v = 0L;")
 
     val value = ctx.freshName("value")
     val ev = ExprCode("", "false", value)
     val BigInt = classOf[java.math.BigInteger].getName
 
-    val taskContext = ctx.freshName("taskContext")
-    ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();")
-    val inputMetrics = ctx.freshName("inputMetrics")
-    ctx.addMutableState("InputMetrics", inputMetrics,
-      s"$inputMetrics = $taskContext.taskMetrics().inputMetrics();")
+    val taskContext = ctx.addMutableState("TaskContext", "taskContext",
+      v => s"$v = TaskContext.get();")
+    val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics",
+      v => s"$v = $taskContext.taskMetrics().inputMetrics();")
 
     // In order to periodically update the metrics without inflicting performance penalty, this
     // operator produces elements in batches. After a batch is complete, the metrics are updated
@@ -386,12 +383,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     // the metrics.
 
     // Once number == batchEnd, it's time to progress to the next batch.
-    val batchEnd = ctx.freshName("batchEnd")
-    ctx.addMutableState(ctx.JAVA_LONG, batchEnd, s"$batchEnd = 0;")
+    val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd", v => s"$v = 0;")
 
     // How many values should still be generated by this range operator.
-    val numElementsTodo = ctx.freshName("numElementsTodo")
-    ctx.addMutableState(ctx.JAVA_LONG, numElementsTodo, s"$numElementsTodo = 0L;")
+    val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo", v => s"$v = 0L;")
 
     // How many values should be generated in the next batch.
     val nextBatchTodo = ctx.freshName("nextBatchTodo")
@@ -440,9 +435,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       |   }
      """.stripMargin)
 
-    val input = ctx.freshName("input")
     // Right now, Range is only used when there is one upstream.
-    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+    val input = ctx.addMutableState("scala.collection.Iterator", "input",
+      v => s"$v = inputs[0];")
 
     val localIdx = ctx.freshName("localIdx")
     val localEnd = ctx.freshName("localEnd")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index ff5dd707f0b38..4f28eeb725cbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -70,7 +70,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
     val ctx = newCodeGenContext()
     val numFields = columnTypes.size
     val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) =>
-      val accessorName = ctx.freshName("accessor")
       val accessorCls = dt match {
         case NullType => classOf[NullColumnAccessor].getName
         case BooleanType => classOf[BooleanColumnAccessor].getName
@@ -89,7 +88,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
         case array: ArrayType => classOf[ArrayColumnAccessor].getName
         case t: MapType => classOf[MapColumnAccessor].getName
       }
-      ctx.addMutableState(accessorCls, accessorName)
+      val accessorName = ctx.addMutableState(accessorCls, "accessor")
 
       val createCode = dt match {
         case t if ctx.isPrimitiveType(dt) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index c96ed6ef41016..670a06ce1962a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -134,18 +134,16 @@ case class BroadcastHashJoinExec(
     // create a name for HashedRelation
     val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
     val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
-    val relationTerm = ctx.freshName("relation")
     val clsName = broadcastRelation.value.getClass.getName
 
     // At the end of the task, we update the avg hash probe.
     val avgHashProbe = metricTerm(ctx, "avgHashProbe")
-    val addTaskListener = genTaskListener(avgHashProbe, relationTerm)
 
-    ctx.addMutableState(clsName, relationTerm,
-      s"""
-        | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
-        | incPeakExecutionMemory($relationTerm.estimatedSize());
-        | $addTaskListener
+    val relationTerm = ctx.addMutableState(clsName, "relation",
+      v => s"""
+        | $v = (($clsName) $broadcast.value()).asReadOnlyCopy();
+        | incPeakExecutionMemory($v.estimatedSize());
+        | ${genTaskListener(avgHashProbe, v)}
       """.stripMargin)
     (broadcastRelation, relationTerm)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 554b73181116c..677a65fac9190 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -422,10 +422,8 @@ case class SortMergeJoinExec(
    */
   private def genScanner(ctx: CodegenContext): (String, String) = {
     // Create class member for next row from both sides.
-    val leftRow = ctx.freshName("leftRow")
-    ctx.addMutableState("InternalRow", leftRow)
-    val rightRow = ctx.freshName("rightRow")
-    ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;")
+    val leftRow = ctx.addMutableState("InternalRow", "leftRow")
+    val rightRow = ctx.addMutableState("InternalRow", "rightRow", v => s"$v = null;")
 
     // Create variables for join keys from both sides.
     val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
@@ -436,14 +434,13 @@ case class SortMergeJoinExec(
     val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
 
     // A list to hold all matched rows from right side.
-    val matches = ctx.freshName("matches")
     val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
 
     val spillThreshold = getSpillThreshold
     val inMemoryThreshold = getInMemoryThreshold
 
-    ctx.addMutableState(clsName, matches,
-      s"$matches = new $clsName($inMemoryThreshold, $spillThreshold);")
+    val matches = ctx.addMutableState(clsName, "matches",
+      v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);")
 
     // Copy the left keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -578,10 +575,10 @@ case class SortMergeJoinExec( override def needCopyResult: Boolean = true override def doProduce(ctx: CodegenContext): String = { - val leftInput = ctx.freshName("leftInput") - ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];") - val rightInput = ctx.freshName("rightInput") - ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];") + val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", + v => s"$v = inputs[0];") + val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput", + v => s"$v = inputs[1];") val (leftRow, matches) = genScanner(ctx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index a8556f6ba107a..f6625cb7790db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -71,8 +71,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val stopEarly = ctx.freshName("stopEarly") - ctx.addMutableState(ctx.JAVA_BOOLEAN, stopEarly, s"$stopEarly = false;") + val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly", v => s"$v = false;") ctx.addNewFunction("stopEarly", s""" @Override @@ -80,8 +79,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { return $stopEarly; } """, inlineToOuterClass = true) - val countTerm = ctx.freshName("count") - ctx.addMutableState(ctx.JAVA_INT, countTerm, s"$countTerm = 0;") + val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count", v => s"$v = 0;") s""" | if ($countTerm < $limit) { | $countTerm += 1; From 870d106da7dabbade7811cc494dfabb82b1d0259 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 25 Nov 2017 17:28:21 +0000 Subject: [PATCH 02/26] eliminate initialization with default value --- .../catalyst/expressions/regexpExpressions.scala | 14 ++++++-------- .../catalyst/expressions/stringExpressions.scala | 6 +++--- .../spark/sql/execution/ColumnarBatchScan.scala | 8 ++++---- .../execution/aggregate/HashAggregateExec.scala | 6 +++--- .../sql/execution/basicPhysicalOperators.scala | 6 +++--- .../sql/execution/joins/SortMergeJoinExec.scala | 2 +- .../org/apache/spark/sql/execution/limit.scala | 4 ++-- 7 files changed, 22 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index f4657797fdcc9..2e46d4dade44a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -321,12 +321,10 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val matcher = ctx.freshName("matcher") - val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex", v => s"$v = null;") - val termPattern = ctx.addMutableState(classNamePattern, "pattern", v => s"$v = null;") - val termLastReplacement = ctx.addMutableState("String", "lastReplacement", - v => s"$v = null;") - val termLastReplacementInUTF8 = ctx.addMutableState("UTF8String", "lastReplacementInUTF8", - v => s"$v = null;") + val termLastRegex = 
ctx.addMutableState("UTF8String", "lastRegex") + val termPattern = ctx.addMutableState(classNamePattern, "pattern") + val termLastReplacement = ctx.addMutableState("String", "lastReplacement") + val termLastReplacementInUTF8 = ctx.addMutableState("UTF8String", "lastReplacementInUTF8") val termResult = ctx.addMutableState(classNameStringBuffer, "result", v => s"$v = new $classNameStringBuffer();") @@ -414,8 +412,8 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio val matcher = ctx.freshName("matcher") val matchResult = ctx.freshName("matchResult") - val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex", v => s"$v = null;") - val termPattern = ctx.addMutableState(classNamePattern, "pattern", v => s"$v = null;") + val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") + val termPattern = ctx.addMutableState(classNamePattern, "pattern") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 60e7c99afb8f9..cf062e83d3b8b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -534,9 +534,9 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - val termLastMatching = ctx.addMutableState("UTF8String", "lastMatching", v => s"$v = null;") - val termLastReplace = ctx.addMutableState("UTF8String", "lastReplace", v => s"$v = null;") - val termDict = ctx.addMutableState(classNameDict, "dict", v => s"$v = null;") + val termLastMatching = ctx.addMutableState("UTF8String", "lastMatching") + val termLastReplace = ctx.addMutableState("UTF8String", "lastReplace") + val termDict = ctx.addMutableState(classNameDict, "dict") nullSafeCodeGen(ctx, ev, (src, matching, replace) => { val check = if (matchingExpr.foldable && replaceExpr.foldable) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index fed07df42b37e..8e85b40d1a7ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -75,17 +75,17 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") - val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime", v => s"$v = 0;") + val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") val columnarBatchClz = classOf[ColumnarBatch].getName - val batch = ctx.addMutableState(columnarBatchClz, "batch", v => s"$v = null;") + val batch = ctx.addMutableState(columnarBatchClz, "batch") - val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx", v => s"$v = 0;") + val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") val columnVectorClzs = vectorTypes.getOrElse( Seq.fill(output.indices.size)(classOf[ColumnVector].getName)) val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map { case (columnVectorClz, i) => - val name = 
ctx.addMutableState(columnVectorClz, s"colInstance$i", v => s"$v = null;") + val name = ctx.addMutableState(columnVectorClz, s"colInstance$i") (name, s"$name = ($columnVectorClz) $batch.column($i);") }.unzip diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 54f2a2e2b392a..e528f744772fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -178,7 +178,7 @@ case class HashAggregateExec( private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg", v => s"$v = false;") + val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") // The generated function doesn't have input row in the code context. ctx.INPUT_ROW = null @@ -565,7 +565,7 @@ case class HashAggregateExec( } private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg", v => s"$v = false;") + val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else { @@ -750,7 +750,7 @@ case class HashAggregateExec( val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter", v => s"$v = 0;") + val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter") (s"$countTerm < ${testFallbackStartsAt.get._1}", s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 2d7b9b6441625..105d367e2d232 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -363,7 +363,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) protected override def doProduce(ctx: CodegenContext): String = { val numOutput = metricTerm(ctx, "numOutputRows") - val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange", v => s"$v = false;") + val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange") val number = ctx.addMutableState(ctx.JAVA_LONG, "number", v => s"$v = 0L;") val value = ctx.freshName("value") @@ -383,10 +383,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // the metrics. // Once number == batchEnd, it's time to progress to the next batch. - val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd", v => s"$v = 0;") + val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd") // How many values should still be generated by this range operator. - val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo", v => s"$v = 0L;") + val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo") // How many values should be generated in the next batch. 
val nextBatchTodo = ctx.freshName("nextBatchTodo") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 677a65fac9190..d52cb7824c4d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -423,7 +423,7 @@ case class SortMergeJoinExec( private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. val leftRow = ctx.addMutableState("InternalRow", "leftRow") - val rightRow = ctx.addMutableState("InternalRow", "rightRow", v => s"$v = null;") + val rightRow = ctx.addMutableState("InternalRow", "rightRow") // Create variables for join keys from both sides. val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index f6625cb7790db..d0ef969faf5f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -71,7 +71,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly", v => s"$v = false;") + val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") ctx.addNewFunction("stopEarly", s""" @Override @@ -79,7 +79,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { return $stopEarly; } """, inlineToOuterClass = true) - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count", v => s"$v = 0;") + val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") s""" | if ($countTerm < $limit) { | $countTerm += 1; From 24d7087fb8f831ecabdca1f491a5211bb7dde970 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 26 Nov 2017 15:21:34 +0000 Subject: [PATCH 03/26] allocate a global Java array to store a lot of mutable state in a class --- .../expressions/codegen/CodeGenerator.scala | 114 ++++++++++++++++-- .../codegen/GeneratedProjectionSuite.scala | 51 ++++++++ 2 files changed, 156 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 37b5ad3665732..33f7d53f5867c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -154,6 +154,21 @@ class CodegenContext { val mutableStates: mutable.ArrayBuffer[(String, String, String)] = mutable.ArrayBuffer.empty[(String, String, String)] + // An array keyed by the tuple of mutable states' types and initialization code, holds the + // current max index of the array + var mutableStateArrayIdx: mutable.Map[(String, String), Int] = + mutable.Map.empty[(String, String), Int] + + // An array keyed by the tuple of mutable states' types and initialization code, holds the + // current name of the mutableStateArray into which state of the given key will be compacted + var mutableStateArrayCurrentNames: mutable.Map[(String, String), String] = + 
mutable.Map.empty[(String, String), String] + + // An array keyed by the tuple of mutable states' types, array names and initialization code, + // holds the code that will initialize the mutableStateArray when initialized in loops + var mutableStateArrayInitCodes: mutable.ArrayBuffer[(String, String, String)] = + mutable.ArrayBuffer.empty[(String, String, String)] + /** * Add a mutable state as a field to the generated class. c.f. the comments above. * @@ -169,18 +184,59 @@ class CodegenContext { * @param inline whether the declaration and initialization code may be inlined rather than * compacted. If true, the name is not changed * @return the name of the mutable state variable, which is either the original name if the - * variable is inlined to the class, or an array access if the variable is to be stored - * in an array of variables of the same type and initialization. + * variable is inlined to the outer class, or an array access if the variable is to be + * stored in an array of variables of the same type and initialization. + * primitive type variables will be inlined into outer class when the total number of + * mutable variables is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD` + * the max size of an array for compaction is given by + * `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. */ def addMutableState( javaType: String, variableName: String, codeFunctions: String => String = _ => "", inline: Boolean = false): String = { - val newVariableName = if (!inline) freshName(variableName) else variableName - val initCode = codeFunctions(newVariableName) - mutableStates += ((javaType, newVariableName, initCode)) - newVariableName + val varName = if (!inline) freshName(variableName) else variableName + val initCode = codeFunctions(varName) + + if (inline || + // want to put a primitive type variable at outerClass for performance + isPrimitiveType(javaType) && + (mutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) || + // identify no simply-assigned object + !isPrimitiveType(javaType) && !initCode.matches("(^[\\w_]+\\d+\\s*=\\s*null;|" + + "^[\\w_]+\\d+\\s*=\\s*new\\s*[\\w\\.]+\\(\\);$|" + + "^$)")) { + // primitive type or non-simply-assigned state is declared inline to the outer class + mutableStates += ((javaType, varName, initCode)) + varName + } else { + // Create an initialization code agnostic to the actual variable name which we can key by + val initCodeKey = initCode.replaceAll(varName, "*VALUE*") + + val arrayName = mutableStateArrayCurrentNames.getOrElse((javaType, initCodeKey), "") + val prevIdx = mutableStateArrayIdx.getOrElse((javaType, arrayName), -1) + if (0 <= prevIdx && prevIdx < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT - 1) { + // a mutableStateArray for the given type and initialization has already been declared, + // update the max index of the array and return an array element + val idx = prevIdx + 1 + mutableStateArrayIdx.update((javaType, arrayName), idx) + s"$arrayName[$idx]" + } else { + // mutableStateArray has not been declared yet for the given type and initialization code. + // Create a new name for the array, and add an entry to keep track of current array name + // for type and initialized code. 
In addition, the type, array name, and qualified initialization + // code are stored for code generation + val arrayName = freshName("mutableStateArray") + val qualifiedInitCode = initCode.replaceAll( + varName, s"$arrayName[${CodeGenerator.INIT_LOOP_VARIABLE_NAME}]") + mutableStateArrayCurrentNames += (javaType, initCodeKey) -> arrayName + mutableStateArrayInitCodes += ((javaType, arrayName, qualifiedInitCode)) + val idx = 0 + mutableStateArrayIdx += (javaType, arrayName) -> idx + s"$arrayName[$idx]" + } + } } /** @@ -201,18 +257,46 @@ class CodegenContext { def declareMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. - mutableStates.distinct.map { case (javaType, variableName, _) => + val inlinedStates = mutableStates.distinct.map { case (javaType, variableName, _) => s"private $javaType $variableName;" - }.mkString("\n") + } + + val arrayStates = mutableStateArrayInitCodes.map { case (javaType, arrayName, _) => + val length = mutableStateArrayIdx((javaType, arrayName)) + 1 + if (javaType.matches("^.*\\[\\]$")) { + // initializer had a one-dimensional array variable + val baseType = javaType.substring(0, javaType.length - 2) + s"private $javaType[] $arrayName = new $baseType[$length][];" + } else { + // initializer had a scalar variable + s"private $javaType[] $arrayName = new $javaType[$length];" + } + } + + (inlinedStates ++ arrayStates).mkString("\n") } def initMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. val initCodes = mutableStates.distinct.map(_._3 + "\n") + // array state is initialized in loops + val arrayInitCodes = mutableStateArrayInitCodes.map { case (javaType, arrayName, qualifiedInitCode) => + if (qualifiedInitCode == "") { + "" + } else { + val loopIdxVar = CodeGenerator.INIT_LOOP_VARIABLE_NAME + s""" + for (int $loopIdxVar = 0; $loopIdxVar < $arrayName.length; $loopIdxVar++) { + $qualifiedInitCode + } + """ + } + } + // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. - splitExpressions(expressions = initCodes, funcName = "init", arguments = Nil) + splitExpressions(expressions = initCodes ++ arrayInitCodes, funcName = "init", arguments = Nil) } /** @@ -1177,6 +1261,18 @@ object CodeGenerator extends Logging { // class. val GENERATED_CLASS_SIZE_THRESHOLD = 1000000 + // This is the threshold for the number of global variables, whose types are primitive or + // complex (e.g. a multi-dimensional array), that will be placed in the outer class + val OUTER_CLASS_VARIABLES_THRESHOLD = 10000 + + // This is the maximum number of array elements used to keep global variables in one Java array. + // 32767 is the maximum integer value that does not require a constant pool entry in a Java + // bytecode instruction + val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 + + // This is the name of the index variable used in the loops that initialize global variables + val INIT_LOOP_VARIABLE_NAME = "i" + /** * Compile the Java source code into a Java class, using Janino. 
* diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 6031bdf19e957..90ce011520ffd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -219,4 +219,55 @@ class GeneratedProjectionSuite extends SparkFunSuite { // - one is the mutableRow assert(globalVariables.length == 3) } + + test("SPARK-18016: generated projections on wider table requiring state compaction") { + val N = 6000 + val wideRow1 = new GenericInternalRow((0 until N).toArray[Any]) + val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) + val wideRow2 = new GenericInternalRow( + (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + val schema2 = StructType((1 to N).map(i => StructField("", StringType))) + val joined = new JoinedRow(wideRow1, wideRow2) + val joinedSchema = StructType(schema1 ++ schema2) + val nested = new JoinedRow(InternalRow(joined, joined), joined) + val nestedSchema = StructType( + Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) + + // test generated UnsafeProjection + val unsafeProj = UnsafeProjection.create(nestedSchema) + val unsafe: UnsafeRow = unsafeProj(nested) + (0 until N).foreach { i => + val s = UTF8String.fromString(i.toString) + assert(i === unsafe.getInt(i + 2)) + assert(s === unsafe.getUTF8String(i + 2 + N)) + assert(i === unsafe.getStruct(0, N * 2).getInt(i)) + assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i === unsafe.getStruct(1, N * 2).getInt(i)) + assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated SafeProjection + val safeProj = FromUnsafeProjection(nestedSchema) + val result = safeProj(unsafe) + // Can't compare GenericInternalRow with JoinedRow directly + (0 until N).foreach { i => + val s = UTF8String.fromString(i.toString) + assert(i === result.getInt(i + 2)) + assert(s === result.getUTF8String(i + 2 + N)) + assert(i === result.getStruct(0, N * 2).getInt(i)) + assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i === result.getStruct(1, N * 2).getInt(i)) + assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated MutableProjection + val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val mutableProj = GenerateMutableProjection.generate(exprs) + val row1 = mutableProj(result) + assert(result === row1) + val row2 = mutableProj(result) + assert(result === row2) + } } From 3eb584200056a8332756af905ab81c430da867be Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 26 Nov 2017 15:44:11 +0000 Subject: [PATCH 04/26] fix scala style error --- .../expressions/codegen/CodeGenerator.scala | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 33f7d53f5867c..c1ba9befe39b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -281,17 +281,18 @@ class CodegenContext { // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. val initCodes = mutableStates.distinct.map(_._3 + "\n") // array state is initialized in loops - val arrayInitCodes = mutableStateArrayInitCodes.map { case (javaType, arrayName, qualifiedInitCode) => - if (qualifiedInitCode == "") { - "" - } else { - val loopIdxVar = CodeGenerator.INIT_LOOP_VARIABLE_NAME - s""" - for (int $loopIdxVar = 0; $loopIdxVar < $arrayName.length; $loopIdxVar++) { - $qualifiedInitCode - } - """ - } + val arrayInitCodes = mutableStateArrayInitCodes.map { + case (javaType, arrayName, qualifiedInitCode) => + if (qualifiedInitCode == "") { + "" + } else { + val loopIdxVar = CodeGenerator.INIT_LOOP_VARIABLE_NAME + s""" + for (int $loopIdxVar = 0; $loopIdxVar < $arrayName.length; $loopIdxVar++) { + $qualifiedInitCode + } + """ + } } // The generated initialization code may exceed 64kb function size limit in JVM if there are too From 074d711ff0526380078b0626ea08d68a870eb140 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 27 Nov 2017 10:21:20 +0000 Subject: [PATCH 05/26] fix test failure of ExpressionEncoderSuite.NestedArray --- .../catalyst/expressions/codegen/CodeGenerator.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c1ba9befe39b3..42e87b5e9d234 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -203,10 +203,12 @@ class CodegenContext { // want to put a primitive type variable at outerClass for performance isPrimitiveType(javaType) && (mutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) || - // identify no simply-assigned object - !isPrimitiveType(javaType) && !initCode.matches("(^[\\w_]+\\d+\\s*=\\s*null;|" - + "^[\\w_]+\\d+\\s*=\\s*new\\s*[\\w\\.]+\\(\\);$|" - + "^$)")) { + // identify multi-dimensional array or no simply-assigned object + !isPrimitiveType(javaType) && + (javaType.contains("[][]") || + !initCode.matches("(^[\\w_]+\\d+\\s*=\\s*null;|" + + "^[\\w_]+\\d+\\s*=\\s*new\\s*[\\w\\.]+\\(\\);$|" + + "^$)"))) { // primitive type or non-simply-assigned state is declared inline to the outer class mutableStates += ((javaType, varName, initCode)) varName From eafa3f85e1ec20de4101a2fe1aeee8f5c96bee4e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 30 Nov 2017 03:05:08 +0000 Subject: [PATCH 06/26] rebase with master --- .../codegen/GenerateMutableProjection.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4aa8a9f455377..e62b98d8806a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -65,17 +65,17 @@ object GenerateMutableProjection extends 
CodeGenerator[Seq[Expression], MutableP val value = ctx.addMutableState(ctx.javaType(e.dataType), s"value_$i", v => s"$v = ${ctx.defaultValue(e.dataType)};") (s""" - ${ev.code} - $isNull = ${ev.isNull}; - $value = ${ev.value}; - """, isNull, value, i) + ${ev.code} + $isNull = ${ev.isNull}; + $value = ${ev.value}; + """, isNull, value, i) } else { val value = ctx.addMutableState(ctx.javaType(e.dataType), s"value_$i", v => s"$v = ${ctx.defaultValue(e.dataType)};") (s""" - ${ev.code} - $value = ${ev.value}; - """, ev.isNull, value, i) + ${ev.code} + $value = ${ev.value}; + """, ev.isNull, value, i) } } From 90d15f3aee489bee37100ed0d53ebfd10e64ca84 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 30 Nov 2017 12:04:32 +0000 Subject: [PATCH 07/26] address review comments --- .../expressions/codegen/CodeGenerator.scala | 16 ++++++---------- .../codegen/GenerateMutableProjection.scala | 8 +++----- .../codegen/GenerateUnsafeProjection.scala | 15 +++++++++------ .../expressions/datetimeExpressions.scala | 5 +++-- .../catalyst/expressions/objects/objects.scala | 10 ++++++---- .../catalyst/expressions/regexpExpressions.scala | 10 ++++++---- 6 files changed, 33 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 42e87b5e9d234..98c7cb267f1b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -154,7 +154,7 @@ class CodegenContext { val mutableStates: mutable.ArrayBuffer[(String, String, String)] = mutable.ArrayBuffer.empty[(String, String, String)] - // An array keyed by the tuple of mutable states' types and initialization code, holds the + // An array keyed by the tuple of mutable states' types and array name, holds the // current max index of the array var mutableStateArrayIdx: mutable.Map[(String, String), Int] = mutable.Map.empty[(String, String), Int] @@ -186,7 +186,10 @@ class CodegenContext { * @return the name of the mutable state variable, which is either the original name if the * variable is inlined to the outer class, or an array access if the variable is to be * stored in an array of variables of the same type and initialization. - * primitive type variables will be inlined into outer class when the total number of + * There are two use cases. One is to use the original name for a global variable instead + * of a fresh name. The second is to use the original initialization statement when it is + * complex (e.g. it allocates a multi-dimensional array, or the object constructor takes variables). + * Primitive type variables will be inlined into outer class when the total number of * mutable variables is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD` * the max size of an array for compaction is given by * `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. 
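As a rough illustration of the contract documented above, consider two states of the same Java type with no initialization code; the concrete names below are only indicative, since the array name comes from freshName and may carry a numeric suffix:

    val a = ctx.addMutableState("UTF8String", "lastRegex")     // returns e.g. "mutableStateArray[0]"
    val b = ctx.addMutableState("UTF8String", "lastMatching")  // returns e.g. "mutableStateArray[1]"
    // declareMutableStates() would then emit a single field of the form
    //   private UTF8String[] mutableStateArray = new UTF8String[2];
    // and generated code refers to the states only through the returned array accesses.
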
@@ -202,14 +205,7 @@ class CodegenContext { if (inline || // want to put a primitive type variable at outerClass for performance isPrimitiveType(javaType) && - (mutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) || - // identify multi-dimensional array or no simply-assigned object - !isPrimitiveType(javaType) && - (javaType.contains("[][]") || - !initCode.matches("(^[\\w_]+\\d+\\s*=\\s*null;|" - + "^[\\w_]+\\d+\\s*=\\s*new\\s*[\\w\\.]+\\(\\);$|" - + "^$)"))) { - // primitive type or non-simply-assigned state is declared inline to the outer class + (mutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)) { mutableStates += ((javaType, varName, initCode)) varName } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index e62b98d8806a4..9d1105d493128 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -61,17 +61,15 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP case (ev, i) => val e = expressions(i) if (e.nullable) { - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, s"isNull_$i", v => s"$v = true;") - val value = ctx.addMutableState(ctx.javaType(e.dataType), s"value_$i", - v => s"$v = ${ctx.defaultValue(e.dataType)};") + val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, s"isNull_$i") + val value = ctx.addMutableState(ctx.javaType(e.dataType), s"value_$i") (s""" ${ev.code} $isNull = ${ev.isNull}; $value = ${ev.value}; """, isNull, value, i) } else { - val value = ctx.addMutableState(ctx.javaType(e.dataType), s"value_$i", - v => s"$v = ${ctx.defaultValue(e.dataType)};") + val value = ctx.addMutableState(ctx.javaType(e.dataType), s"value_$i") (s""" ${ev.code} $value = ${ev.value}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 36ffa8dcdd2b6..cba12050d3aac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -73,8 +73,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro bufferHolder: String, isTopLevel: Boolean = false): String = { val rowWriterClass = classOf[UnsafeRowWriter].getName - val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", - v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});") + val rowWriter = ctx.freshName("rowWriter") + ctx.addMutableState(rowWriterClass, rowWriter, + v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});", inline = true) val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, @@ -316,12 +317,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => true } - val result = ctx.addMutableState("UnsafeRow", "result", - v => s"$v = new UnsafeRow(${expressions.length});") + val result = ctx.freshName("result") + ctx.addMutableState("UnsafeRow", result, + v => s"$v = new 
UnsafeRow(${expressions.length});", inline = true) val holderClass = classOf[BufferHolder].getName - val holder = ctx.addMutableState(holderClass, "holder", - v => s"$v = new $holderClass($result, ${numVarLenFields * 32});") + val holder = ctx.freshName("holder") + ctx.addMutableState(holderClass, holder, + v => s"$v = new $holderClass($result, ${numVarLenFields * 32});", inline = true) val resetBufferHolder = if (numVarLenFields == 0) { "" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index c48b92beed788..d6834c3f0709f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -485,12 +485,13 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val c = ctx.addMutableState(cal, "cal", + val c = ctx.freshName("cal") + ctx.addMutableState(cal, c, v => s""" $v = $cal.getInstance($dtu.getTimeZone("UTC")); $v.setFirstDayOfWeek($cal.MONDAY); $v.setMinimalDaysInFirstWeek(4); - """) + """, inline = true) s""" $c.setTimeInMillis($time * 1000L * 3600L * 24L); ${ev.value} = $c.get($cal.WEEK_OF_YEAR); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 5c20944e65686..a12beb8eb836d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1156,14 +1156,15 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) // try conf from env, otherwise create a new one val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - val serializer = ctx.addMutableState(serializerInstanceClass, "serializer", + val serializer = ctx.freshName("serializer") + ctx.addMutableState(serializerInstanceClass, serializer, v => s""" if ($env == null) { $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); } else { $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); } - """) + """, inline = true) // Code to serialize. val input = child.genCode(ctx) @@ -1201,14 +1202,15 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B // try conf from env, otherwise create a new one val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - val serializer = ctx.addMutableState(serializerInstanceClass, "serializer", + val serializer = ctx.freshName("serializer") + ctx.addMutableState(serializerInstanceClass, serializer, v => s""" if ($env == null) { $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); } else { $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); } - """) + """, inline = true) // Code to deserialize. 
val input = child.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 2e46d4dade44a..4e4a097e340d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -118,8 +118,9 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - val pattern = ctx.addMutableState(patternClass, "patternLike", - v => s"""$v = ${patternClass}.compile("$regexStr");""") + val pattern = ctx.freshName("patternLike") + ctx.addMutableState(patternClass, pattern, + v => s"""$v = ${patternClass}.compile("$regexStr");""", inline = true) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) @@ -193,8 +194,9 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - val pattern = ctx.addMutableState(patternClass, "patternRLike", - v => s"""$v = ${patternClass}.compile("$regexStr");""") + val pattern = ctx.freshName("patternRLike") + ctx.addMutableState(patternClass, pattern, + v => s"""$v = ${patternClass}.compile("$regexStr");""", inline = true) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) From c456c079a344688ebd55a240ba84ed660036e261 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 30 Nov 2017 17:07:41 +0000 Subject: [PATCH 08/26] add useFreshName parameter to addMutableState --- .../expressions/codegen/CodeGenerator.scala | 8 ++++--- .../codegen/GenerateUnsafeProjection.scala | 9 +++---- .../expressions/datetimeExpressions.scala | 3 +-- .../sql/catalyst/expressions/generators.scala | 3 ++- .../expressions/objects/objects.scala | 24 +++++++++---------- .../expressions/regexpExpressions.scala | 6 ++--- .../apache/spark/sql/execution/SortExec.scala | 3 +-- .../sql/execution/WholeStageCodegenExec.scala | 3 ++- .../execution/basicPhysicalOperators.scala | 12 +++++----- .../joins/BroadcastHashJoinExec.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 4 ++-- 11 files changed, 36 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 98c7cb267f1b9..40d71fc59168d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -182,7 +182,8 @@ class CodegenContext { * initialize this field. An argument is the name of the mutable state variable * If left blank, the field will be default-initialized. * @param inline whether the declaration and initialization code may be inlined rather than - * compacted. If true, the name is not changed + * compacted. 
+ * @param useFreshName If false and inline is true, the name is not changed * @return the name of the mutable state variable, which is either the original name if the * variable is inlined to the outer class, or an array access if the variable is to be * stored in an array of variables of the same type and initialization. @@ -198,8 +199,9 @@ class CodegenContext { javaType: String, variableName: String, codeFunctions: String => String = _ => "", - inline: Boolean = false): String = { - val varName = if (!inline) freshName(variableName) else variableName + inline: Boolean = false, + useFreshName: Boolean = true): String = { + val varName = if (useFreshName) freshName(variableName) else variableName val initCode = codeFunctions(varName) if (inline || diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index cba12050d3aac..601170d385d09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -73,8 +73,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro bufferHolder: String, isTopLevel: Boolean = false): String = { val rowWriterClass = classOf[UnsafeRowWriter].getName - val rowWriter = ctx.freshName("rowWriter") - ctx.addMutableState(rowWriterClass, rowWriter, + val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});", inline = true) val resetWriter = if (isTopLevel) { @@ -317,13 +316,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => true } - val result = ctx.freshName("result") - ctx.addMutableState("UnsafeRow", result, + val result = ctx.addMutableState("UnsafeRow", "result", v => s"$v = new UnsafeRow(${expressions.length});", inline = true) val holderClass = classOf[BufferHolder].getName - val holder = ctx.freshName("holder") - ctx.addMutableState(holderClass, holder, + val holder = ctx.addMutableState(holderClass, "holder", v => s"$v = new $holderClass($result, ${numVarLenFields * 32});", inline = true) val resetBufferHolder = if (numVarLenFields == 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index d6834c3f0709f..c5cbfdae6d1d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -485,8 +485,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val c = ctx.freshName("cal") - ctx.addMutableState(cal, c, + val c = ctx.addMutableState(cal, "cal", v => s""" $v = $cal.getInstance($dtu.getTimeZone("UTC")); $v.setFirstDayOfWeek($cal.MONDAY); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index aa2c72ceb85b7..edab58bbd8d97 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -217,7 +217,8 @@ case class Stack(children: Seq[Expression]) extends Generator { ctx.addMutableState( s"$wrapperClass", ev.value, - _ => s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);", inline = true) + v => s"$v = $wrapperClass$$.MODULE$$.make($rowData);", + inline = true, useFreshName = false) ev.copy(code = code, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a12beb8eb836d..07977b126f52d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -546,7 +546,7 @@ case class MapObjects private( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) - ctx.addMutableState(elementJavaType, loopValue, inline = true) + ctx.addMutableState(elementJavaType, loopValue, inline = true, useFreshName = false) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -642,7 +642,7 @@ case class MapObjects private( } val loopNullCheck = if (loopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, inline = true) + ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, inline = true, useFreshName = false) inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" case _ => s"$loopIsNull = $loopValue == null;" @@ -806,10 +806,10 @@ case class CatalystToExternalMap private( val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] val keyElementJavaType = ctx.javaType(mapType.keyType) - ctx.addMutableState(keyElementJavaType, keyLoopValue, inline = true) + ctx.addMutableState(keyElementJavaType, keyLoopValue, inline = true, useFreshName = false) val genKeyFunction = keyLambdaFunction.genCode(ctx) val valueElementJavaType = ctx.javaType(mapType.valueType) - ctx.addMutableState(valueElementJavaType, valueLoopValue, inline = true) + ctx.addMutableState(valueElementJavaType, valueLoopValue, inline = true, useFreshName = false) val genValueFunction = valueLambdaFunction.genCode(ctx) val genInputData = inputData.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -842,7 +842,7 @@ case class CatalystToExternalMap private( val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) val valueLoopNullCheck = if (valueLoopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, inline = true) + ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, inline = true, useFreshName = false) s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" } else { "" @@ -992,8 +992,8 @@ case class ExternalMapToCatalyst private( val keyElementJavaType = ctx.javaType(keyType) val valueElementJavaType = ctx.javaType(valueType) - ctx.addMutableState(keyElementJavaType, key, inline = true) - ctx.addMutableState(valueElementJavaType, value, inline = true) + ctx.addMutableState(keyElementJavaType, key, inline = true, useFreshName = false) + ctx.addMutableState(valueElementJavaType, value, inline = true, useFreshName = false) val (defineEntries, defineKeyValue) 
= child.dataType match { case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => @@ -1029,14 +1029,14 @@ case class ExternalMapToCatalyst private( } val keyNullCheck = if (keyIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, inline = true) + ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, inline = true, useFreshName = false) s"$keyIsNull = $key == null;" } else { "" } val valueNullCheck = if (valueIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, inline = true) + ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, inline = true, useFreshName = false) s"$valueIsNull = $value == null;" } else { "" @@ -1156,8 +1156,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) // try conf from env, otherwise create a new one val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - val serializer = ctx.freshName("serializer") - ctx.addMutableState(serializerInstanceClass, serializer, + val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForEncode", v => s""" if ($env == null) { $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); @@ -1202,8 +1201,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B // try conf from env, otherwise create a new one val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - val serializer = ctx.freshName("serializer") - ctx.addMutableState(serializerInstanceClass, serializer, + val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForDecode", v => s""" if ($env == null) { $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 4e4a097e340d3..07b4469e0d1a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -118,8 +118,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - val pattern = ctx.freshName("patternLike") - ctx.addMutableState(patternClass, pattern, + val pattern = ctx.addMutableState(patternClass, "patternLike", v => s"""$v = ${patternClass}.compile("$regexStr");""", inline = true) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. @@ -194,8 +193,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - val pattern = ctx.freshName("patternRLike") - ctx.addMutableState(patternClass, pattern, + val pattern = ctx.addMutableState(patternClass, "patternRLike", v => s"""$v = ${patternClass}.compile("$regexStr");""", inline = true) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
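Taken together, the three call shapes of addMutableState after this patch look roughly as follows; the identifiers (clsName, inMemoryThreshold, spillThreshold, patternClass, regexStr, elementJavaType, loopValue) are borrowed from the hunks above, so this is a sketch rather than a compilable unit:

    // default: the state is eligible for compaction into a shared state array
    val matches = ctx.addMutableState(clsName, "matches",
      v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);")

    // inline = true: always declared as a standalone field of the outer class
    val pattern = ctx.freshName("patternLike")
    ctx.addMutableState(patternClass, pattern,
      v => s"""$v = ${patternClass}.compile("$regexStr");""", inline = true)

    // inline = true, useFreshName = false: additionally keeps the caller-chosen
    // name, for variables that other generated code references by name
    ctx.addMutableState(elementJavaType, loopValue, inline = true, useFreshName = false)
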
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index 9b05127942419..1cc6983ab2f2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -142,8 +142,7 @@ case class SortExec( v => s"$v = $thisPlan.createSorter();") val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics", v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();") - val sortedIterator = ctx.addMutableState("scala.collection.Iterator", "sortedIter", - _ => "") + val sortedIterator = ctx.addMutableState("scala.collection.Iterator", "sortedIter") val addToSorter = ctx.freshName("addToSorter") val addToSorterFuncName = ctx.addNewFunction(addToSorter, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 94c7e6a4899e7..b8ad59cc3eda9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -283,7 +283,8 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp override def doProduce(ctx: CodegenContext): String = { // Right now, InputAdapter is only used when there is one input RDD. - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") + val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", + inline = true) val row = ctx.freshName("row") s""" | while ($input.hasNext() && !stopEarly()) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 105d367e2d232..412f68ce9f770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -301,7 +301,7 @@ case class SampleExec( | } """.stripMargin.trim) s"$initSamplerFuncName();" - }) + }, inline = true) val samplingCount = ctx.freshName("samplingCount") s""" @@ -317,7 +317,7 @@ case class SampleExec( v => s""" | $v = new $samplerClass($lowerBound, $upperBound, false); | $v.setSeed(${seed}L + partitionIndex); - """.stripMargin.trim) + """.stripMargin.trim, inline = true) s""" | if ($sampler.sample() != 0) { @@ -364,16 +364,16 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val numOutput = metricTerm(ctx, "numOutputRows") val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange") - val number = ctx.addMutableState(ctx.JAVA_LONG, "number", v => s"$v = 0L;") + val number = ctx.addMutableState(ctx.JAVA_LONG, "number") val value = ctx.freshName("value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName val taskContext = ctx.addMutableState("TaskContext", "taskContext", - v => s"$v = TaskContext.get();") + v => s"$v = TaskContext.get();", inline = true) val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics", - v => s"$v = $taskContext.taskMetrics().inputMetrics();") + v => s"$v = $taskContext.taskMetrics().inputMetrics();", inline = true) // In order to periodically update the metrics without inflicting performance penalty, this // operator 
produces elements in batches. After a batch is complete, the metrics are updated @@ -437,7 +437,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // Right now, Range is only used when there is one upstream. val input = ctx.addMutableState("scala.collection.Iterator", "input", - v => s"$v = inputs[0];") + v => s"$v = inputs[0];", inline = true) val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 670a06ce1962a..f035e52e0c347 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -144,7 +144,7 @@ case class BroadcastHashJoinExec( | $v = (($clsName) $broadcast.value()).asReadOnlyCopy(); | incPeakExecutionMemory($v.estimatedSize()); | ${genTaskListener(avgHashProbe, v)} - """.stripMargin) + """.stripMargin, inline = true) (broadcastRelation, relationTerm) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index d52cb7824c4d6..0c77f3e7c63d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -576,9 +576,9 @@ case class SortMergeJoinExec( override def doProduce(ctx: CodegenContext): String = { val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", - v => s"$v = inputs[0];") + v => s"$v = inputs[0];", inline = true) val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput", - v => s"$v = inputs[1];") + v => s"$v = inputs[1];", inline = true) val (leftRow, matches) = genScanner(ctx) From 9ca5ab3094e3003f170cd216d2ec8bf3b72a9500 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 30 Nov 2017 19:38:30 +0000 Subject: [PATCH 09/26] rebase with master --- .../spark/sql/catalyst/expressions/regexpExpressions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 07b4469e0d1a9..31d8224e27b0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -316,6 +316,8 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def prettyName: String = "regexp_replace" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val termResult = ctx.freshName("termResult") + val classNamePattern = classOf[Pattern].getCanonicalName val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName @@ -325,8 +327,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val termPattern = ctx.addMutableState(classNamePattern, "pattern") val termLastReplacement = ctx.addMutableState("String", "lastReplacement") val termLastReplacementInUTF8 = ctx.addMutableState("UTF8String", "lastReplacementInUTF8") - val termResult = 
ctx.addMutableState(classNameStringBuffer, "result",
-      v => s"$v = new $classNameStringBuffer();")

    val setEvNotNull = if (nullable) {
      s"${ev.isNull} = false;"

From 5b36c615c15cbb8f6f4ea1b5cde53b6feb63d420 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Fri, 1 Dec 2017 02:01:26 +0000
Subject: [PATCH 10/26] fix test failures

---
 .../sql/catalyst/expressions/codegen/CodeGenerator.scala   | 7 ++++++-
 .../sql/catalyst/expressions/RegexpExpressionsSuite.scala  | 5 +++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 40d71fc59168d..064fa3cf449f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -203,11 +203,16 @@ class CodegenContext {
      useFreshName: Boolean = true): String = {
    val varName = if (useFreshName) freshName(variableName) else variableName
    val initCode = codeFunctions(varName)
+    if (javaType.contains("[][]")) {
+      Thread.dumpStack()
+    }

    if (inline ||
      // want to put a primitive type variable at outerClass for performance
      isPrimitiveType(javaType) &&
-        (mutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)) {
+        (mutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) ||
+      // type is multi-dimensional array
+      javaType.contains("[][]")) {
      mutableStates += ((javaType, varName, initCode))
      varName
    } else {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
index 4fa61fbaf66c2..97753a87dabf5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
@@ -183,8 +183,9 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
    val ctx = new CodegenContext
    RegExpReplace(Literal("100"), Literal("(\\d+)"), Literal("num")).genCode(ctx)
    // four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8)
-    // are always required
-    assert(ctx.mutableStates.length == 4)
+    // are always required, which are allocated in a type-based global array
+    assert(ctx.mutableStates.length == 0)
+    assert(ctx.mutableStateArrayInitCodes.length == 3)
  }

  test("RegexExtract") {

From fd51d750a0ad89770377248487149d577fe6e6fc Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Thu, 7 Dec 2017 18:52:36 +0000
Subject: [PATCH 11/26] drop creating a loop for initialization

---
 .../expressions/codegen/CodeGenerator.scala | 59 +++++++------------
 1 file changed, 21 insertions(+), 38 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 064fa3cf449f4..e60d97c8ac0ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -159,15 +159,16 @@ class CodegenContext {
  var mutableStateArrayIdx: mutable.Map[(String, String), Int] =
mutable.Map.empty[(String, String), Int] - // An array keyed by the tuple of mutable states' types and initialization code, holds the + // An array keyed by the tuple of mutable states' types, holds the // current name of the mutableStateArray into which state of the given key will be compacted - var mutableStateArrayCurrentNames: mutable.Map[(String, String), String] = - mutable.Map.empty[(String, String), String] + var mutableStateArrayCurrentNames: mutable.Map[String, String] = + mutable.Map.empty[String, String] - // An array keyed by the tuple of mutable states' types, array names and initialization code, - // holds the code that will initialize the mutableStateArray when initialized in loops - var mutableStateArrayInitCodes: mutable.ArrayBuffer[(String, String, String)] = - mutable.ArrayBuffer.empty[(String, String, String)] + // An array keyed by the tuple of mutable states' types, array names, array index, and + // initialization code, holds the code that will initialize the mutableStateArray when + // initialized in loops + var mutableStateArrayInitCodes: mutable.ArrayBuffer[String] = + mutable.ArrayBuffer.empty[String] /** * Add a mutable state as a field to the generated class. c.f. the comments above. @@ -202,10 +203,6 @@ class CodegenContext { inline: Boolean = false, useFreshName: Boolean = true): String = { val varName = if (useFreshName) freshName(variableName) else variableName - val initCode = codeFunctions(varName) - if (javaType.contains("[][]")) { - Thread.dumpStack() - } if (inline || // want to put a primitive type variable at outerClass for performance @@ -213,31 +210,29 @@ class CodegenContext { (mutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) || // type is multi-dimensional array javaType.contains("[][]")) { + val initCode = codeFunctions(varName) mutableStates += ((javaType, varName, initCode)) varName } else { - // Create an initialization code agnostic to the actual variable name which we can key by - val initCodeKey = initCode.replaceAll(varName, "*VALUE*") - - val arrayName = mutableStateArrayCurrentNames.getOrElse((javaType, initCodeKey), "") + val arrayName = mutableStateArrayCurrentNames.getOrElse(javaType, "") val prevIdx = mutableStateArrayIdx.getOrElse((javaType, arrayName), -1) if (0 <= prevIdx && prevIdx < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT - 1) { - // a mutableStateArray for the given type and initialization has already been declared, + // a mutableStateArray for the given type and name has already been declared, // update the max index of the array and return an array element val idx = prevIdx + 1 + val initCode = codeFunctions(s"$arrayName[$idx]") + mutableStateArrayInitCodes += initCode mutableStateArrayIdx.update((javaType, arrayName), idx) s"$arrayName[$idx]" } else { - // mutableStateArray has not been declared yet for the given type and initialization code. + // mutableStateArray has not been declared yet for the given type and name. // Create a new name for the array, and add an entry to keep track of current array name - // for type and initialized code. In addition, type, array name, and qualified initialized - // code is stored for code generation + // for type and initialized code. 
In addition, init code is stored for code generation val arrayName = freshName("mutableStateArray") - val qualifiedInitCode = initCode.replaceAll( - varName, s"$arrayName[${CodeGenerator.INIT_LOOP_VARIABLE_NAME}]") - mutableStateArrayCurrentNames += (javaType, initCodeKey) -> arrayName - mutableStateArrayInitCodes += ((javaType, arrayName, qualifiedInitCode)) + mutableStateArrayCurrentNames += javaType -> arrayName val idx = 0 + val initCode = codeFunctions(s"$arrayName[$idx]") + mutableStateArrayInitCodes += initCode mutableStateArrayIdx += (javaType, arrayName) -> idx s"$arrayName[$idx]" } @@ -266,7 +261,7 @@ class CodegenContext { s"private $javaType $variableName;" } - val arrayStates = mutableStateArrayInitCodes.map { case (javaType, arrayName, _) => + val arrayStates = mutableStateArrayIdx.keys.map { case (javaType, arrayName) => val length = mutableStateArrayIdx((javaType, arrayName)) + 1 if (javaType.matches("^.*\\[\\]$")) { // initializer had an one-dimensional array variable @@ -285,20 +280,8 @@ class CodegenContext { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. val initCodes = mutableStates.distinct.map(_._3 + "\n") - // array state is initialized in loops - val arrayInitCodes = mutableStateArrayInitCodes.map { - case (javaType, arrayName, qualifiedInitCode) => - if (qualifiedInitCode == "") { - "" - } else { - val loopIdxVar = CodeGenerator.INIT_LOOP_VARIABLE_NAME - s""" - for (int $loopIdxVar = 0; $loopIdxVar < $arrayName.length; $loopIdxVar++) { - $qualifiedInitCode - } - """ - } - } + // statements for array element initialization + val arrayInitCodes = mutableStateArrayInitCodes.distinct.map(_ + "\n") // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. 
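At this point in the series the compaction scheme has reached its core shape: every non-inlined
mutable state becomes a slot of a per-type array, and each slot records its own initialization
statement instead of being initialized through a generated loop. The following is a minimal,
standalone Scala sketch of that idea (the object name, SizeLimit, and freshName are illustrative
assumptions only; this is not Spark's actual CodegenContext):

    import scala.collection.mutable

    object MutableStateCompactionSketch {
      val SizeLimit = 32768 // mirrors CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT

      private val freshIds = mutable.Map.empty[String, Int]
      private def freshName(prefix: String): String = {
        val id = freshIds.getOrElse(prefix, 0)
        freshIds(prefix) = id + 1
        s"$prefix$id"
      }

      // (javaType, arrayName) -> max index handed out so far
      private val arrayIdx = mutable.Map.empty[(String, String), Int]
      // javaType -> name of the array currently receiving new slots
      private val currentNames = mutable.Map.empty[String, String]
      // one statement per compacted element, emitted verbatim into init()
      val initCodes = mutable.ArrayBuffer.empty[String]

      def addMutableState(javaType: String, initFunc: String => String): String = {
        var arrayName = currentNames.getOrElseUpdate(javaType, freshName("mutableStateArray"))
        var idx = arrayIdx.getOrElse((javaType, arrayName), -1) + 1
        if (idx >= SizeLimit) { // start a new array rather than exceed the per-array limit
          arrayName = freshName("mutableStateArray")
          currentNames(javaType) = arrayName
          idx = 0
        }
        arrayIdx((javaType, arrayName)) = idx
        val element = s"$arrayName[$idx]"
        initCodes += initFunc(element) // e.g. "mutableStateArray0[3] = null;"
        element
      }

      def main(args: Array[String]): Unit = {
        println(addMutableState("UTF8String", v => s"$v = null;"))  // mutableStateArray0[0]
        println(addMutableState("UTF8String", v => s"$v = null;"))  // mutableStateArray0[1]
        initCodes.foreach(println) // one init statement per slot; no init loop is generated
      }
    }

Recording one statement per slot is what makes the initialization loop (and
INIT_LOOP_VARIABLE_NAME) unnecessary: the earlier design had to key arrays by type and init code
so that a single loop body could initialize every element, whereas per-element statements let
states with different initializers share one array per type.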
From effe918f25a5681ed8e307773d58392e0773d54f Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Fri, 8 Dec 2017 02:31:04 +0000
Subject: [PATCH 12/26] fix test failure

---
 .../spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
index 97753a87dabf5..a1931aa82eb3d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
@@ -185,7 +185,7 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
    // four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8)
    // are always required, which are allocated in a type-based global array
    assert(ctx.mutableStates.length == 0)
-    assert(ctx.mutableStateArrayInitCodes.length == 3)
+    assert(ctx.mutableStateArrayInitCodes.length == 4)
  }

  test("RegexExtract") {

From 9df109c7b1f0b8cd575d67205e9375bfd05ee284 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Fri, 8 Dec 2017 02:31:31 +0000
Subject: [PATCH 13/26] update comments and clean up code

---
 .../expressions/codegen/CodeGenerator.scala | 26 +++++++++----------
 1 file changed, 12 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index e60d97c8ac0ac..27b10ac3153a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -154,19 +154,17 @@ class CodegenContext {
  val mutableStates: mutable.ArrayBuffer[(String, String, String)] =
    mutable.ArrayBuffer.empty[(String, String, String)]

-  // An array keyed by the tuple of mutable states' types and array name, holds the
-  // current max index of the array
+  // A map keyed by the tuple of mutable states' types and array name, holds the current max
+  // index of the array
  var mutableStateArrayIdx: mutable.Map[(String, String), Int] =
    mutable.Map.empty[(String, String), Int]

-  // An array keyed by the tuple of mutable states' types, holds the
-  // current name of the mutableStateArray into which state of the given key will be compacted
+  // A map keyed by mutable states' types holds the current name of the mutableStateArray
+  // into which state of the given key will be compacted
  var mutableStateArrayCurrentNames: mutable.Map[String, String] =
    mutable.Map.empty[String, String]

-  // An array keyed by the tuple of mutable states' types, array names, array index, and
-  // initialization code, holds the code that will initialize the mutableStateArray when
-  // initialized in loops
+  // An array holds the code that will initialize each element of the mutableStateArray
  var mutableStateArrayInitCodes: mutable.ArrayBuffer[String] =
    mutable.ArrayBuffer.empty[String]

@@ -220,21 +218,21 @@ class CodegenContext {
      // a mutableStateArray for the given type and name has already been declared,
      // update the max index of the array and return an array element
      val idx = prevIdx + 1
+      mutableStateArrayIdx.update((javaType, arrayName), idx)
      val initCode =
codeFunctions(s"$arrayName[$idx]") mutableStateArrayInitCodes += initCode - mutableStateArrayIdx.update((javaType, arrayName), idx) s"$arrayName[$idx]" } else { // mutableStateArray has not been declared yet for the given type and name. // Create a new name for the array, and add an entry to keep track of current array name - // for type and initialized code. In addition, init code is stored for code generation - val arrayName = freshName("mutableStateArray") - mutableStateArrayCurrentNames += javaType -> arrayName + // for type. In addition, init code is stored for code generation + val newArrayName = freshName("mutableStateArray") + mutableStateArrayCurrentNames += javaType -> newArrayName val idx = 0 - val initCode = codeFunctions(s"$arrayName[$idx]") + mutableStateArrayIdx += (javaType, newArrayName) -> idx + val initCode = codeFunctions(s"$newArrayName[$idx]") mutableStateArrayInitCodes += initCode - mutableStateArrayIdx += (javaType, arrayName) -> idx - s"$arrayName[$idx]" + s"$newArrayName[$idx]" } } } From 634d4945ba6215a5c33bbd4ece91bad0ad91d09b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 10 Dec 2017 14:34:50 +0000 Subject: [PATCH 14/26] address review comment --- .../expressions/codegen/GenerateMutableProjection.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 9d1105d493128..5dae048b43d5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -60,16 +60,15 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val projectionCodes = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) + val value = ctx.addMutableState(ctx.javaType(e.dataType), "value") if (e.nullable) { - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, s"isNull_$i") - val value = ctx.addMutableState(ctx.javaType(e.dataType), s"value_$i") + val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull") (s""" ${ev.code} $isNull = ${ev.isNull}; $value = ${ev.value}; """, isNull, value, i) } else { - val value = ctx.addMutableState(ctx.javaType(e.dataType), s"value_$i") (s""" ${ev.code} $value = ${ev.value}; From d3438fd4a969dc98032d3c89efeff6cdd214200c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 12 Dec 2017 08:46:06 +0000 Subject: [PATCH 15/26] address review comment --- .../codegen/GeneratedProjectionSuite.scala | 26 +------------------ 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 90ce011520ffd..3d00f377bed34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -233,32 +233,8 @@ class GeneratedProjectionSuite extends SparkFunSuite { val nestedSchema = StructType( Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) - // test generated UnsafeProjection - val 
unsafeProj = UnsafeProjection.create(nestedSchema)
-    val unsafe: UnsafeRow = unsafeProj(nested)
-    (0 until N).foreach { i =>
-      val s = UTF8String.fromString(i.toString)
-      assert(i === unsafe.getInt(i + 2))
-      assert(s === unsafe.getUTF8String(i + 2 + N))
-      assert(i === unsafe.getStruct(0, N * 2).getInt(i))
-      assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N))
-      assert(i === unsafe.getStruct(1, N * 2).getInt(i))
-      assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N))
-    }
-
-    // test generated SafeProjection
    val safeProj = FromUnsafeProjection(nestedSchema)
-    val result = safeProj(unsafe)
-    // Can't compare GenericInternalRow with JoinedRow directly
-    (0 until N).foreach { i =>
-      val s = UTF8String.fromString(i.toString)
-      assert(i === result.getInt(i + 2))
-      assert(s === result.getUTF8String(i + 2 + N))
-      assert(i === result.getStruct(0, N * 2).getInt(i))
-      assert(s === result.getStruct(0, N * 2).getUTF8String(i + N))
-      assert(i === result.getStruct(1, N * 2).getInt(i))
-      assert(s === result.getStruct(1, N * 2).getUTF8String(i + N))
-    }
+    val result = safeProj(nested)

    // test generated MutableProjection
    val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) =>

From f4f37549c719e3e8dcbd76108bbfac3d77e8107d Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Tue, 12 Dec 2017 14:25:49 +0000
Subject: [PATCH 16/26] address review comment

---
 .../expressions/codegen/CodeGenerator.scala   | 47 +++++++++----------
 1 file changed, 22 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 27b10ac3153a0..3a4230c47da82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -212,28 +212,28 @@ class CodegenContext {
      mutableStates += ((javaType, varName, initCode))
      varName
    } else {
-      val arrayName = mutableStateArrayCurrentNames.getOrElse(javaType, "")
-      val prevIdx = mutableStateArrayIdx.getOrElse((javaType, arrayName), -1)
-      if (0 <= prevIdx && prevIdx < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT - 1) {
-        // a mutableStateArray for the given type and name has already been declared,
-        // update the max index of the array and return an array element
-        val idx = prevIdx + 1
-        mutableStateArrayIdx.update((javaType, arrayName), idx)
-        val initCode = codeFunctions(s"$arrayName[$idx]")
-        mutableStateArrayInitCodes += initCode
-        s"$arrayName[$idx]"
-      } else {
-        // mutableStateArray has not been declared yet for the given type and name.
-        // Create a new name for the array, and add an entry to keep track of current array name
-        // for type. In addition, init code is stored for code generation
-        val newArrayName = freshName("mutableStateArray")
-        mutableStateArrayCurrentNames += javaType -> newArrayName
-        val idx = 0
-        mutableStateArrayIdx += (javaType, newArrayName) -> idx
-        val initCode = codeFunctions(s"$newArrayName[$idx]")
-        mutableStateArrayInitCodes += initCode
-        s"$newArrayName[$idx]"
+      // If a mutableStateArray has not been declared yet for the given type, create a new name
+      // for the array. If the mutableStateArray for the given type and name has been declared,
+      // update the max index of the array. Then, add an entry to keep track of the current array
+      // name for the type, and the init code is stored for code generation. 
Finally, return an array element
+      val (arrayName, newIdx) = {
+        val compactArrayName = "mutableStateArray"
+        var name = mutableStateArrayCurrentNames.getOrElse(javaType, freshName(compactArrayName))
+        var idx = mutableStateArrayIdx.getOrElse((javaType, name), -1)
+        if (idx >= CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT - 1) {
+          // Create a new array name to avoid an array index larger than 32767, which would
+          // require a constant pool entry
+          name = freshName(compactArrayName)
+          idx = -1
+        }
+        (name, idx + 1)
      }
+      mutableStateArrayCurrentNames(javaType) = arrayName
+      mutableStateArrayIdx((javaType, arrayName)) = newIdx
+
+      val initCode = codeFunctions(s"$arrayName[$newIdx]")
+      mutableStateArrayInitCodes += initCode
+      s"$arrayName[$newIdx]"
    }
  }

@@ -277,7 +277,7 @@ class CodegenContext {
  def initMutableStates(): String = {
    // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in
    // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
-    val initCodes = mutableStates.distinct.map(_._3 + "\n")
+    val initCodes = mutableStates.map(_._3).distinct.map(_ + "\n")

@@ -1257,9 +1257,6 @@ object CodeGenerator extends Logging {
  // bytecode instruction
  val MUTABLESTATEARRAY_SIZE_LIMIT = 32768

-  // This is an index variable name used in a loop for initializing global variables
-  val INIT_LOOP_VARIABLE_NAME = "i"
-
  /**
   * Compile the Java source code into a Java class, using Janino.
   *

From f1e1fca57e232113a1f8f402bef0ee5cf99e798a Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Wed, 13 Dec 2017 01:54:38 +0000
Subject: [PATCH 17/26] address review comments except test case

---
 .../expressions/codegen/CodeGenerator.scala | 14 +++++++------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 3a4230c47da82..67a1d3fa329fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -177,9 +177,9 @@ class CodegenContext {
   * the list of default imports available.
   * Also, generic type arguments are accepted but ignored.
   * @param variableName Name of the field.
-   * @param codeFunctions Function includes statement(s) to put into the init() method to
-   *                      initialize this field. An argument is the name of the mutable state variable
-   *                      If left blank, the field will be default-initialized.
+   * @param initFunc Function includes statement(s) to put into the init() method to initialize
+   *                 this field. An argument is the name of the mutable state variable.
+   *                 If left blank, the field will be default-initialized.
   * @param inline whether the declaration and initialization code may be inlined rather than
   *               compacted. 
* @param useFreshName If false and inline is true, the name is not changed @@ -197,7 +197,7 @@ class CodegenContext { def addMutableState( javaType: String, variableName: String, - codeFunctions: String => String = _ => "", + initFunc: String => String = _ => "", inline: Boolean = false, useFreshName: Boolean = true): String = { val varName = if (useFreshName) freshName(variableName) else variableName @@ -208,7 +208,7 @@ class CodegenContext { (mutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) || // type is multi-dimensional array javaType.contains("[][]")) { - val initCode = codeFunctions(varName) + val initCode = initFunc(varName) mutableStates += ((javaType, varName, initCode)) varName } else { @@ -231,7 +231,7 @@ class CodegenContext { mutableStateArrayCurrentNames(javaType) = arrayName mutableStateArrayIdx((javaType, arrayName)) = newIdx - val initCode = codeFunctions(s"$arrayName[$newIdx]") + val initCode = initFunc(s"$arrayName[$newIdx]") mutableStateArrayInitCodes += initCode s"$arrayName[$newIdx]" } @@ -279,7 +279,7 @@ class CodegenContext { // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. val initCodes = mutableStates.map(_._3).distinct.map(_ + "\n") // statements for array element initialization - val arrayInitCodes = mutableStateArrayInitCodes.distinct.map(_ + "\n") + val arrayInitCodes = mutableStateArrayInitCodes.distinct // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. From 0937ef20a57a7b0648bd3145c383d9645c67de2c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 13 Dec 2017 08:11:20 +0000 Subject: [PATCH 18/26] rebase with master --- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 67a1d3fa329fc..414daf1b9a960 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1122,9 +1122,8 @@ class CodegenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. 
- addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;")
-    addMutableState(javaType(expr.dataType), value,
-      s"$value = ${defaultValue(expr.dataType)};")
+    addMutableState(JAVA_BOOLEAN, isNull, inline = true, useFreshName = false)
+    addMutableState(javaType(expr.dataType), value, inline = true, useFreshName = false)

    subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
    val state = SubExprEliminationState(isNull, value)

From 4bfcc1a2e94e23b188be1560b17142db0503a0d4 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Wed, 13 Dec 2017 15:05:36 +0000
Subject: [PATCH 19/26] Do not use compaction for frequently-accessed
 variables

---
 .../org/apache/spark/sql/execution/ColumnarBatchScan.scala  | 2 +-
 .../org/apache/spark/sql/execution/DataSourceScanExec.scala | 6 ++++--
 .../scala/org/apache/spark/sql/execution/SortExec.scala     | 3 ++-
 .../spark/sql/execution/joins/SortMergeJoinExec.scala       | 4 ++--
 4 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 8e85b40d1a7ea..2a9bceaf086a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -70,7 +70,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
  override protected def doProduce(ctx: CodegenContext): String = {
    // PhysicalRDD always just has one input
    val input = ctx.addMutableState("scala.collection.Iterator", "input",
-      v => s"$v = inputs[0];")
+      v => s"$v = inputs[0];", inline = true)

    // metrics
    val numOutputRows = metricTerm(ctx, "numOutputRows")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 4c3b1c49f703d..45b66c71726ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -110,7 +110,8 @@ case class RowDataSourceScanExec(
  override protected def doProduce(ctx: CodegenContext): String = {
    val numOutputRows = metricTerm(ctx, "numOutputRows")
    // PhysicalRDD always just has one input
-    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
+    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
+      inline = true)
    val exprRows = output.zipWithIndex.map{ case (a, i) =>
      BoundReference(i, a.dataType, a.nullable)
    }
@@ -352,7 +353,8 @@ case class FileSourceScanExec(
    }
    val numOutputRows = metricTerm(ctx, "numOutputRows")
    // PhysicalRDD always just has one input
-    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
+    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
+      inline = true)
    val row = ctx.freshName("row")

    ctx.INPUT_ROW = row
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 1cc6983ab2f2a..4925d6090880b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -142,7 +142,8 @@ case class SortExec(
      v => s"$v = $thisPlan.createSorter();")
    val metrics =
ctx.addMutableState(classOf[TaskMetrics].getName, "metrics",
      v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();")
-    val sortedIterator = ctx.addMutableState("scala.collection.Iterator", "sortedIter")
+    val sortedIterator = ctx.addMutableState("scala.collection.Iterator", "sortedIter",
+      inline = true)

    val addToSorter = ctx.freshName("addToSorter")
    val addToSorterFuncName = ctx.addNewFunction(addToSorter,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 0c77f3e7c63d3..00444bdcf3824 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -422,8 +422,8 @@ case class SortMergeJoinExec(
   */
  private def genScanner(ctx: CodegenContext): (String, String) = {
    // Create class member for next row from both sides.
-    val leftRow = ctx.addMutableState("InternalRow", "leftRow")
-    val rightRow = ctx.addMutableState("InternalRow", "rightRow")
+    val leftRow = ctx.addMutableState("InternalRow", "leftRow", inline = true)
+    val rightRow = ctx.addMutableState("InternalRow", "rightRow", inline = true)

    // Create variables for join keys from both sides.
    val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)

From 49119a9de53cd7bb2bf91df6b691358da22e1b00 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Wed, 13 Dec 2017 18:41:56 +0000
Subject: [PATCH 20/26] exclude mutable state from argument list for
 ExpressionCodegen

---
 .../expressions/codegen/CodeGenerator.scala   | 16 ++++++++
 .../expressions/CodeGenerationSuite.scala     | 41 +++++++++++++++++++
 2 files changed, 57 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 414daf1b9a960..1c35e3fd245c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -168,6 +168,22 @@ class CodegenContext {
  var mutableStateArrayInitCodes: mutable.ArrayBuffer[String] =
    mutable.ArrayBuffer.empty[String]

+  /**
+   * Return true if a given variable has been described as a global variable
+   */
+  def isDeclaredMutableState(varName: String): Boolean = {
+    val j = varName.indexOf("[")
+    val qualifiedName = if (j < 0) varName else varName.substring(0, j)
+    mutableStates.exists { case s =>
+      val i = s._2.indexOf("[")
+      qualifiedName == (if (i < 0) s._2 else s._2.substring(0, i))
+    } ||
+    mutableStateArrayIdx.keys.exists { case key =>
+      val i = key._2.indexOf("[")
+      qualifiedName == (if (i < 0) key._2 else key._2.substring(0, i))
+    }
+  }
+
  /**
   * Add a mutable state as a field to the generated class. c.f. the comments above. 
*
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index a969811019161..96bfbdb8dd9a6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -401,4 +401,45 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
    ctx.addReferenceObj("foo", foo)
    assert(ctx.mutableStates.isEmpty)
  }
+
+  test("SPARK-18016: Compact mutable states by using an array") {
+    val ctx1 = new CodegenContext
+    for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) {
+      ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;")
+    }
+    assert(ctx1.mutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
+    // When the number of primitive type mutable states is over the threshold, others are
+    // allocated into an array
+    assert(ctx1.mutableStateArrayIdx.size == 1)
+    assert(ctx1.mutableStateArrayInitCodes.size == 10)
+
+    val ctx2 = new CodegenContext
+    for (i <- 1 to CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) {
+      ctx2.addMutableState("InternalRow[]", "r", v => s"$v = new InternalRow[$i];")
+    }
+    // When the number of non-primitive type mutable states is over the threshold, others are
+    // allocated into a new array
+    assert(ctx2.mutableStateArrayIdx.size == 2)
+    assert(ctx2.mutableStateArrayInitCodes.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10)
+  }
+
+  test("SPARK-18016: check whether a variable is declared as mutable") {
+    val ctx = new CodegenContext
+    val var1 = ctx.addMutableState(ctx.JAVA_INT, "ij")
+    val var2 = ctx.addMutableState("int[]", "array")
+    val var3 = ctx.addMutableState("int[][]", "b")
+
+    assert(ctx.isDeclaredMutableState(var1))
+    assert(ctx.isDeclaredMutableState(var2))
+    assert(ctx.isDeclaredMutableState(s"$var2[1]"))
+    assert(ctx.isDeclaredMutableState(var3))
+    assert(ctx.isDeclaredMutableState(s"$var3[]"))
+    assert(ctx.isDeclaredMutableState(s"$var3[1][]"))
+
+    assert(!ctx.isDeclaredMutableState("i"))
+    assert(!ctx.isDeclaredMutableState("j"))
+    assert(!ctx.isDeclaredMutableState("ij99"))
+    assert(!ctx.isDeclaredMutableState("arr"))
+    assert(!ctx.isDeclaredMutableState("bb[]"))
+  }
}

From 24f49c5a5840517f4f6edbb50c955e758c7e2e5d Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Thu, 14 Dec 2017 07:37:14 +0000
Subject: [PATCH 21/26] fix test failures

---
 .../spark/sql/catalyst/expressions/stringExpressions.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index cf062e83d3b8b..c02c41db1668e 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -291,8 +291,7 @@ case class Elt(children: Seq[Expression])

    val indexVal = ctx.freshName("index")
    val indexMatched = ctx.freshName("eltIndexMatched")
-    val stringVal = ctx.freshName("stringVal")
-    ctx.addMutableState(ctx.javaType(dataType), stringVal)
+    val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal")

    val assignStringValue = strings.zipWithIndex.map { case (eval, index) =>
      s"""
From 
15e967ec278abdae190bd3f0d3e937a54f674e19 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Thu, 14 Dec 2017 07:41:45 +0000
Subject: [PATCH 22/26] address review comments

---
 .../expressions/codegen/CodeGenerator.scala   | 101 +++++++++---------
 .../expressions/CodeGenerationSuite.scala     |  25 +----
 2 files changed, 55 insertions(+), 71 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 1c35e3fd245c9..fe7b11a964cfe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -154,10 +154,9 @@ class CodegenContext {
  val mutableStates: mutable.ArrayBuffer[(String, String, String)] =
    mutable.ArrayBuffer.empty[(String, String, String)]

-  // A map keyed by the tuple of mutable states' types and array name, holds the current max
-  // index of the array
-  var mutableStateArrayIdx: mutable.Map[(String, String), Int] =
-    mutable.Map.empty[(String, String), Int]
+  // A map keyed by mutable states' types holds the status of mutableStateArray
+  var mutableStateArrayMap: mutable.Map[String, MutableStateArrays] =
+    mutable.Map.empty[String, MutableStateArrays]

-  // A map keyed by mutable states' types holds the current name of the mutableStateArray
-  // into which state of the given key will be compacted
  var mutableStateArrayCurrentNames: mutable.Map[String, String] =
    mutable.Map.empty[String, String]

-  // An array keyed by the tuple of mutable states' types, array names, array index, and
-  // initialization code, holds the code that will initialize the mutableStateArray when
-  // initialized in loops
-  var mutableStateArrayInitCodes: mutable.ArrayBuffer[String] =
+  // An array holds the code that will initialize each state
+  val mutableStateInitCodes: mutable.ArrayBuffer[String] =
    mutable.ArrayBuffer.empty[String]

-  /**
-   * Return true if a given variable has been described as a global variable
-   */
-  def isDeclaredMutableState(varName: String): Boolean = {
-    val j = varName.indexOf("[")
-    val qualifiedName = if (j < 0) varName else varName.substring(0, j)
-    mutableStates.exists { case s =>
-      val i = s._2.indexOf("[")
-      qualifiedName == (if (i < 0) s._2 else s._2.substring(0, i))
-    } ||
-    mutableStateArrayIdx.keys.exists { case key =>
-      val i = key._2.indexOf("[")
-      qualifiedName == (if (i < 0) key._2 else key._2.substring(0, i))
-    }
+  // Holding names and current index of mutableStateArrays for a certain type
+  class MutableStateArrays {
+    val arrayNames = mutable.ListBuffer.empty[String]
+    createNewArray()
+
+    private[this] var currentIndex = 0
+
+    private def createNewArray() = arrayNames.append(freshName("mutableStateArray"))
+
+    def getCurrentIndex: Int = { currentIndex }
+
+    def getNextSlot(): String = {
+      if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) {
+        val res = s"${arrayNames.last}[$currentIndex]"
+        currentIndex += 1
+        res
+      } else {
+        createNewArray()
+        currentIndex = 1
+        s"${arrayNames.last}[0]"
+      }
    }
+  }

  /**
   * Add a mutable state as a field to the generated class. c.f. the comments above.
   * the list of default imports available.
   * Also, generic type arguments are accepted but ignored.
   * @param variableName Name of the field.
   * @param initFunc Function includes statement(s) to put into the init() method to initialize
-   *                 this field. An argument is the name of the mutable state variable.
-   *                 If left blank, the field will be default-initialized.
+   *                 this field. The argument is the name of the mutable state variable.
+   *                 If left blank, the field will be default-initialized.
   * @param inline whether the declaration and initialization code may be inlined rather than
   *               compacted. 
* @param useFreshName If false and inline is true, the name is not changed
@@ -228,28 +236,16 @@ class CodegenContext {
      mutableStates += ((javaType, varName, initCode))
      varName
    } else {
-      // If a mutableStateArray has not been declared yet for the given type, create a new name
-      // for the array. If the mutableStateArray for the given type and name has been declared,
-      // update the max index of the array. Then, add an entry to keep track of the current array
-      // name for the type, and the init code is stored for code generation. Finally, return an array element
-      val (arrayName, newIdx) = {
-        val compactArrayName = "mutableStateArray"
-        var name = mutableStateArrayCurrentNames.getOrElse(javaType, freshName(compactArrayName))
-        var idx = mutableStateArrayIdx.getOrElse((javaType, name), -1)
-        if (idx >= CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT - 1) {
-          // Create a new array name to avoid an array index larger than 32767, which would
-          // require a constant pool entry
-          name = freshName(compactArrayName)
-          idx = -1
-        }
-        (name, idx + 1)
-      }
-      mutableStateArrayCurrentNames(javaType) = arrayName
-      mutableStateArrayIdx((javaType, arrayName)) = newIdx
-
-      val initCode = initFunc(s"$arrayName[$newIdx]")
-      mutableStateArrayInitCodes += initCode
-      s"$arrayName[$newIdx]"
+      // If mutableStateArray has not been declared yet for the given type, create a new
+      // name for the array. If the mutableStateArray for the given type has been declared,
+      // update the current index of the array.
+      // Then, initialization code is stored for code generation.
+      val arrays = mutableStateArrayMap.getOrElseUpdate(javaType, new MutableStateArrays)
+      val element = arrays.getNextSlot()
+
+      val initCode = initFunc(element)
+      mutableStateInitCodes += initCode
+      element
    }
  }

@@ -255,14 +271,22 @@ class CodegenContext {
  def declareMutableStates(): String = {
    // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in
    // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
    val inlinedStates = mutableStates.distinct.map { case (javaType, variableName, _) =>
      s"private $javaType $variableName;"
    }

-    val arrayStates = mutableStateArrayIdx.keys.map { case (javaType, arrayName) =>
-      val length = mutableStateArrayIdx((javaType, arrayName)) + 1
-      if (javaType.matches("^.*\\[\\]$")) {
-        // initializer had an one-dimensional array variable
-        val baseType = javaType.substring(0, javaType.length - 2)
-        s"private $javaType[] $arrayName = new $baseType[$length][];"
-      } else {
-        // initializer had a scalar variable
-        s"private $javaType[] $arrayName = new $javaType[$length];"
+    val arrayStates = mutableStateArrayMap.flatMap { case (javaType, mutableStateArrays) =>
+      val numElements = mutableStateArrays.arrayNames.size
+      mutableStateArrays.arrayNames.zipWithIndex.map { case (arrayName, index) =>
+        val length = if (index + 1 == numElements) {
+          mutableStateArrays.getCurrentIndex
+        } else {
+          CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT
+        }
+        if (javaType.contains("[]")) {
+          // initializer had an one-dimensional array variable
+          val baseType = javaType.substring(0, javaType.length - 2)
+          s"private $javaType[] $arrayName = new $baseType[$length][];"
+        } else {
+          // initializer had a scalar variable
+          s"private $javaType[] $arrayName = new $javaType[$length];"
+        }
      }
    }

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 96bfbdb8dd9a6..5c291761e1e5b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -410,7 +410,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { 
assert(ctx1.mutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
    // When the number of primitive type mutable states is over the threshold, others are
    // allocated into an array
-    assert(ctx1.mutableStateArrayIdx.size == 1)
+    assert(ctx1.mutableStateArrayMap.get(ctx1.JAVA_INT).get.arrayNames.size == 1)
    assert(ctx1.mutableStateArrayInitCodes.size == 10)

    val ctx2 = new CodegenContext
@@ -419,27 +419,8 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
    }
    // When the number of non-primitive type mutable states is over the threshold, others are
    // allocated into a new array
-    assert(ctx2.mutableStateArrayIdx.size == 2)
+    assert(ctx2.mutableStateArrayMap.get("InternalRow[]").get.arrayNames.size == 2)
    assert(ctx2.mutableStateArrayInitCodes.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10)
  }
-
-  test("SPARK-18016: check whether a variable is declared as mutable") {
-    val ctx = new CodegenContext
-    val var1 = ctx.addMutableState(ctx.JAVA_INT, "ij")
-    val var2 = ctx.addMutableState("int[]", "array")
-    val var3 = ctx.addMutableState("int[][]", "b")
-
-    assert(ctx.isDeclaredMutableState(var1))
-    assert(ctx.isDeclaredMutableState(var2))
-    assert(ctx.isDeclaredMutableState(s"$var2[1]"))
-    assert(ctx.isDeclaredMutableState(var3))
-    assert(ctx.isDeclaredMutableState(s"$var3[]"))
-    assert(ctx.isDeclaredMutableState(s"$var3[1][]"))
-
-    assert(!ctx.isDeclaredMutableState("i"))
-    assert(!ctx.isDeclaredMutableState("j"))
-    assert(!ctx.isDeclaredMutableState("ij99"))
-    assert(!ctx.isDeclaredMutableState("arr"))
-    assert(!ctx.isDeclaredMutableState("bb[]"))
-  }
}
+

From d6c1a97ebf98c7251ec6923a1b9f4c84ec3a4b75 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Thu, 14 Dec 2017 11:01:47 +0000
Subject: [PATCH 23/26] address review comments

---
 .../expressions/codegen/CodeGenerator.scala   | 70 ++++++++-----------
 .../codegen/GenerateMutableProjection.scala   | 16 ++---
 .../codegen/GenerateUnsafeProjection.scala    |  6 +-
 .../expressions/datetimeExpressions.scala     | 18 ++---
 .../sql/catalyst/expressions/generators.scala |  5 +-
 .../expressions/objects/objects.scala         | 38 +++++-----
 .../expressions/regexpExpressions.scala       |  4 +-
 .../expressions/CodeGenerationSuite.scala     |  6 +-
 .../expressions/RegexpExpressionsSuite.scala  |  2 +-
 .../sql/execution/ColumnarBatchScan.scala     |  2 +-
 .../sql/execution/DataSourceScanExec.scala    |  4 +-
 .../apache/spark/sql/execution/SortExec.scala |  2 +-
 .../sql/execution/WholeStageCodegenExec.scala |  2 +-
 .../execution/basicPhysicalOperators.scala    | 10 +--
 .../joins/BroadcastHashJoinExec.scala         |  2 +-
 .../execution/joins/SortMergeJoinExec.scala   |  8 +--
 16 files changed, 91 insertions(+), 104 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index fe7b11a964cfe..19bad7a8895bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -138,33 +138,24 @@ class CodegenContext {

  /**
   * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
-   * 3-tuple: java type, variable name, code to init it.
-   * As an example, ("int", "count", "count = 0;") will produce code:
+   * 2-tuple: java type, variable name. 
+ * As an example, ("int", "count") will produce code: * {{{ * private int count; * }}} - * as a member variable, and add - * {{{ - * count = 0; - * }}} - * to the constructor. + * as a member variable * * They will be kept as member variables in generated classes like `SpecificProjection`. */ - val mutableStates: mutable.ArrayBuffer[(String, String, String)] = - mutable.ArrayBuffer.empty[(String, String, String)] + val mutableStates: mutable.ArrayBuffer[(String, String)] = + mutable.ArrayBuffer.empty[(String, String)] // An map keyed by mutable states' types holds the status of mutableStateArray - var mutableStateArrayMap: mutable.Map[String, MutableStateArrays] = + val mutableStateArrayMap: mutable.Map[String, MutableStateArrays] = mutable.Map.empty[String, MutableStateArrays] - // An map keyed by mutable states' types holds the current name of the mutableStateArray - // into which state of the given key will be compacted - var mutableStateArrayCurrentNames: mutable.Map[String, String] = - mutable.Map.empty[String, String] - - // An array holds the code that will initialize each element of the mutableStateArray - var mutableStateArrayInitCodes: mutable.ArrayBuffer[String] = + // An array holds the code that will initialize each state + val mutableStateInitCodes: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty[String] // Holding names and current index of mutableStateArrays for a certain type @@ -176,7 +167,7 @@ class CodegenContext { private def createNewArray() = arrayNames.append(freshName("mutableStateArray")) - def getCurrentIndex: Int = { currentIndex } + def getCurrentIndex: Int = currentIndex def getNextSlot(): String = { if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) { @@ -204,8 +195,10 @@ class CodegenContext { * @param initFunc Function includes statement(s) to put into the init() method to initialize * this field. The argument is the name of the mutable state variable. * If left blank, the field will be default-initialized. - * @param inline whether the declaration and initialization code may be inlined rather than - * compacted. + * @param forceInline whether the declaration and initialization code may be inlined rather than + * compacted. Please set `true` into forceInline, if you want to access the + * status fast (e.g. 
frequently accessed) or if you want to use the original + * variable name * @param useFreshName If false and inline is true, the name is not changed * @return the name of the mutable state variable, which is either the original name if the * variable is inlined to the outer class, or an array access if the variable is to be @@ -222,29 +215,24 @@ class CodegenContext { javaType: String, variableName: String, initFunc: String => String = _ => "", - inline: Boolean = false, + forceInline: Boolean = false, useFreshName: Boolean = true): String = { val varName = if (useFreshName) freshName(variableName) else variableName - if (inline || - // want to put a primitive type variable at outerClass for performance - isPrimitiveType(javaType) && - (mutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) || - // type is multi-dimensional array - javaType.contains("[][]")) { + // want to put a primitive type variable at outerClass for performance + val canInlinePrimitive = isPrimitiveType(javaType) && + (mutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) + if (forceInline || canInlinePrimitive || javaType.contains("[][]")) { val initCode = initFunc(varName) - mutableStates += ((javaType, varName, initCode)) + mutableStates += ((javaType, varName)) + mutableStateInitCodes += initCode varName } else { - // If mutableStateArray has not been declared yet for the given type, create a new - // name for the array, If the mutableStateArray for the given type has been declared, - // update the current index of the array. - // Then, initialization code is stored for code generation. val arrays = mutableStateArrayMap.getOrElseUpdate(javaType, new MutableStateArrays) val element = arrays.getNextSlot() val initCode = initFunc(element) - mutableStateArrayInitCodes += initCode + mutableStateInitCodes += initCode element } } @@ -267,14 +255,14 @@ class CodegenContext { def declareMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. - val inlinedStates = mutableStates.distinct.map { case (javaType, variableName, _) => + val inlinedStates = mutableStates.distinct.map { case (javaType, variableName) => s"private $javaType $variableName;" } val arrayStates = mutableStateArrayMap.flatMap { case (javaType, mutableStateArrays) => - val numElements = mutableStateArrays.arrayNames.size + val numArrays = mutableStateArrays.arrayNames.size mutableStateArrays.arrayNames.zipWithIndex.map { case (arrayName, index) => - val length = if (index + 1 == numElements) { + val length = if (index + 1 == numArrays) { mutableStateArrays.getCurrentIndex } else { CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT @@ -296,13 +284,11 @@ class CodegenContext { def initMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. - val initCodes = mutableStates.map(_._3).distinct.map(_ + "\n") - // statements for array element initialization - val arrayInitCodes = mutableStateArrayInitCodes.distinct + val initCodes = mutableStateInitCodes.distinct // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. 
- splitExpressions(expressions = initCodes ++ arrayInitCodes, funcName = "init", arguments = Nil) + splitExpressions(expressions = initCodes, funcName = "init", arguments = Nil) } /** @@ -1141,8 +1127,8 @@ class CodegenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. - addMutableState(JAVA_BOOLEAN, isNull, inline = true, useFreshName = false) - addMutableState(javaType(expr.dataType), value, inline = true, useFreshName = false) + addMutableState(JAVA_BOOLEAN, isNull, forceInline = true, useFreshName = false) + addMutableState(javaType(expr.dataType), value, forceInline = true, useFreshName = false) subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 5dae048b43d5f..eeec8cb841cd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -57,22 +57,22 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP case _ => true }.unzip val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) - val projectionCodes = exprVals.zip(index).map { + val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) val value = ctx.addMutableState(ctx.javaType(e.dataType), "value") if (e.nullable) { val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull") (s""" - ${ev.code} - $isNull = ${ev.isNull}; - $value = ${ev.value}; - """, isNull, value, i) + |${ev.code} + |$isNull = ${ev.isNull}; + |$value = ${ev.value}; + """.stripMargin, isNull, value, i) } else { (s""" - ${ev.code} - $value = ${ev.value}; - """, ev.isNull, value, i) + |${ev.code} + |$value = ${ev.value}; + """.stripMargin, ev.isNull, value, i) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 601170d385d09..36ffa8dcdd2b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -74,7 +74,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro isTopLevel: Boolean = false): String = { val rowWriterClass = classOf[UnsafeRowWriter].getName val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", - v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});", inline = true) + v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});") val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, @@ -317,11 +317,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val result = ctx.addMutableState("UnsafeRow", "result", - v => s"$v = new UnsafeRow(${expressions.length});", inline = true) + v => s"$v = new 
UnsafeRow(${expressions.length});") val holderClass = classOf[BufferHolder].getName val holder = ctx.addMutableState(holderClass, "holder", - v => s"$v = new $holderClass($result, ${numVarLenFields * 32});", inline = true) + v => s"$v = new $holderClass($result, ${numVarLenFields * 32});") val resetBufferHolder = if (numVarLenFields == 0) { "" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index c5cbfdae6d1d9..dea685d9cbfe4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -485,16 +485,16 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val c = ctx.addMutableState(cal, "cal", - v => s""" - $v = $cal.getInstance($dtu.getTimeZone("UTC")); - $v.setFirstDayOfWeek($cal.MONDAY); - $v.setMinimalDaysInFirstWeek(4); - """, inline = true) + val c = ctx.addMutableState(cal, "cal", v => + s""" + |$v = $cal.getInstance($dtu.getTimeZone("UTC")); + |$v.setFirstDayOfWeek($cal.MONDAY); + |$v.setMinimalDaysInFirstWeek(4); + """) s""" - $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.value} = $c.get($cal.WEEK_OF_YEAR); - """ + |$c.setTimeInMillis($time * 1000L * 3600L * 24L); + |${ev.value} = $c.get($cal.WEEK_OF_YEAR); + """ }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index edab58bbd8d97..1cd73a92a8635 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -212,13 +212,12 @@ case class Stack(children: Seq[Expression]) extends Generator { s"${eval.code}\n$rowData[$row] = ${eval.value};" }) - // Create the collection. Inline to outer class. + // Create the collection. 
val wrapperClass = classOf[mutable.WrappedArray[_]].getName ctx.addMutableState( s"$wrapperClass", ev.value, - v => s"$v = $wrapperClass$$.MODULE$$.make($rowData);", - inline = true, useFreshName = false) + v => s"$v = $wrapperClass$$.MODULE$$.make($rowData);", useFreshName = false) ev.copy(code = code, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 07977b126f52d..468bdeacc6eed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -546,7 +546,7 @@ case class MapObjects private( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) - ctx.addMutableState(elementJavaType, loopValue, inline = true, useFreshName = false) + ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -642,7 +642,7 @@ case class MapObjects private( } val loopNullCheck = if (loopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, inline = true, useFreshName = false) + ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" case _ => s"$loopIsNull = $loopValue == null;" @@ -806,10 +806,11 @@ case class CatalystToExternalMap private( val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] val keyElementJavaType = ctx.javaType(mapType.keyType) - ctx.addMutableState(keyElementJavaType, keyLoopValue, inline = true, useFreshName = false) + ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false) val genKeyFunction = keyLambdaFunction.genCode(ctx) val valueElementJavaType = ctx.javaType(mapType.valueType) - ctx.addMutableState(valueElementJavaType, valueLoopValue, inline = true, useFreshName = false) + ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true, + useFreshName = false) val genValueFunction = valueLambdaFunction.genCode(ctx) val genInputData = inputData.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -842,7 +843,8 @@ case class CatalystToExternalMap private( val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) val valueLoopNullCheck = if (valueLoopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, inline = true, useFreshName = false) + ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, + useFreshName = false) s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" } else { "" @@ -992,8 +994,8 @@ case class ExternalMapToCatalyst private( val keyElementJavaType = ctx.javaType(keyType) val valueElementJavaType = ctx.javaType(valueType) - ctx.addMutableState(keyElementJavaType, key, inline = true, useFreshName = false) - ctx.addMutableState(valueElementJavaType, value, inline = true, useFreshName = false) + ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false) + ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false) val (defineEntries, defineKeyValue) = 
child.dataType match { case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => @@ -1029,14 +1031,14 @@ case class ExternalMapToCatalyst private( } val keyNullCheck = if (keyIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, inline = true, useFreshName = false) + ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) s"$keyIsNull = $key == null;" } else { "" } val valueNullCheck = if (valueIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, inline = true, useFreshName = false) + ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) s"$valueIsNull = $value == null;" } else { "" @@ -1156,14 +1158,14 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) // try conf from env, otherwise create a new one val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForEncode", - v => s""" - if ($env == null) { - $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - } else { - $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - } - """, inline = true) + val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForEncode", v => + s""" + |if ($env == null) { + | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); + |} else { + | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); + |} + """, forceInline = true) // Code to serialize. val input = child.genCode(ctx) @@ -1208,7 +1210,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B } else { $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); } - """, inline = true) + """, forceInline = true) // Code to deserialize. val input = child.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 31d8224e27b0e..5ac8121a08532 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -119,7 +119,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi val regexStr = StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) val pattern = ctx.addMutableState(patternClass, "patternLike", - v => s"""$v = ${patternClass}.compile("$regexStr");""", inline = true) + v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) @@ -194,7 +194,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress val regexStr = StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) val pattern = ctx.addMutableState(patternClass, "patternRLike", - v => s"""$v = ${patternClass}.compile("$regexStr");""", inline = true) + v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
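// Illustrative sketch (not part of this patch) of the Like/RLike idiom above: a
// literal pattern is compiled once into a force-inlined field, assuming `ctx` and a
// Java-escaped pattern literal `regexStr` are in scope:
val pattern = ctx.addMutableState("java.util.regex.Pattern", "pattern",
  v => s"""$v = java.util.regex.Pattern.compile("$regexStr");""", forceInline = true)
// forceInline keeps the compiled pattern as a dedicated member, so per-row matching
// avoids the extra array load a compacted slot would cost.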
val eval = left.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 5c291761e1e5b..dccafff6e13a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -402,7 +402,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ctx.mutableStates.isEmpty) } - test("SPARK-18016: Compact mutable states by using an array") { + test("SPARK-18016: def mutable states by using an array") { val ctx1 = new CodegenContext for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) { ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;") @@ -411,7 +411,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { // When the number of primitive type mutable states is over the threshold, others are // allocated into an array assert(ctx1.mutableStateArrayMap.get(ctx1.JAVA_INT).get.arrayNames.size == 1) - assert(ctx1.mutableStateArrayInitCodes.size == 10) + assert(ctx1.mutableStateInitCodes.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) val ctx2 = new CodegenContext for (i <- 1 to CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) { @@ -420,7 +420,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { // When the number of non-primitive type mutable states is over the threshold, others are // allocated into a new array assert(ctx2.mutableStateArrayMap.get("InternalRow[]").get.arrayNames.size == 2) - assert(ctx2.mutableStateArrayInitCodes.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) + assert(ctx2.mutableStateInitCodes.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index a1931aa82eb3d..c1f4a5532e10b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -185,7 +185,7 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8) // are always required, which are allocated in type-based global array assert(ctx.mutableStates.length == 0) - assert(ctx.mutableStateArrayInitCodes.length == 4) + assert(ctx.mutableStateInitCodes.length == 4) } test("RegexExtract") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 2a9bceaf086a3..7aff06ab4f550 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -70,7 +70,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { override protected def doProduce(ctx: CodegenContext): String = { // PhysicalRDD always just has one input val input = ctx.addMutableState("scala.collection.Iterator", "input", - v => s"$v = inputs[0];", inline = true) + v => s"$v = inputs[0];", forceInline = true) // metrics 
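// Illustrative sketch (not part of this patch) of the compaction behavior the tests
// above exercise; the "+ 10" overflow is taken from the test itself:
val ctx = new CodegenContext
for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) {
  ctx.addMutableState(ctx.JAVA_INT, "i", v => s"$v = $i;")
}
// Expected outcome: the first OUTER_CLASS_VARIABLES_THRESHOLD ints stay as inlined
// fields, the overflowed 10 share one compacted int[] array, and every request still
// contributes exactly one entry of initialization code.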
val numOutputRows = metricTerm(ctx, "numOutputRows") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 45b66c71726ce..377f5ff24e5d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -111,7 +111,7 @@ case class RowDataSourceScanExec( val numOutputRows = metricTerm(ctx, "numOutputRows") // PhysicalRDD always just has one input val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", - inline = true) + forceInline = true) val exprRows = output.zipWithIndex.map{ case (a, i) => BoundReference(i, a.dataType, a.nullable) } @@ -354,7 +354,7 @@ case class FileSourceScanExec( val numOutputRows = metricTerm(ctx, "numOutputRows") // PhysicalRDD always just has one input val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", - inline = true) + forceInline = true) val row = ctx.freshName("row") ctx.INPUT_ROW = row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index 4925d6090880b..a8246de39132d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -143,7 +143,7 @@ case class SortExec( val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics", v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();") val sortedIterator = ctx.addMutableState("scala.collection.Iterator", "sortedIter", - inline = true) + forceInline = true) val addToSorter = ctx.freshName("addToSorter") val addToSorterFuncName = ctx.addNewFunction(addToSorter, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index b8ad59cc3eda9..e123b47e4cfd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -284,7 +284,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp override def doProduce(ctx: CodegenContext): String = { // Right now, InputAdapter is only used when there is one input RDD. 
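// Aside, illustrative only: the recurring `inputs[0]` idiom from the scan and
// input-adapter operators above. `inputs` is the iterator array handed to the
// generated class's init method, which is also where the init function's code runs:
val input = ctx.addMutableState("scala.collection.Iterator", "input",
  v => s"$v = inputs[0];", forceInline = true)
// so the field is wired to its upstream iterator before processNext() first executes.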
val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", - inline = true) + forceInline = true) val row = ctx.freshName("row") s""" | while ($input.hasNext() && !stopEarly()) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 412f68ce9f770..d1bd3960fce75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -301,7 +301,7 @@ case class SampleExec( | } """.stripMargin.trim) s"$initSamplerFuncName();" - }, inline = true) + }, forceInline = true) val samplingCount = ctx.freshName("samplingCount") s""" @@ -317,7 +317,7 @@ case class SampleExec( v => s""" | $v = new $samplerClass($lowerBound, $upperBound, false); | $v.setSeed(${seed}L + partitionIndex); - """.stripMargin.trim, inline = true) + """.stripMargin.trim, forceInline = true) s""" | if ($sampler.sample() != 0) { @@ -371,9 +371,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val BigInt = classOf[java.math.BigInteger].getName val taskContext = ctx.addMutableState("TaskContext", "taskContext", - v => s"$v = TaskContext.get();", inline = true) + v => s"$v = TaskContext.get();", forceInline = true) val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics", - v => s"$v = $taskContext.taskMetrics().inputMetrics();", inline = true) + v => s"$v = $taskContext.taskMetrics().inputMetrics();", forceInline = true) // In order to periodically update the metrics without inflicting performance penalty, this // operator produces elements in batches. After a batch is complete, the metrics are updated @@ -437,7 +437,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // Right now, Range is only used when there is one upstream. val input = ctx.addMutableState("scala.collection.Iterator", "input", - v => s"$v = inputs[0];", inline = true) + v => s"$v = inputs[0];", forceInline = true) val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index f035e52e0c347..e87f8b0639dc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -144,7 +144,7 @@ case class BroadcastHashJoinExec( | $v = (($clsName) $broadcast.value()).asReadOnlyCopy(); | incPeakExecutionMemory($v.estimatedSize()); | ${genTaskListener(avgHashProbe, v)} - """.stripMargin, inline = true) + """.stripMargin, forceInline = true) (broadcastRelation, relationTerm) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 00444bdcf3824..c1bb5786cb0b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -422,8 +422,8 @@ case class SortMergeJoinExec( */ private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. 
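// Illustrative sketch (not part of this patch): multi-statement initialization
// through the init function, as with the broadcast relation above; `clsName` and
// `broadcast` are assumed to be in scope:
val relationTerm = ctx.addMutableState(clsName, "relation",
  v => s"""
     | $v = (($clsName) $broadcast.value()).asReadOnlyCopy();
     | incPeakExecutionMemory($v.estimatedSize());
   """.stripMargin, forceInline = true)
// The entire block lands in init(), so setup of any complexity is fine as long as it
// only references the single name handed to the function.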
- val leftRow = ctx.addMutableState("InternalRow", "leftRow", inline = true) - val rightRow = ctx.addMutableState("InternalRow", "rightRow", inline = true) + val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) + val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) // Create variables for join keys from both sides. val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) @@ -576,9 +576,9 @@ case class SortMergeJoinExec( override def doProduce(ctx: CodegenContext): String = { val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", - v => s"$v = inputs[0];", inline = true) + v => s"$v = inputs[0];", forceInline = true) val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput", - v => s"$v = inputs[1];", inline = true) + v => s"$v = inputs[1];", forceInline = true) val (leftRow, matches) = genScanner(ctx) From a9d40e9665bbe5e31a1c3dbd76e675c01730b487 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 14 Dec 2017 15:21:27 +0000 Subject: [PATCH 24/26] address review comments --- .../expressions/codegen/CodeGenerator.scala | 26 +++++++++---------- .../codegen/GenerateMutableProjection.scala | 2 ++ .../expressions/datetimeExpressions.scala | 4 +-- .../expressions/objects/objects.scala | 18 ++++++------- .../ArithmeticExpressionSuite.scala | 4 +-- .../sql/catalyst/expressions/CastSuite.scala | 2 +- .../expressions/CodeGenerationSuite.scala | 20 +++++++------- .../expressions/ComplexTypeSuite.scala | 2 +- .../ConditionalExpressionSuite.scala | 2 +- .../expressions/NullExpressionsSuite.scala | 2 +- .../catalyst/expressions/PredicateSuite.scala | 4 +-- .../expressions/RegexpExpressionsSuite.scala | 4 +-- .../catalyst/expressions/ScalaUDFSuite.scala | 2 +- .../codegen/GeneratedProjectionSuite.scala | 4 +-- .../optimizer/complexTypesSuite.scala | 2 +- .../sql/execution/ColumnarBatchScan.scala | 6 ++--- .../sql/execution/DataSourceScanExec.scala | 6 ++--- .../apache/spark/sql/execution/limit.scala | 4 +-- 18 files changed, 58 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 19bad7a8895bd..e9db1e315f229 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -147,15 +147,15 @@ class CodegenContext { * * They will be kept as member variables in generated classes like `SpecificProjection`. 
*/ - val mutableStates: mutable.ArrayBuffer[(String, String)] = + val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] = mutable.ArrayBuffer.empty[(String, String)] // An map keyed by mutable states' types holds the status of mutableStateArray - val mutableStateArrayMap: mutable.Map[String, MutableStateArrays] = + val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] = mutable.Map.empty[String, MutableStateArrays] // An array holds the code that will initialize each state - val mutableStateInitCodes: mutable.ArrayBuffer[String] = + val mutableStateInitCode: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty[String] // Holding names and current index of mutableStateArrays for a certain type @@ -202,7 +202,7 @@ class CodegenContext { * @param useFreshName If false and inline is true, the name is not changed * @return the name of the mutable state variable, which is either the original name if the * variable is inlined to the outer class, or an array access if the variable is to be - * stored in an array of variables of the same type and initialization. + * stored in an array of variables of the same type. * There are two use cases. One is to use the original name for global variable instead * of fresh name. Second is to use the original initialization statement since it is * complex (e.g. allocate multi-dimensional array or object constructor has varibles). @@ -217,22 +217,22 @@ class CodegenContext { initFunc: String => String = _ => "", forceInline: Boolean = false, useFreshName: Boolean = true): String = { - val varName = if (useFreshName) freshName(variableName) else variableName // want to put a primitive type variable at outerClass for performance val canInlinePrimitive = isPrimitiveType(javaType) && - (mutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) + (inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) if (forceInline || canInlinePrimitive || javaType.contains("[][]")) { + val varName = if (useFreshName) freshName(variableName) else variableName val initCode = initFunc(varName) - mutableStates += ((javaType, varName)) - mutableStateInitCodes += initCode + inlinedMutableStates += ((javaType, varName)) + mutableStateInitCode += initCode varName } else { - val arrays = mutableStateArrayMap.getOrElseUpdate(javaType, new MutableStateArrays) + val arrays = arrayCompactedMutableStates.getOrElseUpdate(javaType, new MutableStateArrays) val element = arrays.getNextSlot() val initCode = initFunc(element) - mutableStateInitCodes += initCode + mutableStateInitCode += initCode element } } @@ -255,11 +255,11 @@ class CodegenContext { def declareMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. 
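// Illustrative sketch (not part of this patch): the inline-vs-compact decision rule
// implemented above, distilled into one expression (field names as in the patch):
val inlined = forceInline ||
  (isPrimitiveType(javaType) &&
    inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) ||
  javaType.contains("[][]")
// true  -> a dedicated `private <javaType> <name>;` field on the outer class
// false -> an access like mutableStateArray_1[7] from the per-type MutableStateArrays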
- val inlinedStates = mutableStates.distinct.map { case (javaType, variableName) => + val inlinedStates = inlinedMutableStates.distinct.map { case (javaType, variableName) => s"private $javaType $variableName;" } - val arrayStates = mutableStateArrayMap.flatMap { case (javaType, mutableStateArrays) => + val arrayStates = arrayCompactedMutableStates.flatMap { case (javaType, mutableStateArrays) => val numArrays = mutableStateArrays.arrayNames.size mutableStateArrays.arrayNames.zipWithIndex.map { case (arrayName, index) => val length = if (index + 1 == numArrays) { @@ -284,7 +284,7 @@ class CodegenContext { def initMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. - val initCodes = mutableStateInitCodes.distinct + val initCodes = mutableStateInitCode.distinct // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index eeec8cb841cd6..b53c0087e7e2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -57,6 +57,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP case _ => true }.unzip val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) + + // 4-tuples: (code for projection, isNull variable name, value variable name, column index) val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index dea685d9cbfe4..cfec7f82951a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -490,11 +490,11 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa |$v = $cal.getInstance($dtu.getTimeZone("UTC")); |$v.setFirstDayOfWeek($cal.MONDAY); |$v.setMinimalDaysInFirstWeek(4); - """) + """.stripMargin) s""" |$c.setTimeInMillis($time * 1000L * 3600L * 24L); |${ev.value} = $c.get($cal.WEEK_OF_YEAR); - """ + """.stripMargin }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 468bdeacc6eed..a59aad5be8715 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1165,7 +1165,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) |} else { | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); |} - """, forceInline = true) + """.stripMargin) // Code to serialize. 
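// Illustrative sketch (not part of this patch): why the .stripMargin calls added
// above matter. Without it, the emitted Java keeps the margin characters verbatim:
val withMargins =
  """
    |if (env == null) { a(); } else { b(); }
  """                                  // still contains "    |if (env == null)..."
val clean = withMargins.stripMargin    // strips everything through each leading '|'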
val input = child.genCode(ctx) @@ -1203,14 +1203,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B // try conf from env, otherwise create a new one val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForDecode", - v => s""" - if ($env == null) { - $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - } else { - $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - } - """, forceInline = true) + val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForDecode", v => + s""" + |if ($env == null) { + | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); + |} else { + | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); + |} + """.stripMargin) // Code to deserialize. val input = child.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index be638d80e45d8..6edb4348f8309 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -348,10 +348,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("SPARK-22704: Least and greatest use less global variables") { val ctx1 = new CodegenContext() Least(Seq(Literal(1), Literal(1))).genCode(ctx1) - assert(ctx1.mutableStates.size == 1) + assert(ctx1.inlinedMutableStates.size == 1) val ctx2 = new CodegenContext() Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2) - assert(ctx2.mutableStates.size == 1) + assert(ctx2.inlinedMutableStates.size == 1) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 65617be05a434..1dd040e4696a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -851,6 +851,6 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val ctx = new CodegenContext cast("1", IntegerType).genCode(ctx) cast("2", LongType).genCode(ctx) - assert(ctx.mutableStates.length == 0) + assert(ctx.inlinedMutableStates.length == 0) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index dccafff6e13a1..27878f54f24e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -385,33 +385,33 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val ctx = new CodegenContext val schema = new StructType().add("a", IntegerType).add("b", StringType) CreateExternalRow(Seq(Literal(1), Literal("x")), schema).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } test("SPARK-22696: InitializeJavaBean should not use global variables") { val 
ctx = new CodegenContext InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), Map("add" -> Literal(1))).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } test("SPARK-22716: addReferenceObj should not add mutable states") { val ctx = new CodegenContext val foo = new Object() ctx.addReferenceObj("foo", foo) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } - test("SPARK-18016: def mutable states by using an array") { + test("SPARK-18016: define mutable states by using an array") { val ctx1 = new CodegenContext for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) { ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;") } - assert(ctx1.mutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) + assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) // When the number of primitive type mutable states is over the threshold, others are // allocated into an array - assert(ctx1.mutableStateArrayMap.get(ctx1.JAVA_INT).get.arrayNames.size == 1) - assert(ctx1.mutableStateInitCodes.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) + assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1) + assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) val ctx2 = new CodegenContext for (i <- 1 to CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) { @@ -419,8 +419,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } // When the number of non-primitive type mutable states is over the threshold, others are // allocated into a new array - assert(ctx2.mutableStateArrayMap.get("InternalRow[]").get.arrayNames.size == 2) - assert(ctx2.mutableStateInitCodes.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) + assert(ctx2.inlinedMutableStates.isEmpty) + assert(ctx2.arrayCompactedMutableStates.get("InternalRow[]").get.arrayNames.size == 2) + assert(ctx2.arrayCompactedMutableStates("InternalRow[]").getCurrentIndex == 10) + assert(ctx2.mutableStateInitCode.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 6dfca7d73a3df..84190f0bd5f7d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -304,6 +304,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22693: CreateNamedStruct should not use global variables") { val ctx = new CodegenContext CreateNamedStruct(Seq("a", "x", "b", 2.0)).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 60d84aae1fa3f..a099119732e25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -150,6 +150,6 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper 
test("SPARK-22705: case when should use less global variables") { val ctx = new CodegenContext() CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), Literal(-1)).genCode(ctx) - assert(ctx.mutableStates.size == 1) + assert(ctx.inlinedMutableStates.size == 1) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index a23cd95632770..cc6c15cb2c909 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -159,7 +159,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22705: Coalesce should use less global variables") { val ctx = new CodegenContext() Coalesce(Seq(Literal("a"), Literal("b"))).genCode(ctx) - assert(ctx.mutableStates.size == 1) + assert(ctx.inlinedMutableStates.size == 1) } test("AtLeastNNonNulls should not throw 64kb exception") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 15cb0bea08f17..8a8f8e10225fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -249,7 +249,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22705: In should use less global variables") { val ctx = new CodegenContext() In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } test("INSET") { @@ -440,6 +440,6 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22693: InSet should not use global variables") { val ctx = new CodegenContext InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index c1f4a5532e10b..2a0a42c65b086 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -184,8 +184,8 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { RegExpReplace(Literal("100"), Literal("(\\d+)"), Literal("num")).genCode(ctx) // four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8) // are always required, which are allocated in type-based global array - assert(ctx.mutableStates.length == 0) - assert(ctx.mutableStateInitCodes.length == 4) + assert(ctx.inlinedMutableStates.length == 0) + assert(ctx.mutableStateInitCode.length == 4) } test("RegexExtract") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 70dea4b39d55d..10e3ffd0dff97 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala 
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -51,6 +51,6 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22695: ScalaUDF should not use global variables") { val ctx = new CodegenContext ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 3d00f377bed34..2c45b3b0c73d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -222,10 +222,10 @@ class GeneratedProjectionSuite extends SparkFunSuite { test("SPARK-18016: generated projections on wider table requiring state compaction") { val N = 6000 - val wideRow1 = new GenericInternalRow((0 until N).toArray[Any]) + val wideRow1 = new GenericInternalRow(new Array[Any](N)) val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) val wideRow2 = new GenericInternalRow( - (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + Array.tabulate[Any](N)(i => UTF8String.fromString(i.toString))) val schema2 = StructType((1 to N).map(i => StructField("", StringType))) val joined = new JoinedRow(wideRow1, wideRow2) val joinedSchema = StructType(schema1 ++ schema2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index e3675367d78e4..0d11958876ce9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -168,7 +168,7 @@ class ComplexTypesSuite extends PlanTest{ test("SPARK-22570: CreateArray should not create a lot of global variables") { val ctx = new CodegenContext CreateArray(Seq(Literal(1))).genCode(ctx) - assert(ctx.mutableStates.length == 0) + assert(ctx.inlinedMutableStates.length == 0) } test("simplify map ops") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 7aff06ab4f550..782cec5e292ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -70,17 +70,17 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { override protected def doProduce(ctx: CodegenContext): String = { // PhysicalRDD always just has one input val input = ctx.addMutableState("scala.collection.Iterator", "input", - v => s"$v = inputs[0];", forceInline = true) + v => s"$v = inputs[0];") // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") - val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") + val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0 val columnarBatchClz = classOf[ColumnarBatch].getName val batch = ctx.addMutableState(columnarBatchClz, "batch") - val idx = 
ctx.addMutableState(ctx.JAVA_INT, "batchIdx") + val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0 val columnVectorClzs = vectorTypes.getOrElse( Seq.fill(output.indices.size)(classOf[ColumnVector].getName)) val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 377f5ff24e5d0..4c3b1c49f703d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -110,8 +110,7 @@ case class RowDataSourceScanExec( override protected def doProduce(ctx: CodegenContext): String = { val numOutputRows = metricTerm(ctx, "numOutputRows") // PhysicalRDD always just has one input - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", - forceInline = true) + val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => BoundReference(i, a.dataType, a.nullable) } @@ -353,8 +352,7 @@ case class FileSourceScanExec( } val numOutputRows = metricTerm(ctx, "numOutputRows") // PhysicalRDD always just has one input - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", - forceInline = true) + val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") val row = ctx.freshName("row") ctx.INPUT_ROW = row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index d0ef969faf5f3..c168637fc9768 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -71,7 +71,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") + val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = 0 ctx.addNewFunction("stopEarly", s""" @Override @@ -79,7 +79,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { return $stopEarly; } """, inlineToOuterClass = true) - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") + val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0 s""" | if ($countTerm < $limit) { | $countTerm += 1; From 31914c0f768a6aab24bed22b0c15b6376b6d070f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 15 Dec 2017 13:36:47 +0000 Subject: [PATCH 25/26] address review comments --- .../expressions/codegen/CodeGenerator.scala | 50 ++++++++++++------- .../expressions/CodeGenerationSuite.scala | 1 - .../apache/spark/sql/execution/SortExec.scala | 4 +- .../apache/spark/sql/execution/limit.scala | 2 +- 4 files changed, 34 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e9db1e315f229..6851525e15d33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -137,7 +137,7 @@ class CodegenContext {
   var currentVars: Seq[ExprCode] = null
 
   /**
-   * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
+   * Holding expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as a
    * 2-tuple: java type, variable name.
    * As an example, ("int", "count") will produce code:
    * {{{
@@ -150,7 +150,11 @@ class CodegenContext {
   val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] =
     mutable.ArrayBuffer.empty[(String, String)]
 
-  // An map keyed by mutable states' types holds the status of mutableStateArray
+  /**
+   * The mapping between mutable state types and corresponding compacted arrays.
+   * The keys are java type strings. The values are [[MutableStateArrays]], which encapsulate
+   * the compacted arrays for the mutable states with the same java type.
+   */
   val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] =
     mutable.Map.empty[String, MutableStateArrays]
 
@@ -158,7 +162,10 @@ class CodegenContext {
   val mutableStateInitCode: mutable.ArrayBuffer[String] =
     mutable.ArrayBuffer.empty[String]
 
-  // Holding names and current index of mutableStateArrays for a certain type
+  /**
+   * This class holds the names of the mutableStateArrays used for compacting mutable states
+   * of a certain type, and tracks the next available slot of the current compacted array.
+   */
   class MutableStateArrays {
     val arrayNames = mutable.ListBuffer.empty[String]
     createNewArray()
@@ -169,6 +176,11 @@
 
     def getCurrentIndex: Int = currentIndex
 
+    /**
+     * Returns the reference of the next available slot in the current compacted array. The size
+     * of each compacted array is controlled by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
+     * Once the threshold is reached, a new compacted array is created.
+     */
     def getNextSlot(): String = {
       if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) {
         val res = s"${arrayNames.last}[$currentIndex]"
@@ -199,17 +211,19 @@
    *                    compacted. Please set `true` into forceInline, if you want to access the
    *                    status fast (e.g. frequently accessed) or if you want to use the original
    *                    variable name
-   * @param useFreshName If false and inline is true, the name is not changed
-   * @return the name of the mutable state variable, which is either the original name if the
-   *         variable is inlined to the outer class, or an array access if the variable is to be
-   *         stored in an array of variables of the same type.
-   *         There are two use cases. One is to use the original name for global variable instead
-   *         of fresh name. Second is to use the original initialization statement since it is
-   *         complex (e.g. allocate multi-dimensional array or object constructor has varibles).
-   *         Primitive type variables will be inlined into outer class when the total number of
-   *         mutable variables is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD`
-   *         the max size of an array for compaction is given by
-   *         `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
+   * @param useFreshName If this is false and forceInline is true, the name is not changed
+   * @return the name of the mutable state variable, which is the original name or fresh name if
+   *         the variable is inlined to the outer class, or an array access if the variable is to
+   *         be stored in an array of variables of the same type.
+ * A variable will be inlined into the outer class when one of the following conditions + * are satisfied: + * 1. forceInline is true + * 2. its type is primitive type and the total number of the inlined mutable variables + * is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD` + * 3. its type is multi-dimensional array + * A primitive type variable will be inlined into outer class when the total number of + * When a variable is compacted into an array, the max size of the array for compaction + * is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. */ def addMutableState( javaType: String, @@ -1099,9 +1113,9 @@ class CodegenContext { val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) commonExprs.foreach { e => val expr = e.head - val fnName = freshName("evalExpr") - val isNull = s"${fnName}IsNull" - val value = s"${fnName}Value" + val fnName = freshName("subExpr") + val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") + val value = addMutableState(javaType(expr.dataType), "subExprValue") // Generate the code for this expression tree and wrap it in a function. val eval = expr.genCode(this) @@ -1127,8 +1141,6 @@ class CodegenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. - addMutableState(JAVA_BOOLEAN, isNull, forceInline = true, useFreshName = false) - addMutableState(javaType(expr.dataType), value, forceInline = true, useFreshName = false) subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 27878f54f24e3..b1a44528e64d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -425,4 +425,3 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ctx2.mutableStateInitCode.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index a8246de39132d..996598055b1ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -139,9 +139,9 @@ case class SortExec( // the iterator to return sorted rows. 
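// A sketch, not part of this patch: after the change above, the state for an
// eliminated common subexpression is requested through addMutableState like any other
// state (written here from inside CodegenContext, so no `ctx.` prefix). The slot
// strings in the comment are assumptions for illustration:
val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
val value = addMutableState(javaType(expr.dataType), "subExprValue")
// e.g. isNull == "mutableStateArray_3[1]" -- the generated helper function then
// assigns into those references once per input row instead of into named fields.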
val thisPlan = ctx.addReferenceObj("plan", this) sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter", - v => s"$v = $thisPlan.createSorter();") + v => s"$v = $thisPlan.createSorter();", forceInline = true) val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics", - v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();") + v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();", forceInline = true) val sortedIterator = ctx.addMutableState("scala.collection.Iterator", "sortedIter", forceInline = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index c168637fc9768..cccee63bc0680 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -71,7 +71,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = 0 + val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false ctx.addNewFunction("stopEarly", s""" @Override From 0e45c19059e8973be995cf5058f50324986256e6 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 19 Dec 2017 12:13:01 +0000 Subject: [PATCH 26/26] address review comments --- .../catalyst/expressions/codegen/CodeGenerator.scala | 11 ++++++----- .../sql/catalyst/expressions/regexpExpressions.scala | 2 ++ .../org/apache/spark/sql/execution/SortExec.scala | 1 + .../spark/sql/execution/WholeStageCodegenExec.scala | 1 + .../sql/execution/aggregate/HashAggregateExec.scala | 7 +++++-- .../spark/sql/execution/basicPhysicalOperators.scala | 8 +++----- .../sql/execution/joins/BroadcastHashJoinExec.scala | 1 + .../spark/sql/execution/joins/SortMergeJoinExec.scala | 2 ++ 8 files changed, 21 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 6851525e15d33..41a920ba3d677 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -208,10 +208,12 @@ class CodegenContext { * this field. The argument is the name of the mutable state variable. * If left blank, the field will be default-initialized. * @param forceInline whether the declaration and initialization code may be inlined rather than - * compacted. Please set `true` into forceInline, if you want to access the - * status fast (e.g. frequently accessed) or if you want to use the original - * variable name - * @param useFreshName If this is false and forceInline is true, the name is not changed + * compacted. Please set `true` into forceInline for one of the followings: + * 1. use the original name of the status + * 2. expect to non-frequently generate the status + * (e.g. 
not many sort operators in one stage)
+   * @param useFreshName If this is false and the mutable state ends up being inlined in the
+   *                     outer class, the name is not changed
    * @return the name of the mutable state variable, which is the original name or fresh name if
    *         the variable is inlined to the outer class, or an array access if the variable is to
    *         be stored in an array of variables of the same type.
@@ -221,7 +223,6 @@
    *         2. its type is primitive type and the total number of the inlined mutable variables
    *            is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD`
    *         3. its type is multi-dimensional array
-   *         A primitive type variable will be inlined into outer class when the total number of
    *         When a variable is compacted into an array, the max size of the array for compaction
    *         is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
    */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 5ac8121a08532..fa5425c77ebba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -118,6 +118,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
     if (rVal != null) {
       val regexStr =
         StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
+      // inline mutable state since not many Like operations in a task
       val pattern = ctx.addMutableState(patternClass, "patternLike",
         v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true)
 
@@ -194,6 +194,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
     if (rVal != null) {
       val regexStr =
         StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString())
+      // inline mutable state since not many RLike operations in a task
       val pattern = ctx.addMutableState(patternClass, "patternRLike",
         v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 996598055b1ba..daff3c49e7517 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -138,6 +138,7 @@ case class SortExec(
     // Initialize the class member variables. This includes the instance of the Sorter and
     // the iterator to return sorted rows.
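// Aside, illustrative only, contrasting the guidance documented above; `sorterClass`
// and `thisPlan` are assumed to be in scope:
// state created once per operator but read on every row -> force-inline it:
val sorter = ctx.addMutableState(sorterClass, "sorter",
  v => s"$v = $thisPlan.createSorter();", forceInline = true)
// state minted per expression instance, potentially thousands of times -> compact it:
val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "tmpIsNull")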
     val thisPlan = ctx.addReferenceObj("plan", this)
+    // inline mutable state since not many Sort operations in a task
     sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter",
       v => s"$v = $thisPlan.createSorter();", forceInline = true)
     val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index e123b47e4cfd8..9e7008d1e0c31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -283,6 +283,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
 
   override def doProduce(ctx: CodegenContext): String = {
     // Right now, InputAdapter is only used when there is one input RDD.
+    // inline mutable state since an InputAdapter is used once in a task
     val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
       forceInline = true)
     val row = ctx.freshName("row")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index e528f744772fb..b1af360d85095 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -605,12 +605,15 @@ case class HashAggregateExec(
     }
 
     // Create a name for the iterator from the regular hash map.
-    val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, "mapIter")
+    // inline mutable state since not many aggregation operations in a task
+    val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName,
+      "mapIter", forceInline = true)
     // create hashMap
     val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
     hashMapTerm = ctx.addMutableState(hashMapClassName, "hashMap",
       v => s"$v = $thisPlan.createHashMap();")
-    sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, "sorter")
+    sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, "sorter",
+      forceInline = true)
 
     val doAgg = ctx.freshName("doAggregateWithKeys")
     val peakMemory = metricTerm(ctx, "peakMemory")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index d1bd3960fce75..78137d3f97cfc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -284,6 +284,7 @@ case class SampleExec(
     val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
     val initSampler = ctx.freshName("initSampler")
+    // inline mutable state since not many Sample operations in a task
     val sampler = ctx.addMutableState(s"$samplerClass", "sampleReplace",
       v => {
         val initSamplerFuncName = ctx.addNewFunction(initSampler,
@@ -317,7 +318,7 @@ case class SampleExec(
         v => s"""
           | $v = new $samplerClass($lowerBound, $upperBound, false);
          | $v.setSeed(${seed}L + partitionIndex);
-         """.stripMargin.trim, forceInline = true)
+         """.stripMargin.trim)
 
       s"""
         | if ($sampler.sample() != 0) {
@@ -370,6 +371,7 @@ case class RangeExec(range: 
org.apache.spark.sql.catalyst.plans.logical.Range) val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName + // inline mutable state since not many Range operations in a task val taskContext = ctx.addMutableState("TaskContext", "taskContext", v => s"$v = TaskContext.get();", forceInline = true) val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics", @@ -435,10 +437,6 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | } """.stripMargin) - // Right now, Range is only used when there is one upstream. - val input = ctx.addMutableState("scala.collection.Iterator", "input", - v => s"$v = inputs[0];", forceInline = true) - val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") val range = ctx.freshName("range") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index e87f8b0639dc5..ee763e23415cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -139,6 +139,7 @@ case class BroadcastHashJoinExec( // At the end of the task, we update the avg hash probe. val avgHashProbe = metricTerm(ctx, "avgHashProbe") + // inline mutable state since not many join operations in a task val relationTerm = ctx.addMutableState(clsName, "relation", v => s""" | $v = (($clsName) $broadcast.value()).asReadOnlyCopy(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index c1bb5786cb0b9..073730462a75f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -422,6 +422,7 @@ case class SortMergeJoinExec( */ private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. + // inline mutable state since not many join operations in a task val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) @@ -575,6 +576,7 @@ case class SortMergeJoinExec( override def needCopyResult: Boolean = true override def doProduce(ctx: CodegenContext): String = { + // inline mutable state since not many join operations in a task val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", v => s"$v = inputs[0];", forceInline = true) val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput",