22# pylint: disable=missing-function-docstring
33import asyncio
44import concurrent .futures
5+ import contextlib
56import functools
67import gc
78import inspect
89import math
910import sys
1011import threading
12+ import traceback
1113import unittest
1214import weakref
1315
@@ -120,6 +122,41 @@ def counting_fn(*args) -> int:
120122 return counting_fn , counter
121123
122124
125+ class LineCapture :
126+ def __init__ (self ):
127+ self .line = None
128+
129+ def record_next_line (self ):
130+ """Record the next line in the parent frame"""
131+ self .line = inspect .currentframe ().f_back .f_lineno + 1
132+
133+
134+ class ExceptionContextManager :
135+ exception : Exception
136+
137+
138+ @contextlib .contextmanager
139+ def assertRaisesWithLineInStackTrace (test : unittest .TestCase , exception_type , line : LineCapture ):
140+ try :
141+ container = ExceptionContextManager ()
142+ yield container
143+ except exception_type as exception :
144+ container .exception = exception
145+ traceback_exception = traceback .TracebackException .from_exception (exception )
146+ if not len (traceback_exception .stack ):
147+ test .fail ("Exception stack not preserved. Did you use the raw assertRaises by mistake?" )
148+ locations = [(frame .filename , frame .lineno ) for frame in traceback_exception .stack ]
149+ line_number = line .line
150+ error_message = [
151+ f"Traceback for exception { repr (exception )} did not have frame on line { line_number } . Exception below\n "
152+ ]
153+ error_message .extend (traceback_exception .format ())
154+ test .assertIn ((__file__ , line_number ), locations , msg = "" .join (error_message ))
155+
156+ else :
157+ test .fail ("expected exception not called" )
158+
159+
123160class TestFunctionInspection (unittest .TestCase ):
124161 """Unit tests for function inspection"""
125162
@@ -317,33 +354,42 @@ def test_partial(self):
317354
318355 def test_failing_function (self ):
319356 counter = Counter ()
357+ failing_line = LineCapture ()
320358
321359 @once .once
322360 def sample_failing_fn ():
361+ nonlocal failing_line
323362 if counter .get_incremented () < 4 :
363+ failing_line .record_next_line ()
324364 raise ValueError ("expected failure" )
325365 return 1
326366
327- with self .assertRaises (ValueError ):
367+ with assertRaisesWithLineInStackTrace (self , ValueError , failing_line ):
368+ sample_failing_fn ()
369+ with assertRaisesWithLineInStackTrace (self , ValueError , failing_line ) as cm :
328370 sample_failing_fn ()
371+ self .assertEqual (cm .exception .args [0 ], "expected failure" )
329372 self .assertEqual (counter .get_incremented (), 2 )
330- with self . assertRaises ( ValueError ):
373+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
331374 sample_failing_fn ()
332375 self .assertEqual (counter .get_incremented (), 3 , "Function call incremented the counter" )
333376
334377 def test_failing_function_retry_exceptions (self ):
335378 counter = Counter ()
379+ failing_line = LineCapture ()
336380
337381 @once .once (retry_exceptions = True )
338382 def sample_failing_fn ():
383+ nonlocal failing_line
339384 if counter .get_incremented () < 4 :
385+ failing_line .record_next_line ()
340386 raise ValueError ("expected failure" )
341387 return 1
342388
343- with self . assertRaises ( ValueError ):
389+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
344390 sample_failing_fn ()
345391 self .assertEqual (counter .get_incremented (), 2 )
346- with self . assertRaises ( ValueError ):
392+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
347393 sample_failing_fn ()
348394 # This ensures that this was a new function call, not a cached result.
349395 self .assertEqual (counter .get_incremented (), 4 )
@@ -363,13 +409,15 @@ def yielding_iterator():
363409
364410 def test_failing_generator (self ):
365411 counter = Counter ()
412+ failing_line = LineCapture ()
366413
367414 @once .once
368415 def sample_failing_fn ():
369416 yield counter .get_incremented ()
370417 result = counter .get_incremented ()
371418 yield result
372419 if result == 2 :
420+ failing_line .record_next_line ()
373421 raise ValueError ("expected failure after 2." )
374422
375423 # Both of these calls should return the same results.
@@ -379,9 +427,9 @@ def sample_failing_fn():
379427 self .assertEqual (next (call2 ), 1 )
380428 self .assertEqual (next (call1 ), 2 )
381429 self .assertEqual (next (call2 ), 2 )
382- with self . assertRaises ( ValueError ):
430+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
383431 next (call1 )
384- with self . assertRaises ( ValueError ):
432+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
385433 next (call2 )
386434 # These next 2 calls should also fail.
387435 call3 = sample_failing_fn ()
@@ -390,20 +438,22 @@ def sample_failing_fn():
390438 self .assertEqual (next (call4 ), 1 )
391439 self .assertEqual (next (call3 ), 2 )
392440 self .assertEqual (next (call4 ), 2 )
393- with self . assertRaises ( ValueError ):
441+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
394442 next (call3 )
395- with self . assertRaises ( ValueError ):
443+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
396444 next (call4 )
397445
398446 def test_failing_generator_retry_exceptions (self ):
399447 counter = Counter ()
448+ failing_line = LineCapture ()
400449
401450 @once .once (retry_exceptions = True )
402451 def sample_failing_fn ():
403452 yield counter .get_incremented ()
404453 result = counter .get_incremented ()
405454 yield result
406455 if result == 2 :
456+ failing_line .record_next_line ()
407457 raise ValueError ("expected failure after 2." )
408458
409459 # Both of these calls should return the same results.
@@ -413,9 +463,9 @@ def sample_failing_fn():
413463 self .assertEqual (next (call2 ), 1 )
414464 self .assertEqual (next (call1 ), 2 )
415465 self .assertEqual (next (call2 ), 2 )
416- with self . assertRaises ( ValueError ):
466+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
417467 next (call1 )
418- with self . assertRaises ( ValueError ):
468+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
419469 next (call2 )
420470 # These next 2 calls should succeed.
421471 call3 = sample_failing_fn ()
@@ -906,33 +956,37 @@ def execute(*args):
906956
907957 async def test_failing_function (self ):
908958 counter = Counter ()
959+ failing_line = LineCapture ()
909960
910961 @once .once
911962 async def sample_failing_fn ():
912963 if counter .get_incremented () < 4 :
964+ failing_line .record_next_line ()
913965 raise ValueError ("expected failure" )
914966 return 1
915967
916- with self . assertRaises ( ValueError ):
968+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
917969 await sample_failing_fn ()
918970 self .assertEqual (counter .get_incremented (), 2 )
919- with self . assertRaises ( ValueError ):
971+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
920972 await sample_failing_fn ()
921973 self .assertEqual (counter .get_incremented (), 3 , "Function call incremented the counter" )
922974
923975 async def test_failing_function_retry_exceptions (self ):
924976 counter = Counter ()
977+ failing_line = LineCapture ()
925978
926979 @once .once (retry_exceptions = True )
927980 async def sample_failing_fn ():
928981 if counter .get_incremented () < 4 :
982+ failing_line .record_next_line ()
929983 raise ValueError ("expected failure" )
930984 return 1
931985
932- with self . assertRaises ( ValueError ):
986+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
933987 await sample_failing_fn ()
934988 self .assertEqual (counter .get_incremented (), 2 )
935- with self . assertRaises ( ValueError ):
989+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
936990 await sample_failing_fn ()
937991 # This ensures that this was a new function call, not a cached result.
938992 self .assertEqual (counter .get_incremented (), 4 )
@@ -985,13 +1039,15 @@ async def async_yielding_iterator():
9851039
9861040 async def test_failing_generator (self ):
9871041 counter = Counter ()
1042+ failing_line = LineCapture ()
9881043
9891044 @once .once
9901045 async def sample_failing_fn ():
9911046 yield counter .get_incremented ()
9921047 result = counter .get_incremented ()
9931048 yield result
9941049 if result == 2 :
1050+ failing_line .record_next_line ()
9951051 raise ValueError ("we raise an error when result is exactly 2" )
9961052
9971053 # Both of these calls should return the same results.
@@ -1001,9 +1057,9 @@ async def sample_failing_fn():
10011057 self .assertEqual (await anext (call2 ), 1 )
10021058 self .assertEqual (await anext (call1 ), 2 )
10031059 self .assertEqual (await anext (call2 ), 2 )
1004- with self . assertRaises ( ValueError ):
1060+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
10051061 await anext (call1 )
1006- with self . assertRaises ( ValueError ):
1062+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
10071063 await anext (call2 )
10081064 # These next 2 calls should also fail.
10091065 call3 = sample_failing_fn ()
@@ -1012,20 +1068,22 @@ async def sample_failing_fn():
10121068 self .assertEqual (await anext (call4 ), 1 )
10131069 self .assertEqual (await anext (call3 ), 2 )
10141070 self .assertEqual (await anext (call4 ), 2 )
1015- with self . assertRaises ( ValueError ):
1071+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
10161072 await anext (call3 )
1017- with self . assertRaises ( ValueError ):
1073+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
10181074 await anext (call4 )
10191075
10201076 async def test_failing_generator_retry_exceptions (self ):
10211077 counter = Counter ()
1078+ failing_line = LineCapture ()
10221079
10231080 @once .once (retry_exceptions = True )
10241081 async def sample_failing_fn ():
10251082 yield counter .get_incremented ()
10261083 result = counter .get_incremented ()
10271084 yield result
10281085 if result == 2 :
1086+ failing_line .record_next_line ()
10291087 raise ValueError ("we raise an error when result is exactly 2" )
10301088
10311089 # Both of these calls should return the same results.
@@ -1035,9 +1093,9 @@ async def sample_failing_fn():
10351093 self .assertEqual (await anext (call2 ), 1 )
10361094 self .assertEqual (await anext (call1 ), 2 )
10371095 self .assertEqual (await anext (call2 ), 2 )
1038- with self . assertRaises ( ValueError ):
1096+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
10391097 await anext (call1 )
1040- with self . assertRaises ( ValueError ):
1098+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
10411099 await anext (call2 )
10421100 # These next 2 calls should succeed.
10431101 call3 = sample_failing_fn ()
0 commit comments