Skip to content

Commit bb262eb

Browse files
committed
fix: wip
1 parent 1d06ac9 commit bb262eb

File tree

3 files changed

+132
-83
lines changed

3 files changed

+132
-83
lines changed

src/api/mod.rs

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,27 @@
11
use anyhow::{Error, Result};
22
use futures_util::future::OptionFuture;
33
use futures_util::stream::{Stream, TryStream};
4-
use futures_util::{FutureExt, TryStreamExt};
4+
use futures_util::FutureExt;
55
use metrics::{metrics, metrics_wrapper};
66
use sqlx::PgPool;
7+
use tokio::io::Error as IoError;
78
use tokio::io::{AsyncRead, AsyncWrite};
8-
use tokio::io::{Error as IoError, Result as IoResult};
99
use tokio::net::TcpListener;
1010
use tokio::net::ToSocketAddrs;
11-
use tokio_stream::wrappers::TcpListenerStream;
1211
use tracing::info;
1312

13+
use crate::api::proxy::{ProxyListener, ToProxyListener};
1414
use crate::config::ProxyProtocol;
15-
use proxy::ToProxyStream;
1615

1716
mod metrics;
1817
mod proxy;
1918
mod routes;
2019
mod tls;
2120

22-
pub struct Api<H, P, S> {
23-
http: Option<H>,
21+
pub struct Api<S> {
22+
http: Option<ProxyListener<IoError>>,
2423
https: Option<S>,
25-
prom: Option<P>,
24+
prom: Option<ProxyListener<IoError>>,
2625
pool: PgPool,
2726
}
2827

@@ -33,8 +32,6 @@ pub async fn new<A: ToSocketAddrs>(
3332
pool: PgPool,
3433
) -> Result<
3534
Api<
36-
impl Stream<Item = IoResult<impl AsyncRead + AsyncWrite + Send + Unpin + 'static>> + Send,
37-
impl Stream<Item = IoResult<impl AsyncRead + AsyncWrite + Send + Unpin + 'static>> + Send,
3835
impl Stream<Item = Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static, Error>> + Send,
3936
>,
4037
> {
@@ -44,18 +41,11 @@ pub async fn new<A: ToSocketAddrs>(
4441

4542
let (http, https, prom) = tokio::try_join!(http, https, prom)?;
4643

47-
let http = http.map(|http| {
48-
let http = TcpListenerStream::new(http).map_ok(move |stream| stream.source(http_proxy));
49-
proxy::wrap(http).try_buffer_unordered(100)
50-
});
51-
let https = https.map(|https| {
52-
let https = TcpListenerStream::new(https).map_ok(move |stream| stream.source(https_proxy));
53-
tls::stream(https, pool.clone())
54-
});
55-
let prom = prom.map(move |prom| {
56-
let prom = TcpListenerStream::new(prom).map_ok(move |stream| stream.source(prom_proxy));
57-
proxy::wrap(prom).try_buffer_unordered(100)
58-
});
44+
let http = http.map(|http| http.source(http_proxy));
45+
let https = https
46+
.map(|https| https.source(https_proxy))
47+
.map(|listener| tls::stream(listener, pool.clone()));
48+
let prom = prom.map(|prom| prom.source(prom_proxy));
5949

6050
Ok(Api {
6151
http,
@@ -65,12 +55,8 @@ pub async fn new<A: ToSocketAddrs>(
6555
})
6656
}
6757

68-
impl<H, P, S> Api<H, P, S>
58+
impl<S> Api<S>
6959
where
70-
H: TryStream<Error = IoError> + Send + Unpin + 'static,
71-
H::Ok: AsyncRead + AsyncWrite + Send + Unpin + 'static,
72-
P: TryStream<Error = IoError> + Send + Unpin + 'static,
73-
P::Ok: AsyncRead + AsyncWrite + Send + Unpin + 'static,
7460
S: TryStream<Error = Error> + Send + Unpin + 'static,
7561
S::Ok: AsyncRead + AsyncWrite + Send + Unpin + 'static,
7662
{

src/api/proxy.rs

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use futures_util::stream::{MapOk, Stream, TryBufferUnordered, TryStream};
1+
use futures_util::stream::{MapOk, Stream, TryBufferUnordered};
22
use futures_util::{FutureExt, StreamExt, TryStreamExt};
33
use ppp::error::ParseError;
44
use ppp::model::{Addresses, Header};
@@ -13,17 +13,48 @@ use tokio::io::{AsyncRead, AsyncWrite, Error as IoError, ErrorKind, ReadBuf, Res
1313
use tokio::net::{TcpListener, TcpStream};
1414
use tokio_stream::wrappers::TcpListenerStream;
1515
use tracing::field::{display, Empty};
16+
use tracing::instrument::Instrumented;
1617
use tracing::{debug_span, error, info, Instrument, Span};
1718

1819
use crate::config::ProxyProtocol;
19-
use tracing::instrument::Instrumented;
2020

21-
struct ProxyListener {
21+
pub(super) type WrapOutput<E> = Instrumented<ProxyStreamFuture<E>>;
22+
pub(super) type Wrap<E> = fn(conn: ProxyStream) -> WrapOutput<E>;
23+
pub(super) type ProxyListener<E> = TryBufferUnordered<MapOk<ProxyListenerImpl, Wrap<E>>>;
24+
25+
pub(super) trait ToProxyListener<E>: Sized {
26+
fn source(self, proxy: ProxyProtocol) -> ProxyListener<E>;
27+
}
28+
29+
fn wrap<E>(stream: ProxyStream) -> WrapOutput<E> {
30+
let span = debug_span!("ADDR", remote.addr = Empty);
31+
ProxyStreamFuture {
32+
stream: Some(stream),
33+
phantom: PhantomData,
34+
}
35+
.instrument(span)
36+
}
37+
38+
impl ToProxyListener<IoError> for TcpListener {
39+
fn source(
40+
self,
41+
proxy: ProxyProtocol,
42+
) -> TryBufferUnordered<MapOk<ProxyListenerImpl, Wrap<IoError>>> {
43+
ProxyListenerImpl {
44+
listener: TcpListenerStream::new(self),
45+
proxy,
46+
}
47+
.map_ok(wrap as Wrap<IoError>)
48+
.try_buffer_unordered(100)
49+
}
50+
}
51+
52+
pub(super) struct ProxyListenerImpl {
2253
listener: TcpListenerStream,
2354
proxy: ProxyProtocol,
2455
}
2556

26-
impl Stream for ProxyListener {
57+
impl Stream for ProxyListenerImpl {
2758
type Item = IoResult<ProxyStream>;
2859

2960
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
@@ -53,35 +84,15 @@ impl Stream for ProxyListener {
5384

5485
// wrap tcplistener instead of tcpstream
5586

56-
pub(super) trait ToProxyStream: Sized {
57-
fn source(self, proxy: ProxyProtocol) -> ProxyStream;
58-
}
59-
60-
impl ToProxyStream for TcpStream {
61-
fn source(self, proxy: ProxyProtocol) -> ProxyStream {
62-
let data = match proxy {
63-
ProxyProtocol::Enabled => Some(Default::default()),
64-
ProxyProtocol::Disabled => None,
65-
};
66-
ProxyStream {
67-
stream: self,
68-
data,
69-
start_of_data: 0,
70-
}
71-
}
72-
}
73-
74-
struct ProxyStreamFuture<E> {
87+
pub(super) struct ProxyStreamFuture<E> {
7588
stream: Option<ProxyStream>,
7689
phantom: PhantomData<E>,
7790
}
7891

79-
type Wrap<E> = fn(conn: ProxyStream) -> Instrumented<ProxyStreamFuture<E>>;
80-
8192
impl<E: Unpin> Future for ProxyStreamFuture<E> {
8293
type Output = Result<ProxyStream, E>;
8394

84-
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
95+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
8596
let this = self.get_mut();
8697
let span = Span::current();
8798
let stream = this
@@ -109,24 +120,7 @@ impl<E: Unpin> Future for ProxyStreamFuture<E> {
109120
}
110121
}
111122

112-
pub(super) fn wrap(
113-
listener: TcpListener,
114-
) -> TryBufferUnordered<MapOk<ProxyListener, Wrap<IoError>>> {
115-
ProxyListener {
116-
listener: TcpListenerStream::new(listener),
117-
proxy: ProxyProtocol::Enabled,
118-
}
119-
.map_ok(|mut stream| {
120-
let span = debug_span!("test");
121-
ProxyStreamFuture {
122-
stream: Some(stream),
123-
phantom: PhantomData,
124-
}
125-
.instrument(span)
126-
})
127-
.try_buffer_unordered(100)
128-
}
129-
123+
// maybe use state machine
130124
struct PeerAddrFuture<'a> {
131125
stream: &'a mut ProxyStream,
132126
}

src/api/tls.rs

Lines changed: 82 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
use anyhow::{anyhow, Error, Result};
2-
use futures_util::stream::{repeat, Stream};
3-
use futures_util::{StreamExt, TryStreamExt};
2+
use futures_util::stream::{repeat, ErrInto, Stream, Zip};
3+
use futures_util::{StreamExt, TryStreamExt, FutureExt};
44
use parking_lot::RwLock;
55
use rustls::internal::pemfile::{certs, pkcs8_private_keys};
66
use rustls::{NoClientAuth, ServerConfig};
77
use sqlx::PgPool;
88
use std::sync::Arc;
9-
use tokio::io::{AsyncRead, AsyncWrite};
10-
use tokio_rustls::TlsAcceptor;
9+
use tokio::io::{AsyncRead, AsyncWrite, Error as IoError};
10+
use tokio_rustls::{TlsAcceptor, Accept};
1111
use tracing::{error, info};
1212

13-
use super::proxy;
14-
use crate::api::proxy::ProxyStream;
13+
use super::proxy::ProxyListener;
1514
use crate::cert::{Cert, CertFacade};
1615
use crate::util::to_u64;
16+
use futures_util::stream::Repeat;
17+
use std::future::Future;
1718

1819
struct Acceptor {
1920
pool: PgPool,
@@ -79,25 +80,93 @@ impl Acceptor {
7980
}
8081
}
8182

82-
pub(super) fn stream<S, E>(
83-
listener: S,
83+
trait Tuple {
84+
type A;
85+
type B;
86+
}
87+
88+
impl<A, B> Tuple for (A, B) {
89+
type A = A;
90+
type B = B;
91+
}
92+
93+
trait ResultInherit {
94+
type Ok;
95+
type Error;
96+
}
97+
98+
impl<T, E> ResultInherit for Result<T, E> {
99+
type Ok = T;
100+
type Error = E;
101+
}
102+
103+
trait ToTlsListener {
104+
fn tls(self, pool: PgPool) -> ProxyListener<IoError>;
105+
}
106+
107+
type TlsListenerBeforeMap = Zip<ErrInto<ProxyListener<IoError>, Error>, Repeat<Arc<Acceptor>>>;
108+
109+
type MapResultOutput = Result<
110+
(
111+
<<<TlsListenerBeforeMap as Stream>::Item as Tuple>::A as ResultInherit>::Ok,
112+
Arc<Acceptor>,
113+
),
114+
<<<TlsListenerBeforeMap as Stream>::Item as Tuple>::A as ResultInherit>::Error,
115+
>;
116+
117+
type MapResult = fn(<TlsListenerBeforeMap as Stream>::Item) -> MapResultOutput;
118+
119+
fn map_result((listener, acceptor): <TlsListenerBeforeMap as Stream>::Item) -> MapResultOutput {
120+
listener.map(|listener| (listener, acceptor))
121+
}
122+
123+
pub(super) fn stream(
124+
listener: ProxyListener<IoError>,
84125
pool: PgPool,
85126
) -> impl Stream<Item = Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static, Error>> + Send
86-
where
87-
S: Stream<Item = Result<ProxyStream, E>> + Send + 'static,
88-
E: Into<Error> + Send + 'static,
89127
{
90128
let acceptor = Acceptor::new(pool);
91129

92-
proxy::wrap(listener.err_into())
130+
listener
131+
.err_into()
93132
.zip(repeat(acceptor))
94133
.map(|(conn, acceptor)| conn.map(|c| (c, acceptor)))
95134
.map_ok(|(conn, acceptor)| async move {
96-
let (conn, tls) = tokio::try_join!(conn, acceptor.load_cert())?;
135+
let tls = acceptor.load_cert().await?;
97136
Ok(tls.accept(conn).await?)
98137
})
99138
.try_buffer_unordered(100)
100139
.inspect_err(|err| error!("Stream error: {:?}", err))
101140
.filter(|stream| futures_util::future::ready(stream.is_ok()))
102141
.into_stream()
103142
}
143+
144+
enum PrepareTlsFuture {
145+
Start(ProxyListener<IoError>, Arc<Acceptor>),
146+
LoadingCert()
147+
148+
}
149+
struct PrepareTlsFuture {
150+
listener: Option<ProxyListener<IoError>>,
151+
accept: Option<Accept<ProxyListener<IoError>>>
152+
acceptor: Arc<Acceptor>
153+
}
154+
155+
impl Future for PrepareTlsFuture {
156+
type Output = Result<Accept<ProxyListener<IoError>>>;
157+
158+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159+
let this = self.get_mut();
160+
let tls = match this.acceptor.load_cert().poll_unpin(cx)? {
161+
Err(e) => return Poll::Ready(Err(e)),
162+
Ok(tls) => tls,
163+
};
164+
165+
let stream = match this.listener.take() {
166+
Some(stream) => stream,
167+
None => unreachable!("Future cannot be polled anymore"),
168+
};
169+
170+
match tls.accept(stream)
171+
}
172+
}

0 commit comments

Comments
 (0)