Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 2344bc0

Browse files
author
Davies Liu
committed
fix test
1 parent 12ff88a commit 2344bc0

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,12 @@ class CodeGenContext {
166166
* Returns a function to generate equal expression in Java
167167
*/
168168
def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match {
169-
case BinaryType => { case (eval1, eval2) => s"java.util.Arrays.equals($eval1, $eval2)" }
170-
case dt if isNativeType(dt) => { case (eval1, eval2) => s"$eval1 == $eval2" }
171-
case other => { case (eval1, eval2) => s"$eval1.equals($eval2)" }
169+
case BinaryType => { case (eval1, eval2) =>
170+
s"java.util.Arrays.equals($eval1, $eval2)" }
171+
case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType =>
172+
{ case (eval1, eval2) => s"$eval1 == $eval2" }
173+
case other =>
174+
{ case (eval1, eval2) => s"$eval1.equals($eval2)" }
172175
}
173176

174177
/**
@@ -221,7 +224,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
221224
*/
222225
protected def compile(code: String): Class[_] = {
223226
val startTime = System.nanoTime()
224-
val clazz = new ClassBodyEvaluator(code).getClazz()
227+
val clazz = try {
228+
new ClassBodyEvaluator(code).getClazz()
229+
} catch {
230+
case e: Exception =>
231+
logError(s"failed to compile:\n $code", e)
232+
throw e
233+
}
225234
val endTime = System.nanoTime()
226235
def timeMs: Double = (endTime - startTime).toDouble / 1000000
227236
logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
167167

168168
leftEval.code + rightEval.code + s"""
169169
boolean ${ev.nullTerm} = false;
170-
${htype} ${ev.primitiveTerm} = ${leftEval.primitiveTerm};
171-
${ev.primitiveTerm}.union(${rightEval.primitiveTerm});
170+
${htype} ${ev.primitiveTerm} = (${htype})${leftEval.primitiveTerm};
171+
${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm});
172172
"""
173173
case _ => super.genCode(ctx, ev)
174174
}
@@ -191,9 +191,5 @@ case class CountSet(child: Expression) extends UnaryExpression {
191191
}
192192
}
193193

194-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
195-
castOrNull(ctx, ev, c => s"$c.size().toLong()")
196-
}
197-
198194
override def toString: String = s"$child.count()"
199195
}

0 commit comments

Comments
 (0)