Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit e03edaa

Browse files
author
Davies Liu
committed
consts fold
1 parent 86fac2c commit e03edaa

File tree

7 files changed

+33
-30
lines changed

7 files changed

+33
-30
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
184184
if (!${ev.nullTerm}) {
185185
${eval2.code}
186186
if(!${eval2.nullTerm}) {
187-
${ev.primitiveTerm} = (${ctx.primitiveType(dataType)})($resultCode);
187+
${ev.primitiveTerm} = $resultCode;
188188
} else {
189189
${ev.nullTerm} = true;
190190
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,15 @@ abstract class BinaryArithmetic extends BinaryExpression {
118118
}
119119
}
120120

121-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
122-
if (left.dataType.isInstanceOf[DecimalType]) {
121+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
122+
case dt: DecimalType =>
123123
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
124-
} else {
124+
// byte and short are casted into int when add, minus, times or divide
125+
case ByteType | ShortType =>
126+
defineCodeGen(ctx, ev, (eval1, eval2) =>
127+
s"(${ctx.primitiveType(dataType)})($eval1 $symbol $eval2)")
128+
case _ =>
125129
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
126-
}
127130
}
128131

129132
protected def evalInternal(evalE1: Any, evalE2: Any): Any =

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
4040
* @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
4141
* valid if `nullTerm` is set to `true`.
4242
*/
43-
case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, primitiveTerm: Term)
43+
case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, var primitiveTerm: Term)
4444

4545
/**
4646
* A context for codegen, which is used to bookkeeping the expressions those are not supported

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -82,23 +82,25 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
8282
override def eval(input: Row): Any = value
8383

8484
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
85+
// change the nullTerm and primitiveTerm to consts, to inline them
8586
if (value == null) {
86-
s"""
87-
final boolean ${ev.nullTerm} = true;
88-
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
89-
"""
87+
ev.nullTerm = "true"
88+
ev.primitiveTerm = ctx.defaultValue(dataType)
89+
""
9090
} else {
9191
dataType match {
92+
case BooleanType =>
93+
ev.nullTerm = "false"
94+
ev.primitiveTerm = value.toString
95+
""
9296
case FloatType => // This must go before NumericType
93-
s"""
94-
final boolean ${ev.nullTerm} = false;
95-
final float ${ev.primitiveTerm} = ${value}f;
96-
"""
97+
ev.nullTerm = "false"
98+
ev.primitiveTerm = s"${value}f"
99+
""
97100
case dt: NumericType if !dt.isInstanceOf[DecimalType] =>
98-
s"""
99-
final boolean ${ev.nullTerm} = false;
100-
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value;
101-
"""
101+
ev.nullTerm = "false"
102+
ev.primitiveTerm = value.toString
103+
""
102104
// eval() version may be faster for non-primitive types
103105
case other =>
104106
super.genCode(ctx, ev)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,9 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr
8383

8484
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
8585
val eval = child.gen(ctx)
86-
eval.code + s"""
87-
final boolean ${ev.nullTerm} = false;
88-
final boolean ${ev.primitiveTerm} = ${eval.nullTerm};
89-
"""
86+
ev.nullTerm = "false"
87+
ev.primitiveTerm = eval.nullTerm
88+
eval.code
9089
}
9190

9291
override def toString: String = s"IS NULL $child"
@@ -103,10 +102,9 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E
103102

104103
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
105104
val eval = child.gen(ctx)
106-
eval.code + s"""
107-
boolean ${ev.nullTerm} = false;
108-
boolean ${ev.primitiveTerm} = !${eval.nullTerm};
109-
"""
105+
ev.nullTerm = "false"
106+
ev.primitiveTerm = s"(!(${eval.nullTerm}))"
107+
eval.code
110108
}
111109
}
112110

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
304304
val eval1 = left.gen(ctx)
305305
val eval2 = right.gen(ctx)
306306
val equalCode = ctx.equalFunc(left.dataType)(eval1.primitiveTerm, eval2.primitiveTerm)
307+
ev.nullTerm = "false"
307308
eval1.code + eval2.code + s"""
308-
final boolean ${ev.nullTerm} = false;
309309
final boolean ${ev.primitiveTerm} = (${eval1.nullTerm} && ${eval2.nullTerm}) ||
310310
(!${eval1.nullTerm} && $equalCode);
311311
"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ case class NewSet(elementType: DataType) extends LeafExpression {
6464
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
6565
elementType match {
6666
case IntegerType | LongType =>
67+
ev.nullTerm = "false"
6768
s"""
68-
boolean ${ev.nullTerm} = false;
6969
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dataType)}();
7070
"""
7171
case _ => super.genCode(ctx, ev)
@@ -111,11 +111,11 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
111111
val setEval = set.gen(ctx)
112112
val htype = ctx.primitiveType(dataType)
113113

114+
ev.nullTerm = "false"
114115
itemEval.code + setEval.code + s"""
115116
if (!${itemEval.nullTerm} && !${setEval.nullTerm}) {
116117
(($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm});
117118
}
118-
boolean ${ev.nullTerm} = false;
119119
${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm};
120120
"""
121121
case _ => super.genCode(ctx, ev)
@@ -164,8 +164,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
164164
val rightEval = right.gen(ctx)
165165
val htype = ctx.primitiveType(dataType)
166166

167+
ev.nullTerm = "false"
167168
leftEval.code + rightEval.code + s"""
168-
boolean ${ev.nullTerm} = false;
169169
${htype} ${ev.primitiveTerm} = (${htype})${leftEval.primitiveTerm};
170170
${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm});
171171
"""

0 commit comments

Comments
 (0)