11use std:: borrow:: Cow ;
2+ use std:: iter;
23
34use anyhow:: Result ;
45use anyhow:: { bail, Context } ;
@@ -13,10 +14,11 @@ use ruff_python_ast::helpers::Truthiness;
1314use ruff_python_ast:: parenthesize:: parenthesized_range;
1415use ruff_python_ast:: visitor:: Visitor ;
1516use ruff_python_ast:: {
16- self as ast, Arguments , BoolOp , ExceptHandler , Expr , Keyword , Stmt , UnaryOp ,
17+ self as ast, AnyNodeRef , Arguments , BoolOp , ExceptHandler , Expr , Keyword , Stmt , UnaryOp ,
1718} ;
1819use ruff_python_ast:: { visitor, whitespace} ;
1920use ruff_python_codegen:: Stylist ;
21+ use ruff_python_semantic:: { Binding , BindingKind } ;
2022use ruff_source_file:: LineRanges ;
2123use ruff_text_size:: Ranged ;
2224
@@ -266,47 +268,48 @@ fn check_assert_in_except(name: &str, body: &[Stmt]) -> Vec<Diagnostic> {
266268
267269/// PT009
268270pub ( crate ) fn unittest_assertion (
269- checker : & Checker ,
271+ checker : & mut Checker ,
270272 expr : & Expr ,
271273 func : & Expr ,
272274 args : & [ Expr ] ,
273275 keywords : & [ Keyword ] ,
274- ) -> Option < Diagnostic > {
275- match func {
276- Expr :: Attribute ( ast :: ExprAttribute { attr , .. } ) => {
277- if let Ok ( unittest_assert ) = UnittestAssert :: try_from ( attr . as_str ( ) ) {
278- let mut diagnostic = Diagnostic :: new (
279- PytestUnittestAssertion {
280- assertion : unittest_assert . to_string ( ) ,
281- } ,
282- func . range ( ) ,
283- ) ;
284- // We're converting an expression to a statement, so avoid applying the fix if
285- // the assertion is part of a larger expression.
286- if checker . semantic ( ) . current_statement ( ) . is_expr_stmt ( )
287- && checker . semantic ( ) . current_expression_parent ( ) . is_none ( )
288- && !checker . comment_ranges ( ) . intersects ( expr . range ( ) )
289- {
290- if let Ok ( stmt ) = unittest_assert . generate_assert ( args , keywords ) {
291- diagnostic . set_fix ( Fix :: unsafe_edit ( Edit :: range_replacement (
292- checker. generator ( ) . stmt ( & stmt ) ,
293- parenthesized_range (
294- expr. into ( ) ,
295- checker . semantic ( ) . current_statement ( ) . into ( ) ,
296- checker . comment_ranges ( ) ,
297- checker . locator ( ) . contents ( ) ,
298- )
299- . unwrap_or ( expr . range ( ) ) ,
300- ) ) ) ;
301- }
302- }
303- Some ( diagnostic )
304- } else {
305- None
306- }
276+ ) {
277+ let Expr :: Attribute ( ast :: ExprAttribute { attr , .. } ) = func else {
278+ return ;
279+ } ;
280+
281+ let Ok ( unittest_assert ) = UnittestAssert :: try_from ( attr . as_str ( ) ) else {
282+ return ;
283+ } ;
284+
285+ let mut diagnostic = Diagnostic :: new (
286+ PytestUnittestAssertion {
287+ assertion : unittest_assert . to_string ( ) ,
288+ } ,
289+ func . range ( ) ,
290+ ) ;
291+
292+ // We're converting an expression to a statement, so avoid applying the fix if
293+ // the assertion is part of a larger expression.
294+ if checker. semantic ( ) . current_statement ( ) . is_expr_stmt ( )
295+ && checker . semantic ( ) . current_expression_parent ( ) . is_none ( )
296+ && !checker . comment_ranges ( ) . intersects ( expr. range ( ) )
297+ {
298+ if let Ok ( stmt ) = unittest_assert . generate_assert ( args , keywords ) {
299+ diagnostic . set_fix ( Fix :: unsafe_edit ( Edit :: range_replacement (
300+ checker . generator ( ) . stmt ( & stmt ) ,
301+ parenthesized_range (
302+ expr . into ( ) ,
303+ checker . semantic ( ) . current_statement ( ) . into ( ) ,
304+ checker . comment_ranges ( ) ,
305+ checker . locator ( ) . contents ( ) ,
306+ )
307+ . unwrap_or ( expr . range ( ) ) ,
308+ ) ) ) ;
307309 }
308- _ => None ,
309310 }
311+
312+ checker. diagnostics . push ( diagnostic) ;
310313}
311314
312315/// ## What it does
@@ -364,9 +367,96 @@ impl Violation for PytestUnittestRaisesAssertion {
364367}
365368
366369/// PT027
367- pub ( crate ) fn unittest_raises_assertion (
370+ pub ( crate ) fn unittest_raises_assertion_call ( checker : & mut Checker , call : & ast:: ExprCall ) {
371+ // Bindings in `with` statements are handled by `unittest_raises_assertion_bindings`.
372+ if let Stmt :: With ( ast:: StmtWith { items, .. } ) = checker. semantic ( ) . current_statement ( ) {
373+ let call_ref = AnyNodeRef :: from ( call) ;
374+
375+ if items. iter ( ) . any ( |item| {
376+ AnyNodeRef :: from ( & item. context_expr ) . ptr_eq ( call_ref) && item. optional_vars . is_some ( )
377+ } ) {
378+ return ;
379+ }
380+ }
381+
382+ if let Some ( diagnostic) = unittest_raises_assertion ( call, vec ! [ ] , checker) {
383+ checker. diagnostics . push ( diagnostic) ;
384+ }
385+ }
386+
387+ /// PT027
388+ pub ( crate ) fn unittest_raises_assertion_binding (
368389 checker : & Checker ,
390+ binding : & Binding ,
391+ ) -> Option < Diagnostic > {
392+ if !matches ! ( binding. kind, BindingKind :: WithItemVar ) {
393+ return None ;
394+ }
395+
396+ let semantic = checker. semantic ( ) ;
397+
398+ let Stmt :: With ( with) = binding. statement ( semantic) ? else {
399+ return None ;
400+ } ;
401+
402+ let Expr :: Call ( call) = corresponding_context_expr ( binding, with) ? else {
403+ return None ;
404+ } ;
405+
406+ let mut edits = vec ! [ ] ;
407+
408+ // Rewrite all references to `.exception` to `.value`:
409+ // ```py
410+ // # Before
411+ // with self.assertRaises(Exception) as e:
412+ // ...
413+ // print(e.exception)
414+ //
415+ // # After
416+ // with pytest.raises(Exception) as e:
417+ // ...
418+ // print(e.value)
419+ // ```
420+ for reference_id in binding. references ( ) {
421+ let reference = semantic. reference ( reference_id) ;
422+ let node_id = reference. expression_id ( ) ?;
423+
424+ let mut ancestors = semantic. expressions ( node_id) . skip ( 1 ) ;
425+
426+ let Expr :: Attribute ( ast:: ExprAttribute { attr, .. } ) = ancestors. next ( ) ? else {
427+ continue ;
428+ } ;
429+
430+ if attr. as_str ( ) == "exception" {
431+ edits. push ( Edit :: range_replacement ( "value" . to_string ( ) , attr. range ) ) ;
432+ }
433+ }
434+
435+ unittest_raises_assertion ( call, edits, checker)
436+ }
437+
438+ fn corresponding_context_expr < ' a > ( binding : & Binding , with : & ' a ast:: StmtWith ) -> Option < & ' a Expr > {
439+ with. items . iter ( ) . find_map ( |item| {
440+ let Some ( optional_var) = & item. optional_vars else {
441+ return None ;
442+ } ;
443+
444+ let Expr :: Name ( name) = optional_var. as_ref ( ) else {
445+ return None ;
446+ } ;
447+
448+ if name. range == binding. range {
449+ Some ( & item. context_expr )
450+ } else {
451+ None
452+ }
453+ } )
454+ }
455+
456+ fn unittest_raises_assertion (
369457 call : & ast:: ExprCall ,
458+ extra_edits : Vec < Edit > ,
459+ checker : & Checker ,
370460) -> Option < Diagnostic > {
371461 let Expr :: Attribute ( ast:: ExprAttribute { attr, .. } ) = call. func . as_ref ( ) else {
372462 return None ;
@@ -385,19 +475,25 @@ pub(crate) fn unittest_raises_assertion(
385475 } ,
386476 call. func . range ( ) ,
387477 ) ;
478+
388479 if !checker
389480 . comment_ranges ( )
390481 . has_comments ( call, checker. source ( ) )
391482 {
392483 if let Some ( args) = to_pytest_raises_args ( checker, attr. as_str ( ) , & call. arguments ) {
393484 diagnostic. try_set_fix ( || {
394- let ( import_edit , binding) = checker. importer ( ) . get_or_import_symbol (
485+ let ( import_pytest_raises , binding) = checker. importer ( ) . get_or_import_symbol (
395486 & ImportRequest :: import ( "pytest" , "raises" ) ,
396487 call. func . start ( ) ,
397488 checker. semantic ( ) ,
398489 ) ?;
399- let edit = Edit :: range_replacement ( format ! ( "{binding}({args})" ) , call. range ( ) ) ;
400- Ok ( Fix :: unsafe_edits ( import_edit, [ edit] ) )
490+ let replace_call =
491+ Edit :: range_replacement ( format ! ( "{binding}({args})" ) , call. range ( ) ) ;
492+
493+ Ok ( Fix :: unsafe_edits (
494+ import_pytest_raises,
495+ iter:: once ( replace_call) . chain ( extra_edits) ,
496+ ) )
401497 } ) ;
402498 }
403499 }
0 commit comments