Skip to content

Commit f46a454

Browse files
authored
feat(transport): Add tracing support to server (#175)
* feat(transport): Add tracing to server * Add tracing example * Fix duplicate versions of tracing-subscriber
1 parent 1626c2e commit f46a454

File tree

8 files changed

+197
-10
lines changed

8 files changed

+197
-10
lines changed

tonic-examples/Cargo.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ path = "src/multiplex/client.rs"
6666
name = "gcp-client"
6767
path = "src/gcp/client.rs"
6868

69+
[[bin]]
70+
name = "tracing-client"
71+
path = "src/tracing/client.rs"
72+
73+
[[bin]]
74+
name = "tracing-server"
75+
path = "src/tracing/server.rs"
76+
6977
[dependencies]
7078
tonic = { path = "../tonic", features = ["tls"] }
7179
bytes = "0.4"
@@ -82,6 +90,12 @@ serde = { version = "1.0", features = ["derive"] }
8290
serde_json = "1.0"
8391
rand = "0.6"
8492

93+
# Tracing
94+
tracing = "0.1"
95+
tracing-subscriber = { version = "0.2.0-alpha", features = ["tracing-log"] }
96+
tracing-attributes = "0.1"
97+
tracing-futures = "0.2"
98+
8599
# Required for wellknown types
86100
prost-types = "0.5"
87101

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
pub mod hello_world {
2+
tonic::include_proto!("helloworld");
3+
}
4+
5+
use hello_world::{greeter_client::GreeterClient, HelloRequest};
6+
use tracing_attributes::instrument;
7+
8+
#[tokio::main]
9+
async fn main() -> Result<(), Box<dyn std::error::Error>> {
10+
tracing_subscriber::FmtSubscriber::builder()
11+
.with_max_level(tracing::Level::DEBUG)
12+
.init();
13+
14+
say_hi("Bob".into()).await?;
15+
16+
Ok(())
17+
}
18+
19+
#[instrument]
20+
async fn say_hi(name: String) -> Result<(), Box<dyn std::error::Error>> {
21+
let mut client = GreeterClient::connect("http://[::1]:50051").await?;
22+
23+
let request = tonic::Request::new(HelloRequest { name });
24+
25+
tracing::info!(
26+
message = "Sending request.",
27+
request = %request.get_ref().name
28+
);
29+
30+
let response = client.say_hello(request).await?;
31+
32+
tracing::info!(
33+
message = "Got a response.",
34+
response = %response.get_ref().message
35+
);
36+
37+
Ok(())
38+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use tonic::{transport::Server, Request, Response, Status};
2+
3+
pub mod hello_world {
4+
tonic::include_proto!("helloworld");
5+
}
6+
7+
use hello_world::{
8+
greeter_server::{Greeter, GreeterServer},
9+
HelloReply, HelloRequest,
10+
};
11+
12+
#[derive(Default)]
13+
pub struct MyGreeter {}
14+
15+
#[tonic::async_trait]
16+
impl Greeter for MyGreeter {
17+
async fn say_hello(
18+
&self,
19+
request: Request<HelloRequest>,
20+
) -> Result<Response<HelloReply>, Status> {
21+
tracing::info!(message = "Inbound request.", metadata = ?request.metadata());
22+
23+
let reply = hello_world::HelloReply {
24+
message: format!("Hello {}!", request.into_inner().name).into(),
25+
};
26+
27+
tracing::debug!(message = "Sending reply.", response = %reply.message);
28+
29+
Ok(Response::new(reply))
30+
}
31+
}
32+
33+
#[tokio::main]
34+
async fn main() -> Result<(), Box<dyn std::error::Error>> {
35+
tracing_subscriber::FmtSubscriber::builder()
36+
.with_max_level(tracing::Level::DEBUG)
37+
.init();
38+
39+
let addr = "[::1]:50051".parse().unwrap();
40+
let greeter = MyGreeter::default();
41+
42+
tracing::info!(message = "Starting server.", %addr);
43+
44+
Server::builder()
45+
.trace_fn(|_| tracing::info_span!("helloworld_server"))
46+
.add_service(GreeterServer::new(greeter))
47+
.serve(addr)
48+
.await?;
49+
50+
Ok(())
51+
}

tonic-interop/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ console = "0.9"
3232
structopt = "0.2"
3333

3434
tracing = "0.1"
35-
tracing-subscriber = "0.1.3"
35+
tracing-subscriber = "0.2.0-alpha"
3636
tracing-log = "0.1.0"
3737

3838
[build-dependencies]

tonic/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ transport = [
3131
"tower",
3232
"tower-balance",
3333
"tower-load",
34+
"tracing-futures",
3435
]
3536
tls = ["tokio-rustls"]
3637
tls-roots = ["rustls-native-certs"]
@@ -68,6 +69,7 @@ tower = { git = "https://github.com/tower-rs/tower", optional = true}
6869
tower-make = { version = "0.3", features = ["connect"] }
6970
tower-balance = { git = "https://github.com/tower-rs/tower", optional = true }
7071
tower-load = { git = "https://github.com/tower-rs/tower", optional = true }
72+
tracing-futures = { version = "0.2", optional = true }
7173

7274
# rustls
7375
tokio-rustls = { version = "0.12", optional = true }
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
use pin_project::pin_project;
2+
use std::{
3+
future::Future,
4+
pin::Pin,
5+
task::{Context, Poll},
6+
};
7+
8+
#[pin_project]
9+
pub struct ResponseFuture<F> {
10+
inner: F,
11+
span: Option<Span>,
12+
}
13+
14+
impl<F: Future> Future for ResponseFuture<F> {
15+
type Output = F::Output;
16+
17+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
18+
let me = self.project();
19+
20+
if let Some(span) = me.span.clone().take() {
21+
let _enter = span.enter();
22+
// me.poll(cx).map_err(Into::into)
23+
}
24+
}
25+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
use pin_project::pin_project;
2+
use std::{
3+
future::Future,
4+
pin::Pin,
5+
task::{Context, Poll},
6+
};
7+
8+
#[pin_project]
9+
pub struct ResponseFuture<F> {
10+
inner: F,
11+
span: Option<Span>,
12+
}
13+
14+
impl<F: Future> Future for ResponseFuture<F> {
15+
type Output = F::Output;
16+
17+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
18+
let me = self.project();
19+
20+
if let Some(span) = me.span.clone().take() {
21+
let _enter = span.enter();
22+
// me.poll(cx).map_err(Into::into)
23+
}
24+
}
25+
}

tonic/src/transport/server/mod.rs

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use futures_util::{
1515
future::{self, poll_fn, MapErr},
1616
TryFutureExt,
1717
};
18-
use http::{Request, Response};
18+
use http::{HeaderMap, Request, Response};
1919
use hyper::{
2020
server::{accept::Accept, conn},
2121
Body,
@@ -39,9 +39,11 @@ use tower::{
3939
};
4040
#[cfg(feature = "tls")]
4141
use tracing::error;
42+
use tracing_futures::{Instrument, Instrumented};
4243

4344
type BoxService = tower::util::BoxService<Request<Body>, Response<BoxBody>, crate::Error>;
4445
type Interceptor = Arc<dyn Layer<BoxService, Service = BoxService> + Send + Sync + 'static>;
46+
type TraceInterceptor = Arc<dyn Fn(&HeaderMap) -> tracing::Span + Send + Sync + 'static>;
4547

4648
/// A default batteries included `transport` server.
4749
///
@@ -54,6 +56,7 @@ type Interceptor = Arc<dyn Layer<BoxService, Service = BoxService> + Send + Sync
5456
#[derive(Default, Clone)]
5557
pub struct Server {
5658
interceptor: Option<Interceptor>,
59+
trace_interceptor: Option<TraceInterceptor>,
5760
concurrency_limit: Option<usize>,
5861
// timeout: Option<Duration>,
5962
#[cfg(feature = "tls")]
@@ -211,6 +214,17 @@ impl Server {
211214
}
212215
}
213216

217+
/// Intercept inbound headers and add a [`tracing::Span`] to each response future.
218+
pub fn trace_fn<F>(self, f: F) -> Self
219+
where
220+
F: Fn(&HeaderMap) -> tracing::Span + Send + Sync + 'static,
221+
{
222+
Server {
223+
trace_interceptor: Some(Arc::new(f)),
224+
..self
225+
}
226+
}
227+
214228
/// Create a router with the `S` typed service as the first service.
215229
///
216230
/// This will clone the `Server` builder and create a router that will
@@ -241,6 +255,7 @@ impl Server {
241255
F: Future<Output = ()>,
242256
{
243257
let interceptor = self.interceptor.clone();
258+
let span = self.trace_interceptor.clone();
244259
let concurrency_limit = self.concurrency_limit;
245260
let init_connection_window_size = self.init_connection_window_size;
246261
let init_stream_window_size = self.init_stream_window_size;
@@ -282,6 +297,7 @@ impl Server {
282297
interceptor,
283298
concurrency_limit,
284299
// timeout,
300+
span,
285301
};
286302

287303
let server = hyper::Server::builder(incoming)
@@ -399,8 +415,10 @@ impl fmt::Debug for Server {
399415
}
400416
}
401417

402-
#[derive(Debug)]
403-
struct Svc<S>(S);
418+
struct Svc<S> {
419+
inner: S,
420+
span: Option<TraceInterceptor>,
421+
}
404422

405423
impl<S> Service<Request<Body>> for Svc<S>
406424
where
@@ -409,14 +427,26 @@ where
409427
{
410428
type Response = Response<BoxBody>;
411429
type Error = crate::Error;
412-
type Future = MapErr<S::Future, fn(S::Error) -> crate::Error>;
430+
type Future = MapErr<Instrumented<S::Future>, fn(S::Error) -> crate::Error>;
413431

414432
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
415-
self.0.poll_ready(cx).map_err(Into::into)
433+
self.inner.poll_ready(cx).map_err(Into::into)
416434
}
417435

418436
fn call(&mut self, req: Request<Body>) -> Self::Future {
419-
self.0.call(req).map_err(|e| e.into())
437+
let span = if let Some(trace_interceptor) = &self.span {
438+
trace_interceptor(req.headers())
439+
} else {
440+
tracing::Span::none()
441+
};
442+
443+
self.inner.call(req).instrument(span).map_err(|e| e.into())
444+
}
445+
}
446+
447+
impl<S> fmt::Debug for Svc<S> {
448+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449+
f.debug_struct("Svc").finish()
420450
}
421451
}
422452

@@ -425,6 +455,7 @@ struct MakeSvc<S> {
425455
concurrency_limit: Option<usize>,
426456
// timeout: Option<Duration>,
427457
inner: S,
458+
span: Option<TraceInterceptor>,
428459
}
429460

430461
impl<S, T> Service<T> for MakeSvc<S>
@@ -447,6 +478,7 @@ where
447478
let svc = self.inner.clone();
448479
let concurrency_limit = self.concurrency_limit;
449480
// let timeout = self.timeout.clone();
481+
let span = self.span.clone();
450482

451483
Box::pin(async move {
452484
let svc = ServiceBuilder::new()
@@ -455,10 +487,10 @@ where
455487
.service(svc);
456488

457489
let svc = if let Some(interceptor) = interceptor {
458-
let layered = interceptor.layer(BoxService::new(Svc(svc)));
459-
BoxService::new(Svc(layered))
490+
let layered = interceptor.layer(BoxService::new(Svc { inner: svc, span }));
491+
BoxService::new(layered)
460492
} else {
461-
BoxService::new(Svc(svc))
493+
BoxService::new(Svc { inner: svc, span })
462494
};
463495

464496
Ok(svc)

0 commit comments

Comments
 (0)