|
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
+from pyspark.sql import SQLContext, IntegerType

 _have_scipy = False
 _have_numpy = False
@@ -424,6 +425,22 @@ def test_zip_with_different_number_of_items(self):
         self.assertEquals(a.count(), b.count())
         self.assertRaises(Exception, lambda: a.zip(b).count())

+    def test_count_approx_distinct(self):
+        rdd = self.sc.parallelize(range(1000))
+        self.assertTrue(950 < rdd.countApproxDistinct(0.04) < 1050)
+        self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.04) < 1050)
+        self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.04) < 1050)
+        self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.04) < 1050)
+
+        rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7)
+        self.assertTrue(18 < rdd.countApproxDistinct() < 22)
+        self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22)
+        self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22)
+        self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22)
+
+        self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001))
+        self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.5))
+
     def test_histogram(self):
         # empty
         rdd = self.sc.parallelize([])
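The new test above exercises `RDD.countApproxDistinct`, which returns an estimated count of distinct elements; the optional argument is the relative accuracy (smaller values give a tighter estimate at higher cost, and out-of-range values raise `ValueError`, as the last two assertions check). A minimal standalone sketch of the same API, assuming only a local PySpark installation; the app name is illustrative:

from pyspark import SparkContext

sc = SparkContext("local", "countApproxDistinct-demo")  # illustrative app name

rdd = sc.parallelize(range(1000))
# Estimate the number of distinct elements; 0.04 is the relative accuracy.
print(rdd.countApproxDistinct(0.04))  # approximately 1000

# Works for any hashable element type, e.g. tuples of values.
print(rdd.map(lambda x: (x, -x)).countApproxDistinct(0.04))

sc.stop()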
|
@@ -537,6 +554,27 @@ def test_repartitionAndSortWithinPartitions(self):
         self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)])


+class TestSQL(PySparkTestCase):
+
+    def setUp(self):
+        PySparkTestCase.setUp(self)
+        self.sqlCtx = SQLContext(self.sc)
+
+    def test_udf(self):
+        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], 5)
+
+    def test_broadcast_in_udf(self):
+        bar = {"a": "aa", "b": "bb", "c": "abc"}
+        foo = self.sc.broadcast(bar)
+        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+        self.assertEqual("abc", res[0])
+        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+        self.assertEqual("", res[0])
+
+
 class TestIO(PySparkTestCase):

     def test_stdout_redirection(self):
|