Skip to content

Commit 98093bb

Browse files
cloud-fan authored and dongjoon-hyun committed
[SPARK-32764][SQL] -0.0 should be equal to 0.0
This is a Spark 3.0 regression introduced by apache#26761. We missed a corner case: `java.lang.Double.compare` treats 0.0 and -0.0 as different, which breaks SQL semantics. This PR adds back the `OrderingUtil`, to provide custom compare methods that take care of 0.0 vs -0.0. This fixes a correctness bug. Yes, now `SELECT 0.0 > -0.0` returns false correctly, as in Spark 2.x. New tests added. Closes apache#29647 from cloud-fan/float. Authored-by: Wenchen Fan <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]> (cherry picked from commit 4144b6d) Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent c633869 commit 98093bb

File tree

8 files changed

+151
-7
lines changed

8 files changed

+151
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import org.apache.spark.metrics.source.CodegenMetrics
3838
import org.apache.spark.sql.catalyst.InternalRow
3939
import org.apache.spark.sql.catalyst.expressions._
4040
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
41-
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
41+
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, SQLOrderingUtil}
4242
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
4343
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
4444
import org.apache.spark.sql.internal.SQLConf
@@ -624,8 +624,12 @@ class CodegenContext extends Logging {
624624
def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
625625
// java boolean doesn't support > or < operator
626626
case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
627-
case DoubleType => s"java.lang.Double.compare($c1, $c2)"
628-
case FloatType => s"java.lang.Float.compare($c1, $c2)"
627+
case DoubleType =>
628+
val clsName = SQLOrderingUtil.getClass.getName.stripSuffix("$")
629+
s"$clsName.compareDoubles($c1, $c2)"
630+
case FloatType =>
631+
val clsName = SQLOrderingUtil.getClass.getName.stripSuffix("$")
632+
s"$clsName.compareFloats($c1, $c2)"
629633
// use c1 - c2 may overflow
630634
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
631635
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.util
19+
20+
object SQLOrderingUtil {

  /**
   * Compares two doubles following SQL semantics rather than raw
   * `java.lang.Double.compare` semantics:
   *  1. NaN == NaN
   *  2. NaN is greater than any non-NaN double
   *  3. -0.0 == 0.0
   *
   * The primitive `==` check up front makes 0.0 and -0.0 compare equal;
   * everything else (including NaN ordering) is delegated to
   * `java.lang.Double.compare`.
   */
  def compareDoubles(x: Double, y: Double): Int =
    if (x != y) java.lang.Double.compare(x, y) else 0

  /**
   * Compares two floats following SQL semantics rather than raw
   * `java.lang.Float.compare` semantics:
   *  1. NaN == NaN
   *  2. NaN is greater than any non-NaN float
   *  3. -0.0 == 0.0
   *
   * The primitive `==` check up front makes 0.0f and -0.0f compare equal;
   * everything else (including NaN ordering) is delegated to
   * `java.lang.Float.compare`.
   */
  def compareFloats(x: Float, y: Float): Int =
    if (x != y) java.lang.Float.compare(x, y) else 0
}

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag
2222
import scala.util.Try
2323

2424
import org.apache.spark.annotation.Stable
25+
import org.apache.spark.sql.catalyst.util.SQLOrderingUtil
2526

2627
/**
2728
* The data type representing `Double` values. Please use the singleton `DataTypes.DoubleType`.
@@ -38,7 +39,7 @@ class DoubleType private() extends FractionalType {
3839
private[sql] val numeric = implicitly[Numeric[Double]]
3940
private[sql] val fractional = implicitly[Fractional[Double]]
4041
private[sql] val ordering =
41-
(x: Double, y: Double) => java.lang.Double.compare(x, y)
42+
(x: Double, y: Double) => SQLOrderingUtil.compareDoubles(x, y)
4243
private[sql] val asIntegral = DoubleType.DoubleAsIfIntegral
4344

4445
override private[sql] def exactNumeric = DoubleExactNumeric

sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag
2222
import scala.util.Try
2323

2424
import org.apache.spark.annotation.Stable
25+
import org.apache.spark.sql.catalyst.util.SQLOrderingUtil
2526

2627
/**
2728
* The data type representing `Float` values. Please use the singleton `DataTypes.FloatType`.
@@ -38,7 +39,7 @@ class FloatType private() extends FractionalType {
3839
private[sql] val numeric = implicitly[Numeric[Float]]
3940
private[sql] val fractional = implicitly[Fractional[Float]]
4041
private[sql] val ordering =
41-
(x: Float, y: Float) => java.lang.Float.compare(x, y)
42+
(x: Float, y: Float) => SQLOrderingUtil.compareFloats(x, y)
4243
private[sql] val asIntegral = FloatType.FloatAsIfIntegral
4344

4445
override private[sql] def exactNumeric = FloatExactNumeric

sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.types
2020
import scala.math.Numeric._
2121
import scala.math.Ordering
2222

23+
import org.apache.spark.sql.catalyst.util.SQLOrderingUtil
2324
import org.apache.spark.sql.types.Decimal.DecimalIsConflicted
2425

2526
private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering {
@@ -148,7 +149,7 @@ private[sql] object FloatExactNumeric extends FloatIsFractional {
148149
}
149150
}
150151

151-
override def compare(x: Float, y: Float): Int = java.lang.Float.compare(x, y)
152+
override def compare(x: Float, y: Float): Int = SQLOrderingUtil.compareFloats(x, y)
152153
}
153154

154155
private[sql] object DoubleExactNumeric extends DoubleIsFractional {
@@ -176,7 +177,7 @@ private[sql] object DoubleExactNumeric extends DoubleIsFractional {
176177
}
177178
}
178179

179-
override def compare(x: Double, y: Double): Int = java.lang.Double.compare(x, y)
180+
override def compare(x: Double, y: Double): Int = SQLOrderingUtil.compareDoubles(x, y)
180181
}
181182

182183
private[sql] object DecimalExactNumeric extends DecimalIsConflicted {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,4 +538,20 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
538538
val inSet = InSet(BoundReference(0, IntegerType, true), Set.empty)
539539
checkEvaluation(inSet, false, row)
540540
}
541+
542+
test("SPARK-32764: compare special double/float values") {
  // SQL comparison semantics for floating point: NaN equals NaN, NaN is
  // greater than any non-NaN value, and -0.0 equals 0.0.
  val dNaN = Literal(Double.NaN)
  val dInf = Literal(Double.PositiveInfinity)
  checkEvaluation(EqualTo(dNaN, dNaN), true)
  checkEvaluation(EqualTo(dNaN, dInf), false)
  checkEvaluation(EqualTo(Literal(0.0D), Literal(-0.0D)), true)
  checkEvaluation(GreaterThan(dNaN, dInf), true)
  checkEvaluation(GreaterThan(dNaN, dNaN), false)
  checkEvaluation(GreaterThan(Literal(0.0D), Literal(-0.0D)), false)

  // The same semantics must hold for floats.
  val fNaN = Literal(Float.NaN)
  val fInf = Literal(Float.PositiveInfinity)
  checkEvaluation(EqualTo(fNaN, fNaN), true)
  checkEvaluation(EqualTo(fNaN, fInf), false)
  checkEvaluation(EqualTo(Literal(0.0F), Literal(-0.0F)), true)
  checkEvaluation(GreaterThan(fNaN, fInf), true)
  checkEvaluation(GreaterThan(fNaN, fNaN), false)
  checkEvaluation(GreaterThan(Literal(0.0F), Literal(-0.0F)), false)
}
541557
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.util
19+
20+
import java.lang.{Double => JDouble, Float => JFloat}
21+
22+
import org.apache.spark.SparkFunSuite
23+
24+
class SQLOrderingUtilSuite extends SparkFunSuite {

  test("compareDoublesSQL") {
    // For ordinary values the SQL ordering must agree with
    // java.lang.Double.compare in both argument orders.
    def shouldMatchDefaultOrder(a: Double, b: Double): Unit = {
      assert(SQLOrderingUtil.compareDoubles(a, b) === JDouble.compare(a, b))
      assert(SQLOrderingUtil.compareDoubles(b, a) === JDouble.compare(b, a))
    }
    shouldMatchDefaultOrder(0d, 0d)
    shouldMatchDefaultOrder(0d, 1d)
    shouldMatchDefaultOrder(-1d, 1d)
    shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue)

    // A NaN with a non-canonical bit pattern must still behave like Double.NaN.
    val specialNaN = JDouble.longBitsToDouble(0x7ff1234512345678L)
    assert(JDouble.isNaN(specialNaN))
    assert(JDouble.doubleToRawLongBits(Double.NaN) != JDouble.doubleToRawLongBits(specialNaN))

    // SQL semantics: NaN == NaN, NaN > any non-NaN, and -0.0 == 0.0.
    assert(SQLOrderingUtil.compareDoubles(Double.NaN, Double.NaN) === 0)
    assert(SQLOrderingUtil.compareDoubles(Double.NaN, specialNaN) === 0)
    assert(SQLOrderingUtil.compareDoubles(Double.NaN, Double.PositiveInfinity) > 0)
    assert(SQLOrderingUtil.compareDoubles(specialNaN, Double.PositiveInfinity) > 0)
    assert(SQLOrderingUtil.compareDoubles(Double.NaN, Double.NegativeInfinity) > 0)
    assert(SQLOrderingUtil.compareDoubles(Double.PositiveInfinity, Double.NaN) < 0)
    assert(SQLOrderingUtil.compareDoubles(Double.NegativeInfinity, Double.NaN) < 0)
    assert(SQLOrderingUtil.compareDoubles(0.0d, -0.0d) === 0)
    assert(SQLOrderingUtil.compareDoubles(-0.0d, 0.0d) === 0)
  }

  test("compareFloatsSQL") {
    // For ordinary values the SQL ordering must agree with
    // java.lang.Float.compare in both argument orders.
    def shouldMatchDefaultOrder(a: Float, b: Float): Unit = {
      assert(SQLOrderingUtil.compareFloats(a, b) === JFloat.compare(a, b))
      assert(SQLOrderingUtil.compareFloats(b, a) === JFloat.compare(b, a))
    }
    shouldMatchDefaultOrder(0f, 0f)
    shouldMatchDefaultOrder(0f, 1f)
    shouldMatchDefaultOrder(-1f, 1f)
    shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue)

    // A NaN with a non-canonical bit pattern must still behave like Float.NaN.
    val specialNaN = JFloat.intBitsToFloat(-6966608)
    assert(JFloat.isNaN(specialNaN))
    assert(JFloat.floatToRawIntBits(Float.NaN) != JFloat.floatToRawIntBits(specialNaN))

    // Bug fix: these assertions previously called compareDoubles, which
    // silently widened the Float arguments to Double and left the
    // compareFloats NaN / -0.0f handling completely untested.
    assert(SQLOrderingUtil.compareFloats(Float.NaN, Float.NaN) === 0)
    assert(SQLOrderingUtil.compareFloats(Float.NaN, specialNaN) === 0)
    assert(SQLOrderingUtil.compareFloats(Float.NaN, Float.PositiveInfinity) > 0)
    assert(SQLOrderingUtil.compareFloats(specialNaN, Float.PositiveInfinity) > 0)
    assert(SQLOrderingUtil.compareFloats(Float.NaN, Float.NegativeInfinity) > 0)
    assert(SQLOrderingUtil.compareFloats(Float.PositiveInfinity, Float.NaN) < 0)
    assert(SQLOrderingUtil.compareFloats(Float.NegativeInfinity, Float.NaN) < 0)
    assert(SQLOrderingUtil.compareFloats(0.0f, -0.0f) === 0)
    assert(SQLOrderingUtil.compareFloats(-0.0f, 0.0f) === 0)
  }
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2450,6 +2450,11 @@ class DataFrameSuite extends QueryTest
24502450
assert(df.schema === new StructType().add(StructField("d", DecimalType(38, 0))))
24512451
}
24522452
}
2453+
2454+
test("SPARK-32764: -0.0 and 0.0 should be equal") {
  // `pos > neg` must evaluate to false: SQL treats positive and negative
  // zero as equal values.
  val df = Seq((0.0, -0.0)).toDF("pos", "neg")
  checkAnswer(df.select($"pos" > $"neg"), Row(false))
}
24532458
}
24542459

24552460
case class GroupByKey(a: Int, b: Int)

0 commit comments

Comments
 (0)