Skip to content

Commit da9f24b

Browse files
committed
fix: wip
1 parent 9868d3d commit da9f24b

File tree

4 files changed

+80
-38
lines changed

4 files changed

+80
-38
lines changed

src/api/mod.rs

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,75 @@
1-
use anyhow::Result;
1+
use anyhow::{Result, Error};
22
use futures_util::future::OptionFuture;
3-
use futures_util::{FutureExt, TryStreamExt};
3+
use futures_util::stream::{BoxStream, Stream};
4+
use futures_util::{FutureExt, StreamExt, TryStreamExt};
5+
use metrics::{metrics, metrics_wrapper};
46
use sqlx::PgPool;
7+
use tokio::io::{Error as IoError, Result as IoResult};
58
use tokio::net::TcpListener;
69
use tokio::net::ToSocketAddrs;
710
use tokio_stream::wrappers::TcpListenerStream;
811
use tracing::{debug_span, info};
9-
use tracing_futures::Instrument;
1012

11-
use metrics::{metrics, metrics_wrapper};
12-
use proxy::ProxyStream;
13+
use crate::api::proxy::{PeerAddr, ProxyStream};
14+
use futures_util::io::{AsyncRead, AsyncWrite};
1315

1416
mod metrics;
1517
mod proxy;
1618
mod routes;
1719
mod tls;
1820

19-
pub struct Api {
20-
http: Option<TcpListenerStream>,
21-
https: Option<TcpListenerStream>,
22-
prom: Option<TcpListenerStream>,
21+
type Connection = Box<dyn PeerAddr<IoError> + Send + Unpin + 'static>;
22+
type Listener = BoxStream<'static, IoResult<Connection>>;
23+
24+
pub struct Api<H, S> {
25+
http: Option<H>,
26+
https: Option<S>,
27+
prom: Option<H>,
2328
pool: PgPool,
2429
}
2530

26-
impl Api {
31+
impl <H, S> Api<H, S> {
32+
fn prepare_listener(listener: TcpListener, proxy: bool) -> Listener {
33+
let listener = match listener {
34+
Some(listener) => TcpListenerStream::new(listener),
35+
None => return None,
36+
};
37+
38+
let mapper = match proxy {
39+
true => |stream| Box::new(ProxyStream::from(stream)) as Connection,
40+
false => |stream| Box::new(stream) as Connection
41+
};
42+
43+
let listener = listener
44+
.map_ok(mapper)
45+
.boxed();
46+
47+
Some(listener)
48+
}
49+
2750
pub async fn new<A: ToSocketAddrs>(
28-
http: Option<A>,
29-
https: Option<A>,
30-
prom: Option<A>,
51+
(http, http_proxy): (Option<A>, bool),
52+
(https, https_proxy): (Option<A>, bool),
53+
(prom, prom_proxy): (Option<A>, bool),
3154
pool: PgPool,
32-
) -> Result<Self> {
55+
) -> Result<Api<
56+
impl Stream<Item = IoResult<impl AsyncRead + AsyncWrite + Send + Unpin + 'static>> + Send,
57+
impl Stream<Item = Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static, Error>> + Send
58+
>> {
3359
let http = OptionFuture::from(http.map(TcpListener::bind)).map(Option::transpose);
3460
let https = OptionFuture::from(https.map(TcpListener::bind)).map(Option::transpose);
3561
let prom = OptionFuture::from(prom.map(TcpListener::bind)).map(Option::transpose);
3662

3763
let (http, https, prom) = tokio::try_join!(http, https, prom)?;
3864

65+
let http = Api::prepare_listener(http, http_proxy);
66+
let https = Api::prepare_listener(https, https_proxy);
67+
let prom = Api::prepare_listener(prom, prom_proxy);
68+
3969
Ok(Api {
40-
http: http.map(TcpListenerStream::new),
41-
https: https.map(TcpListenerStream::new),
42-
prom: prom.map(TcpListenerStream::new),
70+
http: proxy::wrap(http).try_buffer_unordered(100),
71+
https: tls::stream(https, pool.clone()),
72+
prom: proxy::wrap(prom).try_buffer_unordered(100),
4373
pool,
4474
})
4575
}
@@ -52,31 +82,29 @@ impl Api {
5282

5383
let http = self
5484
.http
55-
.map(|http| {
56-
let addr = http.as_ref().local_addr();
57-
proxy::wrap(http.map_ok(ProxyStream::from))
58-
.try_buffer_unordered(100)
59-
.instrument(debug_span!("HTTP", local.addr = ?addr))
60-
})
85+
.map(|http|
86+
proxy::wrap(http).try_buffer_unordered(100)
87+
)
6188
.map(|http| warp::serve(routes.clone()).serve_incoming(http))
6289
.map(tokio::spawn);
6390

6491
let pool = self.pool.clone();
6592
let https = self
6693
.https
6794
.map(|https| {
68-
let addr = https.as_ref().local_addr();
69-
let https = https.map_ok(ProxyStream::from);
70-
tls::stream(https, pool).instrument(debug_span!("HTTPS", local.addr = ?addr))
95+
//let addr = https.as_ref().local_addr();
96+
tls::stream(https, pool)
97+
//.instrument(debug_span!("HTTPS", local.addr = ?addr))
7198
})
7299
.map(|https| warp::serve(routes).serve_incoming(https))
73100
.map(tokio::spawn);
74101

75102
let prom = self
76103
.prom
77104
.map(|prom| {
78-
let addr = prom.as_ref().local_addr();
79-
prom.instrument(debug_span!("PROM", local.addr = ?addr))
105+
//let addr = prom.as_ref().local_addr();
106+
prom
107+
//.instrument(debug_span!("PROM", local.addr = ?addr))
80108
})
81109
.map(|prom| warp::serve(metrics()).serve_incoming(prom))
82110
.map(tokio::spawn);

src/api/proxy.rs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ use std::future::Future;
77
use std::io::{Cursor, IoSlice, Write};
88
use std::mem::MaybeUninit;
99
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
10+
use std::ops::DerefMut;
1011
use std::pin::Pin;
1112
use std::task::{Context, Poll};
1213
use tokio::io::{AsyncRead, AsyncWrite, Error as IoError, ErrorKind, ReadBuf, Result as IoResult};
1314
use tokio::net::TcpStream;
1415
use tracing::field::{display, Empty};
15-
use tracing::{debug_span, error, info, Instrument};
16+
use tracing::{debug_span, error, info, Instrument, Span};
1617

1718
pub(super) fn wrap<S, O, E, P>(
1819
stream: S,
@@ -24,8 +25,8 @@ where
2425
{
2526
stream.map_ok(|mut conn| {
2627
let span = debug_span!("ADDR", remote.addr = Empty);
27-
let span_two = span.clone();
2828
async move {
29+
let span = Span::current();
2930
match conn.proxy_peer().await {
3031
Ok(addr) => {
3132
span.record("remote.addr", &display(addr));
@@ -38,7 +39,7 @@ where
3839
}
3940
Ok(conn)
4041
}
41-
.instrument(span_two)
42+
.instrument(span)
4243
})
4344
}
4445

@@ -124,8 +125,8 @@ impl<'a> Future for PeerAddrFuture<'a> {
124125
type Output = IoResult<SocketAddr>;
125126

126127
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127-
let this = &mut self.get_mut();
128-
// add option again to make impossible to pull future later
128+
let this = self.get_mut();
129+
129130
let data = match &mut this.stream.data {
130131
Some(ref mut data) => data,
131132
None => unreachable!("Future cannot be polled anymore"),
@@ -136,7 +137,7 @@ impl<'a> Future for PeerAddrFuture<'a> {
136137

137138
let stream = Pin::new(&mut this.stream.stream);
138139
let buf = match stream.poll_read(cx, &mut buf) {
139-
Poll::Ready(Ok(_)) => buf.filled_mut(),
140+
Poll::Ready(Ok(_)) => buf.filled(),
140141
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
141142
Poll::Pending => return Poll::Pending,
142143
};
@@ -178,7 +179,7 @@ impl AsyncRead for ProxyStream {
178179
buf: &mut ReadBuf<'_>,
179180
) -> Poll<IoResult<()>> {
180181
let this = self.get_mut();
181-
// handle the case were the full data has no space in the first place
182+
// todo: handle the case were the data has no space in buf
182183
if let Some(data) = this.data.take() {
183184
buf.put_slice(&data.get_ref()[this.start_of_data..])
184185
}
@@ -215,3 +216,9 @@ impl AsyncWrite for ProxyStream {
215216
self.stream.is_write_vectored()
216217
}
217218
}
219+
220+
impl<E: std::error::Error> PeerAddr<E> for Box<(dyn PeerAddr<E> + Send + Unpin + 'static)> {
221+
fn proxy_peer<'a>(&'a mut self) -> BoxFuture<'a, Result<SocketAddr, E>> {
222+
self.deref_mut().proxy_peer()
223+
}
224+
}

src/config.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,14 @@ use tracing::{debug, info, info_span};
88
#[derive(Deserialize, Debug)]
99
pub struct Api {
1010
pub http: Option<String>,
11+
#[serde(default)]
12+
pub http_proxy: bool,
1113
pub https: Option<String>,
14+
#[serde(default)]
15+
pub https_proxy: bool,
1216
pub prom: Option<String>,
17+
#[serde(default)]
18+
pub prom_proxy: bool,
1319
}
1420

1521
fn default_acme() -> String {

src/main.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,11 @@ fn run() -> Result<()> {
5353
DatabaseAuthority::new(pool.clone(), &config.general.name, config.records);
5454
let dns = DNS::new(&config.general.dns, authority);
5555

56+
let api = &config.api;
5657
let api = Api::new(
57-
config.api.http.as_deref(),
58-
config.api.https.as_deref(),
59-
config.api.prom.as_deref(),
58+
(api.http.as_deref(), api.http_proxy),
59+
(api.https.as_deref(), api.https_proxy),
60+
(api.prom.as_deref(), api.prom_proxy),
6061
pool.clone(),
6162
)
6263
.and_then(Api::spawn);

0 commit comments

Comments
 (0)