@@ -89,15 +89,15 @@ impl FormatRule<Suite, PyFormatContext<'_>> for FormatSuite {
89
89
}
90
90
91
91
SuiteKind :: Function => {
92
- if let Some ( docstring) = DocstringStmt :: try_from_statement ( first) {
92
+ if let Some ( docstring) = DocstringStmt :: try_from_statement ( first, self . kind ) {
93
93
SuiteChildStatement :: Docstring ( docstring)
94
94
} else {
95
95
SuiteChildStatement :: Other ( first)
96
96
}
97
97
}
98
98
99
99
SuiteKind :: Class => {
100
- if let Some ( docstring) = DocstringStmt :: try_from_statement ( first) {
100
+ if let Some ( docstring) = DocstringStmt :: try_from_statement ( first, self . kind ) {
101
101
if !comments. has_leading ( first)
102
102
&& lines_before ( first. start ( ) , source) > 1
103
103
&& !source_type. is_stub ( )
@@ -150,7 +150,7 @@ impl FormatRule<Suite, PyFormatContext<'_>> for FormatSuite {
150
150
true
151
151
} else if f. options ( ) . preview ( ) . is_enabled ( )
152
152
&& self . kind == SuiteKind :: TopLevel
153
- && DocstringStmt :: try_from_statement ( first. statement ( ) ) . is_some ( )
153
+ && DocstringStmt :: try_from_statement ( first. statement ( ) , self . kind ) . is_some ( )
154
154
{
155
155
// Only in preview mode, insert a newline after a module level docstring, but treat
156
156
// it as a docstring otherwise. See: https://github.com/psf/black/pull/3932.
@@ -543,17 +543,25 @@ impl<'ast> IntoFormat<PyFormatContext<'ast>> for Suite {
543
543
544
544
/// A statement representing a docstring.
545
545
#[ derive( Copy , Clone , Debug ) ]
546
- pub ( crate ) struct DocstringStmt < ' a > ( & ' a Stmt ) ;
546
+ pub ( crate ) struct DocstringStmt < ' a > {
547
+ /// The [`Stmt::Expr`]
548
+ docstring : & ' a Stmt ,
549
+ /// The parent suite kind
550
+ suite_kind : SuiteKind ,
551
+ }
547
552
548
553
impl < ' a > DocstringStmt < ' a > {
549
554
/// Checks if the statement is a simple string that can be formatted as a docstring
550
- fn try_from_statement ( stmt : & ' a Stmt ) -> Option < DocstringStmt < ' a > > {
555
+ fn try_from_statement ( stmt : & ' a Stmt , suite_kind : SuiteKind ) -> Option < DocstringStmt < ' a > > {
551
556
let Stmt :: Expr ( ast:: StmtExpr { value, .. } ) = stmt else {
552
557
return None ;
553
558
} ;
554
559
555
560
match value. as_ref ( ) {
556
- Expr :: StringLiteral ( value) if !value. implicit_concatenated => Some ( DocstringStmt ( stmt) ) ,
561
+ Expr :: StringLiteral ( value) if !value. implicit_concatenated => Some ( DocstringStmt {
562
+ docstring : stmt,
563
+ suite_kind,
564
+ } ) ,
557
565
_ => None ,
558
566
}
559
567
}
@@ -562,14 +570,14 @@ impl<'a> DocstringStmt<'a> {
562
570
impl Format < PyFormatContext < ' _ > > for DocstringStmt < ' _ > {
563
571
fn fmt ( & self , f : & mut Formatter < PyFormatContext < ' _ > > ) -> FormatResult < ( ) > {
564
572
let comments = f. context ( ) . comments ( ) . clone ( ) ;
565
- let node_comments = comments. leading_dangling_trailing ( self . 0 ) ;
573
+ let node_comments = comments. leading_dangling_trailing ( self . docstring ) ;
566
574
567
575
if FormatStmtExpr . is_suppressed ( node_comments. trailing , f. context ( ) ) {
568
- suppressed_node ( self . 0 ) . fmt ( f)
576
+ suppressed_node ( self . docstring ) . fmt ( f)
569
577
} else {
570
578
// SAFETY: Safe because `DocStringStmt` guarantees that it only ever wraps a `ExprStmt` containing a `ExprStringLiteral`.
571
579
let string_literal = self
572
- . 0
580
+ . docstring
573
581
. as_expr_stmt ( )
574
582
. unwrap ( )
575
583
. value
@@ -587,23 +595,25 @@ impl Format<PyFormatContext<'_>> for DocstringStmt<'_> {
587
595
]
588
596
) ?;
589
597
590
- // Comments after docstrings need a newline between the docstring and the comment.
591
- // (https://github.com/astral-sh/ruff/issues/7948)
592
- // ```python
593
- // class ModuleBrowser:
594
- // """Browse module classes and functions in IDLE."""
595
- // # ^ Insert a newline above here
596
- //
597
- // def __init__(self, master, path, *, _htest=False, _utest=False):
598
- // pass
599
- // ```
600
- if let Some ( own_line) = node_comments
601
- . trailing
602
- . iter ( )
603
- . find ( |comment| comment. line_position ( ) . is_own_line ( ) )
604
- {
605
- if lines_before ( own_line. start ( ) , f. context ( ) . source ( ) ) < 2 {
606
- empty_line ( ) . fmt ( f) ?;
598
+ if self . suite_kind == SuiteKind :: Class {
599
+ // Comments after class docstrings need a newline between the docstring and the
600
+ // comment (https://github.com/astral-sh/ruff/issues/7948).
601
+ // ```python
602
+ // class ModuleBrowser:
603
+ // """Browse module classes and functions in IDLE."""
604
+ // # ^ Insert a newline above here
605
+ //
606
+ // def __init__(self, master, path, *, _htest=False, _utest=False):
607
+ // pass
608
+ // ```
609
+ if let Some ( own_line) = node_comments
610
+ . trailing
611
+ . iter ( )
612
+ . find ( |comment| comment. line_position ( ) . is_own_line ( ) )
613
+ {
614
+ if lines_before ( own_line. start ( ) , f. context ( ) . source ( ) ) < 2 {
615
+ empty_line ( ) . fmt ( f) ?;
616
+ }
607
617
}
608
618
}
609
619
@@ -625,7 +635,7 @@ pub(crate) enum SuiteChildStatement<'a> {
625
635
impl < ' a > SuiteChildStatement < ' a > {
626
636
pub ( crate ) const fn statement ( self ) -> & ' a Stmt {
627
637
match self {
628
- SuiteChildStatement :: Docstring ( docstring) => docstring. 0 ,
638
+ SuiteChildStatement :: Docstring ( docstring) => docstring. docstring ,
629
639
SuiteChildStatement :: Other ( statement) => statement,
630
640
}
631
641
}
0 commit comments