Skip to content

Commit 593d617

Browse files
author
Davies Liu
committed
pushing codegen into Expression
1 parent 2bcdf8c commit 593d617

File tree

16 files changed

+650
-587
lines changed

16 files changed

+650
-587
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.Logging
2121
import org.apache.spark.sql.catalyst.errors.attachTree
22+
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
2223
import org.apache.spark.sql.types._
2324
import org.apache.spark.sql.catalyst.trees
2425

@@ -41,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
4142
override def qualifiers: Seq[String] = throw new UnsupportedOperationException
4243

4344
override def exprId: ExprId = throw new UnsupportedOperationException
45+
46+
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
47+
s"""
48+
final boolean ${ev.nullTerm} = i.isNullAt($ordinal);
49+
final ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ?
50+
${ctx.defaultPrimitive(dataType)} : (${ctx.getColumn(dataType, ordinal)});
51+
"""
52+
}
4453
}
4554

4655
object BindReferences extends Logging {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
2121
import java.text.{DateFormat, SimpleDateFormat}
2222

2323
import org.apache.spark.Logging
24+
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
2425
import org.apache.spark.sql.catalyst.util.DateUtils
2526
import org.apache.spark.sql.types._
2627

@@ -433,6 +434,42 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
433434
val evaluated = child.eval(input)
434435
if (evaluated == null) null else cast(evaluated)
435436
}
437+
438+
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = this match {
439+
440+
case Cast(child @ BinaryType(), StringType) =>
441+
castOrNull (ctx, ev, c =>
442+
s"new org.apache.spark.sql.types.UTF8String().set($c)",
443+
StringType)
444+
445+
case Cast(child @ DateType(), StringType) =>
446+
castOrNull(ctx, ev, c =>
447+
s"""new org.apache.spark.sql.types.UTF8String().set(
448+
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""",
449+
StringType)
450+
451+
case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
452+
castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c?1:0)", dt)
453+
454+
case Cast(child @ DecimalType(), IntegerType) =>
455+
castOrNull(ctx, ev, c => s"($c).toInt()", IntegerType)
456+
457+
case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
458+
castOrNull(ctx, ev, c => s"($c).to${ctx.termForType(dt)}()", dt)
459+
460+
case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
461+
castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c)", dt)
462+
463+
// Special handling required for timestamps in hive test cases since the toString function
464+
// does not match the expected output.
465+
case Cast(e, StringType) if e.dataType != TimestampType =>
466+
castOrNull(ctx, ev, c =>
467+
s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))",
468+
StringType)
469+
470+
case other =>
471+
super.genSource(ctx, ev)
472+
}
436473
}
437474

438475
object Cast {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
21+
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
2122
import org.apache.spark.sql.catalyst.trees
2223
import org.apache.spark.sql.catalyst.trees.TreeNode
2324
import org.apache.spark.sql.types._
@@ -51,6 +52,51 @@ abstract class Expression extends TreeNode[Expression] {
5152
/** Returns the result of evaluating this expression on a given input Row */
5253
def eval(input: Row = null): Any
5354

55+
/**
56+
* Returns an [[EvaluatedExpression]], which contains Java source code that
57+
* can be used to generate the result of evaluating the expression on an input row.
58+
* @param ctx a [[CodeGenContext]]
59+
*/
60+
def gen(ctx: CodeGenContext): EvaluatedExpression = {
61+
val nullTerm = ctx.freshName("nullTerm")
62+
val primitiveTerm = ctx.freshName("primitiveTerm")
63+
val objectTerm = ctx.freshName("objectTerm")
64+
val ve = EvaluatedExpression("", nullTerm, primitiveTerm, objectTerm)
65+
ve.code = genSource(ctx, ve)
66+
67+
// Only inject debugging code if debugging is turned on.
68+
// val debugCode =
69+
// if (debugLogging) {
70+
// val localLogger = log
71+
// val localLoggerTree = reify { localLogger }
72+
// s"""
73+
// $localLoggerTree.debug(
74+
// ${this.toString} + ": " + (if (${ev.nullTerm}) "null" else ${ev.primitiveTerm}.toString))
75+
// """
76+
// } else {
77+
// ""
78+
// }
79+
80+
ve
81+
}
82+
83+
/**
84+
* Returns Java source code for this expression
85+
*/
86+
def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
87+
val e = this.asInstanceOf[Expression]
88+
ctx.references += e
89+
s"""
90+
/* expression: ${this} */
91+
Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
92+
boolean ${ev.nullTerm} = ${ev.objectTerm} == null;
93+
${ctx.primitiveForType(e.dataType)} ${ev.primitiveTerm} =
94+
${ctx.defaultPrimitive(e.dataType)};
95+
if (!${ev.nullTerm}) ${ev.primitiveTerm} =
96+
(${ctx.termForType(e.dataType)})${ev.objectTerm};
97+
"""
98+
}
99+
54100
/**
55101
* Returns `true` if this expression and all its children have been resolved to a specific schema
56102
* and input data types checking passed, and `false` if it still contains any unresolved
@@ -116,6 +162,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
116162
override def nullable: Boolean = left.nullable || right.nullable
117163

118164
override def toString: String = s"($left $symbol $right)"
165+
166+
167+
/**
168+
* Short hand for generating binary evaluation code, which depends on two sub-evaluations of
169+
* the same type. If either of the sub-expressions is null, the result of this computation
170+
* is assumed to be null.
171+
*
172+
* @param f a function from two primitive term names to a tree that evaluates them.
173+
*/
174+
def evaluate(ctx: CodeGenContext,
175+
ev: EvaluatedExpression,
176+
f: (String, String) => String): String =
177+
evaluateAs(left.dataType)(ctx, ev, f)
178+
179+
def evaluateAs(resultType: DataType)(ctx: CodeGenContext,
180+
ev: EvaluatedExpression,
181+
f: (String, String) => String): String = {
182+
// TODO: Right now some timestamp tests fail if we enforce this...
183+
if (left.dataType != right.dataType) {
184+
// log.warn(s"${left.dataType} != ${right.dataType}")
185+
}
186+
187+
val eval1 = left.gen(ctx)
188+
val eval2 = right.gen(ctx)
189+
val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
190+
191+
eval1.code + eval2.code +
192+
s"""
193+
boolean ${ev.nullTerm} = ${eval1.nullTerm} || ${eval2.nullTerm};
194+
${ctx.primitiveForType(resultType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(resultType)};
195+
if(!${ev.nullTerm}) {
196+
${ev.primitiveTerm} = (${ctx.primitiveForType(resultType)})($resultCode);
197+
}
198+
"""
199+
}
119200
}
120201

121202
abstract class LeafExpression extends Expression with trees.LeafNode[Expression] {
@@ -124,6 +205,19 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
124205

125206
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
126207
self: Product =>
208+
def castOrNull(ctx: CodeGenContext,
209+
ev: EvaluatedExpression,
210+
f: String => String, dataType: DataType): String = {
211+
val eval = child.gen(ctx)
212+
eval.code +
213+
s"""
214+
boolean ${ev.nullTerm} = ${eval.nullTerm};
215+
${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)};
216+
if (!${ev.nullTerm}) {
217+
${ev.primitiveTerm} = ${f(eval.primitiveTerm)};
218+
}
219+
"""
220+
}
127221
}
128222

129223
// TODO Semantically we probably not need GroupExpression

0 commit comments

Comments
 (0)