diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index 25a650ffbb7446..1fe9743a331e36 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -105,7 +105,7 @@ def __init__(self, coro, *, loop=None, name=None): else: self._name = str(name) - self._cancel_requested = False + self._num_cancels_requested = 0 self._must_cancel = False self._fut_waiter = None self._coro = coro @@ -202,9 +202,9 @@ def cancel(self, msg=None): self._log_traceback = False if self.done(): return False - if self._cancel_requested: + self._num_cancels_requested += 1 + if self._num_cancels_requested > 1: return False - self._cancel_requested = True if self._fut_waiter is not None: if self._fut_waiter.cancel(msg=msg): # Leave self._fut_waiter; it may be a Task that @@ -216,15 +216,13 @@ def cancel(self, msg=None): self._cancel_message = msg return True - def cancelling(self): - return self._cancel_requested + def cancelling(self) -> int: + return self._num_cancels_requested - def uncancel(self): - if self._cancel_requested: - self._cancel_requested = False - return True - else: - return False + def uncancel(self) -> int: + if self._num_cancels_requested > 0: + self._num_cancels_requested -= 1 + return self._num_cancels_requested def __step(self, exc=None): if self.done(): diff --git a/Lib/asyncio/timeouts.py b/Lib/asyncio/timeouts.py index 9a9b56f9df622b..dd5d2beb4f683f 100644 --- a/Lib/asyncio/timeouts.py +++ b/Lib/asyncio/timeouts.py @@ -90,11 +90,11 @@ async def __aexit__( if self._state is _State.CANCELLING: self._state = _State.CANCELLED - counter = _COUNTERS[self._task] - if counter == 1: + + if self._task.uncancel() == 0: + # Since there are no outstanding cancel requests, we're + # handling this. raise TimeoutError - else: - _COUNTERS[self._task] = counter - 1 elif self._state is _State.ENTERED: self._state = _State.EXITED @@ -106,19 +106,6 @@ def _on_timeout(self) -> None: self._state = _State.CANCELLING # drop the reference early self._timeout_handler = None - counter = _COUNTERS.get(self._task) - if counter is None: - _COUNTERS[self._task] = 1 - self._task.add_done_callback(_drop_task) - else: - _COUNTERS[self._task] = counter + 1 - - -_COUNTERS: Dict[tasks.Task, int] = {} - - -def _drop_task(task: tasks.Task) -> None: - del _COUNTERS[task] def timeout(delay: Optional[float]) -> Timeout: diff --git a/Lib/test/test_asyncio/test_timeouts.py b/Lib/test/test_asyncio/test_timeouts.py index ce985e75ac6301..e08c258cd47240 100644 --- a/Lib/test/test_asyncio/test_timeouts.py +++ b/Lib/test/test_asyncio/test_timeouts.py @@ -1,6 +1,7 @@ """Tests for asyncio/timeouts.py""" import unittest +import time import asyncio from asyncio import tasks @@ -16,6 +17,17 @@ class BaseTimeoutTests: def new_task(self, loop, coro, name='TestTask'): return self.__class__.Task(coro, loop=loop, name=name) + def _setupAsyncioLoop(self): + assert self._asyncioTestLoop is None, 'asyncio test loop already initialized' + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.set_debug(True) + self._asyncioTestLoop = loop + loop.set_task_factory(self.new_task) + fut = loop.create_future() + self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut)) + loop.run_until_complete(fut) + async def test_timeout_basic(self): with self.assertRaises(TimeoutError): async with asyncio.timeout(0.01) as cm: @@ -137,6 +149,50 @@ async def outer() -> None: assert not task.cancelled() assert task.done() + async def test_nested_timeouts(self): + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0.1) as outer: + try: + async with asyncio.timeout(0.2) as inner: + await asyncio.sleep(10) + except asyncio.TimeoutError: + # Pretend we start a super long operation here. + self.assertTrue(False) + + async def test_nested_timeouts_concurrent(self): + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0.002): + try: + async with asyncio.timeout(0.003): + # Pretend we crunch some numbers. + time.sleep(0.005) + await asyncio.sleep(1) + except asyncio.TimeoutError: + pass + + async def test_nested_timeouts_loop_busy(self): + """ + After the inner timeout is an expensive operation which should + be stopped by the outer timeout. + + Note: this fails for now. + """ + start = time.perf_counter() + try: + async with asyncio.timeout(0.002) as outer: + try: + async with asyncio.timeout(0.001) as inner: + # Pretend the loop is busy for a while. + time.sleep(0.010) + await asyncio.sleep(0.001) + except asyncio.TimeoutError: + # This sleep should be interrupted. + await asyncio.sleep(0.050) + except asyncio.TimeoutError: + pass + took = time.perf_counter() - start + self.assertTrue(took <= 0.015) + @unittest.skipUnless(hasattr(tasks, '_CTask'), 'requires the C _asyncio module') diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c index 6725e2eba79bc2..51e1e083f44ae5 100644 --- a/Modules/_asynciomodule.c +++ b/Modules/_asynciomodule.c @@ -91,7 +91,7 @@ typedef struct { PyObject *task_context; int task_must_cancel; int task_log_destroy_pending; - int task_cancel_requested; + int task_num_cancels_requested; } TaskObj; typedef struct { @@ -2040,7 +2040,7 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, Py_CLEAR(self->task_fut_waiter); self->task_must_cancel = 0; self->task_log_destroy_pending = 1; - self->task_cancel_requested = 0; + self->task_num_cancels_requested = 0; Py_INCREF(coro); Py_XSETREF(self->task_coro, coro); @@ -2207,10 +2207,10 @@ _asyncio_Task_cancel_impl(TaskObj *self, PyObject *msg) Py_RETURN_FALSE; } - if (self->task_cancel_requested) { + self->task_num_cancels_requested += 1; + if (self->task_num_cancels_requested > 1) { Py_RETURN_FALSE; } - self->task_cancel_requested = 1; if (self->task_fut_waiter) { PyObject *res; @@ -2256,12 +2256,7 @@ _asyncio_Task_cancelling_impl(TaskObj *self) /*[clinic end generated code: output=803b3af96f917d7e input=c50e50f9c3ca4676]*/ /*[clinic end generated code]*/ { - if (self->task_cancel_requested) { - Py_RETURN_TRUE; - } - else { - Py_RETURN_FALSE; - } + return PyLong_FromLong(self->task_num_cancels_requested); } /*[clinic input] @@ -2280,13 +2275,10 @@ _asyncio_Task_uncancel_impl(TaskObj *self) /*[clinic end generated code: output=58184d236a817d3c input=5db95e28fcb6f7cd]*/ /*[clinic end generated code]*/ { - if (self->task_cancel_requested) { - self->task_cancel_requested = 0; - Py_RETURN_TRUE; - } - else { - Py_RETURN_FALSE; + if (self->task_num_cancels_requested > 0) { + self->task_num_cancels_requested -= 1; } + return PyLong_FromLong(self->task_num_cancels_requested); } /*[clinic input]