|
| 1 | +use std::iter::once; |
| 2 | +use std::mem::take; |
| 3 | + |
1 | 4 | use proc_macro2::{Span, TokenStream}; |
2 | 5 | use proc_macro_error::emit_error; |
3 | 6 | use quote::{quote, ToTokens}; |
| 7 | +use syn::punctuated::{Pair, Punctuated}; |
4 | 8 | use syn::spanned::Spanned; |
5 | 9 | use syn::visit_mut::VisitMut; |
6 | 10 | use syn::{ |
7 | | - parse_quote, parse_quote_spanned, token, visit_mut, FnArg, GenericParam, Ident, Lifetime, Pat, |
8 | | - Receiver, ReturnType, Signature, Type, TypeImplTrait, TypeReference, WhereClause, |
| 11 | + parse_quote, parse_quote_spanned, visit_mut, FnArg, GenericParam, Ident, Lifetime, |
| 12 | + LifetimeParam, Pat, Receiver, ReturnType, Signature, Type, TypeImplTrait, TypeParam, |
| 13 | + TypeParamBound, TypeReference, WherePredicate, |
9 | 14 | }; |
10 | 15 |
|
11 | 16 | use super::lifetime; |
12 | 17 |
|
| 18 | +fn type_is_generic(ty: &Type, param: &TypeParam) -> bool { |
| 19 | + match ty { |
| 20 | + Type::Path(path) => path.path.is_ident(¶m.ident), |
| 21 | + _ => false, |
| 22 | + } |
| 23 | +} |
| 24 | + |
13 | 25 | #[derive(Default)] |
14 | 26 | pub struct CollectArgs { |
15 | 27 | needs_boxing: bool, |
@@ -99,48 +111,68 @@ impl HookSignature { |
99 | 111 | .. |
100 | 112 | } = sig; |
101 | 113 |
|
102 | | - let hook_lifetime = { |
103 | | - let hook_lifetime = Lifetime::new("'hook", Span::mixed_site()); |
104 | | - generics.params = { |
105 | | - let elided_lifetimes = &lifetimes.elided; |
106 | | - let params = &generics.params; |
107 | | - |
108 | | - parse_quote!(#hook_lifetime, #(#elided_lifetimes,)* #params) |
109 | | - }; |
110 | | - |
111 | | - let mut where_clause = generics |
112 | | - .where_clause |
113 | | - .clone() |
114 | | - .unwrap_or_else(|| WhereClause { |
115 | | - where_token: token::Where { |
116 | | - span: Span::mixed_site(), |
117 | | - }, |
118 | | - predicates: Default::default(), |
119 | | - }); |
| 114 | + let hook_lifetime = Lifetime::new("'hook", Span::mixed_site()); |
| 115 | + let mut params: Punctuated<_, _> = once(hook_lifetime.clone()) |
| 116 | + .chain(lifetimes.elided) |
| 117 | + .map(|lifetime| { |
| 118 | + GenericParam::Lifetime(LifetimeParam { |
| 119 | + attrs: vec![], |
| 120 | + lifetime, |
| 121 | + colon_token: None, |
| 122 | + bounds: Default::default(), |
| 123 | + }) |
| 124 | + }) |
| 125 | + .map(|param| Pair::new(param, Some(Default::default()))) |
| 126 | + .chain(take(&mut generics.params).into_pairs()) |
| 127 | + .collect(); |
120 | 128 |
|
121 | | - for elided in lifetimes.elided.iter() { |
122 | | - where_clause |
123 | | - .predicates |
124 | | - .push(parse_quote!(#elided: #hook_lifetime)); |
125 | | - } |
| 129 | + for type_param in params.iter_mut().skip(1) { |
| 130 | + match type_param { |
| 131 | + GenericParam::Lifetime(param) => { |
| 132 | + if let Some(predicate) = generics |
| 133 | + .where_clause |
| 134 | + .iter_mut() |
| 135 | + .flat_map(|c| &mut c.predicates) |
| 136 | + .find_map(|predicate| match predicate { |
| 137 | + WherePredicate::Lifetime(p) if p.lifetime == param.lifetime => Some(p), |
| 138 | + _ => None, |
| 139 | + }) |
| 140 | + { |
| 141 | + predicate.bounds.push(hook_lifetime.clone()); |
| 142 | + } else { |
| 143 | + param.colon_token = Some(param.colon_token.unwrap_or_default()); |
| 144 | + param.bounds.push(hook_lifetime.clone()); |
| 145 | + } |
| 146 | + } |
126 | 147 |
|
127 | | - for explicit in lifetimes.explicit.iter() { |
128 | | - where_clause |
129 | | - .predicates |
130 | | - .push(parse_quote!(#explicit: #hook_lifetime)); |
131 | | - } |
| 148 | + GenericParam::Type(param) => { |
| 149 | + if let Some(predicate) = generics |
| 150 | + .where_clause |
| 151 | + .iter_mut() |
| 152 | + .flat_map(|c| &mut c.predicates) |
| 153 | + .find_map(|predicate| match predicate { |
| 154 | + WherePredicate::Type(p) if type_is_generic(&p.bounded_ty, param) => { |
| 155 | + Some(p) |
| 156 | + } |
| 157 | + _ => None, |
| 158 | + }) |
| 159 | + { |
| 160 | + predicate |
| 161 | + .bounds |
| 162 | + .push(TypeParamBound::Lifetime(hook_lifetime.clone())); |
| 163 | + } else { |
| 164 | + param.colon_token = Some(param.colon_token.unwrap_or_default()); |
| 165 | + param |
| 166 | + .bounds |
| 167 | + .push(TypeParamBound::Lifetime(hook_lifetime.clone())); |
| 168 | + } |
| 169 | + } |
132 | 170 |
|
133 | | - for type_param in generics.type_params() { |
134 | | - let type_param_ident = &type_param.ident; |
135 | | - where_clause |
136 | | - .predicates |
137 | | - .push(parse_quote!(#type_param_ident: #hook_lifetime)); |
| 171 | + GenericParam::Const(_) => {} |
138 | 172 | } |
| 173 | + } |
139 | 174 |
|
140 | | - generics.where_clause = Some(where_clause); |
141 | | - |
142 | | - hook_lifetime |
143 | | - }; |
| 175 | + generics.params = params; |
144 | 176 |
|
145 | 177 | let (output, output_type) = Self::rewrite_return_type(&hook_lifetime, return_type); |
146 | 178 | sig.output = output; |
|
0 commit comments