6
6
from collections .abc import Mapping
7
7
from collections .abc import Sequence
8
8
from collections .abc import Sized
9
- from contextlib import AbstractContextManager
10
9
from decimal import Decimal
11
10
import math
12
11
from numbers import Complex
13
12
import pprint
14
13
import re
15
14
import sys
16
- from types import TracebackType
17
15
from typing import Any
18
- from typing import cast
19
- from typing import final
20
- from typing import get_args
21
- from typing import get_origin
22
16
from typing import overload
23
17
from typing import TYPE_CHECKING
24
18
from typing import TypeVar
25
19
26
20
import _pytest ._code
27
21
from _pytest .outcomes import fail
22
+ from _pytest .raises_group import BaseExcT_co_default
23
+ from _pytest .raises_group import RaisesExc
28
24
29
25
30
26
if sys .version_info < (3 , 11 ):
31
- from exceptiongroup import BaseExceptionGroup
32
- from exceptiongroup import ExceptionGroup
27
+ pass
33
28
34
29
if TYPE_CHECKING :
35
30
from numpy import ndarray
@@ -791,15 +786,29 @@ def _as_numpy_array(obj: object) -> ndarray | None:
791
786
792
787
# builtin pytest.raises helper
793
788
794
- E = TypeVar ("E" , bound = BaseException )
789
+ E = TypeVar ("E" , bound = BaseException , default = BaseException )
795
790
796
791
797
792
@overload
798
793
def raises (
799
794
expected_exception : type [E ] | tuple [type [E ], ...],
800
795
* ,
801
796
match : str | re .Pattern [str ] | None = ...,
802
- ) -> RaisesContext [E ]: ...
797
+ check : Callable [[BaseExcT_co_default ], bool ] = ...,
798
+ ) -> RaisesExc [E ]: ...
799
+
800
+
801
+ @overload
802
+ def raises (
803
+ * ,
804
+ match : str | re .Pattern [str ],
805
+ # If exception_type is not provided, check() must do any typechecks itself.
806
+ check : Callable [[BaseException ], bool ] = ...,
807
+ ) -> RaisesExc [BaseException ]: ...
808
+
809
+
810
+ @overload
811
+ def raises (* , check : Callable [[BaseException ], bool ]) -> RaisesExc [BaseException ]: ...
803
812
804
813
805
814
@overload
@@ -812,8 +821,10 @@ def raises(
812
821
813
822
814
823
def raises (
815
- expected_exception : type [E ] | tuple [type [E ], ...], * args : Any , ** kwargs : Any
816
- ) -> RaisesContext [E ] | _pytest ._code .ExceptionInfo [E ]:
824
+ expected_exception : type [E ] | tuple [type [E ], ...] | None = None ,
825
+ * args : Any ,
826
+ ** kwargs : Any ,
827
+ ) -> RaisesExc [BaseException ] | _pytest ._code .ExceptionInfo [E ]:
817
828
r"""Assert that a code block/function call raises an exception type, or one of its subclasses.
818
829
819
830
:param expected_exception:
@@ -960,117 +971,38 @@ def raises(
960
971
"""
961
972
__tracebackhide__ = True
962
973
974
+ if not args :
975
+ if set (kwargs ) - {"match" , "check" , "expected_exception" }:
976
+ msg = "Unexpected keyword arguments passed to pytest.raises: "
977
+ msg += ", " .join (sorted (kwargs ))
978
+ msg += "\n Use context-manager form instead?"
979
+ raise TypeError (msg )
980
+
981
+ if expected_exception is None :
982
+ return RaisesExc (** kwargs )
983
+ return RaisesExc (expected_exception , ** kwargs )
984
+
963
985
if not expected_exception :
964
986
raise ValueError (
965
987
f"Expected an exception type or a tuple of exception types, but got `{ expected_exception !r} `. "
966
988
f"Raising exceptions is already understood as failing the test, so you don't need "
967
989
f"any special code to say 'this should never raise an exception'."
968
990
)
969
-
970
- expected_exceptions : tuple [type [E ], ...]
971
- origin_exc : type [E ] | None = get_origin (expected_exception )
972
- if isinstance (expected_exception , type ):
973
- expected_exceptions = (expected_exception ,)
974
- elif origin_exc and issubclass (origin_exc , BaseExceptionGroup ):
975
- expected_exceptions = (cast (type [E ], expected_exception ),)
976
- else :
977
- expected_exceptions = expected_exception
978
-
979
- def validate_exc (exc : type [E ]) -> type [E ]:
980
- __tracebackhide__ = True
981
- origin_exc : type [E ] | None = get_origin (exc )
982
- if origin_exc and issubclass (origin_exc , BaseExceptionGroup ):
983
- exc_type = get_args (exc )[0 ]
984
- if (
985
- issubclass (origin_exc , ExceptionGroup ) and exc_type in (Exception , Any )
986
- ) or (
987
- issubclass (origin_exc , BaseExceptionGroup )
988
- and exc_type in (BaseException , Any )
989
- ):
990
- return cast (type [E ], origin_exc )
991
- else :
992
- raise ValueError (
993
- f"Only `ExceptionGroup[Exception]` or `BaseExceptionGroup[BaseExeption]` "
994
- f"are accepted as generic types but got `{ exc } `. "
995
- f"As `raises` will catch all instances of the specified group regardless of the "
996
- f"generic argument specific nested exceptions has to be checked "
997
- f"with `ExceptionInfo.group_contains()`"
998
- )
999
-
1000
- elif not isinstance (exc , type ) or not issubclass (exc , BaseException ):
1001
- msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
1002
- not_a = exc .__name__ if isinstance (exc , type ) else type (exc ).__name__
1003
- raise TypeError (msg .format (not_a ))
1004
- else :
1005
- return exc
1006
-
1007
- expected_exceptions = tuple (validate_exc (exc ) for exc in expected_exceptions )
1008
-
1009
- message = f"DID NOT RAISE { expected_exception } "
1010
-
1011
- if not args :
1012
- match : str | re .Pattern [str ] | None = kwargs .pop ("match" , None )
1013
- if kwargs :
1014
- msg = "Unexpected keyword arguments passed to pytest.raises: "
1015
- msg += ", " .join (sorted (kwargs ))
1016
- msg += "\n Use context-manager form instead?"
1017
- raise TypeError (msg )
1018
- return RaisesContext (expected_exceptions , message , match )
1019
- else :
1020
- func = args [0 ]
1021
- if not callable (func ):
1022
- raise TypeError (f"{ func !r} object (type: { type (func )} ) must be callable" )
1023
- try :
1024
- func (* args [1 :], ** kwargs )
1025
- except expected_exceptions as e :
1026
- return _pytest ._code .ExceptionInfo .from_exception (e )
1027
- fail (message )
1028
-
1029
-
1030
- # This doesn't work with mypy for now. Use fail.Exception instead.
1031
- raises .Exception = fail .Exception # type: ignore
1032
-
1033
-
1034
- @final
1035
- class RaisesContext (AbstractContextManager [_pytest ._code .ExceptionInfo [E ]]):
1036
- def __init__ (
1037
- self ,
1038
- expected_exception : type [E ] | tuple [type [E ], ...],
1039
- message : str ,
1040
- match_expr : str | re .Pattern [str ] | None = None ,
1041
- ) -> None :
1042
- self .expected_exception = expected_exception
1043
- self .message = message
1044
- self .match_expr = match_expr
1045
- self .excinfo : _pytest ._code .ExceptionInfo [E ] | None = None
1046
- if self .match_expr is not None :
1047
- re_error = None
1048
- try :
1049
- re .compile (self .match_expr )
1050
- except re .error as e :
1051
- re_error = e
1052
- if re_error is not None :
1053
- fail (f"Invalid regex pattern provided to 'match': { re_error } " )
1054
-
1055
- def __enter__ (self ) -> _pytest ._code .ExceptionInfo [E ]:
1056
- self .excinfo = _pytest ._code .ExceptionInfo .for_later ()
1057
- return self .excinfo
1058
-
1059
- def __exit__ (
1060
- self ,
1061
- exc_type : type [BaseException ] | None ,
1062
- exc_val : BaseException | None ,
1063
- exc_tb : TracebackType | None ,
1064
- ) -> bool :
1065
- __tracebackhide__ = True
1066
- if exc_type is None :
1067
- fail (self .message )
1068
- assert self .excinfo is not None
1069
- if not issubclass (exc_type , self .expected_exception ):
1070
- return False
1071
- # Cast to narrow the exception type now that it's verified.
1072
- exc_info = cast (tuple [type [E ], E , TracebackType ], (exc_type , exc_val , exc_tb ))
1073
- self .excinfo .fill_unfilled (exc_info )
1074
- if self .match_expr is not None :
1075
- self .excinfo .match (self .match_expr )
1076
- return True
991
+ func = args [0 ]
992
+ if not callable (func ):
993
+ raise TypeError (f"{ func !r} object (type: { type (func )} ) must be callable" )
994
+ with RaisesExc (expected_exception ) as excinfo :
995
+ func (* args [1 :], ** kwargs )
996
+ try :
997
+ return excinfo
998
+ finally :
999
+ del excinfo
1000
+
1001
+
1002
+ # note: RaisesExc/RaisesGroup uses fail() internally, so this alias
1003
+ # indicates (to [internal] plugins?) that `pytest.raises` will
1004
+ # raise `_pytest.outcomes.Failed`, where
1005
+ # `outcomes.Failed is outcomes.fail.Exception is raises.Exception`
1006
+ # note: this is *not* the same as `_pytest.main.Failed`
1007
+ # note: mypy does not recognize this attribute
1008
+ raises .Exception = fail .Exception # type: ignore[attr-defined]
0 commit comments