1
- use futures_util:: stream:: { Stream , TryStream } ;
1
+ use futures_util:: stream:: Stream ;
2
2
use futures_util:: TryStreamExt ;
3
3
use ppp:: error:: ParseError ;
4
4
use ppp:: model:: { Addresses , Header } ;
@@ -9,13 +9,39 @@ use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
9
9
use std:: pin:: Pin ;
10
10
use std:: task:: { Context , Poll } ;
11
11
use tokio:: io:: { AsyncRead , AsyncWrite , Error as IoError , ErrorKind , ReadBuf , Result as IoResult } ;
12
- use tokio:: net:: TcpStream ;
12
+ use tokio:: net:: { TcpListener , TcpStream } ;
13
13
use tracing:: field:: { display, Empty } ;
14
14
use tracing:: { debug_span, error, info, Instrument , Span } ;
15
+ use tokio_stream:: wrappers:: TcpListenerStream ;
15
16
16
17
use crate :: config:: ProxyProtocol ;
17
18
18
- // wrap tcplistener instead of tcpstream
19
+ pub ( super ) fn wrap (
20
+ listener : TcpListener ,
21
+ proxy : ProxyProtocol ,
22
+ ) -> impl Stream < Item = IoResult < ProxyStream > > + Send {
23
+ TcpListenerStream :: new ( listener)
24
+ . map_ok ( move |stream| stream. source ( proxy) )
25
+ . map_ok ( |mut conn| {
26
+ let span = debug_span ! ( "ADDR" , remote. addr = Empty ) ;
27
+ async move {
28
+ let span = Span :: current ( ) ;
29
+ match conn. proxy_peer ( ) . await {
30
+ Ok ( addr) => {
31
+ span. record ( "remote.addr" , & display ( addr) ) ;
32
+ info ! ( "Got addr {}" , addr)
33
+ }
34
+ Err ( e) => {
35
+ span. record ( "remote.addr" , & "Unknown" ) ;
36
+ error ! ( "Could net get remote.addr: {}" , e) ;
37
+ }
38
+ }
39
+ Ok ( conn)
40
+ }
41
+ . instrument ( span)
42
+ } )
43
+ . try_buffer_unordered ( 100 )
44
+ }
19
45
20
46
pub ( super ) trait ToProxyStream : Sized {
21
47
fn source ( self , proxy : ProxyProtocol ) -> ProxyStream ;
@@ -35,30 +61,61 @@ impl ToProxyStream for TcpStream {
35
61
}
36
62
}
37
63
38
- pub ( super ) fn wrap < S , E > (
39
- stream : S ,
40
- ) -> impl Stream < Item = Result < impl Future < Output = Result < impl AsyncRead + AsyncWrite , E > > , E > >
41
- where
42
- S : TryStream < Ok = ProxyStream , Error = E > ,
43
- {
44
- stream. map_ok ( |mut conn| {
45
- let span = debug_span ! ( "ADDR" , remote. addr = Empty ) ;
46
- async move {
47
- let span = Span :: current ( ) ;
48
- match conn. proxy_peer ( ) . await {
49
- Ok ( addr) => {
50
- span. record ( "remote.addr" , & display ( addr) ) ;
51
- info ! ( "Got addr {}" , addr)
52
- }
53
- Err ( e) => {
54
- span. record ( "remote.addr" , & "Unknown" ) ;
55
- error ! ( "Could net get remote.addr: {}" , e) ;
56
- }
57
- }
58
- Ok ( conn)
64
+ pub ( super ) struct ProxyStream {
65
+ stream : TcpStream ,
66
+ data : Option < Cursor < Vec < u8 > > > ,
67
+ start_of_data : usize ,
68
+ }
69
+
70
+ impl ProxyStream {
71
+ fn proxy_peer ( & mut self ) -> PeerAddrFuture < ' _ > {
72
+ PeerAddrFuture :: new ( self )
73
+ }
74
+ }
75
+
76
+ impl AsyncRead for ProxyStream {
77
+ fn poll_read (
78
+ self : Pin < & mut Self > ,
79
+ cx : & mut Context < ' _ > ,
80
+ buf : & mut ReadBuf < ' _ > ,
81
+ ) -> Poll < IoResult < ( ) > > {
82
+ let this = self . get_mut ( ) ;
83
+ // todo: handle the case were the data has no space in buf
84
+ if let Some ( data) = this. data . take ( ) {
85
+ buf. put_slice ( & data. get_ref ( ) [ this. start_of_data ..] )
59
86
}
60
- . instrument ( span)
61
- } )
87
+ Pin :: new ( & mut this. stream ) . poll_read ( cx, buf)
88
+ }
89
+ }
90
+
91
+ impl AsyncWrite for ProxyStream {
92
+ fn poll_write (
93
+ mut self : Pin < & mut Self > ,
94
+ cx : & mut Context < ' _ > ,
95
+ buf : & [ u8 ] ,
96
+ ) -> Poll < IoResult < usize > > {
97
+ Pin :: new ( & mut self . stream ) . poll_write ( cx, buf)
98
+ }
99
+
100
+ fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < IoResult < ( ) > > {
101
+ Pin :: new ( & mut self . stream ) . poll_flush ( cx)
102
+ }
103
+
104
+ fn poll_shutdown ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < IoResult < ( ) > > {
105
+ Pin :: new ( & mut self . stream ) . poll_shutdown ( cx)
106
+ }
107
+
108
+ fn poll_write_vectored (
109
+ mut self : Pin < & mut Self > ,
110
+ cx : & mut Context < ' _ > ,
111
+ bufs : & [ IoSlice < ' _ > ] ,
112
+ ) -> Poll < IoResult < usize > > {
113
+ Pin :: new ( & mut self . stream ) . poll_write_vectored ( cx, bufs)
114
+ }
115
+
116
+ fn is_write_vectored ( & self ) -> bool {
117
+ self . stream . is_write_vectored ( )
118
+ }
62
119
}
63
120
64
121
struct PeerAddrFuture < ' a > {
@@ -158,60 +215,3 @@ impl<'a> Future for PeerAddrFuture<'a> {
158
215
this. get_header ( )
159
216
}
160
217
}
161
-
162
- pub ( super ) struct ProxyStream {
163
- stream : TcpStream ,
164
- data : Option < Cursor < Vec < u8 > > > ,
165
- start_of_data : usize ,
166
- }
167
-
168
- impl ProxyStream {
169
- fn proxy_peer ( & mut self ) -> PeerAddrFuture < ' _ > {
170
- PeerAddrFuture :: new ( self )
171
- }
172
- }
173
-
174
- impl AsyncRead for ProxyStream {
175
- fn poll_read (
176
- self : Pin < & mut Self > ,
177
- cx : & mut Context < ' _ > ,
178
- buf : & mut ReadBuf < ' _ > ,
179
- ) -> Poll < IoResult < ( ) > > {
180
- let this = self . get_mut ( ) ;
181
- // todo: handle the case were the data has no space in buf
182
- if let Some ( data) = this. data . take ( ) {
183
- buf. put_slice ( & data. get_ref ( ) [ this. start_of_data ..] )
184
- }
185
- Pin :: new ( & mut this. stream ) . poll_read ( cx, buf)
186
- }
187
- }
188
-
189
- impl AsyncWrite for ProxyStream {
190
- fn poll_write (
191
- mut self : Pin < & mut Self > ,
192
- cx : & mut Context < ' _ > ,
193
- buf : & [ u8 ] ,
194
- ) -> Poll < IoResult < usize > > {
195
- Pin :: new ( & mut self . stream ) . poll_write ( cx, buf)
196
- }
197
-
198
- fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < IoResult < ( ) > > {
199
- Pin :: new ( & mut self . stream ) . poll_flush ( cx)
200
- }
201
-
202
- fn poll_shutdown ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < IoResult < ( ) > > {
203
- Pin :: new ( & mut self . stream ) . poll_shutdown ( cx)
204
- }
205
-
206
- fn poll_write_vectored (
207
- mut self : Pin < & mut Self > ,
208
- cx : & mut Context < ' _ > ,
209
- bufs : & [ IoSlice < ' _ > ] ,
210
- ) -> Poll < IoResult < usize > > {
211
- Pin :: new ( & mut self . stream ) . poll_write_vectored ( cx, bufs)
212
- }
213
-
214
- fn is_write_vectored ( & self ) -> bool {
215
- self . stream . is_write_vectored ( )
216
- }
217
- }
0 commit comments