Skip to content

Commit c1e6206

Browse files
authored
fix(tonic): respect max_message_size when decompressing a message (#2484)
This patch fixes a bug where the configured max_message_size is not respected when a payload is compressed, leading the possibility of over allocating and exhausting memory over the configured maximum message size when decompressing the message. ## Motivation When compression is enabled, and a compressed message is sent or received, the configured max_message_size (for both clients and servers) is only checked against the compressed message size and that limit is not respected for the resulting uncompressed message, which can lead to resource exhaustion. ## Solution Respect the configured, or default, max_message_size limit while decompressing a message, returning an error if the resultant decompressed message would exceed the limit.
1 parent a58c291 commit c1e6206

File tree

4 files changed

+160
-5
lines changed

4 files changed

+160
-5
lines changed

tests/compression/src/compressing_request.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,71 @@ async fn client_mark_compressed_without_header_server_enabled(encoding: Compress
234234
"protocol error: received message with compressed-flag but no grpc-encoding was specified"
235235
);
236236
}
237+
238+
util::parametrized_tests! {
239+
limit_decoded_message_size,
240+
zstd: CompressionEncoding::Zstd,
241+
gzip: CompressionEncoding::Gzip,
242+
deflate: CompressionEncoding::Deflate,
243+
}
244+
245+
#[cfg(test)]
246+
async fn limit_decoded_message_size(encoding: CompressionEncoding) {
247+
use prost::Message;
248+
249+
let under_limit_request = SomeData {
250+
data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
251+
};
252+
let limit = under_limit_request.encoded_len();
253+
let over_limit_request = SomeData {
254+
data: [0_u8; 1 + UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
255+
};
256+
257+
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
258+
259+
let svc = test_server::TestServer::new(Svc::default())
260+
.accept_compressed(encoding)
261+
.max_decoding_message_size(limit);
262+
263+
let request_bytes_counter = Arc::new(AtomicUsize::new(0));
264+
265+
tokio::spawn({
266+
let request_bytes_counter = request_bytes_counter.clone();
267+
async move {
268+
Server::builder()
269+
.layer(
270+
ServiceBuilder::new()
271+
.layer(
272+
ServiceBuilder::new()
273+
.layer(measure_request_body_size_layer(request_bytes_counter))
274+
.into_inner(),
275+
)
276+
.into_inner(),
277+
)
278+
.add_service(svc)
279+
.serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
280+
.await
281+
.unwrap();
282+
}
283+
});
284+
285+
let mut client =
286+
test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
287+
288+
for _ in 0..3 {
289+
// compressed messages that are under or exactly at the limit are successful.
290+
client
291+
.compress_input_unary(under_limit_request.clone())
292+
.await
293+
.unwrap();
294+
let bytes_sent = request_bytes_counter.load(SeqCst);
295+
assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
296+
297+
// compressed messages that are over the limit are fail with resource exhausted
298+
let status = client
299+
.compress_input_unary(over_limit_request.clone())
300+
.await
301+
.unwrap_err();
302+
assert_eq!(status.code(), tonic::Code::ResourceExhausted);
303+
}
304+
}

tests/compression/src/compressing_response.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,3 +488,78 @@ async fn disabling_compression_on_response_from_client_stream(encoding: Compress
488488
let bytes_sent = response_bytes_counter.load(SeqCst);
489489
assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
490490
}
491+
492+
util::parametrized_tests! {
493+
limit_decoded_message_size,
494+
zstd: CompressionEncoding::Zstd,
495+
gzip: CompressionEncoding::Gzip,
496+
deflate: CompressionEncoding::Deflate,
497+
}
498+
499+
#[cfg(test)]
500+
async fn limit_decoded_message_size(encoding: CompressionEncoding) {
501+
use prost::Message;
502+
503+
let under_limit_request = SomeData {
504+
data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
505+
};
506+
let limit = under_limit_request.encoded_len();
507+
508+
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
509+
510+
let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);
511+
512+
let response_bytes_counter = Arc::new(AtomicUsize::new(0));
513+
514+
tokio::spawn({
515+
let response_bytes_counter = response_bytes_counter.clone();
516+
async move {
517+
Server::builder()
518+
.layer(
519+
ServiceBuilder::new()
520+
.layer(MapResponseBodyLayer::new(move |body| {
521+
util::CountBytesBody {
522+
inner: body,
523+
counter: response_bytes_counter.clone(),
524+
}
525+
}))
526+
.into_inner(),
527+
)
528+
.add_service(svc)
529+
.serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
530+
.await
531+
.unwrap();
532+
}
533+
});
534+
535+
let expected = match encoding {
536+
CompressionEncoding::Gzip => "gzip",
537+
CompressionEncoding::Zstd => "zstd",
538+
CompressionEncoding::Deflate => "deflate",
539+
_ => panic!("unexpected encoding {encoding:?}"),
540+
};
541+
542+
// compressed messages that are under or exactly at the limit are successful.
543+
let mut under_limit_client = test_client::TestClient::new(mock_io_channel(client).await)
544+
.accept_compressed(encoding)
545+
.max_decoding_message_size(limit);
546+
547+
for _ in 0..3 {
548+
let res = under_limit_client.compress_output_unary(()).await.unwrap();
549+
assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
550+
let bytes_sent = response_bytes_counter.load(SeqCst);
551+
assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
552+
}
553+
554+
// compressed messages that are over the limit are fail with resource exhausted
555+
let mut over_limit_client = under_limit_client.max_decoding_message_size(limit - 1);
556+
557+
for _ in 0..3 {
558+
let status = over_limit_client
559+
.compress_output_unary(())
560+
.await
561+
.unwrap_err();
562+
563+
assert_eq!(status.code(), tonic::Code::ResourceExhausted);
564+
}
565+
}

tonic/src/codec/compression.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,14 +256,17 @@ pub(crate) fn compress(
256256
pub(crate) fn decompress(
257257
settings: CompressionSettings,
258258
compressed_buf: &mut BytesMut,
259-
out_buf: &mut BytesMut,
259+
mut out_buf: bytes::buf::Limit<&mut BytesMut>,
260260
len: usize,
261261
) -> Result<(), std::io::Error> {
262262
let buffer_growth_interval = settings.buffer_growth_interval;
263263
let estimate_decompressed_len = len * 2;
264-
let capacity =
265-
((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval;
266-
out_buf.reserve(capacity);
264+
let capacity = std::cmp::min(
265+
bytes::buf::Limit::limit(&out_buf),
266+
((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval,
267+
);
268+
269+
out_buf.get_mut().reserve(capacity);
267270

268271
#[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
269272
let mut out_writer = out_buf.writer();

tonic/src/codec/decode.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,16 +211,25 @@ impl StreamingInner {
211211

212212
let decode_buf = if let Some(encoding) = compression {
213213
self.decompress_buf.clear();
214+
let limit = self
215+
.max_message_size
216+
.unwrap_or(DEFAULT_MAX_RECV_MESSAGE_SIZE);
217+
let limited_out_buf = (&mut self.decompress_buf).limit(limit);
214218

215219
if let Err(err) = decompress(
216220
CompressionSettings {
217221
encoding,
218222
buffer_growth_interval: buffer_settings.buffer_size,
219223
},
220224
&mut self.buf,
221-
&mut self.decompress_buf,
225+
limited_out_buf,
222226
len,
223227
) {
228+
if matches!(err.kind(), std::io::ErrorKind::WriteZero) {
229+
return Err(Status::resource_exhausted(format!(
230+
"Error decompressing: size limit, of {limit} bytes, exceeded while decompressing message"
231+
)));
232+
}
224233
let message = if let Direction::Response(status) = self.direction {
225234
format!(
226235
"Error decompressing: {err}, while receiving response with status: {status}"

0 commit comments

Comments
 (0)