Skip to content

Commit 3ff25f8

Browse files
author
Davies Liu
committed
refactor
1 parent 593d617 commit 3ff25f8

File tree

11 files changed

+163
-131
lines changed

11 files changed

+163
-131
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
4646
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
4747
s"""
4848
final boolean ${ev.nullTerm} = i.isNullAt($ordinal);
49-
final ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ?
50-
${ctx.defaultPrimitive(dataType)} : (${ctx.getColumn(dataType, ordinal)});
49+
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ?
50+
${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
5151
"""
5252
}
5353
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -439,33 +439,30 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
439439

440440
case Cast(child @ BinaryType(), StringType) =>
441441
castOrNull (ctx, ev, c =>
442-
s"new org.apache.spark.sql.types.UTF8String().set($c)",
443-
StringType)
442+
s"new org.apache.spark.sql.types.UTF8String().set($c)")
444443

445444
case Cast(child @ DateType(), StringType) =>
446445
castOrNull(ctx, ev, c =>
447446
s"""new org.apache.spark.sql.types.UTF8String().set(
448-
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""",
449-
StringType)
447+
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")
450448

451-
case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
452-
castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c?1:0)", dt)
449+
case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
450+
castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c?1:0)")
453451

454452
case Cast(child @ DecimalType(), IntegerType) =>
455-
castOrNull(ctx, ev, c => s"($c).toInt()", IntegerType)
453+
castOrNull(ctx, ev, c => s"($c).toInt()")
456454

457455
case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
458-
castOrNull(ctx, ev, c => s"($c).to${ctx.termForType(dt)}()", dt)
456+
castOrNull(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
459457

460458
case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
461-
castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c)", dt)
459+
castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")
462460

463461
// Special handling required for timestamps in hive test cases since the toString function
464462
// does not match the expected output.
465463
case Cast(e, StringType) if e.dataType != TimestampType =>
466464
castOrNull(ctx, ev, c =>
467-
s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))",
468-
StringType)
465+
s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))")
469466

470467
case other =>
471468
super.genSource(ctx, ev)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,10 @@ abstract class Expression extends TreeNode[Expression] {
9090
/* expression: ${this} */
9191
Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
9292
boolean ${ev.nullTerm} = ${ev.objectTerm} == null;
93-
${ctx.primitiveForType(e.dataType)} ${ev.primitiveTerm} =
94-
${ctx.defaultPrimitive(e.dataType)};
93+
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} =
94+
${ctx.defaultValue(e.dataType)};
9595
if (!${ev.nullTerm}) ${ev.primitiveTerm} =
96-
(${ctx.termForType(e.dataType)})${ev.objectTerm};
96+
(${ctx.boxedType(e.dataType)})${ev.objectTerm};
9797
"""
9898
}
9999

@@ -173,12 +173,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
173173
*/
174174
def evaluate(ctx: CodeGenContext,
175175
ev: EvaluatedExpression,
176-
f: (String, String) => String): String =
177-
evaluateAs(left.dataType)(ctx, ev, f)
178-
179-
def evaluateAs(resultType: DataType)(ctx: CodeGenContext,
180-
ev: EvaluatedExpression,
181-
f: (String, String) => String): String = {
176+
f: (String, String) => String): String = {
182177
// TODO: Right now some timestamp tests fail if we enforce this...
183178
if (left.dataType != right.dataType) {
184179
// log.warn(s"${left.dataType} != ${right.dataType}")
@@ -188,14 +183,19 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
188183
val eval2 = right.gen(ctx)
189184
val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
190185

191-
eval1.code + eval2.code +
192-
s"""
193-
boolean ${ev.nullTerm} = ${eval1.nullTerm} || ${eval2.nullTerm};
194-
${ctx.primitiveForType(resultType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(resultType)};
195-
if(!${ev.nullTerm}) {
196-
${ev.primitiveTerm} = (${ctx.primitiveForType(resultType)})($resultCode);
197-
}
198-
"""
186+
s"""
187+
${eval1.code}
188+
boolean ${ev.nullTerm} = ${eval1.nullTerm};
189+
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
190+
if (!${ev.nullTerm}) {
191+
${eval2.code}
192+
if(!${eval2.nullTerm}) {
193+
${ev.primitiveTerm} = (${ctx.primitiveType(dataType)})($resultCode);
194+
} else {
195+
${ev.nullTerm} = true;
196+
}
197+
}
198+
"""
199199
}
200200
}
201201

@@ -207,16 +207,15 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
207207
self: Product =>
208208
def castOrNull(ctx: CodeGenContext,
209209
ev: EvaluatedExpression,
210-
f: String => String, dataType: DataType): String = {
210+
f: String => String): String = {
211211
val eval = child.gen(ctx)
212-
eval.code +
213-
s"""
214-
boolean ${ev.nullTerm} = ${eval.nullTerm};
215-
${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)};
216-
if (!${ev.nullTerm}) {
217-
${ev.primitiveTerm} = ${f(eval.primitiveTerm)};
218-
}
219-
"""
212+
eval.code + s"""
213+
boolean ${ev.nullTerm} = ${eval.nullTerm};
214+
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
215+
if (!${ev.nullTerm}) {
216+
${ev.primitiveTerm} = ${f(eval.primitiveTerm)};
217+
}
218+
"""
220219
}
221220
}
222221

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
221221
eval1.code + eval2.code +
222222
s"""
223223
boolean ${ev.nullTerm} = false;
224-
${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} =
225-
${ctx.defaultPrimitive(left.dataType)};
224+
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
225+
${ctx.defaultValue(left.dataType)};
226226
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
227227
${ev.nullTerm} = true;
228228
} else {
@@ -279,8 +279,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
279279
eval1.code + eval2.code +
280280
s"""
281281
boolean ${ev.nullTerm} = false;
282-
${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} =
283-
${ctx.defaultPrimitive(left.dataType)};
282+
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
283+
${ctx.defaultValue(left.dataType)};
284284
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
285285
${ev.nullTerm} = true;
286286
} else {
@@ -412,8 +412,8 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
412412
val eval2 = right.gen(ctx)
413413
eval1.code + eval2.code + s"""
414414
boolean ${ev.nullTerm} = false;
415-
${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} =
416-
${ctx.defaultPrimitive(left.dataType)};
415+
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
416+
${ctx.defaultValue(left.dataType)};
417417

418418
if (${eval1.nullTerm}) {
419419
${ev.nullTerm} = ${eval2.nullTerm};
@@ -468,8 +468,8 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
468468

469469
eval1.code + eval2.code + s"""
470470
boolean ${ev.nullTerm} = false;
471-
${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} =
472-
${ctx.defaultPrimitive(left.dataType)};
471+
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
472+
${ctx.defaultValue(left.dataType)};
473473

474474
if (${eval1.nullTerm}) {
475475
${ev.nullTerm} = ${eval2.nullTerm};

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
7171
dataType match {
7272
case StringType => s"(org.apache.spark.sql.types.UTF8String)i.apply($ordinal)"
7373
case dt: DataType if isNativeType(dt) => s"i.${accessorForType(dt)}($ordinal)"
74-
case _ => s"(${termForType(dataType)})i.apply($ordinal)"
74+
case _ => s"(${boxedType(dataType)})i.apply($ordinal)"
7575
}
7676
}
7777

@@ -86,12 +86,12 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
8686

8787
def accessorForType(dt: DataType): String = dt match {
8888
case IntegerType => "getInt"
89-
case other => s"get${termForType(dt)}"
89+
case other => s"get${boxedType(dt)}"
9090
}
9191

9292
def mutatorForType(dt: DataType): String = dt match {
9393
case IntegerType => "setInt"
94-
case other => s"set${termForType(dt)}"
94+
case other => s"set${boxedType(dt)}"
9595
}
9696

9797
def hashSetForType(dt: DataType): String = dt match {
@@ -101,7 +101,10 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
101101
sys.error(s"Code generation not support for hashset of type $unsupportedType")
102102
}
103103

104-
def primitiveForType(dt: DataType): String = dt match {
104+
/**
105+
* Returns the Java type name used for a value of the given DataType (e.g. "int"), or "Object" if unmapped.
106+
*/
107+
def primitiveType(dt: DataType): String = dt match {
105108
case IntegerType => "int"
106109
case LongType => "long"
107110
case ShortType => "short"
@@ -117,7 +120,10 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
117120
case _ => "Object"
118121
}
119122

120-
def defaultPrimitive(dt: DataType): String = dt match {
123+
/**
124+
* Returns the Java source text of the default value for the given DataType (e.g. "false"), or "null" if unmapped.
125+
*/
126+
def defaultValue(dt: DataType): String = dt match {
121127
case BooleanType => "false"
122128
case FloatType => "-1.0f"
123129
case ShortType => "-1"
@@ -131,7 +137,10 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
131137
case _ => "null"
132138
}
133139

134-
def termForType(dt: DataType): String = dt match {
140+
/**
141+
* Returns the name of the boxed Java type for the given DataType (e.g. "Integer"), or "Object" if unmapped.
142+
*/
143+
def boxedType(dt: DataType): String = dt match {
135144
case IntegerType => "Integer"
136145
case LongType => "Long"
137146
case ShortType => "Short"
@@ -147,6 +156,15 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
147156
case _ => "Object"
148157
}
149158

159+
/**
160+
* Returns a function that builds a Java equality expression from two operand terms of the given DataType.
161+
*/
162+
def equalFunc(dataType: DataType): ((String, String) => String) = dataType match {
163+
case BinaryType => { case (eval1, eval2) => s"java.util.Arrays.equals($eval1, $eval2)" }
164+
case dt if isNativeType(dt) => { case (eval1, eval2) => s"$eval1 == $eval2" }
165+
case other => { case (eval1, eval2) => s"$eval1.equals($eval2)" }
166+
}
167+
150168
/**
151169
* List of data types that have special accessors and setters in [[Row]].
152170
*/
@@ -166,7 +184,6 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
166184
*/
167185
abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
168186

169-
protected val rowType = classOf[Row].getName
170187
protected val exprType = classOf[Expression].getName
171188
protected val mutableRowType = classOf[MutableRow].getName
172189
protected val genericMutableRowType = classOf[GenericMutableRow].getName

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
4545
val ctx = newCodeGenContext()
4646
val columns = expressions.zipWithIndex.map {
4747
case (e, i) =>
48-
s"private ${ctx.primitiveForType(e.dataType)} c$i = ${ctx.defaultPrimitive(e.dataType)};\n"
48+
s"private ${ctx.primitiveType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n"
4949
}.mkString("\n ")
5050

5151
val initColumns = expressions.zipWithIndex.map {
@@ -68,7 +68,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
6868
}.mkString("\n ")
6969

7070
val updateCases = expressions.zipWithIndex.map { case (e, i) =>
71-
s"case $i: { c$i = (${ctx.termForType(e.dataType)})value; return;}"
71+
s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}"
7272
}.mkString("\n ")
7373

7474
val specificAccessorFunctions = ctx.nativeTypes.map { dataType =>
@@ -80,14 +80,14 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
8080
if (cases.count(_ != '\n') > 0) {
8181
s"""
8282
@Override
83-
public ${ctx.primitiveForType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
83+
public ${ctx.primitiveType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
8484
if (isNullAt(i)) {
85-
return ${ctx.defaultPrimitive(dataType)};
85+
return ${ctx.defaultValue(dataType)};
8686
}
8787
switch (i) {
8888
$cases
8989
}
90-
return ${ctx.defaultPrimitive(dataType)};
90+
return ${ctx.defaultValue(dataType)};
9191
}"""
9292
} else {
9393
""
@@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
103103
if (cases.count(_ != '\n') > 0) {
104104
s"""
105105
@Override
106-
public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.primitiveForType(dataType)} value) {
106+
public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.primitiveType(dataType)} value) {
107107
nullBits[i] = false;
108108
switch (i) {
109109
$cases

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
6868
eval.code + s"""
6969
boolean ${ev.nullTerm} = ${eval.nullTerm};
7070
org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} =
71-
${ctx.defaultPrimitive(DecimalType())};
71+
${ctx.defaultValue(DecimalType())};
7272

7373
if (!${ev.nullTerm}) {
7474
${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal();

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,33 +85,33 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
8585
if (value == null) {
8686
s"""
8787
final boolean ${ev.nullTerm} = true;
88-
${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)};
88+
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
8989
"""
9090
} else {
9191
dataType match {
9292
case StringType =>
9393
val v = value.asInstanceOf[UTF8String]
9494
val arr = s"new byte[]{${v.getBytes.map(_.toString).mkString(", ")}}"
9595
s"""
96-
final boolean ${ev.nullTerm} = false;
97-
org.apache.spark.sql.types.UTF8String ${ev.primitiveTerm} =
98-
new org.apache.spark.sql.types.UTF8String().set(${arr});
99-
"""
96+
final boolean ${ev.nullTerm} = false;
97+
org.apache.spark.sql.types.UTF8String ${ev.primitiveTerm} =
98+
new org.apache.spark.sql.types.UTF8String().set(${arr});
99+
"""
100100
case FloatType =>
101101
s"""
102-
final boolean ${ev.nullTerm} = false;
103-
float ${ev.primitiveTerm} = ${value}f;
104-
"""
102+
final boolean ${ev.nullTerm} = false;
103+
float ${ev.primitiveTerm} = ${value}f;
104+
"""
105105
case dt: DecimalType =>
106106
s"""
107-
final boolean ${ev.nullTerm} = false;
108-
${ctx.primitiveForType(dt)} ${ev.primitiveTerm} = new ${ctx.primitiveForType(dt)}().set($value);
109-
"""
107+
final boolean ${ev.nullTerm} = false;
108+
${ctx.primitiveType(dt)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dt)}().set($value);
109+
"""
110110
case dt: NumericType =>
111111
s"""
112-
final boolean ${ev.nullTerm} = false;
113-
${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = $value;
114-
"""
112+
final boolean ${ev.nullTerm} = false;
113+
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value;
114+
"""
115115
case other =>
116116
super.genSource(ctx, ev)
117117
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
5656
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
5757
s"""
5858
boolean ${ev.nullTerm} = true;
59-
${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)};
59+
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
6060
""" +
6161
children.map { e =>
6262
val eval = e.gen(ctx)
@@ -131,4 +131,25 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
131131
}
132132
numNonNulls >= n
133133
}
134+
135+
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
136+
val nonnull = ctx.freshName("nonnull")
137+
val code = children.map { e =>
138+
val eval = e.gen(ctx)
139+
s"""
140+
if($nonnull < $n) {
141+
${eval.code}
142+
if(!${eval.nullTerm}) {
143+
$nonnull += 1;
144+
}
145+
}
146+
"""
147+
}.mkString("\n")
148+
s"""
149+
int $nonnull = 0;
150+
$code
151+
boolean ${ev.nullTerm} = false;
152+
boolean ${ev.primitiveTerm} = $nonnull >= $n;
153+
"""
154+
}
134155
}

0 commit comments

Comments
 (0)