Skip to content

Code gen code review. #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 5, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -435,37 +435,57 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
if (evaluated == null) null else cast(evaluated)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = this match {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
// TODO(cg): Add support for more data types.
(child.dataType, dataType) match {

case Cast(child @ BinaryType(), StringType) =>
castOrNull (ctx, ev, c =>
s"new ${ctx.stringType}().set($c)")
case (BinaryType, StringType) =>
defineCodeGen (ctx, ev, c =>
s"new ${ctx.stringType}().set($c)")

case Cast(child @ DateType(), StringType) =>
castOrNull(ctx, ev, c =>
s"""new ${ctx.stringType}().set(
case (DateType, StringType) =>
defineCodeGen(ctx, ev, c =>
s"""new ${ctx.stringType}().set(
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")

case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c?1:0)")
case (BooleanType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c ? 1 : 0)")

case Cast(child @ DecimalType(), IntegerType) =>
castOrNull(ctx, ev, c => s"($c).toInt()")
case (_: NumericType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This cannot handle DecimalType => NumericType.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't that handled later?


case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
castOrNull(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
case (_: DecimalType, ByteType) =>
defineCodeGen(ctx, ev, c => s"($c).toByte()")

case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")
case (_: DecimalType, ShortType) =>
defineCodeGen(ctx, ev, c => s"($c).toShort()")

// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case Cast(e, StringType) if e.dataType != TimestampType =>
castOrNull(ctx, ev, c =>
s"new ${ctx.stringType}().set(String.valueOf($c))")
case (_: DecimalType, IntegerType) =>
defineCodeGen(ctx, ev, c => s"($c).toInt()")

case other =>
super.genCode(ctx, ev)
case (_: DecimalType, LongType) =>
defineCodeGen(ctx, ev, c => s"($c).toLong()")

case (_: DecimalType, FloatType) =>
defineCodeGen(ctx, ev, c => s"($c).toFloat()")

case (_: DecimalType, DoubleType) =>
defineCodeGen(ctx, ev, c => s"($c).toDouble()")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These could be covered by the next case.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, can you accept and just remove that? I have more comments you need to address anyway (and we should discuss them too).


case (_: DecimalType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")

// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case (TimestampType, StringType) =>
super.genCode(ctx, ev)

case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))")

case other =>
super.genCode(ctx, ev)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ abstract class Expression extends TreeNode[Expression] {
}

/**
* Returns Java source code for this expression.
* Returns Java source code that can be compiled to evaluate this expression.
* The default behavior is to call the eval method of the expression. Concrete expression
* implementations should override this to do actual code generation.
*
* @param ctx a [[CodeGenContext]]
* @param ev a [[GeneratedExpressionCode]] with unique terms.
Expand All @@ -82,10 +84,10 @@ abstract class Expression extends TreeNode[Expression] {
/* expression: ${this} */
Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
boolean ${ev.nullTerm} = ${ev.objectTerm} == null;
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(e.dataType)};
if (!${ev.nullTerm}) ${ev.primitiveTerm} =
(${ctx.boxedType(e.dataType)})${ev.objectTerm};
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)};
if (!${ev.nullTerm}) {
${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${ev.objectTerm};
}
"""
}

Expand Down Expand Up @@ -155,17 +157,17 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express

override def toString: String = s"($left $symbol $right)"


/**
* Short hand for generating binary evaluation code, which depends on two sub-evaluations of
* the same type. If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f a function from two primitive term names to a tree that evaluates them.
* @param f accepts two variable names and returns Java code to compute the output.
*/
def evaluate(ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String) => String): String = {
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String) => String): String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (left.dataType != right.dataType) {
// log.warn(s"${left.dataType} != ${right.dataType}")
Expand Down Expand Up @@ -197,9 +199,22 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]

abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>
def castOrNull(ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: String => String): String = {

/**
* Called by unary expressions to generate a code block that returns null if its child returns
* null, and if not null, uses `f` to generate the expression.
*
* As an example, the following does a boolean inversion (i.e. NOT).
* {{{
* defineCodeGen(ctx, ev, c => s"!($c)")
* }}}
*
* @param f function that accepts a variable name and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: String => String): String = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ case class Abs(child: Expression) extends UnaryArithmetic {
abstract class BinaryArithmetic extends BinaryExpression {
self: Product =>

/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String = ""

override def dataType: DataType = left.dataType
Expand Down Expand Up @@ -119,9 +120,9 @@ abstract class BinaryArithmetic extends BinaryExpression {

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
if (left.dataType.isInstanceOf[DecimalType]) {
evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1.$decimalMethod($eval2)" } )
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
} else {
evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1 $symbol $eval2" } )
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
}

Expand Down Expand Up @@ -205,6 +206,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
}

/**
* Special case handling due to division by 0 => null.
*/
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
Expand All @@ -221,8 +225,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
eval1.code + eval2.code +
s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(left.dataType)};
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)};
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
${ev.nullTerm} = true;
} else {
Expand Down Expand Up @@ -263,6 +266,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}
}

/**
* Special case handling for x % 0 ==> null.
*/
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
Expand All @@ -279,8 +285,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
eval1.code + eval2.code +
s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(left.dataType)};
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)};
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
${ev.nullTerm} = true;
} else {
Expand Down Expand Up @@ -337,7 +342,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
}

/**
* A function that calculates bitwise xor(^) of two numbers.
* A function that calculates bitwise xor of two numbers.
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "^"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,13 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} =
${ctx.defaultValue(DecimalType())};
org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = ${ctx.defaultValue(DecimalType())};

if (!${ev.nullTerm}) {
${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal();
${ev.primitiveTerm} =
${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale);
${ev.nullTerm} = ${ev.primitiveTerm} == null;
${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal();
${ev.primitiveTerm} =
${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale);
${ev.nullTerm} = ${ev.primitiveTerm} == null;
}
"""
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
"""
} else {
// TODO(cg): Add support for more data types.
dataType match {
case StringType =>
val v = value.asInstanceOf[UTF8String]
Expand All @@ -96,12 +97,12 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
final boolean ${ev.nullTerm} = false;
${ctx.stringType} ${ev.primitiveTerm} = new ${ctx.stringType}().set(${arr});
"""
case FloatType =>
case FloatType => // This must go before NumericType
s"""
final boolean ${ev.nullTerm} = false;
float ${ev.primitiveTerm} = ${value}f;
"""
case dt: DecimalType =>
case dt: DecimalType => // This must go before NumericType
s"""
final boolean ${ev.nullTerm} = false;
${ctx.primitiveType(dt)} ${ev.primitiveTerm} =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
children.map { e =>
val eval = e.gen(ctx)
s"""
if(${ev.nullTerm}) {
if (${ev.nullTerm}) {
${eval.code}
if(!${eval.nullTerm}) {
if (!${eval.nullTerm}) {
${ev.nullTerm} = false;
${ev.primitiveTerm} = ${eval.primitiveTerm};
}
Expand Down Expand Up @@ -137,9 +137,9 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
val code = children.map { e =>
val eval = e.gen(ctx)
s"""
if($nonnull < $n) {
if ($nonnull < $n) {
${eval.code}
if(!${eval.nullTerm}) {
if (!${eval.nullTerm}) {
$nonnull += 1;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
castOrNull(ctx, ev, c => s"!($c)")
defineCodeGen(ctx, ev, c => s"!($c)")
}
}

Expand Down Expand Up @@ -220,13 +220,13 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
self: Product =>
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
left.dataType match {
case dt: NumericType if ctx.isNativeType(dt) => evaluate (ctx, ev, {
case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, {
(c1, c3) => s"$c1 $symbol $c3"
})
case TimestampType =>
// java.sql.Timestamp does not have compare()
super.genCode(ctx, ev)
case other => evaluate (ctx, ev, {
case other => defineCodeGen (ctx, ev, {
(c1, c2) => s"$c1.compare($c2) $symbol 0"
})
}
Expand Down Expand Up @@ -277,7 +277,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
evaluate(ctx, ev, ctx.equalFunc(left.dataType))
defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType))
}
}

Expand Down Expand Up @@ -392,7 +392,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
${condEval.code}
if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
if (!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
${trueEval.code}
${ev.nullTerm} = ${trueEval.nullTerm};
${ev.primitiveTerm} = ${trueEval.primitiveTerm};
Expand Down