|
1 | 1 | import asyncio
|
2 | 2 | import inspect
|
| 3 | +import signal |
3 | 4 | from concurrent.futures import Executor
|
4 | 5 | from logging import getLogger
|
5 | 6 | from time import time
|
@@ -334,6 +335,12 @@ async def listen(self) -> None: # pragma: no cover
|
334 | 335 | gr.start_soon(self.prefetcher, queue)
|
335 | 336 | gr.start_soon(self.runner, queue)
|
336 | 337 |
|
| 338 | + # Propagate cancellation to the prefetcher & runner |
| 339 | + def _cancel(*_: Any) -> None: |
| 340 | + gr.cancel_scope.cancel() |
| 341 | + |
| 342 | + signal.signal(signal.SIGINT, _cancel) |
| 343 | + |
337 | 344 | if self.on_exit is not None:
|
338 | 345 | self.on_exit(self)
|
339 | 346 |
|
@@ -361,9 +368,7 @@ async def prefetcher(
|
361 | 368 | message = await iterator.__anext__()
|
362 | 369 | fetched_tasks += 1
|
363 | 370 | await queue.put(message)
|
364 |
| - except asyncio.CancelledError: |
365 |
| - break |
366 |
| - except StopAsyncIteration: |
| 371 | + except (asyncio.CancelledError, StopAsyncIteration): |
367 | 372 | break
|
368 | 373 |
|
369 | 374 | await queue.put(QUEUE_DONE)
|
@@ -394,31 +399,35 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
|
394 | 399 | self.sem.release()
|
395 | 400 |
|
396 | 401 | while True:
|
397 |
| - # Waits for semaphore to be released. |
398 |
| - if self.sem is not None: |
399 |
| - await self.sem.acquire() |
400 |
| - |
401 |
| - self.sem_prefetch.release() |
402 |
| - message = await queue.get() |
403 |
| - if message is QUEUE_DONE: |
404 |
| - # asyncio.wait will throw an error if there is nothing to wait for |
405 |
| - if tasks: |
406 |
| - logger.info("Waiting for running tasks to complete.") |
407 |
| - await asyncio.wait(tasks, timeout=self.wait_tasks_timeout) |
408 |
| - break |
| 402 | + try: |
| 403 | + # Waits for semaphore to be released. |
| 404 | + if self.sem is not None: |
| 405 | + await self.sem.acquire() |
| 406 | + |
| 407 | + self.sem_prefetch.release() |
| 408 | + message = await queue.get() |
| 409 | + if message is QUEUE_DONE: |
| 410 | + # asyncio.wait will throw an error if there is nothing to wait for |
| 411 | + if tasks: |
| 412 | + logger.info("Waiting for running tasks to complete.") |
| 413 | + await asyncio.wait(tasks, timeout=self.wait_tasks_timeout) |
| 414 | + break |
409 | 415 |
|
410 |
| - task = asyncio.create_task( |
411 |
| - self.callback(message=message, raise_err=False), |
412 |
| - ) |
413 |
| - tasks.add(task) |
414 |
| - |
415 |
| - # We want the task to remove itself from the set when it's done. |
416 |
| - # |
417 |
| - # Because if we won't save it anywhere, |
418 |
| - # python's GC can silently cancel task |
419 |
| - # and this behaviour considered to be a Hisenbug. |
420 |
| - # https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/ |
421 |
| - task.add_done_callback(task_cb) |
| 416 | + task = asyncio.create_task( |
| 417 | + self.callback(message=message, raise_err=False), |
| 418 | + ) |
| 419 | + tasks.add(task) |
| 420 | + |
| 421 | + # We want the task to remove itself from the set when it's done. |
| 422 | + # |
| 423 | + # Because if we won't save it anywhere, |
| 424 | + # python's GC can silently cancel task |
| 425 | + # and this behaviour considered to be a Hisenbug. |
| 426 | + # https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/ |
| 427 | + task.add_done_callback(task_cb) |
| 428 | + |
| 429 | + except asyncio.CancelledError: |
| 430 | + break |
422 | 431 |
|
423 | 432 | def _prepare_task(self, name: str, handler: Callable[..., Any]) -> None:
|
424 | 433 | """
|
|
0 commit comments