Skip to content

Commit 1093c0f

Browse files
cloud-fandongjoon-hyun
authored andcommitted
[SPARK-32110][SQL] normalize special floating numbers in HyperLogLog++
### What changes were proposed in this pull request? Currently, Spark treats 0.0 and -0.0 semantically equal, while it still retains the difference between them so that users can see -0.0 when displaying the data set. The comparison expressions in Spark take care of the special floating numbers and implement the correct semantic. However, Spark doesn't always use these comparison expressions to compare values, and we need to normalize the special floating numbers before comparing them in these places: 1. GROUP BY 2. join keys 3. window partition keys This PR fixes one more place that compares values without using comparison expressions: HyperLogLog++ ### Why are the changes needed? Fix the query result ### Does this PR introduce _any_ user-facing change? Yes, the result of HyperLogLog++ becomes correct now. ### How was this patch tested? a new test case, and a few more test cases that pass before this PR to improve test coverage. Closes #30673 from cloud-fan/bug. Authored-by: Wenchen Fan <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]> (cherry picked from commit 6fd2345) Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 54a73ab commit 1093c0f

File tree

4 files changed

+144
-23
lines changed

4 files changed

+144
-23
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,28 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
143143

144144
case _ => throw new IllegalStateException(s"fail to normalize $expr")
145145
}
146+
147+
val FLOAT_NORMALIZER: Any => Any = (input: Any) => {
148+
val f = input.asInstanceOf[Float]
149+
if (f.isNaN) {
150+
Float.NaN
151+
} else if (f == -0.0f) {
152+
0.0f
153+
} else {
154+
f
155+
}
156+
}
157+
158+
val DOUBLE_NORMALIZER: Any => Any = (input: Any) => {
159+
val d = input.asInstanceOf[Double]
160+
if (d.isNaN) {
161+
Double.NaN
162+
} else if (d == -0.0d) {
163+
0.0d
164+
} else {
165+
d
166+
}
167+
}
146168
}
147169

148170
case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with ExpectsInputTypes {
@@ -152,27 +174,8 @@ case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with E
152174
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(FloatType, DoubleType))
153175

154176
private lazy val normalizer: Any => Any = child.dataType match {
155-
case FloatType => (input: Any) => {
156-
val f = input.asInstanceOf[Float]
157-
if (f.isNaN) {
158-
Float.NaN
159-
} else if (f == -0.0f) {
160-
0.0f
161-
} else {
162-
f
163-
}
164-
}
165-
166-
case DoubleType => (input: Any) => {
167-
val d = input.asInstanceOf[Double]
168-
if (d.isNaN) {
169-
Double.NaN
170-
} else if (d == -0.0d) {
171-
0.0d
172-
} else {
173-
d
174-
}
175-
}
177+
case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER
178+
case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER
176179
}
177180

178181
override def nullSafeEval(input: Any): Any = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.util
2222

2323
import org.apache.spark.sql.catalyst.InternalRow
2424
import org.apache.spark.sql.catalyst.expressions.XxHash64Function
25+
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers.{DOUBLE_NORMALIZER, FLOAT_NORMALIZER}
2526
import org.apache.spark.sql.types._
2627

2728
// A helper class for HyperLogLogPlusPlus.
@@ -88,7 +89,12 @@ class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable {
8889
*
8990
* Variable names in the HLL++ paper match variable names in the code.
9091
*/
91-
def update(buffer: InternalRow, bufferOffset: Int, value: Any, dataType: DataType): Unit = {
92+
def update(buffer: InternalRow, bufferOffset: Int, _value: Any, dataType: DataType): Unit = {
93+
val value = dataType match {
94+
case FloatType => FLOAT_NORMALIZER.apply(_value)
95+
case DoubleType => DOUBLE_NORMALIZER.apply(_value)
96+
case _ => _value
97+
}
9298
// Create the hashed value 'x'.
9399
val x = XxHash64Function.hash(value, dataType, 42L)
94100

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,4 +554,94 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
554554
checkEvaluation(GreaterThan(Literal(Float.NaN), Literal(Float.NaN)), false)
555555
checkEvaluation(GreaterThan(Literal(0.0F), Literal(-0.0F)), false)
556556
}
557+
558+
test("SPARK-32110: compare special double/float values in array") {
559+
def createUnsafeDoubleArray(d: Double): Literal = {
560+
Literal(UnsafeArrayData.fromPrimitiveArray(Array(d)), ArrayType(DoubleType))
561+
}
562+
def createSafeDoubleArray(d: Double): Literal = {
563+
Literal(new GenericArrayData(Array(d)), ArrayType(DoubleType))
564+
}
565+
def createUnsafeFloatArray(d: Double): Literal = {
566+
Literal(UnsafeArrayData.fromPrimitiveArray(Array(d.toFloat)), ArrayType(FloatType))
567+
}
568+
def createSafeFloatArray(d: Double): Literal = {
569+
Literal(new GenericArrayData(Array(d.toFloat)), ArrayType(FloatType))
570+
}
571+
def checkExpr(
572+
exprBuilder: (Expression, Expression) => Expression,
573+
left: Double,
574+
right: Double,
575+
expected: Any): Unit = {
576+
// test double
577+
checkEvaluation(
578+
exprBuilder(createUnsafeDoubleArray(left), createUnsafeDoubleArray(right)), expected)
579+
checkEvaluation(
580+
exprBuilder(createUnsafeDoubleArray(left), createSafeDoubleArray(right)), expected)
581+
checkEvaluation(
582+
exprBuilder(createSafeDoubleArray(left), createSafeDoubleArray(right)), expected)
583+
// test float
584+
checkEvaluation(
585+
exprBuilder(createUnsafeFloatArray(left), createUnsafeFloatArray(right)), expected)
586+
checkEvaluation(
587+
exprBuilder(createUnsafeFloatArray(left), createSafeFloatArray(right)), expected)
588+
checkEvaluation(
589+
exprBuilder(createSafeFloatArray(left), createSafeFloatArray(right)), expected)
590+
}
591+
592+
checkExpr(EqualTo, Double.NaN, Double.NaN, true)
593+
checkExpr(EqualTo, Double.NaN, Double.PositiveInfinity, false)
594+
checkExpr(EqualTo, 0.0, -0.0, true)
595+
checkExpr(GreaterThan, Double.NaN, Double.PositiveInfinity, true)
596+
checkExpr(GreaterThan, Double.NaN, Double.NaN, false)
597+
checkExpr(GreaterThan, 0.0, -0.0, false)
598+
}
599+
600+
test("SPARK-32110: compare special double/float values in struct") {
601+
def createUnsafeDoubleRow(d: Double): Literal = {
602+
val dt = new StructType().add("d", "double")
603+
val converter = UnsafeProjection.create(dt)
604+
val unsafeRow = converter.apply(InternalRow(d))
605+
Literal(unsafeRow, dt)
606+
}
607+
def createSafeDoubleRow(d: Double): Literal = {
608+
Literal(InternalRow(d), new StructType().add("d", "double"))
609+
}
610+
def createUnsafeFloatRow(d: Double): Literal = {
611+
val dt = new StructType().add("f", "float")
612+
val converter = UnsafeProjection.create(dt)
613+
val unsafeRow = converter.apply(InternalRow(d.toFloat))
614+
Literal(unsafeRow, dt)
615+
}
616+
def createSafeFloatRow(d: Double): Literal = {
617+
Literal(InternalRow(d.toFloat), new StructType().add("f", "float"))
618+
}
619+
def checkExpr(
620+
exprBuilder: (Expression, Expression) => Expression,
621+
left: Double,
622+
right: Double,
623+
expected: Any): Unit = {
624+
// test double
625+
checkEvaluation(
626+
exprBuilder(createUnsafeDoubleRow(left), createUnsafeDoubleRow(right)), expected)
627+
checkEvaluation(
628+
exprBuilder(createUnsafeDoubleRow(left), createSafeDoubleRow(right)), expected)
629+
checkEvaluation(
630+
exprBuilder(createSafeDoubleRow(left), createSafeDoubleRow(right)), expected)
631+
// test float
632+
checkEvaluation(
633+
exprBuilder(createUnsafeFloatRow(left), createUnsafeFloatRow(right)), expected)
634+
checkEvaluation(
635+
exprBuilder(createUnsafeFloatRow(left), createSafeFloatRow(right)), expected)
636+
checkEvaluation(
637+
exprBuilder(createSafeFloatRow(left), createSafeFloatRow(right)), expected)
638+
}
639+
640+
checkExpr(EqualTo, Double.NaN, Double.NaN, true)
641+
checkExpr(EqualTo, Double.NaN, Double.PositiveInfinity, false)
642+
checkExpr(EqualTo, 0.0, -0.0, true)
643+
checkExpr(GreaterThan, Double.NaN, Double.PositiveInfinity, true)
644+
checkExpr(GreaterThan, Double.NaN, Double.NaN, false)
645+
checkExpr(GreaterThan, 0.0, -0.0, false)
646+
}
557647
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

20+
import java.lang.{Double => JDouble}
2021
import java.util.Random
2122

2223
import scala.collection.mutable
2324

2425
import org.apache.spark.SparkFunSuite
2526
import org.apache.spark.sql.catalyst.InternalRow
2627
import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificInternalRow}
27-
import org.apache.spark.sql.types.{DataType, IntegerType}
28+
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType}
2829

2930
class HyperLogLogPlusPlusSuite extends SparkFunSuite {
3031

@@ -153,4 +154,25 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite {
153154
// Check if the buffers are equal.
154155
assert(buffer2 == buffer1a, "Buffers should be equal")
155156
}
157+
158+
test("SPARK-32110: add 0.0 and -0.0") {
159+
val (hll, input, buffer) = createEstimator(0.05, DoubleType)
160+
input.setDouble(0, 0.0)
161+
hll.update(buffer, input)
162+
input.setDouble(0, -0.0)
163+
hll.update(buffer, input)
164+
evaluateEstimate(hll, buffer, 1);
165+
}
166+
167+
test("SPARK-32110: add NaN") {
168+
val (hll, input, buffer) = createEstimator(0.05, DoubleType)
169+
input.setDouble(0, Double.NaN)
170+
hll.update(buffer, input)
171+
val specialNaN = JDouble.longBitsToDouble(0x7ff1234512345678L)
172+
assert(JDouble.isNaN(specialNaN))
173+
assert(JDouble.doubleToRawLongBits(Double.NaN) != JDouble.doubleToRawLongBits(specialNaN))
174+
input.setDouble(0, specialNaN)
175+
hll.update(buffer, input)
176+
evaluateEstimate(hll, buffer, 1);
177+
}
156178
}

0 commit comments

Comments
 (0)