|
1 | 1 | use bytes::Bytes;
|
2 | 2 | use futures_util::FutureExt;
|
| 3 | +use http::Uri; |
3 | 4 | use integration_tests::pb::{
|
4 | 5 | test_client, test_server, test_stream_client, test_stream_server, Input, InputStream, Output,
|
5 | 6 | OutputStream,
|
6 | 7 | };
|
| 8 | +use std::convert::TryFrom; |
| 9 | +use std::error::Error; |
7 | 10 | use std::time::Duration;
|
8 | 11 | use tokio::sync::oneshot;
|
9 | 12 | use tonic::metadata::{MetadataMap, MetadataValue};
|
| 13 | +use tonic::transport::Endpoint; |
10 | 14 | use tonic::{transport::Server, Code, Request, Response, Status};
|
11 | 15 |
|
12 | 16 | #[tokio::test]
|
@@ -173,8 +177,78 @@ async fn status_from_server_stream() {
|
173 | 177 | assert_eq!(stream.message().await.unwrap(), None);
|
174 | 178 | }
|
175 | 179 |
|
| 180 | +#[tokio::test] |
| 181 | +async fn status_from_server_stream_with_source() { |
| 182 | + trace_init(); |
| 183 | + |
| 184 | + let channel = Endpoint::try_from("http://[::]:50051") |
| 185 | + .unwrap() |
| 186 | + .connect_with_connector_lazy(tower::service_fn(move |_: Uri| async move { |
| 187 | + Err::<mock::MockStream, _>(std::io::Error::new(std::io::ErrorKind::Other, "WTF")) |
| 188 | + })) |
| 189 | + .unwrap(); |
| 190 | + |
| 191 | + let mut client = test_stream_client::TestStreamClient::new(channel); |
| 192 | + |
| 193 | + let error = client.stream_call(InputStream {}).await.unwrap_err(); |
| 194 | + |
| 195 | + let source = error.source().unwrap(); |
| 196 | + source.downcast_ref::<tonic::transport::Error>().unwrap(); |
| 197 | +} |
| 198 | + |
176 | 199 | fn trace_init() {
|
177 | 200 | let _ = tracing_subscriber::FmtSubscriber::builder()
|
178 | 201 | .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
179 | 202 | .try_init();
|
180 | 203 | }
|
| 204 | + |
| 205 | +mod mock { |
| 206 | + use std::{ |
| 207 | + pin::Pin, |
| 208 | + task::{Context, Poll}, |
| 209 | + }; |
| 210 | + |
| 211 | + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
| 212 | + use tonic::transport::server::Connected; |
| 213 | + |
| 214 | + #[derive(Debug)] |
| 215 | + pub struct MockStream(pub tokio::io::DuplexStream); |
| 216 | + |
| 217 | + impl Connected for MockStream { |
| 218 | + type ConnectInfo = (); |
| 219 | + |
| 220 | + /// Create type holding information about the connection. |
| 221 | + fn connect_info(&self) -> Self::ConnectInfo {} |
| 222 | + } |
| 223 | + |
| 224 | + impl AsyncRead for MockStream { |
| 225 | + fn poll_read( |
| 226 | + mut self: Pin<&mut Self>, |
| 227 | + cx: &mut Context<'_>, |
| 228 | + buf: &mut ReadBuf<'_>, |
| 229 | + ) -> Poll<std::io::Result<()>> { |
| 230 | + Pin::new(&mut self.0).poll_read(cx, buf) |
| 231 | + } |
| 232 | + } |
| 233 | + |
| 234 | + impl AsyncWrite for MockStream { |
| 235 | + fn poll_write( |
| 236 | + mut self: Pin<&mut Self>, |
| 237 | + cx: &mut Context<'_>, |
| 238 | + buf: &[u8], |
| 239 | + ) -> Poll<std::io::Result<usize>> { |
| 240 | + Pin::new(&mut self.0).poll_write(cx, buf) |
| 241 | + } |
| 242 | + |
| 243 | + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> { |
| 244 | + Pin::new(&mut self.0).poll_flush(cx) |
| 245 | + } |
| 246 | + |
| 247 | + fn poll_shutdown( |
| 248 | + mut self: Pin<&mut Self>, |
| 249 | + cx: &mut Context<'_>, |
| 250 | + ) -> Poll<std::io::Result<()>> { |
| 251 | + Pin::new(&mut self.0).poll_shutdown(cx) |
| 252 | + } |
| 253 | + } |
| 254 | +} |
0 commit comments