10
10
import inspect
11
11
import socket
12
12
import sys
13
- import threading
14
13
import traceback
15
14
import warnings
16
15
from asyncio import AbstractEventLoop , AbstractEventLoopPolicy
50
49
PytestPluginManager ,
51
50
)
52
51
52
+ _seen_markers : set [int ] = set ()
53
+
54
+
55
+ def _warn_scope_deprecation_once (marker_id : int ) -> None :
56
+ """Issues deprecation warning exactly once per marker ID."""
57
+ if marker_id not in _seen_markers :
58
+ _seen_markers .add (marker_id )
59
+ warnings .warn (PytestDeprecationWarning (_MARKER_SCOPE_KWARG_DEPRECATION_WARNING ))
60
+
61
+
53
62
if sys .version_info >= (3 , 10 ):
54
63
from typing import ParamSpec
55
64
else :
63
72
_ScopeName = Literal ["session" , "package" , "module" , "class" , "function" ]
64
73
_R = TypeVar ("_R" , bound = Union [Awaitable [Any ], AsyncIterator [Any ]])
65
74
_P = ParamSpec ("_P" )
75
+ T = TypeVar ("T" )
66
76
FixtureFunction = Callable [_P , _R ]
77
+ CoroutineFunction = Callable [_P , Awaitable [T ]]
67
78
68
79
69
80
class PytestAsyncioError (Exception ):
@@ -292,7 +303,7 @@ def _asyncgen_fixture_wrapper(
292
303
gen_obj = fixture_function (* args , ** kwargs )
293
304
294
305
async def setup ():
295
- res = await gen_obj .__anext__ () # type: ignore[union-attr]
306
+ res = await gen_obj .__anext__ ()
296
307
return res
297
308
298
309
context = contextvars .copy_context ()
@@ -305,7 +316,7 @@ def finalizer() -> None:
305
316
306
317
async def async_finalizer () -> None :
307
318
try :
308
- await gen_obj .__anext__ () # type: ignore[union-attr]
319
+ await gen_obj .__anext__ ()
309
320
except StopAsyncIteration :
310
321
pass
311
322
else :
@@ -334,8 +345,7 @@ def _wrap_async_fixture(
334
345
runner : Runner ,
335
346
request : FixtureRequest ,
336
347
) -> Callable [AsyncFixtureParams , AsyncFixtureReturnType ]:
337
-
338
- @functools .wraps (fixture_function ) # type: ignore[arg-type]
348
+ @functools .wraps (fixture_function )
339
349
def _async_fixture_wrapper (
340
350
* args : AsyncFixtureParams .args ,
341
351
** kwargs : AsyncFixtureParams .kwargs ,
@@ -448,7 +458,7 @@ def _can_substitute(item: Function) -> bool:
448
458
return inspect .iscoroutinefunction (func )
449
459
450
460
def runtest (self ) -> None :
451
- synchronized_obj = wrap_in_sync ( self .obj )
461
+ synchronized_obj = get_async_test_wrapper ( self , self .obj )
452
462
with MonkeyPatch .context () as c :
453
463
c .setattr (self , "obj" , synchronized_obj )
454
464
super ().runtest ()
@@ -490,7 +500,7 @@ def _can_substitute(item: Function) -> bool:
490
500
)
491
501
492
502
def runtest (self ) -> None :
493
- synchronized_obj = wrap_in_sync ( self .obj )
503
+ synchronized_obj = get_async_test_wrapper ( self , self .obj )
494
504
with MonkeyPatch .context () as c :
495
505
c .setattr (self , "obj" , synchronized_obj )
496
506
super ().runtest ()
@@ -512,7 +522,10 @@ def _can_substitute(item: Function) -> bool:
512
522
)
513
523
514
524
def runtest (self ) -> None :
515
- synchronized_obj = wrap_in_sync (self .obj .hypothesis .inner_test )
525
+ synchronized_obj = get_async_test_wrapper (
526
+ self ,
527
+ self .obj .hypothesis .inner_test ,
528
+ )
516
529
with MonkeyPatch .context () as c :
517
530
c .setattr (self .obj .hypothesis , "inner_test" , synchronized_obj )
518
531
super ().runtest ()
@@ -603,10 +616,68 @@ def _set_event_loop(loop: AbstractEventLoop | None) -> None:
603
616
asyncio .set_event_loop (loop )
604
617
605
618
606
- def _reinstate_event_loop_on_main_thread () -> None :
607
- if threading .current_thread () is threading .main_thread ():
608
- policy = _get_event_loop_policy ()
609
- policy .set_event_loop (policy .new_event_loop ())
619
+ _session_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
620
+ contextvars .ContextVar (
621
+ "_session_loop" ,
622
+ default = None ,
623
+ )
624
+ )
625
+ _package_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
626
+ contextvars .ContextVar (
627
+ "_package_loop" ,
628
+ default = None ,
629
+ )
630
+ )
631
+ _module_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
632
+ contextvars .ContextVar (
633
+ "_module_loop" ,
634
+ default = None ,
635
+ )
636
+ )
637
+ _class_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
638
+ contextvars .ContextVar (
639
+ "_class_loop" ,
640
+ default = None ,
641
+ )
642
+ )
643
+ _function_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
644
+ contextvars .ContextVar (
645
+ "_function_loop" ,
646
+ default = None ,
647
+ )
648
+ )
649
+
650
+ _SCOPE_TO_CONTEXTVAR = {
651
+ "session" : _session_loop ,
652
+ "package" : _package_loop ,
653
+ "module" : _module_loop ,
654
+ "class" : _class_loop ,
655
+ "function" : _function_loop ,
656
+ }
657
+
658
+
659
+ def _get_or_restore_event_loop (loop_scope : _ScopeName ) -> asyncio .AbstractEventLoop :
660
+ """
661
+ Get or restore the appropriate event loop for the given scope.
662
+
663
+ If we have a shared loop for this scope, restore and return it.
664
+ Otherwise, get the current event loop or create a new one.
665
+ """
666
+ shared_loop = _SCOPE_TO_CONTEXTVAR [loop_scope ].get ()
667
+ if shared_loop is not None :
668
+ _reinstate_event_loop_on_main_thread (loop_scope )
669
+ return shared_loop
670
+ else :
671
+ return _get_event_loop_no_warn ()
672
+
673
+
674
+ def _reinstate_event_loop_on_main_thread (loop_scope : _ScopeName ) -> None :
675
+ shared_loop = _SCOPE_TO_CONTEXTVAR [loop_scope ].get ()
676
+ if shared_loop is None :
677
+ return
678
+
679
+ policy = _get_event_loop_policy ()
680
+ policy .set_event_loop (shared_loop )
610
681
611
682
612
683
@pytest .hookimpl (tryfirst = True , hookwrapper = True )
@@ -659,9 +730,22 @@ def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
659
730
return None
660
731
661
732
662
- def wrap_in_sync (
663
- func : Callable [..., Awaitable [Any ]],
664
- ):
733
+ def get_async_test_wrapper (
734
+ item : Function ,
735
+ func : CoroutineFunction [_P , T ],
736
+ ) -> Callable [_P , T ]:
737
+ """Returns a synchronous wrapper for the specified async test function."""
738
+ marker = item .get_closest_marker ("asyncio" )
739
+ assert marker is not None
740
+ default_loop_scope = _get_default_test_loop_scope (item .config )
741
+ loop_scope = _get_marked_loop_scope (marker , default_loop_scope )
742
+ return _wrap_in_sync (func , loop_scope )
743
+
744
+
745
+ def _wrap_in_sync (
746
+ func : CoroutineFunction [_P , T ],
747
+ loop_scope : _ScopeName ,
748
+ ) -> Callable [_P , T ]:
665
749
"""
666
750
Return a sync wrapper around an async function executing it in the
667
751
current event loop.
@@ -670,15 +754,10 @@ def wrap_in_sync(
670
754
@functools .wraps (func )
671
755
def inner (* args , ** kwargs ):
672
756
coro = func (* args , ** kwargs )
673
- try :
674
- _loop = _get_event_loop_no_warn ()
675
- except RuntimeError :
676
- # Handle situation where asyncio.set_event_loop(None) removes shared loops.
677
- _reinstate_event_loop_on_main_thread ()
678
- _loop = _get_event_loop_no_warn ()
757
+ _loop = _get_or_restore_event_loop (loop_scope )
679
758
task = asyncio .ensure_future (coro , loop = _loop )
680
759
try :
681
- _loop .run_until_complete (task )
760
+ return _loop .run_until_complete (task )
682
761
except BaseException :
683
762
# run_until_complete doesn't get the result from exceptions
684
763
# that are not subclasses of `Exception`. Consume all
@@ -758,7 +837,7 @@ def _get_marked_loop_scope(
758
837
if "scope" in asyncio_marker .kwargs :
759
838
if "loop_scope" in asyncio_marker .kwargs :
760
839
raise pytest .UsageError (_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR )
761
- warnings . warn ( PytestDeprecationWarning ( _MARKER_SCOPE_KWARG_DEPRECATION_WARNING ))
840
+ _warn_scope_deprecation_once ( id ( asyncio_marker ))
762
841
scope = asyncio_marker .kwargs .get ("loop_scope" ) or asyncio_marker .kwargs .get (
763
842
"scope"
764
843
)
@@ -796,6 +875,8 @@ def _scoped_runner(
796
875
debug_mode = _get_asyncio_debug (request .config )
797
876
with _temporary_event_loop_policy (new_loop_policy ):
798
877
runner = Runner (debug = debug_mode ).__enter__ ()
878
+ shared_loop = runner .get_loop ()
879
+ _SCOPE_TO_CONTEXTVAR [scope ].set (shared_loop )
799
880
try :
800
881
yield runner
801
882
except Exception as e :
@@ -812,6 +893,8 @@ def _scoped_runner(
812
893
_RUNNER_TEARDOWN_WARNING % traceback .format_exc (),
813
894
RuntimeWarning ,
814
895
)
896
+ finally :
897
+ _SCOPE_TO_CONTEXTVAR [scope ].set (None )
815
898
816
899
return _scoped_runner
817
900
0 commit comments