Skip to content

Commit e57959d

Browse files
author
Davies Liu
committed
add type alias
1 parent 3ff25f8 commit e57959d

File tree

11 files changed

+69
-74
lines changed

11 files changed

+69
-74
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.Logging
2121
import org.apache.spark.sql.catalyst.errors.attachTree
22-
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
22+
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext}
2323
import org.apache.spark.sql.types._
2424
import org.apache.spark.sql.catalyst.trees
2525

@@ -43,7 +43,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
4343

4444
override def exprId: ExprId = throw new UnsupportedOperationException
4545

46-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
46+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
4747
s"""
4848
final boolean ${ev.nullTerm} = i.isNullAt($ordinal);
4949
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ?

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}
2121
import java.text.{DateFormat, SimpleDateFormat}
2222

2323
import org.apache.spark.Logging
24-
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
24+
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext}
2525
import org.apache.spark.sql.catalyst.util.DateUtils
2626
import org.apache.spark.sql.types._
2727

@@ -435,7 +435,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
435435
if (evaluated == null) null else cast(evaluated)
436436
}
437437

438-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = this match {
438+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = this match {
439439

440440
case Cast(child @ BinaryType(), StringType) =>
441441
castOrNull (ctx, ev, c =>
@@ -465,7 +465,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
465465
s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))")
466466

467467
case other =>
468-
super.genSource(ctx, ev)
468+
super.genCode(ctx, ev)
469469
}
470470
}
471471

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
21-
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
21+
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext}
2222
import org.apache.spark.sql.catalyst.trees
2323
import org.apache.spark.sql.catalyst.trees.TreeNode
2424
import org.apache.spark.sql.types._
@@ -62,28 +62,14 @@ abstract class Expression extends TreeNode[Expression] {
6262
val primitiveTerm = ctx.freshName("primitiveTerm")
6363
val objectTerm = ctx.freshName("objectTerm")
6464
val ve = EvaluatedExpression("", nullTerm, primitiveTerm, objectTerm)
65-
ve.code = genSource(ctx, ve)
66-
67-
// Only inject debugging code if debugging is turned on.
68-
// val debugCode =
69-
// if (debugLogging) {
70-
// val localLogger = log
71-
// val localLoggerTree = reify { localLogger }
72-
// s"""
73-
// $localLoggerTree.debug(
74-
// ${this.toString} + ": " + (if (${ev.nullTerm}) "null" else ${ev.primitiveTerm}.toString))
75-
// """
76-
// } else {
77-
// ""
78-
// }
79-
65+
ve.code = genCode(ctx, ve)
8066
ve
8167
}
8268

8369
/**
8470
* Returns Java source code for this expression
8571
*/
86-
def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
72+
def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
8773
val e = this.asInstanceOf[Expression]
8874
ctx.references += e
8975
s"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
21-
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
21+
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, EvaluatedExpression, CodeGenContext}
2222
import org.apache.spark.sql.catalyst.util.TypeUtils
2323
import org.apache.spark.sql.types._
2424

@@ -117,7 +117,7 @@ abstract class BinaryArithmetic extends BinaryExpression {
117117
}
118118
}
119119

120-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
120+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
121121
if (left.dataType.isInstanceOf[DecimalType]) {
122122
evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1.$decimalMethod($eval2)" } )
123123
} else {
@@ -205,7 +205,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
205205
}
206206
}
207207

208-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
208+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
209209
val eval1 = left.gen(ctx)
210210
val eval2 = right.gen(ctx)
211211
val test = if (left.dataType.isInstanceOf[DecimalType]) {
@@ -263,7 +263,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
263263
}
264264
}
265265

266-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
266+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
267267
val eval1 = left.gen(ctx)
268268
val eval2 = right.gen(ctx)
269269
val test = if (left.dataType.isInstanceOf[DecimalType]) {
@@ -406,7 +406,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
406406
}
407407
}
408408

409-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
409+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
410410
if (ctx.isNativeType(left.dataType)) {
411411
val eval1 = left.gen(ctx)
412412
val eval2 = right.gen(ctx)
@@ -430,7 +430,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
430430
}
431431
"""
432432
} else {
433-
super.genSource(ctx, ev)
433+
super.genCode(ctx, ev)
434434
}
435435
}
436436
override def toString: String = s"MaxOf($left, $right)"
@@ -460,7 +460,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
460460
}
461461
}
462462

463-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
463+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
464464
if (ctx.isNativeType(left.dataType)) {
465465

466466
val eval1 = left.gen(ctx)
@@ -486,7 +486,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
486486
}
487487
"""
488488
} else {
489-
super.genSource(ctx, ev)
489+
super.genCode(ctx, ev)
490490
}
491491
}
492492

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,22 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
4141
* valid if `nullTerm` is set to `true`.
4242
* @param objectTerm A possibly boxed version of the result of evaluating this expression.
4343
*/
44-
case class EvaluatedExpression(var code: String,
45-
nullTerm: String,
46-
primitiveTerm: String,
47-
objectTerm: String)
44+
case class EvaluatedExpression(var code: Code,
45+
nullTerm: Term,
46+
primitiveTerm: Term,
47+
objectTerm: Term)
4848

4949
/**
50-
* A context for codegen
51-
* @param references the expressions that don't support codegen
50+
* A context for codegen, which is used to bookkeeping the expressions those are not supported
51+
* by codegen, then they are evaluated directly. The unsupported expression is appended at the
52+
* end of `references`, the position of it is kept in the code, used to access and evaluate it.
5253
*/
53-
case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
54+
class CodeGenContext {
55+
56+
/**
57+
* Holding all the expressions those do not support codegen, will be evaluated directly.
58+
*/
59+
val references: Seq[Expression] = new mutable.ArrayBuffer[Expression]()
5460

5561
protected val stringType = classOf[UTF8String].getName
5662
protected val decimalType = classOf[Decimal].getName
@@ -63,19 +69,19 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
6369
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
6470
* function.)
6571
*/
66-
def freshName(prefix: String): String = {
72+
def freshName(prefix: String): Term = {
6773
s"$prefix${curId.getAndIncrement}"
6874
}
6975

70-
def getColumn(dataType: DataType, ordinal: Int): String = {
76+
def getColumn(dataType: DataType, ordinal: Int): Code = {
7177
dataType match {
7278
case StringType => s"(org.apache.spark.sql.types.UTF8String)i.apply($ordinal)"
7379
case dt: DataType if isNativeType(dt) => s"i.${accessorForType(dt)}($ordinal)"
7480
case _ => s"(${boxedType(dataType)})i.apply($ordinal)"
7581
}
7682
}
7783

78-
def setColumn(destinationRow: String, dataType: DataType, ordinal: Int, value: String): String = {
84+
def setColumn(destinationRow: Term, dataType: DataType, ordinal: Int, value: Term): Code = {
7985
dataType match {
8086
case StringType => s"$destinationRow.update($ordinal, $value)"
8187
case dt: DataType if isNativeType(dt) =>
@@ -84,17 +90,17 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
8490
}
8591
}
8692

87-
def accessorForType(dt: DataType): String = dt match {
93+
def accessorForType(dt: DataType): Term = dt match {
8894
case IntegerType => "getInt"
8995
case other => s"get${boxedType(dt)}"
9096
}
9197

92-
def mutatorForType(dt: DataType): String = dt match {
98+
def mutatorForType(dt: DataType): Term = dt match {
9399
case IntegerType => "setInt"
94100
case other => s"set${boxedType(dt)}"
95101
}
96102

97-
def hashSetForType(dt: DataType): String = dt match {
103+
def hashSetForType(dt: DataType): Term = dt match {
98104
case IntegerType => classOf[IntegerHashSet].getName
99105
case LongType => classOf[LongHashSet].getName
100106
case unsupportedType =>
@@ -104,7 +110,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
104110
/**
105111
* Return the primitive type for a DataType
106112
*/
107-
def primitiveType(dt: DataType): String = dt match {
113+
def primitiveType(dt: DataType): Term = dt match {
108114
case IntegerType => "int"
109115
case LongType => "long"
110116
case ShortType => "short"
@@ -123,7 +129,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
123129
/**
124130
* Return the representation of default value for given DataType
125131
*/
126-
def defaultValue(dt: DataType): String = dt match {
132+
def defaultValue(dt: DataType): Term = dt match {
127133
case BooleanType => "false"
128134
case FloatType => "-1.0f"
129135
case ShortType => "-1"
@@ -140,7 +146,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
140146
/**
141147
* Return the boxed type in Java
142148
*/
143-
def boxedType(dt: DataType): String = dt match {
149+
def boxedType(dt: DataType): Term = dt match {
144150
case IntegerType => "Integer"
145151
case LongType => "Long"
146152
case ShortType => "Short"
@@ -159,7 +165,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
159165
/**
160166
* Returns a function to generate equal expression in Java
161167
*/
162-
def equalFunc(dataType: DataType): ((String, String) => String) = dataType match {
168+
def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match {
163169
case BinaryType => { case (eval1, eval2) => s"java.util.Arrays.equals($eval1, $eval2)" }
164170
case dt if isNativeType(dt) => { case (eval1, eval2) => s"$eval1 == $eval2" }
165171
case other => { case (eval1, eval2) => s"$eval1.equals($eval2)" }
@@ -257,6 +263,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
257263
* expressions that don't support codegen
258264
*/
259265
def newCodeGenContext(): CodeGenContext = {
260-
new CodeGenContext(new mutable.ArrayBuffer[Expression]())
266+
new CodeGenContext
261267
}
262268
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ import org.apache.spark.util.Utils
2727
*/
2828
package object codegen {
2929

30+
type Term = String
31+
type Code = String
32+
3033
/** Canonicalizes an expression so those that differ only by names can reuse the same code. */
3134
object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] {
3235
val batches =

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
20+
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext}
2121
import org.apache.spark.sql.types._
2222

2323
/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
@@ -37,7 +37,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
3737
}
3838
}
3939

40-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
40+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
4141
val eval = child.gen(ctx)
4242
eval.code +s"""
4343
boolean ${ev.nullTerm} = ${eval.nullTerm};
@@ -63,7 +63,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
6363
}
6464
}
6565

66-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
66+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
6767
val eval = child.gen(ctx)
6868
eval.code + s"""
6969
boolean ${ev.nullTerm} = ${eval.nullTerm};

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import java.sql.{Date, Timestamp}
2121

2222
import org.apache.spark.sql.catalyst.CatalystTypeConverters
23-
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, EvaluatedExpression}
23+
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, EvaluatedExpression}
2424
import org.apache.spark.sql.catalyst.util.DateUtils
2525
import org.apache.spark.sql.types._
2626

@@ -81,7 +81,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
8181

8282
override def eval(input: Row): Any = value
8383

84-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
84+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
8585
if (value == null) {
8686
s"""
8787
final boolean ${ev.nullTerm} = true;
@@ -113,7 +113,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
113113
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value;
114114
"""
115115
case other =>
116-
super.genSource(ctx, ev)
116+
super.genCode(ctx, ev)
117117
}
118118
}
119119
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
5353
result
5454
}
5555

56-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
56+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
5757
s"""
5858
boolean ${ev.nullTerm} = true;
5959
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
@@ -81,7 +81,7 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr
8181
child.eval(input) == null
8282
}
8383

84-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
84+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
8585
val eval = child.gen(ctx)
8686
eval.code + s"""
8787
final boolean ${ev.nullTerm} = false;
@@ -101,7 +101,7 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E
101101
child.eval(input) != null
102102
}
103103

104-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
104+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
105105
val eval = child.gen(ctx)
106106
eval.code + s"""
107107
boolean ${ev.nullTerm} = false;
@@ -132,7 +132,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
132132
numNonNulls >= n
133133
}
134134

135-
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
135+
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
136136
val nonnull = ctx.freshName("nonnull")
137137
val code = children.map { e =>
138138
val eval = e.gen(ctx)

0 commit comments

Comments
 (0)