Skip to content

Commit db33fcf

Browse files
committed
fix: wip
1 parent da9f24b commit db33fcf

File tree

4 files changed

+102
-117
lines changed

4 files changed

+102
-117
lines changed

src/api/mod.rs

Lines changed: 59 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,82 @@
1-
use anyhow::{Result, Error};
1+
use anyhow::{Error, Result};
22
use futures_util::future::OptionFuture;
3-
use futures_util::stream::{BoxStream, Stream};
4-
use futures_util::{FutureExt, StreamExt, TryStreamExt};
3+
use futures_util::stream::{Stream, TryStream};
4+
use futures_util::{FutureExt, TryStreamExt};
55
use metrics::{metrics, metrics_wrapper};
66
use sqlx::PgPool;
77
use tokio::io::{Error as IoError, Result as IoResult};
88
use tokio::net::TcpListener;
99
use tokio::net::ToSocketAddrs;
1010
use tokio_stream::wrappers::TcpListenerStream;
11-
use tracing::{debug_span, info};
11+
use tracing::info;
12+
use tokio::io::{AsyncWrite, AsyncRead};
1213

13-
use crate::api::proxy::{PeerAddr, ProxyStream};
14-
use futures_util::io::{AsyncRead, AsyncWrite};
14+
use crate::api::proxy::{ProxyProtocol, ToProxyStream};
1515

1616
mod metrics;
1717
mod proxy;
1818
mod routes;
1919
mod tls;
2020

21-
type Connection = Box<dyn PeerAddr<IoError> + Send + Unpin + 'static>;
22-
type Listener = BoxStream<'static, IoResult<Connection>>;
23-
24-
pub struct Api<H, S> {
21+
pub struct Api<H, P, S> {
2522
http: Option<H>,
2623
https: Option<S>,
27-
prom: Option<H>,
24+
prom: Option<P>,
2825
pool: PgPool,
2926
}
3027

31-
impl <H, S> Api<H, S> {
32-
fn prepare_listener(listener: TcpListener, proxy: bool) -> Listener {
33-
let listener = match listener {
34-
Some(listener) => TcpListenerStream::new(listener),
35-
None => return None,
36-
};
37-
38-
let mapper = match proxy {
39-
true => |stream| Box::new(ProxyStream::from(stream)) as Connection,
40-
false => |stream| Box::new(stream) as Connection
41-
};
42-
43-
let listener = listener
44-
.map_ok(mapper)
45-
.boxed();
46-
47-
Some(listener)
48-
}
28+
pub async fn new<A: ToSocketAddrs>(
29+
(http, http_proxy): (Option<A>, bool),
30+
(https, https_proxy): (Option<A>, bool),
31+
(prom, prom_proxy): (Option<A>, bool),
32+
pool: PgPool,
33+
) -> Result<
34+
Api<
35+
impl Stream<Item = IoResult<impl AsyncRead + AsyncWrite + Send + Unpin>> + Send,
36+
impl Stream<Item = IoResult<impl AsyncRead + AsyncWrite + Send + Unpin>> + Send,
37+
impl Stream<Item = Result<impl AsyncRead + AsyncWrite + Send + Unpin, Error>>
38+
+ Send,
39+
>,
40+
> {
41+
let http = OptionFuture::from(http.map(TcpListener::bind)).map(Option::transpose);
42+
let https = OptionFuture::from(https.map(TcpListener::bind)).map(Option::transpose);
43+
let prom = OptionFuture::from(prom.map(TcpListener::bind)).map(Option::transpose);
44+
45+
let (http, https, prom) = tokio::try_join!(http, https, prom)?;
46+
47+
let http = http.map(|http| {
48+
let http =
49+
TcpListenerStream::new(http).map_ok(|stream| stream.source(ProxyProtocol::Enabled));
50+
proxy::wrap(http).try_buffer_unordered(100)
51+
});
52+
let https = https.map(|https| {
53+
let https =
54+
TcpListenerStream::new(https).map_ok(|stream| stream.source(ProxyProtocol::Enabled));
55+
tls::stream(https, pool.clone())
56+
});
57+
let prom = prom.map(|prom| {
58+
let prom =
59+
TcpListenerStream::new(prom).map_ok(|stream| stream.source(ProxyProtocol::Enabled));
60+
proxy::wrap(prom).try_buffer_unordered(100)
61+
});
62+
63+
Ok(Api {
64+
http,
65+
https,
66+
prom,
67+
pool,
68+
})
69+
}
4970

50-
pub async fn new<A: ToSocketAddrs>(
51-
(http, http_proxy): (Option<A>, bool),
52-
(https, https_proxy): (Option<A>, bool),
53-
(prom, prom_proxy): (Option<A>, bool),
54-
pool: PgPool,
55-
) -> Result<Api<
56-
impl Stream<Item = IoResult<impl AsyncRead + AsyncWrite + Send + Unpin + 'static>> + Send,
57-
impl Stream<Item = Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static, Error>> + Send
58-
>> {
59-
let http = OptionFuture::from(http.map(TcpListener::bind)).map(Option::transpose);
60-
let https = OptionFuture::from(https.map(TcpListener::bind)).map(Option::transpose);
61-
let prom = OptionFuture::from(prom.map(TcpListener::bind)).map(Option::transpose);
62-
63-
let (http, https, prom) = tokio::try_join!(http, https, prom)?;
64-
65-
let http = Api::prepare_listener(http, http_proxy);
66-
let https = Api::prepare_listener(https, https_proxy);
67-
let prom = Api::prepare_listener(prom, prom_proxy);
68-
69-
Ok(Api {
70-
http: proxy::wrap(http).try_buffer_unordered(100),
71-
https: tls::stream(https, pool.clone()),
72-
prom: proxy::wrap(prom).try_buffer_unordered(100),
73-
pool,
74-
})
75-
}
71+
impl<H, P, S> Api<H, P, S>
72+
where
73+
H: TryStream<Error = IoError> + Send + Unpin + 'static,
74+
H::Ok: AsyncRead + AsyncWrite + Send + Unpin + 'static,
75+
P: TryStream<Error = IoError> + Send + Unpin + 'static,
76+
P::Ok: AsyncRead + AsyncWrite + Send + Unpin + 'static,
77+
S: TryStream<Error = Error> + Send + Unpin + 'static,
78+
S::Ok: AsyncRead + AsyncWrite + Send + Unpin + 'static,
79+
{
7680

7781
#[tracing::instrument(name = "Api::spawn", skip(self))]
7882
pub async fn spawn(self) -> Result<()> {
@@ -82,30 +86,17 @@ impl <H, S> Api<H, S> {
8286

8387
let http = self
8488
.http
85-
.map(|http|
86-
proxy::wrap(http).try_buffer_unordered(100)
87-
)
8889
.map(|http| warp::serve(routes.clone()).serve_incoming(http))
8990
.map(tokio::spawn);
9091

9192
let pool = self.pool.clone();
9293
let https = self
9394
.https
94-
.map(|https| {
95-
//let addr = https.as_ref().local_addr();
96-
tls::stream(https, pool)
97-
//.instrument(debug_span!("HTTPS", local.addr = ?addr))
98-
})
9995
.map(|https| warp::serve(routes).serve_incoming(https))
10096
.map(tokio::spawn);
10197

10298
let prom = self
10399
.prom
104-
.map(|prom| {
105-
//let addr = prom.as_ref().local_addr();
106-
prom
107-
//.instrument(debug_span!("PROM", local.addr = ?addr))
108-
})
109100
.map(|prom| warp::serve(metrics()).serve_incoming(prom))
110101
.map(tokio::spawn);
111102

src/api/proxy.rs

Lines changed: 33 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use futures_util::future::{ready, BoxFuture, FutureExt};
21
use futures_util::stream::{Stream, TryStream};
32
use futures_util::TryStreamExt;
43
use ppp::error::ParseError;
@@ -7,21 +6,42 @@ use std::future::Future;
76
use std::io::{Cursor, IoSlice, Write};
87
use std::mem::MaybeUninit;
98
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
10-
use std::ops::DerefMut;
119
use std::pin::Pin;
1210
use std::task::{Context, Poll};
1311
use tokio::io::{AsyncRead, AsyncWrite, Error as IoError, ErrorKind, ReadBuf, Result as IoResult};
1412
use tokio::net::TcpStream;
1513
use tracing::field::{display, Empty};
1614
use tracing::{debug_span, error, info, Instrument, Span};
1715

18-
pub(super) fn wrap<S, O, E, P>(
16+
// wrap tcplistener instead of tcpstream
17+
pub(super) enum ProxyProtocol {
18+
Enabled,
19+
Disabled,
20+
}
21+
22+
pub(super) trait ToProxyStream: Sized {
23+
fn source(self, proxy: ProxyProtocol) -> ProxyStream;
24+
}
25+
26+
impl ToProxyStream for TcpStream {
27+
fn source(self, proxy: ProxyProtocol) -> ProxyStream {
28+
let data = match proxy {
29+
ProxyProtocol::Enabled => Some(Default::default()),
30+
ProxyProtocol::Disabled => None,
31+
};
32+
ProxyStream {
33+
stream: self,
34+
data,
35+
start_of_data: 0,
36+
}
37+
}
38+
}
39+
40+
pub(super) fn wrap<S, E>(
1941
stream: S,
2042
) -> impl Stream<Item = Result<impl Future<Output = Result<impl AsyncRead + AsyncWrite, E>>, E>>
2143
where
22-
P: std::error::Error,
23-
S: TryStream<Ok = O, Error = E>,
24-
O: PeerAddr<P>,
44+
S: TryStream<Ok = ProxyStream, Error = E>,
2545
{
2646
stream.map_ok(|mut conn| {
2747
let span = debug_span!("ADDR", remote.addr = Empty);
@@ -43,16 +63,6 @@ where
4363
})
4464
}
4565

46-
pub(super) trait PeerAddr<E: std::error::Error>: AsyncRead + AsyncWrite {
47-
fn proxy_peer<'a>(&'a mut self) -> BoxFuture<'a, Result<SocketAddr, E>>;
48-
}
49-
50-
impl PeerAddr<tokio::io::Error> for TcpStream {
51-
fn proxy_peer(&mut self) -> BoxFuture<IoResult<SocketAddr>> {
52-
ready(self.peer_addr()).boxed()
53-
}
54-
}
55-
5666
struct PeerAddrFuture<'a> {
5767
stream: &'a mut ProxyStream,
5868
}
@@ -126,16 +136,17 @@ impl<'a> Future for PeerAddrFuture<'a> {
126136

127137
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
128138
let this = self.get_mut();
139+
let stream = &mut this.stream;
129140

130-
let data = match &mut this.stream.data {
141+
let data = match &mut stream.data {
131142
Some(ref mut data) => data,
132-
None => unreachable!("Future cannot be polled anymore"),
143+
None => return Poll::Ready(stream.stream.local_addr()),
133144
};
134145

135146
let mut buf = [MaybeUninit::<u8>::uninit(); 256];
136147
let mut buf = ReadBuf::uninit(&mut buf);
137148

138-
let stream = Pin::new(&mut this.stream.stream);
149+
let stream = Pin::new(&mut stream.stream);
139150
let buf = match stream.poll_read(cx, &mut buf) {
140151
Poll::Ready(Ok(_)) => buf.filled(),
141152
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
@@ -156,19 +167,9 @@ pub(super) struct ProxyStream {
156167
start_of_data: usize,
157168
}
158169

159-
impl From<TcpStream> for ProxyStream {
160-
fn from(stream: TcpStream) -> Self {
161-
ProxyStream {
162-
stream,
163-
data: Some(Default::default()),
164-
start_of_data: 0,
165-
}
166-
}
167-
}
168-
169-
impl PeerAddr<tokio::io::Error> for ProxyStream {
170-
fn proxy_peer(&mut self) -> BoxFuture<IoResult<SocketAddr>> {
171-
PeerAddrFuture::new(self).boxed()
170+
impl ProxyStream {
171+
fn proxy_peer(&mut self) -> PeerAddrFuture<'_> {
172+
PeerAddrFuture::new(self)
172173
}
173174
}
174175

@@ -216,9 +217,3 @@ impl AsyncWrite for ProxyStream {
216217
self.stream.is_write_vectored()
217218
}
218219
}
219-
220-
impl<E: std::error::Error> PeerAddr<E> for Box<(dyn PeerAddr<E> + Send + Unpin + 'static)> {
221-
fn proxy_peer<'a>(&'a mut self) -> BoxFuture<'a, Result<SocketAddr, E>> {
222-
self.deref_mut().proxy_peer()
223-
}
224-
}

src/api/tls.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ use tokio::io::{AsyncRead, AsyncWrite};
1010
use tokio_rustls::TlsAcceptor;
1111
use tracing::{error, info};
1212

13-
use super::proxy::{wrap, PeerAddr};
13+
use super::proxy::wrap;
14+
use crate::api::proxy::ProxyStream;
1415
use crate::cert::{Cert, CertFacade};
1516
use crate::util::to_u64;
1617

@@ -78,14 +79,12 @@ impl Acceptor {
7879
}
7980
}
8081

81-
pub(super) fn stream<S, O, E, P>(
82+
pub(super) fn stream<S, E>(
8283
listener: S,
8384
pool: PgPool,
8485
) -> impl Stream<Item = Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static, Error>> + Send
8586
where
86-
P: std::error::Error + Send + Sync + 'static,
87-
S: Stream<Item = Result<O, E>> + Send + 'static,
88-
O: PeerAddr<P> + Send + Unpin + 'static,
87+
S: Stream<Item = Result<ProxyStream, E>> + Send + 'static,
8988
E: Into<Error> + Send + 'static,
9089
{
9190
let acceptor = Acceptor::new(pool);

src/main.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ use std::str::FromStr;
88
use tokio::runtime::Runtime;
99
use tokio::signal::ctrl_c;
1010
use tracing::{debug, error, info, Instrument};
11-
12-
use crate::acme::DatabasePersist;
13-
use crate::api::Api;
14-
use crate::cert::CertManager;
15-
use crate::dns::{DatabaseAuthority, DNS};
1611
use std::sync::Arc;
1712

13+
use acme::DatabasePersist;
14+
use cert::CertManager;
15+
use dns::{DatabaseAuthority, DNS};
16+
use api::Api;
17+
1818
mod acme;
1919
mod api;
2020
mod cert;
@@ -54,7 +54,7 @@ fn run() -> Result<()> {
5454
let dns = DNS::new(&config.general.dns, authority);
5555

5656
let api = &config.api;
57-
let api = Api::new(
57+
let api = api::new(
5858
(api.http.as_deref(), api.http_proxy),
5959
(api.https.as_deref(), api.https_proxy),
6060
(api.prom.as_deref(), api.prom_proxy),

0 commit comments

Comments
 (0)