@@ -15,7 +15,7 @@ use futures_util::{
1515 future:: { self , poll_fn, MapErr } ,
1616 TryFutureExt ,
1717} ;
18- use http:: { Request , Response } ;
18+ use http:: { HeaderMap , Request , Response } ;
1919use hyper:: {
2020 server:: { accept:: Accept , conn} ,
2121 Body ,
@@ -39,9 +39,11 @@ use tower::{
3939} ;
4040#[ cfg( feature = "tls" ) ]
4141use tracing:: error;
42+ use tracing_futures:: { Instrument , Instrumented } ;
4243
4344type BoxService = tower:: util:: BoxService < Request < Body > , Response < BoxBody > , crate :: Error > ;
4445type Interceptor = Arc < dyn Layer < BoxService , Service = BoxService > + Send + Sync + ' static > ;
46+ type TraceInterceptor = Arc < dyn Fn ( & HeaderMap ) -> tracing:: Span + Send + Sync + ' static > ;
4547
4648/// A default batteries included `transport` server.
4749///
@@ -54,6 +56,7 @@ type Interceptor = Arc<dyn Layer<BoxService, Service = BoxService> + Send + Sync
5456#[ derive( Default , Clone ) ]
5557pub struct Server {
5658 interceptor : Option < Interceptor > ,
59+ trace_interceptor : Option < TraceInterceptor > ,
5760 concurrency_limit : Option < usize > ,
5861 // timeout: Option<Duration>,
5962 #[ cfg( feature = "tls" ) ]
@@ -211,6 +214,17 @@ impl Server {
211214 }
212215 }
213216
217+ /// Intercept inbound headers and add a [`tracing::Span`] to each response future.
218+ pub fn trace_fn < F > ( self , f : F ) -> Self
219+ where
220+ F : Fn ( & HeaderMap ) -> tracing:: Span + Send + Sync + ' static ,
221+ {
222+ Server {
223+ trace_interceptor : Some ( Arc :: new ( f) ) ,
224+ ..self
225+ }
226+ }
227+
214228 /// Create a router with the `S` typed service as the first service.
215229 ///
216230 /// This will clone the `Server` builder and create a router that will
@@ -241,6 +255,7 @@ impl Server {
241255 F : Future < Output = ( ) > ,
242256 {
243257 let interceptor = self . interceptor . clone ( ) ;
258+ let span = self . trace_interceptor . clone ( ) ;
244259 let concurrency_limit = self . concurrency_limit ;
245260 let init_connection_window_size = self . init_connection_window_size ;
246261 let init_stream_window_size = self . init_stream_window_size ;
@@ -282,6 +297,7 @@ impl Server {
282297 interceptor,
283298 concurrency_limit,
284299 // timeout,
300+ span,
285301 } ;
286302
287303 let server = hyper:: Server :: builder ( incoming)
@@ -399,8 +415,10 @@ impl fmt::Debug for Server {
399415 }
400416}
401417
402- #[ derive( Debug ) ]
403- struct Svc < S > ( S ) ;
418+ struct Svc < S > {
419+ inner : S ,
420+ span : Option < TraceInterceptor > ,
421+ }
404422
405423impl < S > Service < Request < Body > > for Svc < S >
406424where
@@ -409,14 +427,26 @@ where
409427{
410428 type Response = Response < BoxBody > ;
411429 type Error = crate :: Error ;
412- type Future = MapErr < S :: Future , fn ( S :: Error ) -> crate :: Error > ;
430+ type Future = MapErr < Instrumented < S :: Future > , fn ( S :: Error ) -> crate :: Error > ;
413431
414432 fn poll_ready ( & mut self , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Self :: Error > > {
415- self . 0 . poll_ready ( cx) . map_err ( Into :: into)
433+ self . inner . poll_ready ( cx) . map_err ( Into :: into)
416434 }
417435
418436 fn call ( & mut self , req : Request < Body > ) -> Self :: Future {
419- self . 0 . call ( req) . map_err ( |e| e. into ( ) )
437+ let span = if let Some ( trace_interceptor) = & self . span {
438+ trace_interceptor ( req. headers ( ) )
439+ } else {
440+ tracing:: Span :: none ( )
441+ } ;
442+
443+ self . inner . call ( req) . instrument ( span) . map_err ( |e| e. into ( ) )
444+ }
445+ }
446+
447+ impl < S > fmt:: Debug for Svc < S > {
448+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
449+ f. debug_struct ( "Svc" ) . finish ( )
420450 }
421451}
422452
@@ -425,6 +455,7 @@ struct MakeSvc<S> {
425455 concurrency_limit : Option < usize > ,
426456 // timeout: Option<Duration>,
427457 inner : S ,
458+ span : Option < TraceInterceptor > ,
428459}
429460
430461impl < S , T > Service < T > for MakeSvc < S >
@@ -447,6 +478,7 @@ where
447478 let svc = self . inner . clone ( ) ;
448479 let concurrency_limit = self . concurrency_limit ;
449480 // let timeout = self.timeout.clone();
481+ let span = self . span . clone ( ) ;
450482
451483 Box :: pin ( async move {
452484 let svc = ServiceBuilder :: new ( )
@@ -455,10 +487,10 @@ where
455487 . service ( svc) ;
456488
457489 let svc = if let Some ( interceptor) = interceptor {
458- let layered = interceptor. layer ( BoxService :: new ( Svc ( svc) ) ) ;
459- BoxService :: new ( Svc ( layered) )
490+ let layered = interceptor. layer ( BoxService :: new ( Svc { inner : svc, span } ) ) ;
491+ BoxService :: new ( layered)
460492 } else {
461- BoxService :: new ( Svc ( svc) )
493+ BoxService :: new ( Svc { inner : svc, span } )
462494 } ;
463495
464496 Ok ( svc)
0 commit comments