Skip to content

Commit 644bc9f

Browse files
committed
Reinstate specific event loop if unset
1 parent 656207e commit 644bc9f

File tree

2 files changed

+414
-30
lines changed

2 files changed

+414
-30
lines changed

pytest_asyncio/plugin.py

Lines changed: 106 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import inspect
1111
import socket
1212
import sys
13-
import threading
1413
import traceback
1514
import warnings
1615
from asyncio import AbstractEventLoop, AbstractEventLoopPolicy
@@ -50,6 +49,16 @@
5049
PytestPluginManager,
5150
)
5251

52+
_seen_markers: set[int] = set()
53+
54+
55+
def _warn_scope_deprecation_once(marker_id: int) -> None:
56+
"""Issues deprecation warning exactly once per marker ID."""
57+
if marker_id not in _seen_markers:
58+
_seen_markers.add(marker_id)
59+
warnings.warn(PytestDeprecationWarning(_MARKER_SCOPE_KWARG_DEPRECATION_WARNING))
60+
61+
5362
if sys.version_info >= (3, 10):
5463
from typing import ParamSpec
5564
else:
@@ -63,7 +72,9 @@
6372
_ScopeName = Literal["session", "package", "module", "class", "function"]
6473
_R = TypeVar("_R", bound=Union[Awaitable[Any], AsyncIterator[Any]])
6574
_P = ParamSpec("_P")
75+
T = TypeVar("T")
6676
FixtureFunction = Callable[_P, _R]
77+
CoroutineFunction = Callable[_P, Awaitable[T]]
6778

6879

6980
class PytestAsyncioError(Exception):
@@ -292,7 +303,7 @@ def _asyncgen_fixture_wrapper(
292303
gen_obj = fixture_function(*args, **kwargs)
293304

294305
async def setup():
295-
res = await gen_obj.__anext__() # type: ignore[union-attr]
306+
res = await gen_obj.__anext__()
296307
return res
297308

298309
context = contextvars.copy_context()
@@ -305,7 +316,7 @@ def finalizer() -> None:
305316

306317
async def async_finalizer() -> None:
307318
try:
308-
await gen_obj.__anext__() # type: ignore[union-attr]
319+
await gen_obj.__anext__()
309320
except StopAsyncIteration:
310321
pass
311322
else:
@@ -334,8 +345,7 @@ def _wrap_async_fixture(
334345
runner: Runner,
335346
request: FixtureRequest,
336347
) -> Callable[AsyncFixtureParams, AsyncFixtureReturnType]:
337-
338-
@functools.wraps(fixture_function) # type: ignore[arg-type]
348+
@functools.wraps(fixture_function)
339349
def _async_fixture_wrapper(
340350
*args: AsyncFixtureParams.args,
341351
**kwargs: AsyncFixtureParams.kwargs,
@@ -448,7 +458,7 @@ def _can_substitute(item: Function) -> bool:
448458
return inspect.iscoroutinefunction(func)
449459

450460
def runtest(self) -> None:
451-
synchronized_obj = wrap_in_sync(self.obj)
461+
synchronized_obj = get_async_test_wrapper(self, self.obj)
452462
with MonkeyPatch.context() as c:
453463
c.setattr(self, "obj", synchronized_obj)
454464
super().runtest()
@@ -490,7 +500,7 @@ def _can_substitute(item: Function) -> bool:
490500
)
491501

492502
def runtest(self) -> None:
493-
synchronized_obj = wrap_in_sync(self.obj)
503+
synchronized_obj = get_async_test_wrapper(self, self.obj)
494504
with MonkeyPatch.context() as c:
495505
c.setattr(self, "obj", synchronized_obj)
496506
super().runtest()
@@ -512,7 +522,10 @@ def _can_substitute(item: Function) -> bool:
512522
)
513523

514524
def runtest(self) -> None:
515-
synchronized_obj = wrap_in_sync(self.obj.hypothesis.inner_test)
525+
synchronized_obj = get_async_test_wrapper(
526+
self,
527+
self.obj.hypothesis.inner_test,
528+
)
516529
with MonkeyPatch.context() as c:
517530
c.setattr(self.obj.hypothesis, "inner_test", synchronized_obj)
518531
super().runtest()
@@ -603,10 +616,68 @@ def _set_event_loop(loop: AbstractEventLoop | None) -> None:
603616
asyncio.set_event_loop(loop)
604617

605618

606-
def _reinstate_event_loop_on_main_thread() -> None:
607-
if threading.current_thread() is threading.main_thread():
608-
policy = _get_event_loop_policy()
609-
policy.set_event_loop(policy.new_event_loop())
619+
_session_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
620+
contextvars.ContextVar(
621+
"_session_loop",
622+
default=None,
623+
)
624+
)
625+
_package_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
626+
contextvars.ContextVar(
627+
"_package_loop",
628+
default=None,
629+
)
630+
)
631+
_module_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
632+
contextvars.ContextVar(
633+
"_module_loop",
634+
default=None,
635+
)
636+
)
637+
_class_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
638+
contextvars.ContextVar(
639+
"_class_loop",
640+
default=None,
641+
)
642+
)
643+
_function_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
644+
contextvars.ContextVar(
645+
"_function_loop",
646+
default=None,
647+
)
648+
)
649+
650+
_SCOPE_TO_CONTEXTVAR = {
651+
"session": _session_loop,
652+
"package": _package_loop,
653+
"module": _module_loop,
654+
"class": _class_loop,
655+
"function": _function_loop,
656+
}
657+
658+
659+
def _get_or_restore_event_loop(loop_scope: _ScopeName) -> asyncio.AbstractEventLoop:
660+
"""
661+
Get or restore the appropriate event loop for the given scope.
662+
663+
If we have a shared loop for this scope, restore and return it.
664+
Otherwise, get the current event loop or create a new one.
665+
"""
666+
shared_loop = _SCOPE_TO_CONTEXTVAR[loop_scope].get()
667+
if shared_loop is not None:
668+
_reinstate_event_loop_on_main_thread(loop_scope)
669+
return shared_loop
670+
else:
671+
return _get_event_loop_no_warn()
672+
673+
674+
def _reinstate_event_loop_on_main_thread(loop_scope: _ScopeName) -> None:
675+
shared_loop = _SCOPE_TO_CONTEXTVAR[loop_scope].get()
676+
if shared_loop is None:
677+
return
678+
679+
policy = _get_event_loop_policy()
680+
policy.set_event_loop(shared_loop)
610681

611682

612683
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
@@ -659,9 +730,22 @@ def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
659730
return None
660731

661732

662-
def wrap_in_sync(
663-
func: Callable[..., Awaitable[Any]],
664-
):
733+
def get_async_test_wrapper(
734+
item: Function,
735+
func: CoroutineFunction[_P, T],
736+
) -> Callable[_P, T]:
737+
"""Returns a synchronous wrapper for the specified async test function."""
738+
marker = item.get_closest_marker("asyncio")
739+
assert marker is not None
740+
default_loop_scope = _get_default_test_loop_scope(item.config)
741+
loop_scope = _get_marked_loop_scope(marker, default_loop_scope)
742+
return _wrap_in_sync(func, loop_scope)
743+
744+
745+
def _wrap_in_sync(
746+
func: CoroutineFunction[_P, T],
747+
loop_scope: _ScopeName,
748+
) -> Callable[_P, T]:
665749
"""
666750
Return a sync wrapper around an async function executing it in the
667751
current event loop.
@@ -670,15 +754,10 @@ def wrap_in_sync(
670754
@functools.wraps(func)
671755
def inner(*args, **kwargs):
672756
coro = func(*args, **kwargs)
673-
try:
674-
_loop = _get_event_loop_no_warn()
675-
except RuntimeError:
676-
# Handle situation where asyncio.set_event_loop(None) removes shared loops.
677-
_reinstate_event_loop_on_main_thread()
678-
_loop = _get_event_loop_no_warn()
757+
_loop = _get_or_restore_event_loop(loop_scope)
679758
task = asyncio.ensure_future(coro, loop=_loop)
680759
try:
681-
_loop.run_until_complete(task)
760+
return _loop.run_until_complete(task)
682761
except BaseException:
683762
# run_until_complete doesn't get the result from exceptions
684763
# that are not subclasses of `Exception`. Consume all
@@ -758,7 +837,7 @@ def _get_marked_loop_scope(
758837
if "scope" in asyncio_marker.kwargs:
759838
if "loop_scope" in asyncio_marker.kwargs:
760839
raise pytest.UsageError(_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR)
761-
warnings.warn(PytestDeprecationWarning(_MARKER_SCOPE_KWARG_DEPRECATION_WARNING))
840+
_warn_scope_deprecation_once(id(asyncio_marker))
762841
scope = asyncio_marker.kwargs.get("loop_scope") or asyncio_marker.kwargs.get(
763842
"scope"
764843
)
@@ -796,6 +875,8 @@ def _scoped_runner(
796875
debug_mode = _get_asyncio_debug(request.config)
797876
with _temporary_event_loop_policy(new_loop_policy):
798877
runner = Runner(debug=debug_mode).__enter__()
878+
shared_loop = runner.get_loop()
879+
_SCOPE_TO_CONTEXTVAR[scope].set(shared_loop)
799880
try:
800881
yield runner
801882
except Exception as e:
@@ -812,6 +893,8 @@ def _scoped_runner(
812893
_RUNNER_TEARDOWN_WARNING % traceback.format_exc(),
813894
RuntimeWarning,
814895
)
896+
finally:
897+
_SCOPE_TO_CONTEXTVAR[scope].set(None)
815898

816899
return _scoped_runner
817900

0 commit comments

Comments
 (0)