50
50
PytestPluginManager ,
51
51
)
52
52
53
+ _seen_markers : set [int ] = set ()
54
+
55
+
56
+ def _warn_scope_deprecation_once (marker_id : int ) -> None :
57
+ """Issues deprecation warning exactly once per marker ID."""
58
+ if marker_id not in _seen_markers :
59
+ _seen_markers .add (marker_id )
60
+ warnings .warn (PytestDeprecationWarning (_MARKER_SCOPE_KWARG_DEPRECATION_WARNING ))
61
+
62
+
53
63
if sys .version_info >= (3 , 10 ):
54
64
from typing import ParamSpec
55
65
else :
63
73
_ScopeName = Literal ["session" , "package" , "module" , "class" , "function" ]
64
74
_R = TypeVar ("_R" , bound = Union [Awaitable [Any ], AsyncIterator [Any ]])
65
75
_P = ParamSpec ("_P" )
76
+ T = TypeVar ("T" )
66
77
FixtureFunction = Callable [_P , _R ]
78
+ CoroutineFunction = Callable [_P , Awaitable [T ]]
67
79
68
80
69
81
class PytestAsyncioError (Exception ):
@@ -292,7 +304,7 @@ def _asyncgen_fixture_wrapper(
292
304
gen_obj = fixture_function (* args , ** kwargs )
293
305
294
306
async def setup ():
295
- res = await gen_obj .__anext__ () # type: ignore[union-attr]
307
+ res = await gen_obj .__anext__ ()
296
308
return res
297
309
298
310
context = contextvars .copy_context ()
@@ -305,7 +317,7 @@ def finalizer() -> None:
305
317
306
318
async def async_finalizer () -> None :
307
319
try :
308
- await gen_obj .__anext__ () # type: ignore[union-attr]
320
+ await gen_obj .__anext__ ()
309
321
except StopAsyncIteration :
310
322
pass
311
323
else :
@@ -334,8 +346,7 @@ def _wrap_async_fixture(
334
346
runner : Runner ,
335
347
request : FixtureRequest ,
336
348
) -> Callable [AsyncFixtureParams , AsyncFixtureReturnType ]:
337
-
338
- @functools .wraps (fixture_function ) # type: ignore[arg-type]
349
+ @functools .wraps (fixture_function )
339
350
def _async_fixture_wrapper (
340
351
* args : AsyncFixtureParams .args ,
341
352
** kwargs : AsyncFixtureParams .kwargs ,
@@ -448,7 +459,7 @@ def _can_substitute(item: Function) -> bool:
448
459
return inspect .iscoroutinefunction (func )
449
460
450
461
def runtest (self ) -> None :
451
- synchronized_obj = wrap_in_sync ( self .obj )
462
+ synchronized_obj = get_async_test_wrapper ( self , self .obj )
452
463
with MonkeyPatch .context () as c :
453
464
c .setattr (self , "obj" , synchronized_obj )
454
465
super ().runtest ()
@@ -490,7 +501,7 @@ def _can_substitute(item: Function) -> bool:
490
501
)
491
502
492
503
def runtest (self ) -> None :
493
- synchronized_obj = wrap_in_sync ( self .obj )
504
+ synchronized_obj = get_async_test_wrapper ( self , self .obj )
494
505
with MonkeyPatch .context () as c :
495
506
c .setattr (self , "obj" , synchronized_obj )
496
507
super ().runtest ()
@@ -512,7 +523,10 @@ def _can_substitute(item: Function) -> bool:
512
523
)
513
524
514
525
def runtest (self ) -> None :
515
- synchronized_obj = wrap_in_sync (self .obj .hypothesis .inner_test )
526
+ synchronized_obj = get_async_test_wrapper (
527
+ self ,
528
+ self .obj .hypothesis .inner_test ,
529
+ )
516
530
with MonkeyPatch .context () as c :
517
531
c .setattr (self .obj .hypothesis , "inner_test" , synchronized_obj )
518
532
super ().runtest ()
@@ -603,10 +617,71 @@ def _set_event_loop(loop: AbstractEventLoop | None) -> None:
603
617
asyncio .set_event_loop (loop )
604
618
605
619
606
- def _reinstate_event_loop_on_main_thread () -> None :
607
- if threading .current_thread () is threading .main_thread ():
608
- policy = _get_event_loop_policy ()
609
- policy .set_event_loop (policy .new_event_loop ())
620
+ _session_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
621
+ contextvars .ContextVar (
622
+ "_session_loop" ,
623
+ default = None ,
624
+ )
625
+ )
626
+ _package_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
627
+ contextvars .ContextVar (
628
+ "_package_loop" ,
629
+ default = None ,
630
+ )
631
+ )
632
+ _module_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
633
+ contextvars .ContextVar (
634
+ "_module_loop" ,
635
+ default = None ,
636
+ )
637
+ )
638
+ _class_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
639
+ contextvars .ContextVar (
640
+ "_class_loop" ,
641
+ default = None ,
642
+ )
643
+ )
644
+ _function_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
645
+ contextvars .ContextVar (
646
+ "_function_loop" ,
647
+ default = None ,
648
+ )
649
+ )
650
+
651
+ _SCOPE_TO_CONTEXTVAR = {
652
+ "session" : _session_loop ,
653
+ "package" : _package_loop ,
654
+ "module" : _module_loop ,
655
+ "class" : _class_loop ,
656
+ "function" : _function_loop ,
657
+ }
658
+
659
+
660
+ def _get_or_restore_event_loop (loop_scope : _ScopeName ) -> asyncio .AbstractEventLoop :
661
+ """
662
+ Get or restore the appropriate event loop for the given scope.
663
+
664
+ If we have a shared loop for this scope, restore and return it.
665
+ Otherwise, get the current event loop or create a new one.
666
+ """
667
+ shared_loop = _SCOPE_TO_CONTEXTVAR [loop_scope ].get ()
668
+ if shared_loop is not None :
669
+ _reinstate_event_loop_on_main_thread (loop_scope )
670
+ return shared_loop
671
+ else :
672
+ return _get_event_loop_no_warn ()
673
+
674
+
675
+ def _reinstate_event_loop_on_main_thread (loop_scope : _ScopeName ) -> None :
676
+ if threading .current_thread () is not threading .main_thread ():
677
+ return
678
+
679
+ shared_loop = _SCOPE_TO_CONTEXTVAR [loop_scope ].get ()
680
+ if shared_loop is None :
681
+ return
682
+
683
+ policy = _get_event_loop_policy ()
684
+ policy .set_event_loop (shared_loop )
610
685
611
686
612
687
@pytest .hookimpl (tryfirst = True , hookwrapper = True )
@@ -659,9 +734,22 @@ def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
659
734
return None
660
735
661
736
662
- def wrap_in_sync (
663
- func : Callable [..., Awaitable [Any ]],
664
- ):
737
+ def get_async_test_wrapper (
738
+ item : Function ,
739
+ func : CoroutineFunction [_P , T ],
740
+ ) -> Callable [_P , T ]:
741
+ """Returns a synchronous wrapper for the specified async test function."""
742
+ marker = item .get_closest_marker ("asyncio" )
743
+ assert marker is not None
744
+ default_loop_scope = _get_default_test_loop_scope (item .config )
745
+ loop_scope = _get_marked_loop_scope (marker , default_loop_scope )
746
+ return _wrap_in_sync (func , loop_scope )
747
+
748
+
749
+ def _wrap_in_sync (
750
+ func : CoroutineFunction [_P , T ],
751
+ loop_scope : _ScopeName ,
752
+ ) -> Callable [_P , T ]:
665
753
"""
666
754
Return a sync wrapper around an async function executing it in the
667
755
current event loop.
@@ -670,15 +758,10 @@ def wrap_in_sync(
670
758
@functools .wraps (func )
671
759
def inner (* args , ** kwargs ):
672
760
coro = func (* args , ** kwargs )
673
- try :
674
- _loop = _get_event_loop_no_warn ()
675
- except RuntimeError :
676
- # Handle situation where asyncio.set_event_loop(None) removes shared loops.
677
- _reinstate_event_loop_on_main_thread ()
678
- _loop = _get_event_loop_no_warn ()
761
+ _loop = _get_or_restore_event_loop (loop_scope )
679
762
task = asyncio .ensure_future (coro , loop = _loop )
680
763
try :
681
- _loop .run_until_complete (task )
764
+ return _loop .run_until_complete (task )
682
765
except BaseException :
683
766
# run_until_complete doesn't get the result from exceptions
684
767
# that are not subclasses of `Exception`. Consume all
@@ -758,7 +841,7 @@ def _get_marked_loop_scope(
758
841
if "scope" in asyncio_marker .kwargs :
759
842
if "loop_scope" in asyncio_marker .kwargs :
760
843
raise pytest .UsageError (_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR )
761
- warnings . warn ( PytestDeprecationWarning ( _MARKER_SCOPE_KWARG_DEPRECATION_WARNING ))
844
+ _warn_scope_deprecation_once ( id ( asyncio_marker ))
762
845
scope = asyncio_marker .kwargs .get ("loop_scope" ) or asyncio_marker .kwargs .get (
763
846
"scope"
764
847
)
@@ -796,6 +879,8 @@ def _scoped_runner(
796
879
debug_mode = _get_asyncio_debug (request .config )
797
880
with _temporary_event_loop_policy (new_loop_policy ):
798
881
runner = Runner (debug = debug_mode ).__enter__ ()
882
+ shared_loop = runner .get_loop ()
883
+ _SCOPE_TO_CONTEXTVAR [scope ].set (shared_loop )
799
884
try :
800
885
yield runner
801
886
except Exception as e :
@@ -812,6 +897,8 @@ def _scoped_runner(
812
897
_RUNNER_TEARDOWN_WARNING % traceback .format_exc (),
813
898
RuntimeWarning ,
814
899
)
900
+ finally :
901
+ _SCOPE_TO_CONTEXTVAR [scope ].set (None )
815
902
816
903
return _scoped_runner
817
904
0 commit comments