Skip to content

Commit 921e914

Browse files
committed
2 parents 789ea21 + e2c901b commit 921e914

File tree

3 files changed

+51
-6
lines changed

3 files changed

+51
-6
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ abstract class RDD[T: ClassTag](
993993
*/
994994
@Experimental
995995
def countApproxDistinct(p: Int, sp: Int): Long = {
996-
require(p >= 4, s"p ($p) must be greater than 0")
996+
require(p >= 4, s"p ($p) must be at least 4")
997997
require(sp <= 32, s"sp ($sp) cannot be greater than 32")
998998
require(sp == 0 || p <= sp, s"p ($p) cannot be greater than sp ($sp)")
999999
val zeroCounter = new HyperLogLogPlus(p, sp)

python/pyspark/rdd.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def portable_hash(x):
6262
6363
>>> portable_hash(None)
6464
0
65-
>>> portable_hash((None, 1))
65+
>>> portable_hash((None, 1)) & 0xffffffff
6666
219750521
6767
"""
6868
if x is None:
@@ -72,7 +72,7 @@ def portable_hash(x):
7272
for i in x:
7373
h ^= portable_hash(i)
7474
h *= 1000003
75-
h &= 0xffffffff
75+
h &= sys.maxint
7676
h ^= len(x)
7777
if h == -1:
7878
h = -2
@@ -1942,7 +1942,7 @@ def _is_pickled(self):
19421942
return True
19431943
return False
19441944

1945-
def _to_jrdd(self):
1945+
def _to_java_object_rdd(self):
19461946
""" Return an JavaRDD of Object by unpickling
19471947
19481948
It will convert each Python object into Java object by Pyrolite, whenever the
@@ -1977,7 +1977,7 @@ def sumApprox(self, timeout, confidence=0.95):
19771977
>>> (rdd.sumApprox(1000) - r) / r < 0.05
19781978
True
19791979
"""
1980-
jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_jrdd()
1980+
jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd()
19811981
jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd())
19821982
r = jdrdd.sumApprox(timeout, confidence).getFinalValue()
19831983
return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high())
@@ -1993,11 +1993,40 @@ def meanApprox(self, timeout, confidence=0.95):
19931993
>>> (rdd.meanApprox(1000) - r) / r < 0.05
19941994
True
19951995
"""
1996-
jrdd = self.map(float)._to_jrdd()
1996+
jrdd = self.map(float)._to_java_object_rdd()
19971997
jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd())
19981998
r = jdrdd.meanApprox(timeout, confidence).getFinalValue()
19991999
return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high())
20002000

2001+
def countApproxDistinct(self, relativeSD=0.05):
    """
    :: Experimental ::
    Return approximate number of distinct elements in the RDD.

    The algorithm used is based on streamlib's implementation of
    "HyperLogLog in Practice: Algorithmic Engineering of a State
    of The Art Cardinality Estimation Algorithm", available
    <a href="http://dx.doi.org/10.1145/2452376.2452456">here</a>.

    @param relativeSD Relative accuracy. Smaller values create
                      counters that require more space.
                      It must be greater than 0.000017.

    >>> n = sc.parallelize(range(1000)).map(str).countApproxDistinct()
    >>> 950 < n < 1050
    True
    >>> n = sc.parallelize([i % 20 for i in range(1000)]).countApproxDistinct()
    >>> 18 < n < 22
    True
    """
    # Reject accuracies outside the range the underlying
    # HyperLogLogPlus implementation can honor.
    if relativeSD > 0.37:
        raise ValueError("relativeSD should be smaller than 0.37")
    elif relativeSD < 0.000017:
        raise ValueError("relativeSD should be greater than 0.000017")
    # the hash space in Java is 2^32
    hashed = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF)
    return hashed._to_java_object_rdd().countApproxDistinct(relativeSD)
2029+
20012030

20022031
class PipelinedRDD(RDD):
20032032

python/pyspark/tests.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,22 @@ def test_zip_with_different_number_of_items(self):
404404
self.assertEquals(a.count(), b.count())
405405
self.assertRaises(Exception, lambda: a.zip(b).count())
406406

407+
def test_count_approx_distinct(self):
    # Estimates for several element types should land within ~5% of
    # the true cardinality, and invalid accuracies must be rejected.
    transforms = (float, str, lambda x: (x, -x))

    # 1000 distinct values with an explicit 4% relative accuracy.
    rdd = self.sc.parallelize(range(1000))
    self.assertTrue(950 < rdd.countApproxDistinct(0.04) < 1050)
    for f in transforms:
        self.assertTrue(950 < rdd.map(f).countApproxDistinct(0.04) < 1050)

    # Only 20 distinct values spread across 7 partitions, default accuracy.
    rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7)
    self.assertTrue(18 < rdd.countApproxDistinct() < 22)
    for f in transforms:
        self.assertTrue(18 < rdd.map(f).countApproxDistinct() < 22)

    # Accuracies outside the supported range raise ValueError.
    self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001))
    self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.5))
422+
407423
def test_histogram(self):
408424
# empty
409425
rdd = self.sc.parallelize([])

0 commit comments

Comments (0)