Skip to content

Commit e162822

Browse files
committed
added groupByKey testcase
1 parent 97742fe commit e162822

File tree

1 file changed

+54
-16
lines changed

1 file changed

+54
-16
lines changed

python/pyspark/streaming_tests.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def test_func(dstream):
275275
self.assertEqual(expected_output, output)
276276

277277
def test_mapPartitions_batch(self):
278-
"""Basic operation test for DStream.mapPartitions with batch deserializer"""
278+
"""Basic operation test for DStream.mapPartitions with batch deserializer."""
279279
test_input = [range(1, 5), range(5, 9), range(9, 13)]
280280
numSlices = 2
281281

@@ -288,7 +288,7 @@ def f(iterator):
288288
self.assertEqual(expected_output, output)
289289

290290
def test_mapPartitions_unbatch(self):
291-
"""Basic operation test for DStream.mapPartitions with unbatch deserializer"""
291+
"""Basic operation test for DStream.mapPartitions with unbatch deserializer."""
292292
test_input = [range(1, 4), range(4, 7), range(7, 10)]
293293
numSlices = 2
294294

@@ -301,8 +301,8 @@ def f(iterator):
301301
self.assertEqual(expected_output, output)
302302

303303
def test_countByValue_batch(self):
304-
"""Basic operation test for DStream.countByValue with batch deserializer"""
305-
test_input = [range(1, 5) + range(1,5), range(5, 7) + range(5, 9), ["a"] * 2 + ["b"] + [""] ]
304+
"""Basic operation test for DStream.countByValue with batch deserializer."""
305+
test_input = [range(1, 5) + range(1,5), range(5, 7) + range(5, 9), ["a", "a", "b", ""]]
306306

307307
def test_func(dstream):
308308
return dstream.countByValue()
@@ -315,7 +315,7 @@ def test_func(dstream):
315315
self.assertEqual(expected_output, output)
316316

317317
def test_countByValue_unbatch(self):
318-
"""Basic operation test for DStream.countByValue with unbatch deserializer"""
318+
"""Basic operation test for DStream.countByValue with unbatch deserializer."""
319319
test_input = [range(1, 4), [1, 1, ""], ["a", "a", "b"]]
320320

321321
def test_func(dstream):
@@ -328,30 +328,72 @@ def test_func(dstream):
328328
self._sort_result_based_on_key(result)
329329
self.assertEqual(expected_output, output)
330330

331+
def test_groupByKey_batch(self):
    """Basic operation test for DStream.groupByKey with batch deserializer."""
    # Three input batches: distinct ints, ints with duplicates, and strings
    # (including the empty string) with duplicates.
    test_input = [range(1, 5), [1, 1, 1, 2, 2, 3], ["a", "a", "b", "", "", ""]]

    def test_func(dstream):
        # Pair every element with a count of 1, then group pairs by element.
        return dstream.map(lambda element: (element, 1)).groupByKey()

    expected_output = [
        [(1, [1]), (2, [1]), (3, [1]), (4, [1])],
        [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])],
        [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])],
    ]
    scattered_output = self._run_stream(test_input, test_func, expected_output)
    # groupByKey yields iterable values; materialize them as lists so they
    # can be compared against the expected literals.
    output = self._convert_iter_value_to_list(scattered_output)
    for result in (output, expected_output):
        self._sort_result_based_on_key(result)
    self.assertEqual(expected_output, output)
344+
345+
def test_groupByKey_unbatch(self):
    """Basic operation test for DStream.groupByKey with unbatch deserializer."""
    # Small batches: distinct ints, mixed int/empty-string duplicates,
    # and string duplicates.
    test_input = [range(1, 4), [1, 1, ""], ["a", "a", "b"]]

    def test_func(dstream):
        # Pair every element with a count of 1, then group pairs by element.
        return dstream.map(lambda element: (element, 1)).groupByKey()

    expected_output = [
        [(1, [1]), (2, [1]), (3, [1])],
        [(1, [1, 1]), ("", [1])],
        [("a", [1, 1]), ("b", [1])],
    ]
    scattered_output = self._run_stream(test_input, test_func, expected_output)
    # groupByKey yields iterable values; materialize them as lists so they
    # can be compared against the expected literals.
    output = self._convert_iter_value_to_list(scattered_output)
    for result in (output, expected_output):
        self._sort_result_based_on_key(result)
    self.assertEqual(expected_output, output)
358+
359+
def _convert_iter_value_to_list(self, outputs):
360+
"""Return key value pair list. Value is converted to iterator to list."""
361+
result = list()
362+
for output in outputs:
363+
result.append(map(lambda (x, y): (x, list(y)), output))
364+
return result
365+
331366
def _sort_result_based_on_key(self, outputs):
367+
"""Sort the list base onf first value."""
332368
for output in outputs:
333369
output.sort(key=lambda x: x[0])
334370

335371
def _run_stream(self, test_input, test_func, expected_output, numSlices=None):
336-
"""Start stream and return the output"""
337-
# Generate input stream with user-defined input
372+
"""
373+
Start stream and return the output.
374+
@param test_input: dataset for the test. This should be list of lists.
375+
@param test_func: wrapped test_function. This function should return PythonDstream object.
376+
@param expected_output: expected output for this testcase.
377+
@param numSlices: the number of slices in the rdd in the dstream.
378+
"""
379+
# Generate input stream with user-defined input.
338380
numSlices = numSlices or self.numInputPartitions
339381
test_input_stream = self.ssc._testInputStream(test_input, numSlices)
340-
# Apply test function to stream
382+
# Apply test function to stream.
341383
test_stream = test_func(test_input_stream)
342-
# Add job to get output from stream
384+
# Add job to get output from stream.
343385
test_stream._test_output(self.result)
344386
self.ssc.start()
345387

346388
start_time = time.time()
347-
# loop until get the result from stream
389+
# Loop until we get the expected number of results from the stream.
348390
while True:
349391
current_time = time.time()
350-
# check time out
392+
# Check time out.
351393
if (current_time - start_time) > self.timeout:
352394
break
353395
self.ssc.awaitTermination(50)
354-
# check if the output is the same length of expexted output
396+
# Check if the output is the same length as the expected output.
355397
if len(expected_output) == len(self.result):
356398
break
357399

@@ -372,9 +414,5 @@ def tearDownClass(cls):
372414
PySparkStreamingTestCase.tearDownClass()
373415

374416

375-
376-
377-
378-
379417
# Allow running this test module directly from the command line.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)