Skip to content

Commit f99a612

Browse files
committed
Fix bugs in string prefix comparison.
1 parent 9d00afc commit f99a612

File tree

2 files changed

+46
-17
lines changed

2 files changed

+46
-17
lines changed

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import com.google.common.base.Charsets;
2121
import com.google.common.primitives.Longs;
22+
import com.google.common.primitives.UnsignedBytes;
2223

2324
import org.apache.spark.annotation.Private;
2425
import org.apache.spark.unsafe.types.UTF8String;
@@ -35,36 +36,33 @@ private PrefixComparators() {}
3536
public static final class StringPrefixComparator extends PrefixComparator {
3637
@Override
3738
public int compare(long aPrefix, long bPrefix) {
38-
// TODO: this can certainly be done more efficiently
39+
// TODO: this can be done more efficiently
3940
byte[] a = Longs.toByteArray(aPrefix);
4041
byte[] b = Longs.toByteArray(bPrefix);
4142
for (int i = 0; i < 8; i++) {
42-
if (a[i] == b[i]) continue;
43-
if (a[i] > b[i]) return -1;
44-
else if (a[i] < b[i]) return 1;
43+
int c = UnsignedBytes.compare(a[i], b[i]);
44+
if (c != 0) return c;
4545
}
4646
return 0;
4747
}
4848

49-
public long computePrefix(UTF8String value) {
50-
// TODO: this can certainly be done more efficiently
51-
return value == null ? 0L : computePrefix(value.toString());
52-
}
53-
54-
public long computePrefix(String value) {
55-
// TODO: this can certainly be done more efficiently
56-
if (value == null || value.length() == 0) {
49+
public long computePrefix(byte[] bytes) {
50+
if (bytes == null) {
5751
return 0L;
5852
} else {
59-
String first4Chars = value.substring(0, Math.min(3, value.length() - 1));
60-
byte[] utf16Bytes = first4Chars.getBytes(Charsets.UTF_16);
6153
byte[] padded = new byte[8];
62-
if (utf16Bytes.length < 8) {
63-
System.arraycopy(utf16Bytes, 0, padded, 0, utf16Bytes.length);
64-
}
54+
System.arraycopy(bytes, 0, padded, 0, Math.min(bytes.length, 8));
6555
return Longs.fromByteArray(padded);
6656
}
6757
}
58+
59+
public long computePrefix(String value) {
60+
return value == null ? 0L : computePrefix(value.getBytes(Charsets.UTF_8));
61+
}
62+
63+
public long computePrefix(UTF8String value) {
64+
return value == null ? 0L : computePrefix(value.getBytes());
65+
}
6866
}
6967

7068
/**
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package org.apache.spark.util.collection.unsafe.sort
2+
3+
import org.scalatest.prop.PropertyChecks
4+
5+
import org.apache.spark.SparkFunSuite
6+
7+
class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
8+
9+
test("String prefix comparator") {
10+
11+
def testPrefixComparison(s1: String, s2: String): Unit = {
12+
val s1Prefix = PrefixComparators.STRING.computePrefix(s1)
13+
val s2Prefix = PrefixComparators.STRING.computePrefix(s2)
14+
val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix)
15+
assert(
16+
(prefixComparisonResult == 0) ||
17+
(prefixComparisonResult < 0 && s1 < s2) ||
18+
(prefixComparisonResult > 0 && s1 > s2))
19+
}
20+
21+
val regressionTests = Table(
22+
("s1", "s2"),
23+
("abc", "世界"),
24+
("你好", "世界"),
25+
("你好123", "你好122")
26+
)
27+
28+
forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
29+
forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
30+
}
31+
}

0 commit comments

Comments
 (0)