diff --git a/Lib/asyncio/queues.py b/Lib/asyncio/queues.py index d591d0ebab481b..b8b6a2a4c97e4a 100644 --- a/Lib/asyncio/queues.py +++ b/Lib/asyncio/queues.py @@ -101,6 +101,8 @@ def _format(self): result += f' _putters[{len(self._putters)}]' if self._unfinished_tasks: result += f' tasks={self._unfinished_tasks}' + if self._shutdown_state is not _QueueState.alive: + result += f' shutdown={self._shutdown_state.value}' return result def qsize(self): @@ -133,7 +135,7 @@ async def put(self, item): Put an item into the queue. If the queue is full, wait until a free slot is available before adding item. """ - if self._shutdown_state != _QueueState.alive: + if self._shutdown_state is not _QueueState.alive: raise QueueShutDown while self.full(): putter = self._get_loop().create_future() @@ -154,7 +156,7 @@ async def put(self, item): # the call. Wake up the next in line. self._wakeup_next(self._putters) raise - if self._shutdown_state != _QueueState.alive: + if self._shutdown_state is not _QueueState.alive: raise QueueShutDown return self.put_nowait(item) @@ -163,7 +165,7 @@ def put_nowait(self, item): If no free slot is immediately available, raise QueueFull. """ - if self._shutdown_state != _QueueState.alive: + if self._shutdown_state is not _QueueState.alive: raise QueueShutDown if self.full(): raise QueueFull @@ -177,10 +179,10 @@ async def get(self): If queue is empty, wait until an item is available. """ - if self._shutdown_state == _QueueState.shutdown_immediate: + if self._shutdown_state is _QueueState.shutdown_immediate: raise QueueShutDown while self.empty(): - if self._shutdown_state != _QueueState.alive: + if self._shutdown_state is not _QueueState.alive: raise QueueShutDown getter = self._get_loop().create_future() self._getters.append(getter) @@ -200,7 +202,7 @@ async def get(self): # the call. Wake up the next in line. self._wakeup_next(self._getters) raise - if self._shutdown_state == _QueueState.shutdown_immediate: + if self._shutdown_state is _QueueState.shutdown_immediate: raise QueueShutDown return self.get_nowait() @@ -210,10 +212,10 @@ def get_nowait(self): Return an item if one is immediately available, else raise QueueEmpty. """ if self.empty(): - if self._shutdown_state != _QueueState.alive: + if self._shutdown_state is not _QueueState.alive: raise QueueShutDown raise QueueEmpty - elif self._shutdown_state == _QueueState.shutdown_immediate: + elif self._shutdown_state is _QueueState.shutdown_immediate: raise QueueShutDown item = self._get() self._wakeup_next(self._putters) @@ -233,6 +235,8 @@ def task_done(self): Raises ValueError if called more times than there were items placed in the queue. """ + if self._shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: + raise QueueShutDown if self._unfinished_tasks <= 0: raise ValueError('task_done() called too many times') self._unfinished_tasks -= 1 @@ -247,8 +251,12 @@ async def join(self): indicate that the item was retrieved and all work on it is complete. When the count of unfinished tasks drops to zero, join() unblocks. """ + if self._shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: + raise QueueShutDown if self._unfinished_tasks > 0: await self._finished.wait() + if self._shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: + raise QueueShutDown def shutdown(self, immediate=False): """Shut-down the queue, making queue gets and puts raise. @@ -259,20 +267,24 @@ def shutdown(self, immediate=False): All blocked callers of put() will be unblocked, and also get() and join() if 'immediate'. The QueueShutDown exception is raised. """ + if self._shutdown_state is _QueueState.shutdown_immediate: + return + # here _shutdown_state is ALIVE or SHUTDOWN if immediate: self._shutdown_state = _QueueState.shutdown_immediate while self._getters: getter = self._getters.popleft() if not getter.done(): getter.set_result(None) - else: + # Release 'blocked' tasks/coros via `.join()` + self._finished.set() + elif self._shutdown_state is _QueueState.alive: # here self._shutdown_state = _QueueState.shutdown while self._putters: putter = self._putters.popleft() if not putter.done(): putter.set_result(None) - # Release 'joined' tasks/coros - self._finished.set() + class PriorityQueue(Queue): """A subclass of Queue; retrieves entries in priority order (lowest first). diff --git a/Lib/multiprocessing/queues.py b/Lib/multiprocessing/queues.py index 5220504369937d..8f226c26edd9f3 100644 --- a/Lib/multiprocessing/queues.py +++ b/Lib/multiprocessing/queues.py @@ -56,7 +56,7 @@ def __init__(self, maxsize=0, *, ctx): self._ignore_epipe = False self._reset() self._shutdown_state = context._default_context.Value( - ctypes.c_uint8, lock=self._rlock + ctypes.c_uint8, _queue_alive, lock=True ) if sys.platform != 'win32': @@ -65,11 +65,13 @@ def __init__(self, maxsize=0, *, ctx): def __getstate__(self): context.assert_spawning(self) return (self._ignore_epipe, self._maxsize, self._reader, self._writer, - self._rlock, self._wlock, self._sem, self._opid) + self._rlock, self._wlock, self._sem, self._opid, + self._shutdown_state) def __setstate__(self, state): (self._ignore_epipe, self._maxsize, self._reader, self._writer, - self._rlock, self._wlock, self._sem, self._opid) = state + self._rlock, self._wlock, self._sem, self._opid, + self._shutdown_state) = state self._reset() def _after_fork(self): @@ -100,21 +102,19 @@ def put(self, obj, block=True, timeout=None): raise Full with self._notempty: - if self._shutdown_state.value != _queue_alive: - raise ShutDown if self._thread is None: self._start_thread() self._buffer.append(obj) self._notempty.notify() def get(self, block=True, timeout=None): - if self._shutdown_state.value == _queue_shutdown_immediate: - raise ShutDown if self._closed: raise ValueError(f"Queue {self!r} is closed") + if self._shutdown_state.value == _queue_shutdown_immediate: + raise ShutDown if block and timeout is None: with self._rlock: - if self._shutdown_state.value != _queue_alive: + if self._shutdown_state.value == _queue_shutdown_immediate: raise ShutDown res = self._recv_bytes() self._sem.release() @@ -127,10 +127,10 @@ def get(self, block=True, timeout=None): if block: timeout = deadline - time.monotonic() if not self._poll(timeout): - if self._shutdown_state.value != _queue_alive: + if self._shutdown_state.value == _queue_shutdown_immediate: raise ShutDown raise Empty - if self._shutdown_state.value != _queue_alive : + if self._shutdown_state.value == _queue_shutdown_immediate: raise ShutDown elif not self._poll(): raise Empty @@ -138,7 +138,7 @@ def get(self, block=True, timeout=None): self._sem.release() finally: self._rlock.release() - if self._shutdown_state.value == _queue_shutdown: + if self._shutdown_state.value == _queue_shutdown_immediate: raise ShutDown # unserialize the data after having released the lock return _ForkingPickler.loads(res) @@ -159,6 +159,17 @@ def get_nowait(self): def put_nowait(self, obj): return self.put(obj, False) + def shutdown(self, immediate=True): + with self._shutdown_state.get_lock(): + if self._shutdown_state.value == _queue_shutdown_immediate: + return + if immediate: + self._shutdown_state.value = _queue_shutdown_immediate + with self._notempty: + self._notempty.notify_all() # cf from @EpicWink + else: + self._shutdown_state.value = _queue_shutdown + def close(self): self._closed = True close = self._close @@ -340,6 +351,8 @@ def __setstate__(self, state): def put(self, obj, block=True, timeout=None): if self._closed: raise ValueError(f"Queue {self!r} is closed") + if self._shutdown_state.value != _queue_alive: + raise ShutDown if not self._sem.acquire(block, timeout): raise Full @@ -352,6 +365,8 @@ def put(self, obj, block=True, timeout=None): def task_done(self): with self._cond: + if self._shutdown_state.value == _queue_shutdown_immediate: + raise ShutDown if not self._unfinished_tasks.acquire(False): raise ValueError('task_done() called too many times') if self._unfinished_tasks._semlock._is_zero(): @@ -359,10 +374,19 @@ def task_done(self): def join(self): with self._cond: - if self._shutdown_state.value == _queue_shutdown_immediate: - return + if self._shutdown_state.value != _queue_alive: + raise ShutDown if not self._unfinished_tasks._semlock._is_zero(): self._cond.wait() + if self._shutdown_state.value == _queue_shutdown_immediate: + raise ShutDown + + def shutdown(self, immediate=True): + initial_shutdown = self._shutdown_state.value + super().shutdown(immediate) + if initial_shutdown == _queue_alive: + with self._cond: + self._cond.notify_all() # here to check YD # # Simplified Queue type -- really just a locked pipe diff --git a/Lib/queue.py b/Lib/queue.py index f08dbd47f188ee..933f1806205072 100644 --- a/Lib/queue.py +++ b/Lib/queue.py @@ -1,5 +1,6 @@ '''A multi-producer, multi-consumer queue.''' +import enum import threading import types from collections import deque @@ -29,9 +30,10 @@ class ShutDown(Exception): '''Raised when put/get with shut-down queue.''' -_queue_alive = "alive" -_queue_shutdown = "shutdown" -_queue_shutdown_immediate = "shutdown-immediate" +class _QueueState(enum.Enum): + ALIVE = "alive" + SHUTDOWN = "shutdown" + SHUTDOWN_IMMEDIATE = "shutdown-immediate" class Queue: @@ -64,7 +66,7 @@ def __init__(self, maxsize=0): self.unfinished_tasks = 0 # Queue shut-down state - self.shutdown_state = _queue_alive + self.shutdown_state = _QueueState.ALIVE def task_done(self): '''Indicate that a formerly enqueued task is complete. @@ -99,7 +101,7 @@ def join(self): ''' with self.all_tasks_done: while self.unfinished_tasks: - if self.shutdown_state == _queue_shutdown_immediate: + if self.shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: return self.all_tasks_done.wait() @@ -144,7 +146,7 @@ def put(self, item, block=True, timeout=None): is immediately available, else raise the Full exception ('timeout' is ignored in that case). ''' - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown with self.not_full: if self.maxsize > 0: @@ -154,7 +156,7 @@ def put(self, item, block=True, timeout=None): elif timeout is None: while self._qsize() >= self.maxsize: self.not_full.wait() - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown elif timeout < 0: raise ValueError("'timeout' must be a non-negative number") @@ -165,7 +167,7 @@ def put(self, item, block=True, timeout=None): if remaining <= 0.0: raise Full self.not_full.wait(remaining) - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown self._put(item) self.unfinished_tasks += 1 @@ -182,35 +184,35 @@ def get(self, block=True, timeout=None): available, else raise the Empty exception ('timeout' is ignored in that case). ''' - if self.shutdown_state == _queue_shutdown_immediate: + if self.shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: raise ShutDown with self.not_empty: if not block: if not self._qsize(): - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown raise Empty elif timeout is None: while not self._qsize(): - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown self.not_empty.wait() - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown elif timeout < 0: raise ValueError("'timeout' must be a non-negative number") else: endtime = time() + timeout while not self._qsize(): - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown remaining = endtime - time() if remaining <= 0.0: raise Empty self.not_empty.wait(remaining) - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown - if self.shutdown_state == _queue_shutdown_immediate: + if self.shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: raise ShutDown item = self._get() self.not_full.notify() @@ -242,8 +244,11 @@ def shutdown(self, immediate=False): and join() if 'immediate'. The ShutDown exception is raised. ''' with self.mutex: + if self.shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: + return + if immediate: - self.shutdown_state = _queue_shutdown_immediate + self.shutdown_state = _QueueState.SHUTDOWN_IMMEDIATE self.not_empty.notify_all() # set self.unfinished_tasks to 0 # to break the loop in 'self.join()' @@ -251,7 +256,7 @@ def shutdown(self, immediate=False): self.unfinished_tasks = 0 self.all_tasks_done.notify_all() else: - self.shutdown_state = _queue_shutdown + self.shutdown_state = _QueueState.SHUTDOWN self.not_full.notify_all() # Override these methods to implement other queue organizations diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py index 75b016f399a13b..16b30a5126feda 100644 --- a/Lib/test/test_asyncio/test_queues.py +++ b/Lib/test/test_asyncio/test_queues.py @@ -525,6 +525,104 @@ class PriorityQueueJoinTests(_QueueJoinTestMixin, unittest.IsolatedAsyncioTestCa class _QueueShutdownTestMixin: q_class = None + async def _get(self, q, go, results): + await go.wait() + try: + msg = await q.get() + results.append(True) + return msg + except asyncio.QueueShutDown: + results.append(False) + return False + + async def _get_shutdown(self, q, go, results): + await go.wait() + try: + msg = await q.get() + results.append(False) + return msg + except asyncio.QueueShutDown: + results.append(True) + return False + + async def _get_nowait(self, q, go, results): + await go.wait() + try: + msg = q.get_nowait() + results.append(True) + return msg + except asyncio.QueueShutDown: + results.append(False) + return False + + async def _get_task_done(self, q, go, results): + await go.wait() + try: + msg = await q.get() + q.task_done() + results.append(True) + return msg + except asyncio.QueueShutDown: + results.append(False) + return False + + async def _get_nowait_shutdown(self, q, go, results): + await go.wait() + try: + msg = q.get_nowait() + results.append(False) + except asyncio.QueueShutDown: + results.append(True) + return True + + async def _put_shutdown(self, q, go, msg, results): + await go.wait() + try: + await q.put(msg) + results.append(False) + except asyncio.QueueShutDown: + results.append(True) + return msg + + async def _put_nowait_shutdown(self, q, go, msg, results): + await go.wait() + try: + q.put_nowait(msg) + results.append(False) + except asyncio.QueueShutDown: + results.append(True) + return msg + + async def _shutdown(self, q, go, immediate): + await asyncio.sleep(0.001) + q.shutdown(immediate) + await asyncio.sleep(0.001) + go.set() + await asyncio.sleep(0.001) + + async def _join(self, q, go, results): + await go.wait() + try: + await q.join() + results.append(True) + return True + except asyncio.QueueShutDown: + results.append(False) + return False + + async def _join_shutdown(self, q, go, results): + await go.wait() + try: + await q.join() + results.append(False) + return False + except asyncio.QueueShutDown: + results.append(True) + return True + except asyncio.CancelledError: + results.append(True) + raise + async def test_empty(self): q = self.q_class() q.shutdown() @@ -555,8 +653,11 @@ async def test_immediate(self): asyncio.QueueShutDown, msg="Didn't appear to shut-down queue" ): await q.get() - async def test_repr_shutdown(self): + + async def test_shutdown_repr(self): q = self.q_class() + self.assertNotIn("alive", repr(q)) + q.shutdown() self.assertIn("shutdown", repr(q)) @@ -564,89 +665,202 @@ async def test_repr_shutdown(self): q.shutdown(immediate=True) self.assertIn("shutdown-immediate", repr(q)) - async def test_get_shutdown_immediate(self): + async def test_shutdown_allowed_transitions(self): + # allowed transitions would be from alive via shutdown to immediate + q = self.q_class() + self.assertEqual("alive", q._shutdown_state.value) + + q.shutdown() + self.assertEqual("shutdown", q._shutdown_state.value) + + q.shutdown(immediate=True) + self.assertEqual("shutdown-immediate", q._shutdown_state.value) + + q.shutdown(immediate=False) + self.assertNotEqual("shutdown", q._shutdown_state.value) + + async def _shutdown_putters(self, immediate): + delay = 0.001 + q = self.q_class(2) results = [] - maxsize = 2 - delay = 1e-3 - - async def get_q(q): - try: - msg = await q.get() - results.append(False) - except asyncio.QueueShutDown: - results.append(True) - return True - - async def shutdown(q, delay, immediate): - await asyncio.sleep(delay) + await q.put("E") + await q.put("W") + # queue full + t = asyncio.create_task(q.put("Y")) + await asyncio.sleep(delay) + self.assertTrue(len(q._putters) == 1) + with self.assertRaises(asyncio.QueueShutDown): + # here `t` raises a QueueShuDown q.shutdown(immediate) - return True + await t + self.assertTrue(not q._putters) - q = self.q_class(maxsize) - t = [asyncio.create_task(get_q(q)) for _ in range(maxsize)] - t += [asyncio.create_task(shutdown(q, delay, True))] - res = await asyncio.gather(*t) + async def test_shutdown_putters_deque(self): + return await self._shutdown_putters(False) - self.assertEqual(results, [True]*maxsize) + async def test_shutdown_immediate_putters_deque(self): + return await self._shutdown_putters(True) - async def test_put_shutdown(self): - maxsize = 2 + async def _shutdown_getters(self, immediate): + delay = 0.001 + q = self.q_class(1) results = [] - delay = 1e-3 - - async def put_twice(q, delay, msg): - await q.put(msg) - await asyncio.sleep(delay) - try: - await q.put(msg+maxsize) - results.append(False) - except asyncio.QueueShutDown: - results.append(True) - return msg - - async def shutdown(q, delay, immediate): - await asyncio.sleep(delay) + await q.put("Y") + # queue full + asyncio.create_task(q.get()) + await asyncio.sleep(delay) + t = asyncio.create_task(q.get()) + await asyncio.sleep(delay) + self.assertTrue(len(q._getters) == 1) + if immediate: + # here `t` raises a QueueShuDown + with self.assertRaises(asyncio.QueueShutDown): + q.shutdown(immediate) + await t + self.assertTrue(not q._getters) + else: + # here `t` is always pending q.shutdown(immediate) + await asyncio.sleep(delay) + self.assertTrue(q._getters) - q = self.q_class(maxsize) - t = [asyncio.create_task(put_twice(q, delay, i+1)) for i in range(maxsize)] - t += [asyncio.create_task(shutdown(q, delay*2, False))] - res = await asyncio.gather(*t) + async def test_shutdown_getters_deque(self): + return await self._shutdown_getters(False) - self.assertEqual(results, [True]*maxsize) + async def test_shutdown_immediate_getters_deque(self): + return await self._shutdown_getters(True) - async def test_put_and_join_shutdown(self): - maxsize = 2 + async def _shutdown_get_nowait(self, immediate): + q = self.q_class(2) results = [] - delay = 1e-3 + go = asyncio.Event() + await q.put("Y") + await q.put("D") + nb = q.qsize() + # queue full + + if immediate: + coros = ( + (self._get_nowait_shutdown(q, go, results)), + (self._get_nowait_shutdown(q, go, results)), + ) + else: + coros = ( + (self._get_nowait(q, go, results)), + (self._get_nowait(q, go, results)), + ) + t = [] + for coro in coros: + t.append(asyncio.create_task(coro)) + t.append(asyncio.create_task(self._shutdown(q, go, immediate))) + res = await asyncio.gather(*t) - async def put_twice(q, delay, msg): - await q.put(msg) - await asyncio.sleep(delay) - try: - await q.put(msg+maxsize) - results.append(False) - except asyncio.QueueShutDown: - results.append(True) - return msg - - async def shutdown(q, delay, immediate): - await asyncio.sleep(delay) - q.shutdown(immediate) + self.assertEqual(results, [True]*len(coros)) + self.assertEqual(len(q._putters), 0) + if immediate: + self.assertEqual(len(q._getters), 0) + self.assertEqual(q._unfinished_tasks, nb) - async def join(q, delay): - await asyncio.sleep(delay) - await q.join() - results.append(True) - return True + async def test_shutdown_get_nowait(self): + return await self._shutdown_get_nowait(False) + + async def test_shutdown_immediate_get_nowait(self): + return await self._shutdown_get_nowait(True) - q = self.q_class(maxsize) - t = [asyncio.create_task(put_twice(q, delay, i+1)) for i in range(maxsize)] - t += [asyncio.create_task(shutdown(q, delay*2, True)), - asyncio.create_task(join(q, delay))] + async def test_shutdown_get_task_done_join(self, immediate=False): + q = self.q_class(2) + results = [] + go = asyncio.Event() + await q.put("Y") + await q.put("D") + self.assertEqual(q._unfinished_tasks, q.qsize()) + + # queue full + + coros = ( + (self._get_task_done(q, go, results)), + (self._get_task_done(q, go, results)), + (self._join(q, go, results)), + (self._join(q, go, results)), + ) + t = [] + for coro in coros: + t.append(asyncio.create_task(coro)) + t.append(asyncio.create_task(self._shutdown(q, go, False))) res = await asyncio.gather(*t) - self.assertEqual(results, [True]*(maxsize+1)) + self.assertEqual(results, [True]*len(coros)) + self.assertIn(t[0].result(), "YD") + self.assertIn(t[1].result(), "YD") + self.assertNotEqual(t[0].result(), t[1].result()) + self.assertEqual(q._unfinished_tasks, 0) + + async def _shutdown_put(self, immediate): + q = self.q_class() + results = [] + go = asyncio.Event() + # queue not empty + + coros = ( + (self._put_shutdown(q, go, "Y", results)), + (self._put_nowait_shutdown(q, go, "D", results)), + ) + t = [] + for coro in coros: + t.append(asyncio.create_task(coro)) + t.append(asyncio.create_task(self._shutdown(q, go, immediate))) + res = await asyncio.gather(*t) + + self.assertEqual(results, [True]*len(coros)) + + async def test_shutdown_put(self): + return await self._shutdown_put(False) + + async def test_shutdown_immediate_put(self): + return await self._shutdown_put(True) + + async def _shutdown_put_join(self, immediate): + q = self.q_class(2) + results = [] + go = asyncio.Event() + await q.put("Y") + await q.put("D") + nb = q.qsize() + # queue fulled + + async def _cancel_join_task(q, delay, t): + await asyncio.sleep(delay) + t.cancel() + await asyncio.sleep(0) + q._finished.set() + + coros = ( + (self._put_shutdown(q, go, "E", results)), + (self._put_nowait_shutdown(q, go, "W", results)), + (self._join_shutdown(q, go, results)), + ) + t = [] + for coro in coros: + t.append(asyncio.create_task(coro)) + t.append(asyncio.create_task(self._shutdown(q, go, immediate))) + if not immediate: + # Here calls `join` is a blocking operation + # so wait for a delay and cancel this blocked task + t.append(asyncio.create_task(_cancel_join_task(q, 0.01, t[2]))) + with self.assertRaises(asyncio.CancelledError) as e: + await asyncio.gather(*t) + else: + res = await asyncio.gather(*t) + + self.assertEqual(results, [True]*len(coros)) + self.assertTrue(q._finished.is_set()) + + async def test_shutdown_put_and_join(self): + return await self._shutdown_put_join(False) + + async def test_shutdown_immediate_put_and_join(self): + return await self._shutdown_put_join(True) + class QueueShutdownTests( _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase diff --git a/Lib/test/test_queue.py b/Lib/test/test_queue.py index 354299b9a5b16a..a2814ac40f2e18 100644 --- a/Lib/test/test_queue.py +++ b/Lib/test/test_queue.py @@ -244,37 +244,230 @@ def test_shrinking_queue(self): def test_shutdown_empty(self): q = self.type2test() q.shutdown() - try: + with self.assertRaises(self.queue.ShutDown): q.put("data") - self.fail("Didn't appear to shut-down queue") - except self.queue.ShutDown: - pass - try: + with self.assertRaises(self.queue.ShutDown): q.get() - self.fail("Didn't appear to shut-down queue") - except self.queue.ShutDown: - pass def test_shutdown_nonempty(self): q = self.type2test() q.put("data") q.shutdown() q.get() - try: + with self.assertRaises(self.queue.ShutDown): q.get() - self.fail("Didn't appear to shut-down queue") - except self.queue.ShutDown: - pass def test_shutdown_immediate(self): q = self.type2test() q.put("data") q.shutdown(immediate=True) - try: + with self.assertRaises(self.queue.ShutDown): q.get() - self.fail("Didn't appear to shut-down queue") - except self.queue.ShutDown: - pass + + def test_shutdown_transition(self): + # allowed transitions would be from alive via shutdown to immediate + q = self.type2test() + self.assertEqual("alive", q.shutdown_state.value) + + q.shutdown() + self.assertEqual("shutdown", q.shutdown_state.value) + + q.shutdown(immediate=True) + self.assertEqual("shutdown-immediate", q.shutdown_state.value) + + q.shutdown(immediate=False) + self.assertEqual("shutdown-immediate", q.shutdown_state.value) + + def test_shutdown_get(self): + q = self.type2test(2) + results = [] + go = threading.Event() + + def get_once(q, go): + go.wait() + try: + msg = q.get() + results.append(False) + except self.queue.ShutDown: + results.append(True) + return True + + thrds = ( + (get_once, (q, go)), + (get_once, (q, go)), + ) + threads = [] + for f, params in thrds: + thread = threading.Thread(target=f, args=params) + thread.start() + threads.append(thread) + q.shutdown() + go.set() + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + def test_shutdown_put(self): + q = self.type2test(2) + results = [] + go = threading.Event() + q.put("Y") + q.put("D") + # queue fulled + + def put_once(q, msg, go): + go.wait() + try: + q.put(msg) + results.append(False) + except self.queue.ShutDown: + results.append(True) + return msg + + thrds = ( + (put_once, (q, 100, go)), + (put_once, (q, 200, go)), + ) + threads = [] + for f, params in thrds: + thread = threading.Thread(target=f, args=params) + thread.start() + threads.append(thread) + q.shutdown() + go.set() + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + def _shutdown_join(self, immediate): + q = self.type2test() + results = [] + go = threading.Event() + + def join(q, go): + go.wait() + q.join() + results.append(True) + + thrds = ( + (join, (q, go)), + (join, (q, go)), + ) + threads = [] + for f, params in thrds: + thread = threading.Thread(target=f, args=params) + thread.start() + threads.append(thread) + go.set() + q.shutdown(immediate) + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + def test_shutdown_immediate_join(self): + return self._shutdown_join(True) + + def test_shutdown_join(self): + return self._shutdown_join(False) + + def _shutdown_put_and_join(self, immediate): + q = self.type2test(2) + results = [] + go = threading.Event() + q.put("Y") + q.put("D") + # queue fulled + + def put_once(q, msg, go): + go.wait() + try: + q.put(msg) + results.append(False) + except self.queue.ShutDown: + results.append(True) + return msg + + def join(q, go): + go.wait() + q.join() + results.append(True) + + thrds = ( + (put_once, (q, 100, go)), + (put_once, (q, 200, go)), + (join, (q, go)), + (join, (q, go)), + ) + threads = [] + for f, params in thrds: + thread = threading.Thread(target=f, args=params) + thread.start() + threads.append(thread) + go.set() + q.shutdown(immediate) + if not immediate: + self.assertTrue(q.unfinished_tasks, 2) + for i in range(2): + thread = threading.Thread(target=q.task_done) + thread.start() + threads.append(thread) + + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + def test_shutdown_immediate_put_and_join(self): + return self._shutdown_put_and_join(True) + + def test_shutdown_put_and_join(self): + return self._shutdown_put_and_join(False) + + def _shutdown_get_and_join(self, immediate): + q = self.type2test() + results = [] + go = threading.Event() + + def get_once(q, go): + go.wait() + try: + msg = q.get() + results.append(False) + except self.queue.ShutDown: + results.append(True) + return True + + def join(q, go): + go.wait() + q.join() + results.append(True) + + thrds = ( + (get_once, (q, go)), + (get_once, (q, go)), + (join, (q, go)), + (join, (q, go)), + ) + threads = [] + for f, params in thrds: + thread = threading.Thread(target=f, args=params) + thread.start() + threads.append(thread) + go.set() + q.shutdown(immediate) + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + def test_shutdown_immediate_get_and_join(self): + return self._shutdown_get_and_join(True) + + def test__shutdown_get_and_join(self): + return self._shutdown_get_and_join(False) class QueueTest(BaseQueueTestMixin):