@@ -2,10 +2,11 @@ use std::iter;
22
33use proc_macro2:: TokenStream ;
44use quote:: { quote, quote_spanned, ToTokens } ;
5+ use syn:: visit_mut:: VisitMut ;
56use syn:: {
67 punctuated:: Punctuated , spanned:: Spanned , Block , Expr , ExprAsync , ExprCall , FieldPat , FnArg ,
78 Ident , Item , ItemFn , Pat , PatIdent , PatReference , PatStruct , PatTuple , PatTupleStruct , PatType ,
8- Path , Signature , Stmt , Token , TypePath ,
9+ Path , ReturnType , Signature , Stmt , Token , Type , TypePath ,
910} ;
1011
1112use crate :: {
@@ -18,7 +19,7 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
1819 input : MaybeItemFnRef < ' a , B > ,
1920 args : InstrumentArgs ,
2021 instrumented_function_name : & str ,
21- self_type : Option < & syn :: TypePath > ,
22+ self_type : Option < & TypePath > ,
2223) -> proc_macro2:: TokenStream {
2324 // these are needed ahead of time, as ItemFn contains the function body _and_
2425 // isn't representable inside a quote!/quote_spanned! macro
@@ -31,7 +32,7 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
3132 } = input;
3233
3334 let Signature {
34- output : return_type ,
35+ output,
3536 inputs : params,
3637 unsafety,
3738 asyncness,
@@ -49,8 +50,37 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
4950
5051 let warnings = args. warnings ( ) ;
5152
53+ let block = if let ReturnType :: Type ( _, return_type) = & output {
54+ let return_type = erase_impl_trait ( return_type) ;
55+ // Install a fake return statement as the first thing in the function
56+ // body, so that we eagerly infer that the return type is what we
57+ // declared in the async fn signature.
58+ // The `#[allow(unreachable_code)]` is given because the return
59+ // statement is unreachable, but does affect inference.
60+ let fake_return_edge = quote_spanned ! { return_type. span( ) =>
61+ #[ allow( unreachable_code) ]
62+ if false {
63+ let __tracing_attr_fake_return: #return_type =
64+ unreachable!( "this is just for type inference, and is unreachable code" ) ;
65+ return __tracing_attr_fake_return;
66+ }
67+ } ;
68+ quote ! {
69+ {
70+ #fake_return_edge
71+ #block
72+ }
73+ }
74+ } else {
75+ quote ! {
76+ {
77+ let _: ( ) = #block;
78+ }
79+ }
80+ } ;
81+
5282 let body = gen_block (
53- block,
83+ & block,
5484 params,
5585 asyncness. is_some ( ) ,
5686 args,
@@ -60,7 +90,7 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
6090
6191 quote ! (
6292 #( #attrs) *
63- #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>( #params) #return_type
93+ #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>( #params) #output
6494 #where_clause
6595 {
6696 #warnings
@@ -76,7 +106,7 @@ fn gen_block<B: ToTokens>(
76106 async_context : bool ,
77107 mut args : InstrumentArgs ,
78108 instrumented_function_name : & str ,
79- self_type : Option < & syn :: TypePath > ,
109+ self_type : Option < & TypePath > ,
80110) -> proc_macro2:: TokenStream {
81111 // generate the span's name
82112 let span_name = args
@@ -393,11 +423,11 @@ impl RecordType {
393423 "Wrapping" ,
394424 ] ;
395425
396- /// Parse `RecordType` from [syn:: Type] by looking up
426+ /// Parse `RecordType` from [Type] by looking up
397427 /// the [RecordType::TYPES_FOR_VALUE] array.
398- fn parse_from_ty ( ty : & syn :: Type ) -> Self {
428+ fn parse_from_ty ( ty : & Type ) -> Self {
399429 match ty {
400- syn :: Type :: Path ( syn :: TypePath { path, .. } )
430+ Type :: Path ( TypePath { path, .. } )
401431 if path
402432 . segments
403433 . iter ( )
@@ -410,9 +440,7 @@ impl RecordType {
410440 {
411441 RecordType :: Value
412442 }
413- syn:: Type :: Reference ( syn:: TypeReference { elem, .. } ) => {
414- RecordType :: parse_from_ty ( & * elem)
415- }
443+ Type :: Reference ( syn:: TypeReference { elem, .. } ) => RecordType :: parse_from_ty ( & * elem) ,
416444 _ => RecordType :: Debug ,
417445 }
418446 }
@@ -471,7 +499,7 @@ pub(crate) struct AsyncInfo<'block> {
471499 // statement that must be patched
472500 source_stmt : & ' block Stmt ,
473501 kind : AsyncKind < ' block > ,
474- self_type : Option < syn :: TypePath > ,
502+ self_type : Option < TypePath > ,
475503 input : & ' block ItemFn ,
476504}
477505
@@ -606,11 +634,11 @@ impl<'block> AsyncInfo<'block> {
606634 if ident == "_self" {
607635 let mut ty = * ty. ty . clone ( ) ;
608636 // extract the inner type if the argument is "&self" or "&mut self"
609- if let syn :: Type :: Reference ( syn:: TypeReference { elem, .. } ) = ty {
637+ if let Type :: Reference ( syn:: TypeReference { elem, .. } ) = ty {
610638 ty = * elem;
611639 }
612640
613- if let syn :: Type :: Path ( tp) = ty {
641+ if let Type :: Path ( tp) = ty {
614642 self_type = Some ( tp) ;
615643 break ;
616644 }
@@ -722,7 +750,7 @@ struct IdentAndTypesRenamer<'a> {
722750 idents : Vec < ( Ident , Ident ) > ,
723751}
724752
725- impl < ' a > syn :: visit_mut :: VisitMut for IdentAndTypesRenamer < ' a > {
753+ impl < ' a > VisitMut for IdentAndTypesRenamer < ' a > {
726754 // we deliberately compare strings because we want to ignore the spans
727755 // If we apply clippy's lint, the behavior changes
728756 #[ allow( clippy:: cmp_owned) ]
@@ -734,11 +762,11 @@ impl<'a> syn::visit_mut::VisitMut for IdentAndTypesRenamer<'a> {
734762 }
735763 }
736764
737- fn visit_type_mut ( & mut self , ty : & mut syn :: Type ) {
765+ fn visit_type_mut ( & mut self , ty : & mut Type ) {
738766 for ( type_name, new_type) in & self . types {
739- if let syn :: Type :: Path ( TypePath { path, .. } ) = ty {
767+ if let Type :: Path ( TypePath { path, .. } ) = ty {
740768 if path_to_string ( path) == * type_name {
741- * ty = syn :: Type :: Path ( new_type. clone ( ) ) ;
769+ * ty = Type :: Path ( new_type. clone ( ) ) ;
742770 }
743771 }
744772 }
@@ -751,10 +779,33 @@ struct AsyncTraitBlockReplacer<'a> {
751779 patched_block : Block ,
752780}
753781
754- impl < ' a > syn :: visit_mut :: VisitMut for AsyncTraitBlockReplacer < ' a > {
782+ impl < ' a > VisitMut for AsyncTraitBlockReplacer < ' a > {
755783 fn visit_block_mut ( & mut self , i : & mut Block ) {
756784 if i == self . block {
757785 * i = self . patched_block . clone ( ) ;
758786 }
759787 }
760788}
789+
790+ // Replaces any `impl Trait` with `_` so it can be used as the type in
791+ // a `let` statement's LHS.
792+ struct ImplTraitEraser ;
793+
794+ impl VisitMut for ImplTraitEraser {
795+ fn visit_type_mut ( & mut self , t : & mut Type ) {
796+ if let Type :: ImplTrait ( ..) = t {
797+ * t = syn:: TypeInfer {
798+ underscore_token : Token ! [ _] ( t. span ( ) ) ,
799+ }
800+ . into ( ) ;
801+ } else {
802+ syn:: visit_mut:: visit_type_mut ( self , t) ;
803+ }
804+ }
805+ }
806+
807+ fn erase_impl_trait ( ty : & Type ) -> Type {
808+ let mut ty = ty. clone ( ) ;
809+ ImplTraitEraser . visit_type_mut ( & mut ty) ;
810+ ty
811+ }
0 commit comments