Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 48c454f

Browse files
committed
Some code gen update.
1 parent 2344bc0 commit 48c454f

File tree

7 files changed

+99
-59
lines changed

7 files changed

+99
-59
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -435,37 +435,57 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
435435
if (evaluated == null) null else cast(evaluated)
436436
}
437437

438-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = this match {
438+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
439+
// TODO(cg): Add support for more data types.
440+
(child.dataType, dataType) match {
439441

440-
case Cast(child @ BinaryType(), StringType) =>
441-
castOrNull (ctx, ev, c =>
442-
s"new ${ctx.stringType}().set($c)")
442+
case (BinaryType, StringType) =>
443+
defineCodeGen (ctx, ev, c =>
444+
s"new ${ctx.stringType}().set($c)")
443445

444-
case Cast(child @ DateType(), StringType) =>
445-
castOrNull(ctx, ev, c =>
446-
s"""new ${ctx.stringType}().set(
446+
case (DateType, StringType) =>
447+
defineCodeGen(ctx, ev, c =>
448+
s"""new ${ctx.stringType}().set(
447449
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")
448450

449-
case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
450-
castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c?1:0)")
451+
case (BooleanType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
452+
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c ? 1 : 0)")
451453

452-
case Cast(child @ DecimalType(), IntegerType) =>
453-
castOrNull(ctx, ev, c => s"($c).toInt()")
454+
case (_: NumericType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
455+
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")
454456

455-
case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
456-
castOrNull(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
457+
case (_: DecimalType, ByteType) =>
458+
defineCodeGen(ctx, ev, c => s"($c).toByte()")
457459

458-
case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
459-
castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")
460+
case (_: DecimalType, ShortType) =>
461+
defineCodeGen(ctx, ev, c => s"($c).toShort()")
460462

461-
// Special handling required for timestamps in hive test cases since the toString function
462-
// does not match the expected output.
463-
case Cast(e, StringType) if e.dataType != TimestampType =>
464-
castOrNull(ctx, ev, c =>
465-
s"new ${ctx.stringType}().set(String.valueOf($c))")
463+
case (_: DecimalType, IntegerType) =>
464+
defineCodeGen(ctx, ev, c => s"($c).toInt()")
466465

467-
case other =>
468-
super.genCode(ctx, ev)
466+
case (_: DecimalType, LongType) =>
467+
defineCodeGen(ctx, ev, c => s"($c).toLong()")
468+
469+
case (_: DecimalType, FloatType) =>
470+
defineCodeGen(ctx, ev, c => s"($c).toFloat()")
471+
472+
case (_: DecimalType, DoubleType) =>
473+
defineCodeGen(ctx, ev, c => s"($c).toDouble()")
474+
475+
case (_: DecimalType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
476+
defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
477+
478+
// Special handling required for timestamps in hive test cases since the toString function
479+
// does not match the expected output.
480+
case (TimestampType, StringType) =>
481+
super.genCode(ctx, ev)
482+
483+
case (_, StringType) =>
484+
defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))")
485+
486+
case other =>
487+
super.genCode(ctx, ev)
488+
}
469489
}
470490
}
471491

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ abstract class Expression extends TreeNode[Expression] {
6969
}
7070

7171
/**
72-
* Returns Java source code for this expression.
72+
* Returns Java source code that can be compiled to evaluate this expression.
73+
* The default behavior is to call the eval method of the expression. Concrete expression
74+
* implementations should override this to do actual code generation.
7375
*
7476
* @param ctx a [[CodeGenContext]]
7577
* @param ev an [[GeneratedExpressionCode]] with unique terms.
@@ -82,10 +84,10 @@ abstract class Expression extends TreeNode[Expression] {
8284
/* expression: ${this} */
8385
Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
8486
boolean ${ev.nullTerm} = ${ev.objectTerm} == null;
85-
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} =
86-
${ctx.defaultValue(e.dataType)};
87-
if (!${ev.nullTerm}) ${ev.primitiveTerm} =
88-
(${ctx.boxedType(e.dataType)})${ev.objectTerm};
87+
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)};
88+
if (!${ev.nullTerm}) {
89+
${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${ev.objectTerm};
90+
}
8991
"""
9092
}
9193

@@ -155,17 +157,17 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
155157

156158
override def toString: String = s"($left $symbol $right)"
157159

158-
159160
/**
160161
* Short hand for generating binary evaluation code, which depends on two sub-evaluations of
161162
* the same type. If either of the sub-expressions is null, the result of this computation
162163
* is assumed to be null.
163164
*
164-
* @param f a function from two primitive term names to a tree that evaluates them.
165+
* @param f accepts two variable names and returns Java code to compute the output.
165166
*/
166-
def evaluate(ctx: CodeGenContext,
167-
ev: GeneratedExpressionCode,
168-
f: (String, String) => String): String = {
167+
protected def defineCodeGen(
168+
ctx: CodeGenContext,
169+
ev: GeneratedExpressionCode,
170+
f: (String, String) => String): String = {
169171
// TODO: Right now some timestamp tests fail if we enforce this...
170172
if (left.dataType != right.dataType) {
171173
// log.warn(s"${left.dataType} != ${right.dataType}")
@@ -197,9 +199,22 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
197199

198200
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
199201
self: Product =>
200-
def castOrNull(ctx: CodeGenContext,
201-
ev: GeneratedExpressionCode,
202-
f: String => String): String = {
202+
203+
/**
204+
* Called by unary expressions to generate a code block that returns null if its parent returns
205+
* null, and if not not null, use `f` to generate the expression.
206+
*
207+
* As an example, the following does a boolean inversion (i.e. NOT).
208+
* {{{
209+
* defineCodeGen(ctx, ev, c => s"!($c)")
210+
* }}}
211+
*
212+
* @param f function that accepts a variable name and returns Java code to compute the output.
213+
*/
214+
protected def defineCodeGen(
215+
ctx: CodeGenContext,
216+
ev: GeneratedExpressionCode,
217+
f: String => String): String = {
203218
val eval = child.gen(ctx)
204219
eval.code + s"""
205220
boolean ${ev.nullTerm} = ${eval.nullTerm};

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ case class Abs(child: Expression) extends UnaryArithmetic {
8787
abstract class BinaryArithmetic extends BinaryExpression {
8888
self: Product =>
8989

90+
/** Name of the function for this expression on a [[Decimal]] type. */
9091
def decimalMethod: String = ""
9192

9293
override def dataType: DataType = left.dataType
@@ -119,9 +120,9 @@ abstract class BinaryArithmetic extends BinaryExpression {
119120

120121
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
121122
if (left.dataType.isInstanceOf[DecimalType]) {
122-
evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1.$decimalMethod($eval2)" } )
123+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
123124
} else {
124-
evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1 $symbol $eval2" } )
125+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
125126
}
126127
}
127128

@@ -205,6 +206,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
205206
}
206207
}
207208

209+
/**
210+
* Special case handling due to division by 0 => null.
211+
*/
208212
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
209213
val eval1 = left.gen(ctx)
210214
val eval2 = right.gen(ctx)
@@ -221,8 +225,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
221225
eval1.code + eval2.code +
222226
s"""
223227
boolean ${ev.nullTerm} = false;
224-
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
225-
${ctx.defaultValue(left.dataType)};
228+
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)};
226229
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
227230
${ev.nullTerm} = true;
228231
} else {
@@ -263,6 +266,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
263266
}
264267
}
265268

269+
/**
270+
* Special case handling for x % 0 ==> null.
271+
*/
266272
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
267273
val eval1 = left.gen(ctx)
268274
val eval2 = right.gen(ctx)
@@ -279,8 +285,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
279285
eval1.code + eval2.code +
280286
s"""
281287
boolean ${ev.nullTerm} = false;
282-
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
283-
${ctx.defaultValue(left.dataType)};
288+
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)};
284289
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
285290
${ev.nullTerm} = true;
286291
} else {
@@ -337,7 +342,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
337342
}
338343

339344
/**
340-
* A function that calculates bitwise xor(^) of two numbers.
345+
* A function that calculates bitwise xor of two numbers.
341346
*/
342347
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
343348
override def symbol: String = "^"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,13 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
6767
val eval = child.gen(ctx)
6868
eval.code + s"""
6969
boolean ${ev.nullTerm} = ${eval.nullTerm};
70-
org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} =
71-
${ctx.defaultValue(DecimalType())};
70+
org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = ${ctx.defaultValue(DecimalType())};
7271

7372
if (!${ev.nullTerm}) {
74-
${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal();
75-
${ev.primitiveTerm} =
76-
${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale);
77-
${ev.nullTerm} = ${ev.primitiveTerm} == null;
73+
${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal();
74+
${ev.primitiveTerm} =
75+
${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale);
76+
${ev.nullTerm} = ${ev.primitiveTerm} == null;
7877
}
7978
"""
8079
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
8888
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
8989
"""
9090
} else {
91+
// TODO(cg): Add support for more data types.
9192
dataType match {
9293
case StringType =>
9394
val v = value.asInstanceOf[UTF8String]
@@ -96,12 +97,12 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
9697
final boolean ${ev.nullTerm} = false;
9798
${ctx.stringType} ${ev.primitiveTerm} = new ${ctx.stringType}().set(${arr});
9899
"""
99-
case FloatType =>
100+
case FloatType => // This must go before NumericType
100101
s"""
101102
final boolean ${ev.nullTerm} = false;
102103
float ${ev.primitiveTerm} = ${value}f;
103104
"""
104-
case dt: DecimalType =>
105+
case dt: DecimalType => // This must go before NumericType
105106
s"""
106107
final boolean ${ev.nullTerm} = false;
107108
${ctx.primitiveType(dt)} ${ev.primitiveTerm} =

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
6161
children.map { e =>
6262
val eval = e.gen(ctx)
6363
s"""
64-
if(${ev.nullTerm}) {
64+
if (${ev.nullTerm}) {
6565
${eval.code}
66-
if(!${eval.nullTerm}) {
66+
if (!${eval.nullTerm}) {
6767
${ev.nullTerm} = false;
6868
${ev.primitiveTerm} = ${eval.primitiveTerm};
6969
}
@@ -137,9 +137,9 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
137137
val code = children.map { e =>
138138
val eval = e.gen(ctx)
139139
s"""
140-
if($nonnull < $n) {
140+
if ($nonnull < $n) {
141141
${eval.code}
142-
if(!${eval.nullTerm}) {
142+
if (!${eval.nullTerm}) {
143143
$nonnull += 1;
144144
}
145145
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex
8585
}
8686

8787
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
88-
castOrNull(ctx, ev, c => s"!($c)")
88+
defineCodeGen(ctx, ev, c => s"!($c)")
8989
}
9090
}
9191

@@ -220,13 +220,13 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
220220
self: Product =>
221221
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
222222
left.dataType match {
223-
case dt: NumericType if ctx.isNativeType(dt) => evaluate (ctx, ev, {
223+
case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, {
224224
(c1, c3) => s"$c1 $symbol $c3"
225225
})
226226
case TimestampType =>
227227
// java.sql.Timestamp does not have compare()
228228
super.genCode(ctx, ev)
229-
case other => evaluate (ctx, ev, {
229+
case other => defineCodeGen (ctx, ev, {
230230
(c1, c2) => s"$c1.compare($c2) $symbol 0"
231231
})
232232
}
@@ -277,7 +277,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
277277
else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
278278
}
279279
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
280-
evaluate(ctx, ev, ctx.equalFunc(left.dataType))
280+
defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType))
281281
}
282282
}
283283

@@ -392,7 +392,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
392392
boolean ${ev.nullTerm} = false;
393393
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
394394
${condEval.code}
395-
if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
395+
if (!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
396396
${trueEval.code}
397397
${ev.nullTerm} = ${trueEval.nullTerm};
398398
${ev.primitiveTerm} = ${trueEval.primitiveTerm};

0 commit comments

Comments
 (0)