Skip to content

Commit 0a020ee

Browse files
Add another test
1 parent 61af8bc commit 0a020ee

File tree

2 files changed

+83
-29
lines changed

2 files changed

+83
-29
lines changed

src/read/decoder.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ impl<'a, R: io::Read> DecoderReader<'a, R> {
126126
/// Decode the requested number of bytes from the b64 buffer into the provided buffer. It's the
127127
/// caller's responsibility to choose the number of b64 bytes to decode correctly.
128128
///
129-
/// Returns a Result with the number of decoded bytes written.
129+
/// Returns a Result with the number of decoded bytes written to `buf`.
130130
fn decode_to_buf(&mut self, num_bytes: usize, buf: &mut [u8]) -> io::Result<usize> {
131131
debug_assert!(self.b64_len >= num_bytes);
132132
debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
@@ -149,8 +149,8 @@ impl<'a, R: io::Read> DecoderReader<'a, R> {
149149
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
150150

151151
self.total_b64_decoded += num_bytes;
152-
self.b64_len -= num_bytes;
153152
self.b64_offset += num_bytes;
153+
self.b64_len -= num_bytes;
154154

155155
debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
156156

@@ -224,6 +224,7 @@ impl<'a, R: Read> Read for DecoderReader<'a, R> {
224224
}
225225

226226
if self.b64_len == 0 {
227+
debug_assert!(at_eof);
227228
// we must be at EOF, and we have no data left to decode
228229
return Ok(0);
229230
};
@@ -236,13 +237,15 @@ impl<'a, R: Read> Read for DecoderReader<'a, R> {
236237
self.b64_len >= BASE64_CHUNK_SIZE
237238
});
238239

240+
debug_assert_eq!(0, self.decoded_len);
241+
239242
if buf.len() < DECODED_CHUNK_SIZE {
240243
// caller requested an annoyingly short read
241-
debug_assert_eq!(0, self.decoded_len);
242-
243244
// have to write to a tmp buf first to avoid double mutable borrow
244245
let mut decoded_chunk = [0_u8; DECODED_CHUNK_SIZE];
245-
// if we are at eof, could have less than BASE64_CHUNK_SIZE
246+
// if we are at eof, could have less than BASE64_CHUNK_SIZE, in which case we have
247+
// to assume that these last few tokens are, in fact, valid (i.e. must be 2-4 b64
248+
// tokens, not 1, since 1 token can't decode to 1 byte).
246249
let to_decode = cmp::min(self.b64_len, BASE64_CHUNK_SIZE);
247250

248251
let decoded = self.decode_to_buf(to_decode, &mut decoded_chunk[..])?;
@@ -256,20 +259,22 @@ impl<'a, R: Read> Read for DecoderReader<'a, R> {
256259

257260
self.flush_decoded_buf(buf)
258261
} else {
259-
let bytes_that_can_decode_into_buf = (buf.len() / DECODED_CHUNK_SIZE)
262+
let b64_bytes_that_can_decode_into_buf = (buf.len() / DECODED_CHUNK_SIZE)
260263
.checked_mul(BASE64_CHUNK_SIZE)
261264
.expect("too many chunks");
262-
debug_assert!(bytes_that_can_decode_into_buf >= BASE64_CHUNK_SIZE);
265+
debug_assert!(b64_bytes_that_can_decode_into_buf >= BASE64_CHUNK_SIZE);
263266

264-
let bytes_available_to_decode = if at_eof {
267+
let b64_bytes_available_to_decode = if at_eof {
265268
self.b64_len
266269
} else {
267270
// only use complete chunks
268271
self.b64_len - self.b64_len % 4
269272
};
270273

271-
let actual_decode_len =
272-
cmp::min(bytes_that_can_decode_into_buf, bytes_available_to_decode);
274+
let actual_decode_len = cmp::min(
275+
b64_bytes_that_can_decode_into_buf,
276+
b64_bytes_available_to_decode,
277+
);
273278
self.decode_to_buf(actual_decode_len, buf)
274279
}
275280
}

src/read/decoder_tests.rs

Lines changed: 68 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -142,25 +142,42 @@ fn read_in_short_increments() {
142142
let mut wrapped_reader = io::Cursor::new(&b64[..]);
143143
let mut decoder = DecoderReader::new(&mut wrapped_reader, config);
144144

145-
let mut total_read = 0_usize;
146-
loop {
147-
assert!(total_read <= size, "tr {} size {}", total_read, size);
148-
if total_read == size {
149-
assert_eq!(bytes, &decoded[..total_read]);
150-
// should be done
151-
assert_eq!(0, decoder.read(&mut decoded[..]).unwrap());
152-
// didn't write anything
153-
assert_eq!(bytes, &decoded[..total_read]);
154-
155-
break;
156-
}
157-
let decode_len = rng.gen_range(1, cmp::max(2, size * 2));
145+
consume_with_short_reads_and_validate(&mut rng, &bytes[..], &mut decoded, &mut decoder);
146+
}
147+
}
158148

159-
let read = decoder
160-
.read(&mut decoded[total_read..total_read + decode_len])
161-
.unwrap();
162-
total_read += read;
163-
}
149+
#[test]
150+
fn read_in_short_increments_with_short_delegate_reads() {
151+
let mut rng = rand::thread_rng();
152+
let mut bytes = Vec::new();
153+
let mut b64 = String::new();
154+
let mut decoded = Vec::new();
155+
156+
for _ in 0..10_000 {
157+
bytes.clear();
158+
b64.clear();
159+
decoded.clear();
160+
161+
let size = rng.gen_range(0, 10 * BUF_SIZE);
162+
bytes.extend(iter::repeat(0).take(size));
163+
// leave room to play around with larger buffers
164+
decoded.extend(iter::repeat(0).take(size * 3));
165+
166+
rng.fill_bytes(&mut bytes[..]);
167+
assert_eq!(size, bytes.len());
168+
169+
let config = random_config(&mut rng);
170+
171+
encode_config_buf(&bytes[..], config, &mut b64);
172+
173+
let mut base_reader = io::Cursor::new(&b64[..]);
174+
let mut decoder = DecoderReader::new(&mut base_reader, config);
175+
let mut short_reader = RandomShortRead {
176+
delegate: &mut decoder,
177+
rng: &mut rand::thread_rng(),
178+
};
179+
180+
consume_with_short_reads_and_validate(&mut rng, &bytes[..], &mut decoded, &mut short_reader)
164181
}
165182
}
166183

@@ -268,6 +285,38 @@ fn reports_invalid_byte_correctly() {
268285
}
269286
}
270287

288+
fn consume_with_short_reads_and_validate<R: Read>(
289+
rng: &mut rand::rngs::ThreadRng,
290+
expected_bytes: &[u8],
291+
decoded: &mut Vec<u8>,
292+
short_reader: &mut R,
293+
) -> () {
294+
let mut total_read = 0_usize;
295+
loop {
296+
assert!(
297+
total_read <= expected_bytes.len(),
298+
"tr {} size {}",
299+
total_read,
300+
expected_bytes.len()
301+
);
302+
if total_read == expected_bytes.len() {
303+
assert_eq!(expected_bytes, &decoded[..total_read]);
304+
// should be done
305+
assert_eq!(0, short_reader.read(&mut decoded[..]).unwrap());
306+
// didn't write anything
307+
assert_eq!(expected_bytes, &decoded[..total_read]);
308+
309+
break;
310+
}
311+
let decode_len = rng.gen_range(1, cmp::max(2, expected_bytes.len() * 2));
312+
313+
let read = short_reader
314+
.read(&mut decoded[total_read..total_read + decode_len])
315+
.unwrap();
316+
total_read += read;
317+
}
318+
}
319+
271320
/// Limits how many bytes a reader will provide in each read call.
272321
/// Useful for shaking out code that may work fine only with typical input sources that always fill
273322
/// the buffer.
@@ -279,7 +328,7 @@ struct RandomShortRead<'a, 'b, R: io::Read, N: rand::Rng> {
279328
impl<'a, 'b, R: io::Read, N: rand::Rng> io::Read for RandomShortRead<'a, 'b, R, N> {
280329
fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
281330
// avoid 0 since it means EOF for non-empty buffers
282-
let effective_len = self.rng.gen_range(1, 20);
331+
let effective_len = cmp::min(self.rng.gen_range(1, 20), buf.len());
283332

284333
self.delegate.read(&mut buf[..effective_len])
285334
}

0 commit comments

Comments
 (0)