]>
Commit | Line | Data |
---|---|---|
0a29b90c FG |
1 | use std::{ |
2 | cmp, | |
3 | io::{self, Read as _}, | |
4 | iter, | |
5 | }; | |
6 | ||
7 | use rand::{Rng as _, RngCore as _}; | |
8 | ||
9 | use super::decoder::{DecoderReader, BUF_SIZE}; | |
10 | use crate::{ | |
fe692bf9 | 11 | alphabet, |
0a29b90c FG |
12 | engine::{general_purpose::STANDARD, Engine, GeneralPurpose}, |
13 | tests::{random_alphabet, random_config, random_engine}, | |
fe692bf9 | 14 | DecodeError, PAD_BYTE, |
0a29b90c FG |
15 | }; |
16 | ||
/// Decodes a fixed set of known base64 vectors with `DecoderReader`, reading through an output
/// buffer of every length from 1 up to the encoded length, and checks the decoded bytes exactly.
#[test]
fn simple() {
    let tests: &[(&[u8], &[u8])] = &[
        (&b"0"[..], &b"MA=="[..]),
        (b"01", b"MDE="),
        (b"012", b"MDEy"),
        (b"0123", b"MDEyMw=="),
        (b"01234", b"MDEyMzQ="),
        (b"012345", b"MDEyMzQ1"),
        (b"0123456", b"MDEyMzQ1Ng=="),
        (b"01234567", b"MDEyMzQ1Njc="),
        (b"012345678", b"MDEyMzQ1Njc4"),
        (b"0123456789", b"MDEyMzQ1Njc4OQ=="),
    ][..];

    for (text_expected, base64data) in tests.iter() {
        // Read n bytes at a time.
        for n in 1..base64data.len() + 1 {
            let mut wrapped_reader = io::Cursor::new(base64data);
            let mut decoder = DecoderReader::new(&mut wrapped_reader, &STANDARD);

            let mut text_got = Vec::new();
            let mut buffer = vec![0u8; n];
            loop {
                // Every vector above is valid, so a read error is a test failure. Report it
                // directly instead of silently stopping and comparing truncated output.
                let read = decoder
                    .read(&mut buffer[..])
                    .unwrap_or_else(|e| panic!("read error with buffer len {}: {}", n, e));
                if read == 0 {
                    break;
                }
                text_got.extend_from_slice(&buffer[..read]);
            }

            assert_eq!(
                text_got,
                *text_expected,
                "\nGot: {}\nExpected: {}",
                String::from_utf8_lossy(&text_got[..]),
                String::from_utf8_lossy(text_expected)
            );
        }
    }
}
58 | ||
/// Make sure we error out on trailing junk: base64 followed by garbage is fed to
/// `DecoderReader` with every output-buffer length, and some read must report an error rather
/// than reaching a clean EOF.
#[test]
fn trailing_junk() {
    let inputs: &[&[u8]] = &[&b"MDEyMzQ1Njc4*!@#$%^&"[..], b"MDEyMzQ1Njc4OQ== "][..];

    for encoded in inputs {
        // Read chunk_len bytes at a time.
        for chunk_len in 1..=encoded.len() {
            let mut source = io::Cursor::new(encoded);
            let mut decoder = DecoderReader::new(&mut source, &STANDARD);
            let mut chunk = vec![0u8; chunk_len];

            // Keep reading until either an error (expected) or a clean EOF (failure).
            let errored = loop {
                match decoder.read(&mut chunk[..]) {
                    Err(_) => break true,
                    Ok(0) => break false,
                    Ok(_) => continue,
                }
            };

            assert!(errored);
        }
    }
}
88 | ||
/// Encodes random data of random length with a random engine, then decodes it through a
/// delegate (`RandomShortRead`) that hands out only a few bytes per read. `read_to_end` must
/// still reassemble the complete original input.
#[test]
fn handles_short_read_from_delegate() {
    let mut rng = rand::thread_rng();
    let mut plain = Vec::new();
    let mut encoded = String::new();
    let mut output = Vec::new();

    for _ in 0..10_000 {
        plain.clear();
        encoded.clear();
        output.clear();

        let len = rng.gen_range(0..(10 * BUF_SIZE));
        plain.resize(len, 0);
        rng.fill_bytes(&mut plain[..len]);
        assert_eq!(len, plain.len());

        let engine = random_engine(&mut rng);
        engine.encode_string(&plain[..], &mut encoded);

        let mut source = io::Cursor::new(encoded.as_bytes());
        let mut throttled = RandomShortRead {
            delegate: &mut source,
            rng: &mut rng,
        };
        let mut decoder = DecoderReader::new(&mut throttled, &engine);

        let produced = decoder.read_to_end(&mut output).unwrap();
        assert_eq!(len, produced);
        assert_eq!(&plain[..], &output[..]);
    }
}
123 | ||
/// Decodes random data while the caller requests randomly sized output chunks. Correctness is
/// checked by `consume_with_short_reads_and_validate`.
#[test]
fn read_in_short_increments() {
    let mut rng = rand::thread_rng();
    let (mut plain, mut encoded, mut output) = (Vec::new(), String::new(), Vec::new());

    for _ in 0..10_000 {
        plain.clear();
        encoded.clear();
        output.clear();

        let len = rng.gen_range(0..(10 * BUF_SIZE));
        plain.resize(len, 0);
        // Oversize the output buffer so the randomly chosen read lengths stay in bounds.
        output.resize(len * 3, 0);

        rng.fill_bytes(&mut plain[..]);
        assert_eq!(len, plain.len());

        let engine = random_engine(&mut rng);
        engine.encode_string(&plain[..], &mut encoded);

        let mut source = io::Cursor::new(&encoded[..]);
        let mut decoder = DecoderReader::new(&mut source, &engine);

        consume_with_short_reads_and_validate(&mut rng, &plain[..], &mut output, &mut decoder);
    }
}
154 | ||
/// Like `read_in_short_increments`, but the decoder is also wrapped in `RandomShortRead`.
/// Each caller read is therefore truncated to a small random length before reaching the
/// decoder.
#[test]
fn read_in_short_increments_with_short_delegate_reads() {
    let mut rng = rand::thread_rng();
    let mut plain = Vec::new();
    let mut encoded = String::new();
    let mut output = Vec::new();

    for _ in 0..10_000 {
        plain.clear();
        encoded.clear();
        output.clear();

        let len = rng.gen_range(0..(10 * BUF_SIZE));
        plain.resize(len, 0);
        // Oversize the output buffer so the randomly chosen read lengths stay in bounds.
        output.resize(len * 3, 0);

        rng.fill_bytes(&mut plain[..]);
        assert_eq!(len, plain.len());

        let engine = random_engine(&mut rng);
        engine.encode_string(&plain[..], &mut encoded);

        let mut source = io::Cursor::new(&encoded[..]);
        let mut decoder = DecoderReader::new(&mut source, &engine);
        // `rng` is mutably borrowed by the validation call below for as long as `throttled`
        // lives, so the limiter needs its own RNG handle.
        let mut limiter_rng = rand::thread_rng();
        let mut throttled = RandomShortRead {
            delegate: &mut decoder,
            rng: &mut limiter_rng,
        };

        consume_with_short_reads_and_validate(&mut rng, &plain[..], &mut output, &mut throttled);
    }
}
194 | ||
/// Replaces the final symbol of an unpadded encoding with every symbol of the alphabet. For
/// each replacement, streaming decoding via `DecoderReader` must produce the same outcome as
/// bulk decoding with `decode_vec`: both `Ok`, or both the same `DecodeError`.
#[test]
fn reports_invalid_last_symbol_correctly() {
    let mut rng = rand::thread_rng();
    let mut bytes = Vec::new();
    let mut b64 = String::new();
    let mut b64_bytes = Vec::new();
    let mut decoded = Vec::new();
    let mut bulk_decoded = Vec::new();

    for _ in 0..1_000 {
        bytes.clear();
        b64.clear();
        b64_bytes.clear();

        let size = rng.gen_range(1..(10 * BUF_SIZE));
        bytes.extend(iter::repeat(0).take(size));
        rng.fill_bytes(&mut bytes[..]);
        assert_eq!(size, bytes.len());

        let config = random_config(&mut rng);
        let alphabet = random_alphabet(&mut rng);
        // Encode without padding so the last character is always a data symbol. With padding
        // enabled, twiddling the last byte could overwrite a pad byte and just produce padding
        // errors.
        let engine = GeneralPurpose::new(alphabet, config.with_encode_padding(false));
        engine.encode_string(&bytes[..], &mut b64);
        b64_bytes.extend(b64.bytes());
        assert_eq!(b64_bytes.len(), b64.len());

        // Change the last character to every possible symbol. Streaming should behave the same
        // as bulk decoding, whether the result is valid or invalid.
        for &s1 in alphabet.symbols.iter() {
            // `read_to_end` and `decode_vec` append, so both outputs start empty. No pre-fill is
            // needed.
            decoded.clear();
            bulk_decoded.clear();

            // replace the last
            *b64_bytes.last_mut().unwrap() = s1;
            let bulk_res = engine.decode_vec(&b64_bytes[..], &mut bulk_decoded);

            let mut wrapped_reader = io::Cursor::new(&b64_bytes[..]);
            let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine);

            // Recover the underlying DecodeError (if any) from the io::Error so it can be
            // compared against the bulk result.
            let stream_res = decoder.read_to_end(&mut decoded).map(|_| ()).map_err(|e| {
                e.into_inner()
                    .and_then(|e| e.downcast::<DecodeError>().ok())
            });

            assert_eq!(bulk_res.map_err(|e| Some(Box::new(e))), stream_res);
        }
    }
}
245 | ||
/// Corrupts one random position of a valid standard-alphabet encoding with `*`. `DecoderReader`
/// must then surface the same `DecodeError` (wrapped as `io::ErrorKind::InvalidData`) that bulk
/// decoding with `decode_vec` reports.
#[test]
fn reports_invalid_byte_correctly() {
    let mut rng = rand::thread_rng();
    let mut plain = Vec::new();
    let mut encoded = String::new();
    let mut stream_decoded = Vec::new();
    let mut bulk_decoded = Vec::new();

    for _ in 0..10_000 {
        plain.clear();
        encoded.clear();
        stream_decoded.clear();
        bulk_decoded.clear();

        let len = rng.gen_range(1..(10 * BUF_SIZE));
        plain.resize(len, 0);
        rng.fill_bytes(&mut plain[..len]);
        assert_eq!(len, plain.len());

        let engine = GeneralPurpose::new(&alphabet::STANDARD, random_config(&mut rng));
        engine.encode_string(&plain[..], &mut encoded);

        // '*' is not a symbol of the standard alphabet, so this always corrupts the text.
        let corrupt_at = rng.gen_range(0..encoded.len());
        let mut corrupted: Vec<u8> = encoded.bytes().collect();
        corrupted[corrupt_at] = b'*';

        let mut source = io::Cursor::new(&corrupted[..]);
        let mut decoder = DecoderReader::new(&mut source, &engine);

        // Pair the inner DecodeError (if one is recoverable) with the io::ErrorKind it came in.
        let stream_err = match decoder.read_to_end(&mut stream_decoded) {
            Ok(_) => None,
            Err(io_err) => {
                let kind = io_err.kind();
                io_err
                    .into_inner()
                    .and_then(|boxed| boxed.downcast::<DecodeError>().ok())
                    .map(|decode_err| (*decode_err, kind))
            }
        };

        let bulk_err = engine.decode_vec(&corrupted[..], &mut bulk_decoded).err();

        // The reported offset is hard to predict: an invalid byte in the last chunk is reported
        // at the first padding location, because it is treated as invalid padding. So we only
        // require agreement with decoding everything at once.
        assert_eq!(
            bulk_err.map(|e| (e, io::ErrorKind::InvalidData)),
            stream_err
        );
    }
}
299 | ||
fe692bf9 FG |
/// Encodes two random pieces separately and concatenates the results, so padding appears in
/// the middle of the stream. Decodes through `ShortRead` and checks that `DecoderReader`
/// reports the same error as bulk decoding: `InvalidByte` at the first padding byte.
#[test]
fn internal_padding_error_with_short_read_concatenated_texts_invalid_byte_error() {
    let mut rng = rand::thread_rng();
    let mut plain = Vec::new();
    let mut encoded = String::new();
    let mut reader_decoded = Vec::new();
    let mut bulk_decoded = Vec::new();

    // STANDARD encodes with padding and requires padding when decoding, so we don't get
    // InvalidPadding merely because padding is present at all.
    let engine = STANDARD;

    for _ in 0..10_000 {
        plain.clear();
        encoded.clear();
        reader_decoded.clear();
        bulk_decoded.clear();

        // At least 2 bytes so there can be a split point between bytes.
        let len = rng.gen_range(2..(10 * BUF_SIZE));
        plain.resize(len, 0);
        rng.fill_bytes(&mut plain[..len]);

        // Concatenating two valid encodings, rather than dropping padding at a random spot,
        // keeps the expected error predictable. A random location might be InvalidLastSymbol
        // at some buffer sizes but InvalidByte when decoded all at once.
        // The first part must not be a multiple of 3 bytes long, so its encoding ends in padding.
        let split = loop {
            let candidate = rng.gen_range(1..len);
            if candidate % 3 != 0 {
                break candidate;
            }
        };

        engine.encode_string(&plain[..split], &mut encoded);
        assert!(encoded.contains('='), "split: {}, b64: {}", split, encoded);
        let bad_byte_pos = encoded.find('=').unwrap();
        engine.encode_string(&plain[split..], &mut encoded);
        let encoded_bytes = encoded.as_bytes();

        // Short reads make it plausible for the padding to land on a read boundary.
        let read_len = rng.gen_range(1..10);
        let mut source = ShortRead {
            max_read_len: read_len,
            delegate: io::Cursor::new(&encoded_bytes),
        };
        let mut decoder = DecoderReader::new(&mut source, &engine);

        let read_decode_err = decoder
            .read_to_end(&mut reader_decoded)
            .map_err(|io_err| {
                let inner = io_err
                    .into_inner()
                    .and_then(|boxed| boxed.downcast::<DecodeError>().ok())
                    .unwrap();
                *inner
            })
            .unwrap_err();

        let bulk_decode_err = engine
            .decode_vec(encoded_bytes, &mut bulk_decoded)
            .unwrap_err();

        assert_eq!(
            bulk_decode_err,
            read_decode_err,
            "read len: {}, bad byte pos: {}, b64: {}",
            read_len,
            bad_byte_pos,
            std::str::from_utf8(encoded_bytes).unwrap()
        );

        // The first '=' comes after 4 symbols per complete 3-byte group of the first part,
        // plus 2 or 3 data symbols for its 1- or 2-byte remainder.
        let remainder_symbols = match split % 3 {
            1 => 2,
            2 => 3,
            _ => unreachable!(),
        };
        let expected_offset = split / 3 * 4 + remainder_symbols;
        assert_eq!(
            DecodeError::InvalidByte(expected_offset, PAD_BYTE),
            read_decode_err
        );
    }
}
384 | ||
/// Overwrites one random position outside the final quad of a valid padded encoding with `=`.
/// Decoding through `DecoderReader` (fed by `ShortRead`) must then fail.
#[test]
fn internal_padding_anywhere_error() {
    let mut rng = rand::thread_rng();
    let mut bytes = Vec::new();
    let mut b64 = String::new();
    let mut reader_decoded = Vec::new();

    // encodes with padding, requires that padding be present so we don't get InvalidPadding
    // just because padding is there at all
    let engine = STANDARD;

    for _ in 0..10_000 {
        bytes.clear();
        b64.clear();
        reader_decoded.clear();

        bytes.resize(10 * BUF_SIZE, 0);
        rng.fill_bytes(&mut bytes[..]);

        // Just shove a padding byte in there somewhere.
        // The specific error to expect is hard to predict precisely: it varies with the
        // position of the padding in the quad and with the read buffer length. But SOMETHING
        // should go wrong.
        engine.encode_string(&bytes[..], &mut b64);
        let mut b64_bytes = b64.as_bytes().to_vec();
        // Put padding somewhere other than the last quad. The bound must come from the
        // *encoded* length. The decoded length (`bytes.len()`) is only about 3/4 of it, so
        // using it would never place padding in the last quarter of the text.
        let pad_pos = rng.gen_range(0..b64_bytes.len() - 4);
        b64_bytes[pad_pos] = PAD_BYTE;

        // short read to make it plausible for padding to happen on a read boundary
        let read_len = rng.gen_range(1..10);
        let mut wrapped_reader = ShortRead {
            max_read_len: read_len,
            delegate: io::Cursor::new(&b64_bytes),
        };

        let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine);

        let result = decoder.read_to_end(&mut reader_decoded);
        assert!(
            result.is_err(),
            "pad pos: {}, read len: {}",
            pad_pos,
            read_len
        );
    }
}
427 | ||
0a29b90c FG |
428 | fn consume_with_short_reads_and_validate<R: io::Read>( |
429 | rng: &mut rand::rngs::ThreadRng, | |
430 | expected_bytes: &[u8], | |
431 | decoded: &mut [u8], | |
432 | short_reader: &mut R, | |
433 | ) { | |
434 | let mut total_read = 0_usize; | |
435 | loop { | |
436 | assert!( | |
437 | total_read <= expected_bytes.len(), | |
438 | "tr {} size {}", | |
439 | total_read, | |
440 | expected_bytes.len() | |
441 | ); | |
442 | if total_read == expected_bytes.len() { | |
443 | assert_eq!(expected_bytes, &decoded[..total_read]); | |
444 | // should be done | |
445 | assert_eq!(0, short_reader.read(&mut *decoded).unwrap()); | |
446 | // didn't write anything | |
447 | assert_eq!(expected_bytes, &decoded[..total_read]); | |
448 | ||
449 | break; | |
450 | } | |
451 | let decode_len = rng.gen_range(1..cmp::max(2, expected_bytes.len() * 2)); | |
452 | ||
453 | let read = short_reader | |
454 | .read(&mut decoded[total_read..total_read + decode_len]) | |
455 | .unwrap(); | |
456 | total_read += read; | |
457 | } | |
458 | } | |
459 | ||
460 | /// Limits how many bytes a reader will provide in each read call. | |
461 | /// Useful for shaking out code that may work fine only with typical input sources that always fill | |
462 | /// the buffer. | |
463 | struct RandomShortRead<'a, 'b, R: io::Read, N: rand::Rng> { | |
464 | delegate: &'b mut R, | |
465 | rng: &'a mut N, | |
466 | } | |
467 | ||
468 | impl<'a, 'b, R: io::Read, N: rand::Rng> io::Read for RandomShortRead<'a, 'b, R, N> { | |
469 | fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> { | |
470 | // avoid 0 since it means EOF for non-empty buffers | |
471 | let effective_len = cmp::min(self.rng.gen_range(1..20), buf.len()); | |
472 | ||
473 | self.delegate.read(&mut buf[..effective_len]) | |
474 | } | |
475 | } | |
fe692bf9 FG |
476 | |
/// Caps every read from `delegate` at `max_read_len` bytes (or the caller's buffer length, if
/// smaller). This produces a steady stream of short reads, so interesting input such as
/// padding can land on a read boundary.
///
/// `max_read_len` should be non-zero. With 0, every read returns 0, which callers treat as EOF.
struct ShortRead<R: io::Read> {
    delegate: R,
    max_read_len: usize,
}

impl<R: io::Read> io::Read for ShortRead<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // `min`, not `max`. The read must be clamped to at most `max_read_len` bytes, and it
        // must never exceed `buf.len()`: slicing past it would panic, and `max` would also
        // never shorten anything.
        let len = self.max_read_len.min(buf.len());
        self.delegate.read(&mut buf[..len])
    }
}