Skip to content

Commit 35da5ab

Browse files
committed
fix: simple implementation
1 parent f5ea489 commit 35da5ab

File tree

3 files changed

+113
-125
lines changed

3 files changed

+113
-125
lines changed

src/api/mod.rs

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
11
use anyhow::{Error, Result};
22
use futures_util::future::OptionFuture;
33
use futures_util::stream::{Stream, TryStream};
4-
use futures_util::{FutureExt, TryStreamExt};
4+
use futures_util::FutureExt;
55
use metrics::{metrics, metrics_wrapper};
66
use sqlx::PgPool;
77
use tokio::io::{AsyncRead, AsyncWrite};
88
use tokio::io::{Error as IoError, Result as IoResult};
99
use tokio::net::TcpListener;
1010
use tokio::net::ToSocketAddrs;
11-
use tokio_stream::wrappers::TcpListenerStream;
1211
use tracing::info;
1312

1413
use crate::config::ProxyProtocol;
15-
use proxy::ToProxyStream;
1614

1715
mod metrics;
1816
mod proxy;
@@ -35,7 +33,7 @@ pub async fn new<A: ToSocketAddrs>(
3533
Api<
3634
impl Stream<Item = IoResult<impl AsyncRead + AsyncWrite + Send + Unpin + 'static>> + Send,
3735
impl Stream<Item = IoResult<impl AsyncRead + AsyncWrite + Send + Unpin + 'static>> + Send,
38-
impl Stream<Item = Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static, Error>> + Send,
36+
impl Stream<Item = Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static>> + Send,
3937
>,
4038
> {
4139
let http = OptionFuture::from(http.map(TcpListener::bind)).map(Option::transpose);
@@ -44,18 +42,11 @@ pub async fn new<A: ToSocketAddrs>(
4442

4543
let (http, https, prom) = tokio::try_join!(http, https, prom)?;
4644

47-
let http = http.map(|http| {
48-
let http = TcpListenerStream::new(http).map_ok(move |stream| stream.source(http_proxy));
49-
proxy::wrap(http).try_buffer_unordered(100)
50-
});
51-
let https = https.map(|https| {
52-
let https = TcpListenerStream::new(https).map_ok(move |stream| stream.source(https_proxy));
53-
tls::stream(https, pool.clone())
54-
});
55-
let prom = prom.map(move |prom| {
56-
let prom = TcpListenerStream::new(prom).map_ok(move |stream| stream.source(prom_proxy));
57-
proxy::wrap(prom).try_buffer_unordered(100)
58-
});
45+
let http = http.map(move |http| proxy::wrap(http, http_proxy));
46+
let prom = prom.map(move |prom| proxy::wrap(prom, prom_proxy));
47+
let https = https
48+
.map(move |https| proxy::wrap(https, https_proxy))
49+
.map(|https| tls::wrap(https, pool.clone()));
5950

6051
Ok(Api {
6152
http,

src/api/proxy.rs

Lines changed: 83 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use futures_util::stream::{Stream, TryStream};
1+
use futures_util::stream::Stream;
22
use futures_util::TryStreamExt;
33
use ppp::error::ParseError;
44
use ppp::model::{Addresses, Header};
@@ -9,13 +9,39 @@ use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
99
use std::pin::Pin;
1010
use std::task::{Context, Poll};
1111
use tokio::io::{AsyncRead, AsyncWrite, Error as IoError, ErrorKind, ReadBuf, Result as IoResult};
12-
use tokio::net::TcpStream;
12+
use tokio::net::{TcpListener, TcpStream};
1313
use tracing::field::{display, Empty};
1414
use tracing::{debug_span, error, info, Instrument, Span};
15+
use tokio_stream::wrappers::TcpListenerStream;
1516

1617
use crate::config::ProxyProtocol;
1718

18-
// wrap tcplistener instead of tcpstream
19+
pub(super) fn wrap(
20+
listener: TcpListener,
21+
proxy: ProxyProtocol,
22+
) -> impl Stream<Item = IoResult<ProxyStream>> + Send {
23+
TcpListenerStream::new(listener)
24+
.map_ok(move |stream| stream.source(proxy))
25+
.map_ok(|mut conn| {
26+
let span = debug_span!("ADDR", remote.addr = Empty);
27+
async move {
28+
let span = Span::current();
29+
match conn.proxy_peer().await {
30+
Ok(addr) => {
31+
span.record("remote.addr", &display(addr));
32+
info!("Got addr {}", addr)
33+
}
34+
Err(e) => {
35+
span.record("remote.addr", &"Unknown");
36+
error!("Could net get remote.addr: {}", e);
37+
}
38+
}
39+
Ok(conn)
40+
}
41+
.instrument(span)
42+
})
43+
.try_buffer_unordered(100)
44+
}
1945

2046
pub(super) trait ToProxyStream: Sized {
2147
fn source(self, proxy: ProxyProtocol) -> ProxyStream;
@@ -35,30 +61,61 @@ impl ToProxyStream for TcpStream {
3561
}
3662
}
3763

38-
pub(super) fn wrap<S, E>(
39-
stream: S,
40-
) -> impl Stream<Item = Result<impl Future<Output = Result<impl AsyncRead + AsyncWrite, E>>, E>>
41-
where
42-
S: TryStream<Ok = ProxyStream, Error = E>,
43-
{
44-
stream.map_ok(|mut conn| {
45-
let span = debug_span!("ADDR", remote.addr = Empty);
46-
async move {
47-
let span = Span::current();
48-
match conn.proxy_peer().await {
49-
Ok(addr) => {
50-
span.record("remote.addr", &display(addr));
51-
info!("Got addr {}", addr)
52-
}
53-
Err(e) => {
54-
span.record("remote.addr", &"Unknown");
55-
error!("Could net get remote.addr: {}", e);
56-
}
57-
}
58-
Ok(conn)
64+
pub(super) struct ProxyStream {
65+
stream: TcpStream,
66+
data: Option<Cursor<Vec<u8>>>,
67+
start_of_data: usize,
68+
}
69+
70+
impl ProxyStream {
71+
fn proxy_peer(&mut self) -> PeerAddrFuture<'_> {
72+
PeerAddrFuture::new(self)
73+
}
74+
}
75+
76+
impl AsyncRead for ProxyStream {
77+
fn poll_read(
78+
self: Pin<&mut Self>,
79+
cx: &mut Context<'_>,
80+
buf: &mut ReadBuf<'_>,
81+
) -> Poll<IoResult<()>> {
82+
let this = self.get_mut();
83+
// todo: handle the case were the data has no space in buf
84+
if let Some(data) = this.data.take() {
85+
buf.put_slice(&data.get_ref()[this.start_of_data..])
5986
}
60-
.instrument(span)
61-
})
87+
Pin::new(&mut this.stream).poll_read(cx, buf)
88+
}
89+
}
90+
91+
impl AsyncWrite for ProxyStream {
92+
fn poll_write(
93+
mut self: Pin<&mut Self>,
94+
cx: &mut Context<'_>,
95+
buf: &[u8],
96+
) -> Poll<IoResult<usize>> {
97+
Pin::new(&mut self.stream).poll_write(cx, buf)
98+
}
99+
100+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
101+
Pin::new(&mut self.stream).poll_flush(cx)
102+
}
103+
104+
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
105+
Pin::new(&mut self.stream).poll_shutdown(cx)
106+
}
107+
108+
fn poll_write_vectored(
109+
mut self: Pin<&mut Self>,
110+
cx: &mut Context<'_>,
111+
bufs: &[IoSlice<'_>],
112+
) -> Poll<IoResult<usize>> {
113+
Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
114+
}
115+
116+
fn is_write_vectored(&self) -> bool {
117+
self.stream.is_write_vectored()
118+
}
62119
}
63120

64121
struct PeerAddrFuture<'a> {
@@ -158,60 +215,3 @@ impl<'a> Future for PeerAddrFuture<'a> {
158215
this.get_header()
159216
}
160217
}
161-
162-
pub(super) struct ProxyStream {
163-
stream: TcpStream,
164-
data: Option<Cursor<Vec<u8>>>,
165-
start_of_data: usize,
166-
}
167-
168-
impl ProxyStream {
169-
fn proxy_peer(&mut self) -> PeerAddrFuture<'_> {
170-
PeerAddrFuture::new(self)
171-
}
172-
}
173-
174-
impl AsyncRead for ProxyStream {
175-
fn poll_read(
176-
self: Pin<&mut Self>,
177-
cx: &mut Context<'_>,
178-
buf: &mut ReadBuf<'_>,
179-
) -> Poll<IoResult<()>> {
180-
let this = self.get_mut();
181-
// todo: handle the case were the data has no space in buf
182-
if let Some(data) = this.data.take() {
183-
buf.put_slice(&data.get_ref()[this.start_of_data..])
184-
}
185-
Pin::new(&mut this.stream).poll_read(cx, buf)
186-
}
187-
}
188-
189-
impl AsyncWrite for ProxyStream {
190-
fn poll_write(
191-
mut self: Pin<&mut Self>,
192-
cx: &mut Context<'_>,
193-
buf: &[u8],
194-
) -> Poll<IoResult<usize>> {
195-
Pin::new(&mut self.stream).poll_write(cx, buf)
196-
}
197-
198-
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
199-
Pin::new(&mut self.stream).poll_flush(cx)
200-
}
201-
202-
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
203-
Pin::new(&mut self.stream).poll_shutdown(cx)
204-
}
205-
206-
fn poll_write_vectored(
207-
mut self: Pin<&mut Self>,
208-
cx: &mut Context<'_>,
209-
bufs: &[IoSlice<'_>],
210-
) -> Poll<IoResult<usize>> {
211-
Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
212-
}
213-
214-
fn is_write_vectored(&self) -> bool {
215-
self.stream.is_write_vectored()
216-
}
217-
}

src/api/tls.rs

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,35 @@ use rustls::internal::pemfile::{certs, pkcs8_private_keys};
66
use rustls::{NoClientAuth, ServerConfig};
77
use sqlx::PgPool;
88
use std::sync::Arc;
9-
use tokio::io::{AsyncRead, AsyncWrite};
9+
use tokio::io::{AsyncRead, AsyncWrite, Result as IoResult};
1010
use tokio_rustls::TlsAcceptor;
1111
use tracing::{error, info};
1212

13-
use super::proxy;
14-
use crate::api::proxy::ProxyStream;
13+
use super::proxy::ProxyStream;
1514
use crate::cert::{Cert, CertFacade};
1615
use crate::util::to_u64;
1716

17+
pub(super) fn wrap(
18+
listener: impl Stream<Item = IoResult<ProxyStream>> + Send,
19+
pool: PgPool,
20+
) -> impl Stream<Item = Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static, Error>> + Send
21+
{
22+
let acceptor = Acceptor::new(pool);
23+
24+
listener
25+
.err_into()
26+
.zip(repeat(acceptor))
27+
.map(|(conn, acceptor)| conn.map(|c| (c, acceptor)))
28+
.map_ok(|(conn, acceptor)| async move {
29+
let tls = acceptor.load_cert().await?;
30+
Ok(tls.accept(conn).await?)
31+
})
32+
.try_buffer_unordered(100)
33+
.inspect_err(|err| error!("Stream error: {:?}", err))
34+
.filter(|stream| futures_util::future::ready(stream.is_ok()))
35+
.into_stream()
36+
}
37+
1838
struct Acceptor {
1939
pool: PgPool,
2040
config: RwLock<(Option<Cert>, Arc<ServerConfig>)>,
@@ -78,26 +98,3 @@ impl Acceptor {
7898
Ok(TlsAcceptor::from(server_config))
7999
}
80100
}
81-
82-
pub(super) fn stream<S, E>(
83-
listener: S,
84-
pool: PgPool,
85-
) -> impl Stream<Item = Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static, Error>> + Send
86-
where
87-
S: Stream<Item = Result<ProxyStream, E>> + Send + 'static,
88-
E: Into<Error> + Send + 'static,
89-
{
90-
let acceptor = Acceptor::new(pool);
91-
92-
proxy::wrap(listener.err_into())
93-
.zip(repeat(acceptor))
94-
.map(|(conn, acceptor)| conn.map(|c| (c, acceptor)))
95-
.map_ok(|(conn, acceptor)| async move {
96-
let (conn, tls) = tokio::try_join!(conn, acceptor.load_cert())?;
97-
Ok(tls.accept(conn).await?)
98-
})
99-
.try_buffer_unordered(100)
100-
.inspect_err(|err| error!("Stream error: {:?}", err))
101-
.filter(|stream| futures_util::future::ready(stream.is_ok()))
102-
.into_stream()
103-
}

0 commit comments

Comments
 (0)