From 33509c14c7d2d368d735ccb9d5938bce9ff516bc Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Mon, 20 Jan 2025 16:53:55 +0000 Subject: [PATCH 1/2] [3.13] gh-128308: pass `**kwargs` to asyncio task_factory (GH-128768) (cherry picked from commit 38a99568763604ccec5d5027f0658100ad76876f) Co-authored-by: Thomas Grainger Co-authored-by: Kumar Aditya --- Doc/library/asyncio-eventloop.rst | 4 +- Lib/asyncio/base_events.py | 26 ++-- Lib/asyncio/events.py | 2 +- Lib/test/test_asyncio/test_base_events.py | 4 +- .../test_asyncio/test_eager_task_factory.py | 12 ++ Lib/test/test_asyncio/test_free_threading.py | 136 ++++++++++++++++++ Lib/test/test_asyncio/test_taskgroups.py | 12 ++ ...-01-13-07-54-32.gh-issue-128308.kYSDRF.rst | 1 + 8 files changed, 176 insertions(+), 21 deletions(-) create mode 100644 Lib/test/test_asyncio/test_free_threading.py create mode 100644 Misc/NEWS.d/next/Library/2025-01-13-07-54-32.gh-issue-128308.kYSDRF.rst diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst index 8027d3525e5999..3642a56785f3fa 100644 --- a/Doc/library/asyncio-eventloop.rst +++ b/Doc/library/asyncio-eventloop.rst @@ -382,9 +382,9 @@ Creating Futures and Tasks If *factory* is ``None`` the default task factory will be set. Otherwise, *factory* must be a *callable* with the signature matching - ``(loop, coro, context=None)``, where *loop* is a reference to the active + ``(loop, coro, **kwargs)``, where *loop* is a reference to the active event loop, and *coro* is a coroutine object. The callable - must return a :class:`asyncio.Future`-compatible object. + must pass on all *kwargs*, and return a :class:`asyncio.Task`-compatible object. .. method:: loop.get_task_factory() diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 910fc76e884d2c..b4a654c2dc2c66 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -458,25 +458,18 @@ def create_future(self): """Create a Future object attached to the loop.""" return futures.Future(loop=self) - def create_task(self, coro, *, name=None, context=None): + def create_task(self, coro, **kwargs): """Schedule a coroutine object. Return a task object. """ self._check_closed() - if self._task_factory is None: - task = tasks.Task(coro, loop=self, name=name, context=context) - if task._source_traceback: - del task._source_traceback[-1] - else: - if context is None: - # Use legacy API if context is not needed - task = self._task_factory(self, coro) - else: - task = self._task_factory(self, coro, context=context) - - task.set_name(name) + if self._task_factory is not None: + return self._task_factory(self, coro, **kwargs) + task = tasks.Task(coro, loop=self, **kwargs) + if task._source_traceback: + del task._source_traceback[-1] try: return task finally: @@ -490,9 +483,10 @@ def set_task_factory(self, factory): If factory is None the default task factory will be set. If factory is a callable, it should have a signature matching - '(loop, coro)', where 'loop' will be a reference to the active - event loop, 'coro' will be a coroutine object. The callable - must return a Future. + '(loop, coro, **kwargs)', where 'loop' will be a reference to the active + event loop, 'coro' will be a coroutine object, and **kwargs will be + arbitrary keyword arguments that should be passed on to Task. + The callable must return a Task. """ if factory is not None and not callable(factory): raise TypeError('task factory must be a callable or None') diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index be495469a0558b..3b740b9b905c0d 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -292,7 +292,7 @@ def create_future(self): # Method scheduling a coroutine object: create a task. - def create_task(self, coro, *, name=None, context=None): + def create_task(self, coro, **kwargs): raise NotImplementedError # Methods for interacting with threads. diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index c14a0bb180d79b..3efb3ae9359db6 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -833,8 +833,8 @@ async def test(): loop.close() def test_create_named_task_with_custom_factory(self): - def task_factory(loop, coro): - return asyncio.Task(coro, loop=loop) + def task_factory(loop, coro, **kwargs): + return asyncio.Task(coro, loop=loop, **kwargs) async def test(): pass diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py index 0e2b189f761521..179a6e44b59934 100644 --- a/Lib/test/test_asyncio/test_eager_task_factory.py +++ b/Lib/test/test_asyncio/test_eager_task_factory.py @@ -302,6 +302,18 @@ async def run(): self.run_coro(run()) + def test_name(self): + name = None + async def coro(): + nonlocal name + name = asyncio.current_task().get_name() + + async def main(): + task = self.loop.create_task(coro(), name="test name") + self.assertEqual(name, "test name") + await task + + self.run_coro(coro()) class AsyncTaskCounter: def __init__(self, loop, *, task_class, eager): diff --git a/Lib/test/test_asyncio/test_free_threading.py b/Lib/test/test_asyncio/test_free_threading.py new file mode 100644 index 00000000000000..05106a2c2fe3f6 --- /dev/null +++ b/Lib/test/test_asyncio/test_free_threading.py @@ -0,0 +1,136 @@ +import asyncio +import unittest +from threading import Thread +from unittest import TestCase + +from test.support import threading_helper + +threading_helper.requires_working_threading(module=True) + + +class MyException(Exception): + pass + + +def tearDownModule(): + asyncio._set_event_loop_policy(None) + + +class TestFreeThreading: + def test_all_tasks_race(self) -> None: + async def main(): + loop = asyncio.get_running_loop() + future = loop.create_future() + + async def coro(): + await future + + tasks = set() + + async with asyncio.TaskGroup() as tg: + for _ in range(100): + tasks.add(tg.create_task(coro())) + + all_tasks = self.all_tasks(loop) + self.assertEqual(len(all_tasks), 101) + + for task in all_tasks: + self.assertEqual(task.get_loop(), loop) + self.assertFalse(task.done()) + + current = self.current_task() + self.assertEqual(current.get_loop(), loop) + self.assertSetEqual(all_tasks, tasks | {current}) + future.set_result(None) + + def runner(): + with asyncio.Runner() as runner: + loop = runner.get_loop() + loop.set_task_factory(self.factory) + runner.run(main()) + + threads = [] + + for _ in range(10): + thread = Thread(target=runner) + threads.append(thread) + + with threading_helper.start_threads(threads): + pass + + def test_run_coroutine_threadsafe(self) -> None: + results = [] + + def in_thread(loop: asyncio.AbstractEventLoop): + coro = asyncio.sleep(0.1, result=42) + fut = asyncio.run_coroutine_threadsafe(coro, loop) + result = fut.result() + self.assertEqual(result, 42) + results.append(result) + + async def main(): + loop = asyncio.get_running_loop() + async with asyncio.TaskGroup() as tg: + for _ in range(10): + tg.create_task(asyncio.to_thread(in_thread, loop)) + self.assertEqual(results, [42] * 10) + + with asyncio.Runner() as r: + loop = r.get_loop() + loop.set_task_factory(self.factory) + r.run(main()) + + def test_run_coroutine_threadsafe_exception(self) -> None: + async def coro(): + await asyncio.sleep(0) + raise MyException("test") + + def in_thread(loop: asyncio.AbstractEventLoop): + fut = asyncio.run_coroutine_threadsafe(coro(), loop) + return fut.result() + + async def main(): + loop = asyncio.get_running_loop() + tasks = [] + for _ in range(10): + task = loop.create_task(asyncio.to_thread(in_thread, loop)) + tasks.append(task) + results = await asyncio.gather(*tasks, return_exceptions=True) + + self.assertEqual(len(results), 10) + for result in results: + self.assertIsInstance(result, MyException) + self.assertEqual(str(result), "test") + + with asyncio.Runner() as r: + loop = r.get_loop() + loop.set_task_factory(self.factory) + r.run(main()) + + +class TestPyFreeThreading(TestFreeThreading, TestCase): + all_tasks = staticmethod(asyncio.tasks._py_all_tasks) + current_task = staticmethod(asyncio.tasks._py_current_task) + + def factory(self, loop, coro, **kwargs): + return asyncio.tasks._PyTask(coro, loop=loop, **kwargs) + + +@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio") +class TestCFreeThreading(TestFreeThreading, TestCase): + all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None)) + current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None)) + + def factory(self, loop, coro, **kwargs): + return asyncio.tasks._CTask(coro, loop=loop, **kwargs) + + +class TestEagerPyFreeThreading(TestPyFreeThreading): + def factory(self, loop, coro, eager_start=True, **kwargs): + return asyncio.tasks._PyTask(coro, loop=loop, **kwargs, eager_start=eager_start) + + +@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio") +class TestEagerCFreeThreading(TestCFreeThreading, TestCase): + def factory(self, loop, coro, eager_start=True, **kwargs): + return asyncio.tasks._CTask(coro, loop=loop, **kwargs, eager_start=eager_start) diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index ad61cb46c7c07c..9f2211e3232e54 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -1081,6 +1081,18 @@ async def throw_error(): # cancellation happens here and error is more understandable await asyncio.sleep(0) + async def test_name(self): + name = None + + async def asyncfn(): + nonlocal name + name = asyncio.current_task().get_name() + + async with asyncio.TaskGroup() as tg: + tg.create_task(asyncfn(), name="example name") + + self.assertEqual(name, "example name") + class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase): loop_factory = asyncio.EventLoop diff --git a/Misc/NEWS.d/next/Library/2025-01-13-07-54-32.gh-issue-128308.kYSDRF.rst b/Misc/NEWS.d/next/Library/2025-01-13-07-54-32.gh-issue-128308.kYSDRF.rst new file mode 100644 index 00000000000000..efa613876a35fd --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-01-13-07-54-32.gh-issue-128308.kYSDRF.rst @@ -0,0 +1 @@ +Support the *name* keyword argument for eager tasks in :func:`asyncio.loop.create_task`, :func:`asyncio.create_task` and :func:`asyncio.TaskGroup.create_task`, by passing on all *kwargs* to the task factory set by :func:`asyncio.loop.set_task_factory`. From ab7e48e628c047f7fa920873f2ff98d5cb617af1 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 13 Feb 2025 15:50:00 +0100 Subject: [PATCH 2/2] Drop free threading tests --- Lib/test/test_asyncio/test_free_threading.py | 136 ------------------- 1 file changed, 136 deletions(-) delete mode 100644 Lib/test/test_asyncio/test_free_threading.py diff --git a/Lib/test/test_asyncio/test_free_threading.py b/Lib/test/test_asyncio/test_free_threading.py deleted file mode 100644 index 05106a2c2fe3f6..00000000000000 --- a/Lib/test/test_asyncio/test_free_threading.py +++ /dev/null @@ -1,136 +0,0 @@ -import asyncio -import unittest -from threading import Thread -from unittest import TestCase - -from test.support import threading_helper - -threading_helper.requires_working_threading(module=True) - - -class MyException(Exception): - pass - - -def tearDownModule(): - asyncio._set_event_loop_policy(None) - - -class TestFreeThreading: - def test_all_tasks_race(self) -> None: - async def main(): - loop = asyncio.get_running_loop() - future = loop.create_future() - - async def coro(): - await future - - tasks = set() - - async with asyncio.TaskGroup() as tg: - for _ in range(100): - tasks.add(tg.create_task(coro())) - - all_tasks = self.all_tasks(loop) - self.assertEqual(len(all_tasks), 101) - - for task in all_tasks: - self.assertEqual(task.get_loop(), loop) - self.assertFalse(task.done()) - - current = self.current_task() - self.assertEqual(current.get_loop(), loop) - self.assertSetEqual(all_tasks, tasks | {current}) - future.set_result(None) - - def runner(): - with asyncio.Runner() as runner: - loop = runner.get_loop() - loop.set_task_factory(self.factory) - runner.run(main()) - - threads = [] - - for _ in range(10): - thread = Thread(target=runner) - threads.append(thread) - - with threading_helper.start_threads(threads): - pass - - def test_run_coroutine_threadsafe(self) -> None: - results = [] - - def in_thread(loop: asyncio.AbstractEventLoop): - coro = asyncio.sleep(0.1, result=42) - fut = asyncio.run_coroutine_threadsafe(coro, loop) - result = fut.result() - self.assertEqual(result, 42) - results.append(result) - - async def main(): - loop = asyncio.get_running_loop() - async with asyncio.TaskGroup() as tg: - for _ in range(10): - tg.create_task(asyncio.to_thread(in_thread, loop)) - self.assertEqual(results, [42] * 10) - - with asyncio.Runner() as r: - loop = r.get_loop() - loop.set_task_factory(self.factory) - r.run(main()) - - def test_run_coroutine_threadsafe_exception(self) -> None: - async def coro(): - await asyncio.sleep(0) - raise MyException("test") - - def in_thread(loop: asyncio.AbstractEventLoop): - fut = asyncio.run_coroutine_threadsafe(coro(), loop) - return fut.result() - - async def main(): - loop = asyncio.get_running_loop() - tasks = [] - for _ in range(10): - task = loop.create_task(asyncio.to_thread(in_thread, loop)) - tasks.append(task) - results = await asyncio.gather(*tasks, return_exceptions=True) - - self.assertEqual(len(results), 10) - for result in results: - self.assertIsInstance(result, MyException) - self.assertEqual(str(result), "test") - - with asyncio.Runner() as r: - loop = r.get_loop() - loop.set_task_factory(self.factory) - r.run(main()) - - -class TestPyFreeThreading(TestFreeThreading, TestCase): - all_tasks = staticmethod(asyncio.tasks._py_all_tasks) - current_task = staticmethod(asyncio.tasks._py_current_task) - - def factory(self, loop, coro, **kwargs): - return asyncio.tasks._PyTask(coro, loop=loop, **kwargs) - - -@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio") -class TestCFreeThreading(TestFreeThreading, TestCase): - all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None)) - current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None)) - - def factory(self, loop, coro, **kwargs): - return asyncio.tasks._CTask(coro, loop=loop, **kwargs) - - -class TestEagerPyFreeThreading(TestPyFreeThreading): - def factory(self, loop, coro, eager_start=True, **kwargs): - return asyncio.tasks._PyTask(coro, loop=loop, **kwargs, eager_start=eager_start) - - -@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio") -class TestEagerCFreeThreading(TestCFreeThreading, TestCase): - def factory(self, loop, coro, eager_start=True, **kwargs): - return asyncio.tasks._CTask(coro, loop=loop, **kwargs, eager_start=eager_start)