Skip to content

Commit 10cafe0

Browse files
committed
fix: export delgeate async write
1 parent 7c37bfc commit 10cafe0

File tree

4 files changed

+97
-59
lines changed

4 files changed

+97
-59
lines changed

src/api/proxy.rs

Lines changed: 3 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use futures_util::{ready, TryStreamExt};
33
use ppp::error::ParseError;
44
use ppp::model::{Addresses, Header};
55
use std::future::Future;
6-
use std::io::IoSlice;
76
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
87
use std::pin::Pin;
98
use std::task::{Context, Poll};
@@ -120,33 +119,7 @@ impl<T> AsyncWrite for ProxyStream<T>
120119
where
121120
T: AsyncWrite + Unpin,
122121
{
123-
fn poll_write(
124-
mut self: Pin<&mut Self>,
125-
cx: &mut Context<'_>,
126-
buf: &[u8],
127-
) -> Poll<IoResult<usize>> {
128-
Pin::new(&mut self.stream).poll_write(cx, buf)
129-
}
130-
131-
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
132-
Pin::new(&mut self.stream).poll_flush(cx)
133-
}
134-
135-
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
136-
Pin::new(&mut self.stream).poll_shutdown(cx)
137-
}
138-
139-
fn poll_write_vectored(
140-
mut self: Pin<&mut Self>,
141-
cx: &mut Context<'_>,
142-
bufs: &[IoSlice<'_>],
143-
) -> Poll<IoResult<usize>> {
144-
Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
145-
}
146-
147-
fn is_write_vectored(&self) -> bool {
148-
self.stream.is_write_vectored()
149-
}
122+
delegate_async_write!(stream);
150123
}
151124

152125
struct RealAddrFuture<'a, T> {
@@ -239,12 +212,10 @@ where
239212

240213
#[cfg(test)]
241214
mod tests {
242-
use futures_util::future;
243215
use ppp::model::{Addresses, Command, Header, Protocol, Version};
244-
use std::io::{Error as IoError, ErrorKind, IoSlice, Result as IoResult};
216+
use std::io::{Error as IoError, ErrorKind, Result as IoResult};
245217
use std::net::SocketAddr;
246-
use std::pin::Pin;
247-
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
218+
use tokio::io::AsyncReadExt;
248219
use tokio_test::io::Builder;
249220

250221
use super::{format_header, RemoteAddr, ToProxyStream};
@@ -354,28 +325,4 @@ mod tests {
354325
let actual = proxy_stream.remote_addr().unwrap();
355326
assert_eq!(SocketAddr::from(([1, 1, 1, 1], 443)), actual)
356327
}
357-
358-
#[tokio::test]
359-
async fn test_async_write_delegation() {
360-
let mut builder = Builder::new();
361-
builder.write("Test1".as_ref());
362-
builder.write("Test2".as_ref());
363-
364-
let mut proxy_stream = builder.build().source(ProxyProtocol::Disabled);
365-
assert_eq!(false, proxy_stream.is_write_vectored());
366-
367-
proxy_stream.write_all("Test1".as_ref()).await.unwrap();
368-
369-
let slice = IoSlice::new("Test2".as_ref());
370-
let size = future::poll_fn(move |cx| {
371-
Pin::new(&mut proxy_stream).poll_write_vectored(cx, &[slice])
372-
})
373-
.await
374-
.unwrap();
375-
assert_eq!(5, size);
376-
377-
let mut proxy_stream = Builder::new().build().source(ProxyProtocol::Disabled);
378-
assert_eq!((), proxy_stream.flush().await.unwrap());
379-
assert_eq!((), proxy_stream.shutdown().await.unwrap());
380-
}
381328
}

src/api/tls.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,15 @@ impl Acceptor {
102102
Ok(TlsAcceptor::from(server_config))
103103
}
104104
}
105+
106+
#[cfg(test)]
107+
mod tests {
108+
use std::io::Cursor;
109+
use tokio::io::split;
110+
111+
#[tokio::test]
112+
async fn test() {
113+
let (_client_read, _server_write) = split(Cursor::new(vec![]));
114+
let (_server_read, _client_write) = split(Cursor::new(vec![]));
115+
}
116+
}

src/main.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ use acme::DatabasePersist;
1414
use cert::CertManager;
1515
use dns::{DatabaseAuthority, DNS};
1616

17+
#[macro_use]
18+
mod util;
1719
mod acme;
1820
mod api;
1921
mod cert;
2022
mod config;
2123
mod dns;
2224
mod domain;
23-
mod util;
2425

2526
static MIGRATOR: Migrator = sqlx::migrate!("migrations/postgres");
2627

src/util.rs

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,56 @@ pub(crate) fn uuid() -> String {
3232
Uuid::new_v4().to_simple().to_string()
3333
}
3434

35+
macro_rules! delegate_async_write {
36+
($item:ident) => {
37+
fn poll_write(
38+
mut self: ::std::pin::Pin<&mut Self>,
39+
cx: &mut ::std::task::Context<'_>,
40+
buf: &[::std::primitive::u8],
41+
) -> ::std::task::Poll<::std::io::Result<usize>> {
42+
::std::pin::Pin::new(&mut self.$item).poll_write(cx, buf)
43+
}
44+
45+
fn poll_flush(
46+
mut self: ::std::pin::Pin<&mut Self>,
47+
cx: &mut ::std::task::Context<'_>,
48+
) -> ::std::task::Poll<::std::io::Result<()>> {
49+
::std::pin::Pin::new(&mut self.$item).poll_flush(cx)
50+
}
51+
52+
fn poll_shutdown(
53+
mut self: ::std::pin::Pin<&mut Self>,
54+
cx: &mut ::std::task::Context<'_>,
55+
) -> ::std::task::Poll<::std::io::Result<()>> {
56+
::std::pin::Pin::new(&mut self.$item).poll_shutdown(cx)
57+
}
58+
59+
fn poll_write_vectored(
60+
mut self: ::std::pin::Pin<&mut Self>,
61+
cx: &mut ::std::task::Context<'_>,
62+
bufs: &[::std::io::IoSlice<'_>],
63+
) -> ::std::task::Poll<::std::io::Result<::std::primitive::usize>> {
64+
::std::pin::Pin::new(&mut self.$item).poll_write_vectored(cx, bufs)
65+
}
66+
67+
fn is_write_vectored(&self) -> ::std::primitive::bool {
68+
self.$item.is_write_vectored()
69+
}
70+
};
71+
}
72+
3573
#[cfg(test)]
3674
mod tests {
37-
use super::{error, now, to_i64, to_u64, uuid};
3875
use anyhow::anyhow;
39-
use std::io::{Error as IoError, ErrorKind};
76+
use futures_util::future;
77+
use std::io::{Error as IoError, ErrorKind, IoSlice};
78+
use std::pin::Pin;
4079
use std::thread;
4180
use std::time::Duration;
81+
use tokio::io::{AsyncWrite, AsyncWriteExt};
82+
use tokio_test::io::Builder;
83+
84+
use super::{error, now, to_i64, to_u64, uuid};
4285

4386
const NUMBER_1: u64 = 2323;
4487
const NUMBER_2: u64 = 940329402394;
@@ -102,4 +145,39 @@ mod tests {
102145
let actual = actual.into_inner().expect("Error has no inner error");
103146
assert_eq!(format!("{}", expected), format!("{}", actual));
104147
}
148+
149+
struct Wrapper<T> {
150+
inner: T,
151+
}
152+
153+
impl<T: AsyncWrite + Unpin> AsyncWrite for Wrapper<T> {
154+
delegate_async_write!(inner);
155+
}
156+
157+
#[tokio::test]
158+
async fn delegate_async_write_works() {
159+
let mut builder = Builder::new();
160+
builder.write("Test1".as_ref());
161+
builder.write("Test2".as_ref());
162+
163+
let mut stream = Wrapper {
164+
inner: builder.build(),
165+
};
166+
assert_eq!(false, stream.is_write_vectored());
167+
168+
stream.write_all("Test1".as_ref()).await.unwrap();
169+
170+
let slice = IoSlice::new("Test2".as_ref());
171+
let size =
172+
future::poll_fn(move |cx| Pin::new(&mut stream).poll_write_vectored(cx, &[slice]))
173+
.await
174+
.unwrap();
175+
assert_eq!(5, size);
176+
177+
let mut stream = Wrapper {
178+
inner: Builder::new().build(),
179+
};
180+
assert_eq!((), stream.flush().await.unwrap());
181+
assert_eq!((), stream.shutdown().await.unwrap());
182+
}
105183
}

0 commit comments

Comments
 (0)