Skip to content

Commit 5e7b6b6

Browse files
Davies Liu, rxin
authored and committed
[SPARK-8117] [SQL] Push codegen implementation into each Expression
This PR moves the codegen implementation of expressions into the Expression class itself, making it easier to manage. It introduces two APIs in Expression: ``` def gen(ctx: CodeGenContext): GeneratedExpressionCode def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code ``` gen(ctx) will call genCode(ctx, ev) to generate Java source code for the current expression. An expression needs to override genCode(). Here are the types: ``` type Term = String type Code = String /** * Java source for evaluating an [[Expression]] given a [[Row]] of input. */ case class GeneratedExpressionCode(var code: Code, nullTerm: Term, primitiveTerm: Term, objectTerm: Term) /** * A context for codegen, which is used for bookkeeping the expressions that are not supported * by codegen and are therefore evaluated directly. An unsupported expression is appended at the * end of `references`; its position is kept in the code and used to access and evaluate it. */ class CodeGenContext { /** * Holds all the expressions that do not support codegen and will be evaluated directly. */ val references: Seq[Expression] = new mutable.ArrayBuffer[Expression]() } ``` This is basically apache#6660, but with the style violation and compilation failure fixed. Author: Davies Liu <[email protected]> Author: Reynold Xin <[email protected]> Closes apache#6690 from rxin/codegen and squashes the following commits: e1368c2 [Reynold Xin] Fixed tests. 73db80e [Reynold Xin] Fixed compilation failure. 19d6435 [Reynold Xin] Fixed style violation. 9adaeaf [Davies Liu] address comments f42c732 [Davies Liu] improve coverage and tests bad6828 [Davies Liu] address comments e03edaa [Davies Liu] consts fold 86fac2c [Davies Liu] fix style 02262c9 [Davies Liu] address comments b5d3617 [Davies Liu] Merge pull request #5 from rxin/codegen 48c454f [Reynold Xin] Some code gen update. 
2344bc0 [Davies Liu] fix test 12ff88a [Davies Liu] fix build c5fb514 [Davies Liu] rename 8c6d82d [Davies Liu] update docs b145047 [Davies Liu] fix style e57959d [Davies Liu] add type alias 3ff25f8 [Davies Liu] refactor 593d617 [Davies Liu] pushing codegen into Expression
1 parent b127ff8 commit 5e7b6b6

23 files changed

+1036
-718
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.Logging
2121
import org.apache.spark.sql.catalyst.errors.attachTree
22+
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
2223
import org.apache.spark.sql.types._
2324
import org.apache.spark.sql.catalyst.trees
2425

@@ -41,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
4142
override def qualifiers: Seq[String] = throw new UnsupportedOperationException
4243

4344
override def exprId: ExprId = throw new UnsupportedOperationException
45+
46+
/**
 * Emits Java source that loads this bound column from the row variable `i` used by the
 * generated code: the null flag comes from `isNullAt(ordinal)`, and the value is the
 * column read, or the default value for `dataType` when the slot is null.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
  val javaType = ctx.javaType(dataType)
  val fallback = ctx.defaultValue(dataType)
  val columnRead = ctx.getColumn(dataType, ordinal)
  s"""
    boolean ${ev.isNull} = i.isNullAt($ordinal);
    $javaType ${ev.primitive} = ${ev.isNull} ?
      $fallback : ($columnRead);
  """
}
4453
}
4554

4655
object BindReferences extends Logging {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
2121
import java.text.{DateFormat, SimpleDateFormat}
2222

2323
import org.apache.spark.Logging
24+
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
2425
import org.apache.spark.sql.catalyst.util.DateUtils
2526
import org.apache.spark.sql.types._
2627

@@ -433,6 +434,47 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
433434
val evaluated = child.eval(input)
434435
if (evaluated == null) null else cast(evaluated)
435436
}
437+
438+
/**
 * Generates Java source for this cast when a specialized translation exists for the
 * (source type, target type) pair. Every pair without one delegates to `super.genCode`.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
  // TODO(cg): Add support for more data types.
  (child.dataType, dataType) match {

    case (BinaryType, StringType) =>
      defineCodeGen(ctx, ev, c =>
        s"new ${ctx.stringType}().set($c)")
    case (DateType, StringType) =>
      defineCodeGen(ctx, ev, c =>
        s"""new ${ctx.stringType}().set(
          org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")
    // Special handling required for timestamps in hive test cases since the toString function
    // does not match the expected output.
    case (TimestampType, StringType) =>
      super.genCode(ctx, ev)
    case (_, StringType) =>
      defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))")

    // fallback for DecimalType, this must be before other numeric types
    case (_, _: DecimalType) =>
      super.genCode(ctx, ev)

    case (BooleanType, dt: NumericType) =>
      defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
    // A decimal casts to true exactly when it is non-zero, matching the `!= 0` rule applied
    // to the other numeric types in the next case; hence the negation of isZero().
    case (_: DecimalType, BooleanType) =>
      defineCodeGen(ctx, ev, c => s"!($c).isZero()")
    case (_: NumericType, BooleanType) =>
      defineCodeGen(ctx, ev, c => s"$c != 0")

    // IntegerType is handled separately because the generic branch below derives the
    // method name from the boxed type name.
    case (_: DecimalType, IntegerType) =>
      defineCodeGen(ctx, ev, c => s"($c).toInt()")
    case (_: DecimalType, dt: NumericType) =>
      defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
    case (_: NumericType, dt: NumericType) =>
      defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")

    case _ =>
      super.genCode(ctx, ev)
  }
}
436478
}
437479

438480
object Cast {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
21+
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term}
2122
import org.apache.spark.sql.catalyst.trees
2223
import org.apache.spark.sql.catalyst.trees.TreeNode
2324
import org.apache.spark.sql.types._
@@ -51,6 +52,44 @@ abstract class Expression extends TreeNode[Expression] {
5152
/** Returns the result of evaluating this expression on a given input Row */
5253
def eval(input: Row = null): Any
5354

55+
/**
 * Produces the [[GeneratedExpressionCode]] for this expression: variable names obtained
 * from `ctx.freshName` for the null flag and the value, together with the Java source
 * returned by `genCode` that computes them for an input row.
 *
 * @param ctx a [[CodeGenContext]]
 * @return a [[GeneratedExpressionCode]] whose `code` has been filled in by `genCode`
 */
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
  // The code field starts empty because genCode needs the term names to build the source.
  val generated = GeneratedExpressionCode(
    "", ctx.freshName("isNull"), ctx.freshName("primitive"))
  generated.code = genCode(ctx, generated)
  generated
}
69+
70+
/**
 * Returns Java source code that can be compiled to evaluate this expression.
 * The default behavior is to call the eval method of the expression. Concrete expression
 * implementations should override this to do actual code generation.
 *
 * The textual form of the expression is embedded in a Java block comment of the generated
 * source, so it is sanitized first to keep the generated source compilable.
 *
 * @param ctx a [[CodeGenContext]]
 * @param ev a [[GeneratedExpressionCode]] with unique terms.
 * @return Java source code
 */
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
  // Register this expression; the generated code reaches it by its position in `references`.
  ctx.references += this
  val objectTerm = ctx.freshName("obj")
  // Doubling every backslash stops the Java compiler from treating any part of the text as
  // a unicode escape, and breaking up the star-slash pair stops the text from closing the
  // surrounding block comment early.
  val commentText = this.toString.replace("\\", "\\\\").replace("*/", "*\\/")
  val referenceIndex = ctx.references.size - 1
  s"""
    /* expression: ${commentText} */
    Object ${objectTerm} = expressions[${referenceIndex}].eval(i);
    boolean ${ev.isNull} = ${objectTerm} == null;
    ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)};
    if (!${ev.isNull}) {
      ${ev.primitive} = (${ctx.boxedType(this.dataType)})${objectTerm};
    }
  """
}
92+
5493
/**
5594
* Returns `true` if this expression and all its children have been resolved to a specific schema
5695
* and input data types checking passed, and `false` if it still contains any unresolved
@@ -116,6 +155,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
116155
override def nullable: Boolean = left.nullable || right.nullable
117156

118157
override def toString: String = s"($left $symbol $right)"
158+
159+
/**
 * Shorthand for generating binary evaluation code, which depends on two sub-evaluations of
 * the same type. If either of the sub-expressions is null, the result of this computation
 * is assumed to be null.
 *
 * The code for the right operand is emitted inside the branch taken when the left operand
 * is non-null, so the generated code skips it whenever the left operand is null.
 *
 * @param ctx a [[CodeGenContext]]
 * @param ev the [[GeneratedExpressionCode]] whose null flag and value variables are declared
 *           by the returned code
 * @param f accepts two variable names and returns Java code to compute the output.
 * @return Java source code
 */
protected def defineCodeGen(
    ctx: CodeGenContext,
    ev: GeneratedExpressionCode,
    f: (Term, Term) => Code): Code = {
  // TODO: Enforce left.dataType == right.dataType; right now some timestamp tests fail if
  // we enforce this.
  val eval1 = left.gen(ctx)
  val eval2 = right.gen(ctx)
  val resultCode = f(eval1.primitive, eval2.primitive)

  s"""
    ${eval1.code}
    boolean ${ev.isNull} = ${eval1.isNull};
    ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
    if (!${ev.isNull}) {
      ${eval2.code}
      if(!${eval2.isNull}) {
        ${ev.primitive} = $resultCode;
      } else {
        ${ev.isNull} = true;
      }
    }
  """
}
119193
}
120194

121195
private[sql] object BinaryExpression {
@@ -128,6 +202,32 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
128202

129203
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
130204
self: Product =>
205+
206+
/**
 * Helper for unary expressions: emits the child's code followed by a block that leaves the
 * result at the default value for `dataType` when the child is null, and otherwise assigns
 * the Java expression produced by `f`.
 *
 * As an example, the following does a boolean inversion (i.e. NOT).
 * {{{
 *   defineCodeGen(ctx, ev, c => s"!($c)")
 * }}}
 *
 * @param ctx a [[CodeGenContext]]
 * @param ev the [[GeneratedExpressionCode]] for this expression; its `isNull` term is
 *           replaced by the child's
 * @param f function that accepts a variable name and returns Java code to compute the output.
 * @return Java source code
 */
protected def defineCodeGen(
    ctx: CodeGenContext,
    ev: GeneratedExpressionCode,
    f: Term => Code): Code = {
  val childEval = child.gen(ctx)
  // The result is null exactly when the child is, so share the child's null flag instead
  // of declaring a second boolean.
  ev.isNull = childEval.isNull
  val resultBlock = s"""
    ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
    if (!${ev.isNull}) {
      ${ev.primitive} = ${f(childEval.primitive)};
    }
  """
  childEval.code + resultBlock
}
131231
}
132232

133233
// TODO Semantically we probably not need GroupExpression

0 commit comments

Comments
 (0)