Skip to content

Commit b0ac3a4

Browse files
Merge pull request #1 from megatron-me-uk/megatron-me-uk-patch-1
add optional argument 'mode' for rdd.pipe
2 parents a0c0161 + a307d13 commit b0ac3a4

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

python/pyspark/rdd.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -687,13 +687,25 @@ def groupBy(self, f, numPartitions=None):
687687
return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
688688

689689
@ignore_unicode_prefix
690-
def pipe(self, command, env={}):
690+
def pipe(self, command, env={}, mode='permissive'):
691691
"""
692692
Return an RDD created by piping elements to a forked external process.
693693
694694
>>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
695695
[u'1', u'2', u'', u'3']
696696
"""
697+
if mode == 'permissive':
698+
def fail_condition(x):
699+
return False
700+
elif mode == 'strict':
701+
def fail_condition(x):
702+
return x == 0
703+
elif mode == 'grep':
704+
def fail_condition(x):
705+
return x == 0 or x == 1
706+
else:
707+
raise ValueError("mode must be one of 'permissive', 'strict' or 'grep'.")
708+
697709
def func(iterator):
698710
pipe = Popen(
699711
shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
@@ -707,7 +719,7 @@ def pipe_objs(out):
707719

708720
def check_return_code():
709721
pipe.wait()
710-
if pipe.returncode:
722+
if fail_condition(pipe.returncode):
711723
raise Exception("Pipe function `%s' exited "
712724
"with error code %d" % (command, pipe.returncode))
713725
else:

python/pyspark/tests.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -878,10 +878,14 @@ def test_pipe_functions(self):
878878
data = ['1', '2', '3']
879879
rdd = self.sc.parallelize(data)
880880
with QuietTest(self.sc):
881-
self.assertRaises(Py4JJavaError, rdd.pipe('cc').collect)
881+
self.assertEqual([], rdd.pipe('cc').collect())
882+
self.assertRaises(Py4JJavaError, rdd.pipe('cc', mode='strict').collect)
882883
result = rdd.pipe('cat').collect()
883884
result.sort()
884885
[self.assertEqual(x, y) for x, y in zip(data, result)]
886+
self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', mode='strict').collect)
887+
self.assertEqual([], rdd.pipe('grep 4').collect())
888+
self.assertEqual([], rdd.pipe('grep 4', mode='grep').collect())
885889

886890

887891
class ProfilerTests(PySparkTestCase):

0 commit comments

Comments
 (0)