Skip to content

Commit 1bdff95

Browse files
committed
Address comment
1 parent 82e97e3 commit 1bdff95

File tree

2 files changed

+16
-32
lines changed

2 files changed

+16
-32
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -309,17 +309,11 @@ trait DivModLike extends BinaryArithmetic {
309309

310310
override def nullable: Boolean = true
311311

312-
final override def eval(input: InternalRow): Any = {
313-
val input2 = right.eval(input)
314-
if (input2 == null || input2 == 0) {
312+
final override def nullSafeEval(input1: Any, input2: Any): Any = {
313+
if (input2 == 0) {
315314
null
316315
} else {
317-
val input1 = left.eval(input)
318-
if (input1 == null) {
319-
null
320-
} else {
321-
evalOperation(input1, input2)
322-
}
316+
evalOperation(input1, input2)
323317
}
324318
}
325319

@@ -516,24 +510,18 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
516510

517511
override def nullable: Boolean = true
518512

519-
override def eval(input: InternalRow): Any = {
520-
val input2 = right.eval(input)
521-
if (input2 == null || input2 == 0) {
513+
override def nullSafeEval(input1: Any, input2: Any): Any = {
514+
if (input2 == 0) {
522515
null
523516
} else {
524-
val input1 = left.eval(input)
525-
if (input1 == null) {
526-
null
527-
} else {
528-
input1 match {
529-
case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer])
530-
case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long])
531-
case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short])
532-
case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte])
533-
case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float])
534-
case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double])
535-
case d: Decimal => pmod(d, input2.asInstanceOf[Decimal])
536-
}
517+
input1 match {
518+
case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer])
519+
case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long])
520+
case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short])
521+
case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte])
522+
case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float])
523+
case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double])
524+
case d: Decimal => pmod(d, input2.asInstanceOf[Decimal])
537525
}
538526
}
539527
}

sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,8 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
164164
val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression],
165165
classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression])
166166

167-
// Do not check these expressions, because these expressions extend NullIntolerant
168-
// and override the eval function.
169-
val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod])
170-
171167
val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction()
172168
.map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
173-
.filterNot(c => ignoreSet.exists(_.getName.equals(c)))
174169
.map(name => Utils.classForName(name))
175170
.filterNot(classOf[NonSQLExpression].isAssignableFrom)
176171

@@ -180,8 +175,9 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
180175
superClass.getMethod("eval", classOf[InternalRow])
181176
val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz)
182177
if (isEvalOverrode && isNullIntolerantMixedIn) {
183-
fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " +
184-
s"or add ${clazz.getName} in the ignoreSet of this test.")
178+
fail(s"${clazz.getName} overrode the eval method and extended " +
179+
s"${classOf[NullIntolerant].getSimpleName}, which may be incorrect. " +
180+
s"You may need to override the nullSafeEval method.")
185181
} else if (!isEvalOverrode && !isNullIntolerantMixedIn) {
186182
fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.")
187183
} else {

0 commit comments

Comments
 (0)