Skip to content

Commit a6b7edf

Browse files
committed
Reinstate specific event loop if unset
1 parent 656207e commit a6b7edf

File tree

2 files changed

+417
-29
lines changed

2 files changed

+417
-29
lines changed

pytest_asyncio/plugin.py

Lines changed: 109 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,16 @@
5050
PytestPluginManager,
5151
)
5252

53+
_seen_markers: set[int] = set()
54+
55+
56+
def _warn_scope_deprecation_once(marker_id: int) -> None:
57+
"""Issues deprecation warning exactly once per marker ID."""
58+
if marker_id not in _seen_markers:
59+
_seen_markers.add(marker_id)
60+
warnings.warn(PytestDeprecationWarning(_MARKER_SCOPE_KWARG_DEPRECATION_WARNING))
61+
62+
5363
if sys.version_info >= (3, 10):
5464
from typing import ParamSpec
5565
else:
@@ -63,7 +73,9 @@
6373
_ScopeName = Literal["session", "package", "module", "class", "function"]
6474
_R = TypeVar("_R", bound=Union[Awaitable[Any], AsyncIterator[Any]])
6575
_P = ParamSpec("_P")
76+
T = TypeVar("T")
6677
FixtureFunction = Callable[_P, _R]
78+
CoroutineFunction = Callable[_P, Awaitable[T]]
6779

6880

6981
class PytestAsyncioError(Exception):
@@ -292,7 +304,7 @@ def _asyncgen_fixture_wrapper(
292304
gen_obj = fixture_function(*args, **kwargs)
293305

294306
async def setup():
295-
res = await gen_obj.__anext__() # type: ignore[union-attr]
307+
res = await gen_obj.__anext__()
296308
return res
297309

298310
context = contextvars.copy_context()
@@ -305,7 +317,7 @@ def finalizer() -> None:
305317

306318
async def async_finalizer() -> None:
307319
try:
308-
await gen_obj.__anext__() # type: ignore[union-attr]
320+
await gen_obj.__anext__()
309321
except StopAsyncIteration:
310322
pass
311323
else:
@@ -334,8 +346,7 @@ def _wrap_async_fixture(
334346
runner: Runner,
335347
request: FixtureRequest,
336348
) -> Callable[AsyncFixtureParams, AsyncFixtureReturnType]:
337-
338-
@functools.wraps(fixture_function) # type: ignore[arg-type]
349+
@functools.wraps(fixture_function)
339350
def _async_fixture_wrapper(
340351
*args: AsyncFixtureParams.args,
341352
**kwargs: AsyncFixtureParams.kwargs,
@@ -448,7 +459,7 @@ def _can_substitute(item: Function) -> bool:
448459
return inspect.iscoroutinefunction(func)
449460

450461
def runtest(self) -> None:
451-
synchronized_obj = wrap_in_sync(self.obj)
462+
synchronized_obj = get_async_test_wrapper(self, self.obj)
452463
with MonkeyPatch.context() as c:
453464
c.setattr(self, "obj", synchronized_obj)
454465
super().runtest()
@@ -490,7 +501,7 @@ def _can_substitute(item: Function) -> bool:
490501
)
491502

492503
def runtest(self) -> None:
493-
synchronized_obj = wrap_in_sync(self.obj)
504+
synchronized_obj = get_async_test_wrapper(self, self.obj)
494505
with MonkeyPatch.context() as c:
495506
c.setattr(self, "obj", synchronized_obj)
496507
super().runtest()
@@ -512,7 +523,10 @@ def _can_substitute(item: Function) -> bool:
512523
)
513524

514525
def runtest(self) -> None:
515-
synchronized_obj = wrap_in_sync(self.obj.hypothesis.inner_test)
526+
synchronized_obj = get_async_test_wrapper(
527+
self,
528+
self.obj.hypothesis.inner_test,
529+
)
516530
with MonkeyPatch.context() as c:
517531
c.setattr(self.obj.hypothesis, "inner_test", synchronized_obj)
518532
super().runtest()
@@ -603,10 +617,71 @@ def _set_event_loop(loop: AbstractEventLoop | None) -> None:
603617
asyncio.set_event_loop(loop)
604618

605619

606-
def _reinstate_event_loop_on_main_thread() -> None:
607-
if threading.current_thread() is threading.main_thread():
608-
policy = _get_event_loop_policy()
609-
policy.set_event_loop(policy.new_event_loop())
620+
_session_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
621+
contextvars.ContextVar(
622+
"_session_loop",
623+
default=None,
624+
)
625+
)
626+
_package_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
627+
contextvars.ContextVar(
628+
"_package_loop",
629+
default=None,
630+
)
631+
)
632+
_module_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
633+
contextvars.ContextVar(
634+
"_module_loop",
635+
default=None,
636+
)
637+
)
638+
_class_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
639+
contextvars.ContextVar(
640+
"_class_loop",
641+
default=None,
642+
)
643+
)
644+
_function_loop: contextvars.ContextVar[asyncio.AbstractEventLoop | None] = (
645+
contextvars.ContextVar(
646+
"_function_loop",
647+
default=None,
648+
)
649+
)
650+
651+
_SCOPE_TO_CONTEXTVAR = {
652+
"session": _session_loop,
653+
"package": _package_loop,
654+
"module": _module_loop,
655+
"class": _class_loop,
656+
"function": _function_loop,
657+
}
658+
659+
660+
def _get_or_restore_event_loop(loop_scope: _ScopeName) -> asyncio.AbstractEventLoop:
661+
"""
662+
Get or restore the appropriate event loop for the given scope.
663+
664+
If we have a shared loop for this scope, restore and return it.
665+
Otherwise, get the current event loop or create a new one.
666+
"""
667+
shared_loop = _SCOPE_TO_CONTEXTVAR[loop_scope].get()
668+
if shared_loop is not None:
669+
_reinstate_event_loop_on_main_thread(loop_scope)
670+
return shared_loop
671+
else:
672+
return _get_event_loop_no_warn()
673+
674+
675+
def _reinstate_event_loop_on_main_thread(loop_scope: _ScopeName) -> None:
676+
if threading.current_thread() is not threading.main_thread():
677+
return
678+
679+
shared_loop = _SCOPE_TO_CONTEXTVAR[loop_scope].get()
680+
if shared_loop is None:
681+
return
682+
683+
policy = _get_event_loop_policy()
684+
policy.set_event_loop(shared_loop)
610685

611686

612687
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
@@ -659,9 +734,22 @@ def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
659734
return None
660735

661736

662-
def wrap_in_sync(
663-
func: Callable[..., Awaitable[Any]],
664-
):
737+
def get_async_test_wrapper(
738+
item: Function,
739+
func: CoroutineFunction[_P, T],
740+
) -> Callable[_P, T]:
741+
"""Returns a synchronous wrapper for the specified async test function."""
742+
marker = item.get_closest_marker("asyncio")
743+
assert marker is not None
744+
default_loop_scope = _get_default_test_loop_scope(item.config)
745+
loop_scope = _get_marked_loop_scope(marker, default_loop_scope)
746+
return _wrap_in_sync(func, loop_scope)
747+
748+
749+
def _wrap_in_sync(
750+
func: CoroutineFunction[_P, T],
751+
loop_scope: _ScopeName,
752+
) -> Callable[_P, T]:
665753
"""
666754
Return a sync wrapper around an async function executing it in the
667755
current event loop.
@@ -670,15 +758,10 @@ def wrap_in_sync(
670758
@functools.wraps(func)
671759
def inner(*args, **kwargs):
672760
coro = func(*args, **kwargs)
673-
try:
674-
_loop = _get_event_loop_no_warn()
675-
except RuntimeError:
676-
# Handle situation where asyncio.set_event_loop(None) removes shared loops.
677-
_reinstate_event_loop_on_main_thread()
678-
_loop = _get_event_loop_no_warn()
761+
_loop = _get_or_restore_event_loop(loop_scope)
679762
task = asyncio.ensure_future(coro, loop=_loop)
680763
try:
681-
_loop.run_until_complete(task)
764+
return _loop.run_until_complete(task)
682765
except BaseException:
683766
# run_until_complete doesn't get the result from exceptions
684767
# that are not subclasses of `Exception`. Consume all
@@ -758,7 +841,7 @@ def _get_marked_loop_scope(
758841
if "scope" in asyncio_marker.kwargs:
759842
if "loop_scope" in asyncio_marker.kwargs:
760843
raise pytest.UsageError(_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR)
761-
warnings.warn(PytestDeprecationWarning(_MARKER_SCOPE_KWARG_DEPRECATION_WARNING))
844+
_warn_scope_deprecation_once(id(asyncio_marker))
762845
scope = asyncio_marker.kwargs.get("loop_scope") or asyncio_marker.kwargs.get(
763846
"scope"
764847
)
@@ -796,6 +879,8 @@ def _scoped_runner(
796879
debug_mode = _get_asyncio_debug(request.config)
797880
with _temporary_event_loop_policy(new_loop_policy):
798881
runner = Runner(debug=debug_mode).__enter__()
882+
shared_loop = runner.get_loop()
883+
_SCOPE_TO_CONTEXTVAR[scope].set(shared_loop)
799884
try:
800885
yield runner
801886
except Exception as e:
@@ -812,6 +897,8 @@ def _scoped_runner(
812897
_RUNNER_TEARDOWN_WARNING % traceback.format_exc(),
813898
RuntimeWarning,
814899
)
900+
finally:
901+
_SCOPE_TO_CONTEXTVAR[scope].set(None)
815902

816903
return _scoped_runner
817904

0 commit comments

Comments
 (0)