Skip to content

Commit ac41620

Browse files
committed
Simplified version.
1 parent c677aed commit ac41620

File tree

2 files changed

+48
-64
lines changed

2 files changed

+48
-64
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -309,17 +309,12 @@ class CodegenContext {
309309
funcCode: String,
310310
inlineToOuterClass: Boolean = false): String = {
311311
val newFunction = addNewFunctionInternal(funcName, funcCode, inlineToOuterClass)
312-
qualifiedFunctionName(newFunction)
313-
}
314-
315-
// Returns the name of the function, qualified by class if it will be inlined to a private,
316-
// inner class
317-
private[this] def qualifiedFunctionName(functionSpec: NewFunctionSpec): String =
318-
functionSpec match {
312+
newFunction match {
319313
case NewFunctionSpec(functionName, None, None) => functionName
320314
case NewFunctionSpec(functionName, Some(_), Some(innerClassInstance)) =>
321315
innerClassInstance + "." + functionName
322316
}
317+
}
323318

324319
private[this] def addNewFunctionInternal(
325320
funcName: String,
@@ -800,19 +795,14 @@ class CodegenContext {
800795
* @param returnType the return type of the split function.
801796
* @param makeSplitFunction makes split function body, e.g. add preparation or cleanup.
802797
* @param foldFunctions folds the split function calls.
803-
* @param makeFunctionCallback a callback function that is called after each function split.
804-
* The name of split function will be passed into the callback.
805-
* @param mergeSplit When true, try to merge split methods.
806798
*/
807799
def splitExpressionsWithCurrentInputs(
808800
expressions: Seq[String],
809801
funcName: String = "apply",
810802
extraArguments: Seq[(String, String)] = Nil,
811803
returnType: String = "void",
812804
makeSplitFunction: String => String = identity,
813-
foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";"),
814-
makeFunctionCallback: String => Unit = identity,
815-
mergeSplit: Boolean = true): String = {
805+
foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = {
816806
// TODO: support whole stage codegen
817807
if (INPUT_ROW == null || currentVars != null) {
818808
expressions.mkString("\n")
@@ -823,9 +813,7 @@ class CodegenContext {
823813
("InternalRow", INPUT_ROW) +: extraArguments,
824814
returnType,
825815
makeSplitFunction,
826-
foldFunctions,
827-
makeFunctionCallback,
828-
mergeSplit)
816+
foldFunctions)
829817
}
830818
}
831819

@@ -841,19 +829,14 @@ class CodegenContext {
841829
* @param returnType the return type of the split function.
842830
* @param makeSplitFunction makes split function body, e.g. add preparation or cleanup.
843831
* @param foldFunctions folds the split function calls.
844-
* @param makeFunctionCallback a callback function that is called after each function split.
845-
* The name of split function will be passed into the callback.
846-
* @param mergeSplit When true, try to merge split methods.
847832
*/
848833
def splitExpressions(
849834
expressions: Seq[String],
850835
funcName: String,
851836
arguments: Seq[(String, String)],
852837
returnType: String = "void",
853838
makeSplitFunction: String => String = identity,
854-
foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";"),
855-
makeFunctionCallback: String => Unit = identity,
856-
mergeSplit: Boolean = true): String = {
839+
foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = {
857840
val blocks = buildCodeBlocks(expressions)
858841

859842
if (blocks.length == 1) {
@@ -869,9 +852,7 @@ class CodegenContext {
869852
| ${makeSplitFunction(body)}
870853
|}
871854
""".stripMargin
872-
val functionSpec = addNewFunctionInternal(name, code, inlineToOuterClass = false)
873-
makeFunctionCallback(qualifiedFunctionName(functionSpec))
874-
functionSpec
855+
addNewFunctionInternal(name, code, inlineToOuterClass = false)
875856
}
876857

877858
val (outerClassFunctions, innerClassFunctions) = functions.partition(_.innerClassName.isEmpty)
@@ -885,8 +866,7 @@ class CodegenContext {
885866
arguments,
886867
returnType,
887868
makeSplitFunction,
888-
foldFunctions,
889-
mergeSplit)
869+
foldFunctions)
890870

891871
foldFunctions(outerClassFunctionCalls ++ innerClassFunctionCalls)
892872
}
@@ -934,7 +914,6 @@ class CodegenContext {
934914
* @param returnType the return type of the split function.
935915
* @param makeSplitFunction makes split function body, e.g. add preparation or cleanup.
936916
* @param foldFunctions folds the split function calls.
937-
* @param mergeSplit When true, try to merge split methods.
938917
* @return an [[Iterable]] containing the methods' invocations
939918
*/
940919
private def generateInnerClassesFunctionCalls(
@@ -943,8 +922,7 @@ class CodegenContext {
943922
arguments: Seq[(String, String)],
944923
returnType: String,
945924
makeSplitFunction: String => String,
946-
foldFunctions: Seq[String] => String,
947-
mergeSplit: Boolean = true): Iterable[String] = {
925+
foldFunctions: Seq[String] => String): Iterable[String] = {
948926
val innerClassToFunctions = mutable.LinkedHashMap.empty[(String, String), Seq[String]]
949927
functions.foreach(f => {
950928
val key = (f.innerClassName.get, f.innerClassInstance.get)
@@ -960,7 +938,7 @@ class CodegenContext {
960938
// for performance reasons, the functions are prepended, instead of appended,
961939
// thus here they are in reversed order
962940
val orderedFunctions = innerClassFunctions.reverse
963-
if (mergeSplit && orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) {
941+
if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) {
964942
// Adding a new function to each inner class which contains the invocation of all the
965943
// ones which have been added to that inner class. For example,
966944
// private class NestedClass {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -289,55 +289,61 @@ case class Elt(children: Seq[Expression])
289289
val index = indexExpr.genCode(ctx)
290290
val strings = stringExprs.map(_.genCode(ctx))
291291
val indexVal = ctx.freshName("index")
292+
293+
// -1 means the given index does not match any of the string indices handled in a split function.
294+
val NOT_MATCHED = -1
295+
// 0 means the given index matches one of the string indices handled in a split function.
296+
val MATCHED = 0
297+
val resultState = ctx.freshName("eltResultState")
298+
292299
val stringVal = ctx.freshName("stringVal")
300+
ctx.addMutableState(ctx.javaType(dataType), stringVal)
301+
293302
val assignStringValue = strings.zipWithIndex.map { case (eval, index) =>
294303
s"""
295-
case ${index + 1}:
296-
${eval.code}
297-
$stringVal = ${eval.isNull} ? null : ${eval.value};
298-
break;
299-
"""
304+
|if ($indexVal == ${index + 1}) {
305+
| ${eval.code}
306+
| $stringVal = ${eval.isNull} ? null : ${eval.value};
307+
| $resultState = (byte)$MATCHED;
308+
| continue;
309+
|}
310+
""".stripMargin
300311
}
301312

302-
var prevFunc = "null"
303-
var codes = ctx.splitExpressionsWithCurrentInputs(
313+
val codes = ctx.splitExpressionsWithCurrentInputs(
304314
expressions = assignStringValue,
305315
funcName = "eltFunc",
306316
extraArguments = ("int", indexVal) :: Nil,
307-
returnType = "UTF8String",
317+
returnType = ctx.JAVA_BYTE,
308318
makeSplitFunction = body =>
309319
s"""
310-
|UTF8String $stringVal = null;
311-
|switch ($indexVal) {
320+
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
321+
|do {
312322
| $body
313-
| default:
314-
| return $prevFunc;
315-
|}
316-
|return $stringVal;
317-
""".stripMargin,
318-
foldFunctions = funcs => s"UTF8String $stringVal = ${funcs.last};",
319-
makeFunctionCallback = f => prevFunc = s"$f(${ctx.INPUT_ROW}, $indexVal)",
320-
mergeSplit = false)
321-
322-
// If no any functions split, wraps all cases in a single switch.
323-
if (prevFunc == "null") {
324-
codes =
323+
|} while (false);
324+
|return $resultState;
325+
""".stripMargin,
326+
foldFunctions = _.map { funcCall =>
325327
s"""
326-
|UTF8String $stringVal = null;
327-
|switch ($indexVal) {
328-
| $codes
328+
|$resultState = $funcCall;
329+
|if ($resultState != $NOT_MATCHED) {
330+
| continue;
329331
|}
330-
""".stripMargin
331-
}
332+
""".stripMargin
333+
}.mkString)
332334

333335
ev.copy(
334336
s"""
335-
${index.code}
336-
final int $indexVal = ${index.value};
337-
$codes
338-
UTF8String ${ev.value} = $stringVal;
339-
final boolean ${ev.isNull} = ${ev.value} == null;
340-
""")
337+
|${index.code}
338+
|final int $indexVal = ${index.value};
339+
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
340+
|$stringVal = ${ctx.defaultValue(dataType)};
341+
|do {
342+
| $codes
343+
|} while (false);
344+
|final UTF8String ${ev.value} = $stringVal;
345+
|final boolean ${ev.isNull} = ${ev.value} == null;
346+
""".stripMargin)
341347
}
342348
}
343349

0 commit comments

Comments
 (0)