base64/engine/general_purpose/decode.rs

use crate::{
    engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode},
    DecodeError, PAD_BYTE,
};

// decode logic operates on chunks of 8 input bytes without padding
const INPUT_CHUNK_LEN: usize = 8;
const DECODED_CHUNK_LEN: usize = 6;

// we read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last
// 2 bytes of any output u64 should not be counted as written to (but must be available in a
// slice).
const DECODED_CHUNK_SUFFIX: usize = 2;

// how many u64's of input to handle at a time
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;

const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN;

// includes the trailing 2 bytes for the final u64 write
const DECODED_BLOCK_LEN: usize =
    CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX;
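// With CHUNKS_PER_FAST_LOOP_BLOCK = 4 this works out to INPUT_BLOCK_LEN = 32 and
// DECODED_BLOCK_LEN = 4 * 6 + 2 = 26.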

#[doc(hidden)]
pub struct GeneralPurposeEstimate {
    /// Total number of decode chunks, including a possibly partial last chunk
    num_chunks: usize,
    decoded_len_estimate: usize,
}

impl GeneralPurposeEstimate {
    pub(crate) fn new(encoded_len: usize) -> Self {
        // Formulas that won't overflow
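        // e.g. encoded_len = 10: num_chunks = 10 / 8 + 1 = 2 and
        // decoded_len_estimate = (10 / 4 + 1) * 3 = 9. The estimate rounds partial 4-byte
        // groups up, so it may exceed the actual decoded length (e.g. "Zm8=" estimates 3
        // bytes but decodes to the 2 bytes "fo").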
        Self {
            num_chunks: encoded_len / INPUT_CHUNK_LEN
                + (encoded_len % INPUT_CHUNK_LEN > 0) as usize,
            decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3,
        }
    }
}

impl DecodeEstimate for GeneralPurposeEstimate {
    fn decoded_len_estimate(&self) -> usize {
        self.decoded_len_estimate
    }
}

/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the decode metadata, or an error.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
pub(crate) fn decode_helper(
    input: &[u8],
    estimate: GeneralPurposeEstimate,
    output: &mut [u8],
    decode_table: &[u8; 256],
    decode_allow_trailing_bits: bool,
    padding_mode: DecodePaddingMode,
) -> Result<DecodeMetadata, DecodeError> {
    let remainder_len = input.len() % INPUT_CHUNK_LEN;

    // Because the fast decode loop reads input in chunks of 8 bytes (unrolled to
    // CHUNKS_PER_FAST_LOOP_BLOCK chunks, where possible) and writes 8 output bytes per chunk (of
    // which only 6 are valid data), we need to be sure that we stop using the fast decode loop
    // soon enough that there will always be 2 more bytes of valid data written after that loop.
    let trailing_bytes_to_skip = match remainder_len {
        // if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
        // and the fast decode logic cannot handle padding
        0 => INPUT_CHUNK_LEN,
        // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte
        1 | 5 => {
            // trailing whitespace is so common that it's worth it to check the last byte to
            // possibly return a better error message
            if let Some(b) = input.last() {
                if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE {
                    return Err(DecodeError::InvalidByte(input.len() - 1, *b));
                }
            }

            return Err(DecodeError::InvalidLength);
        }
        // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes
        // written by the fast decode loop. So, we have to ignore both these 2 bytes and the
        // previous chunk.
        2 => INPUT_CHUNK_LEN + 2,
        // If this is 3 un-padded chars, then it would actually decode to 2 bytes. However, if this
        // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail
        // with an error, not panic from going past the bounds of the output slice, so we let it
        // use stage 3 + 4.
        3 => INPUT_CHUNK_LEN + 3,
        // This can also decode to one output byte because it may be 2 input chars + 2 padding
        // chars, which would decode to 1 byte.
        4 => INPUT_CHUNK_LEN + 4,
        // Everything else is a legal decode len (given that we don't require padding), and will
        // decode to at least 2 bytes of output.
        _ => remainder_len,
    };
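    // e.g. for a 26-byte input, remainder_len is 2, so the fast loops stop 10 bytes before the
    // end; the 8-byte chunk at offset 16 is then handled by the precise stage 3 loop below and
    // the final 2 bytes by decode_suffix, so nothing is written past the real decoded length.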

    // rounded up to include partial chunks
    let mut remaining_chunks = estimate.num_chunks;

    let mut input_index = 0;
    let mut output_index = 0;

    {
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

        // Fast loop, stage 1
        // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;

                input_index += INPUT_BLOCK_LEN;
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }

        // Fast loop, stage 2 (aka still pretty fast loop)
        // 8 bytes at a time for whatever we didn't do in stage 1.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;

                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }

    // Stage 3
    // If input length was such that a chunk had to be deferred until after the fast loop
    // because decoding it would have produced 2 trailing bytes that wouldn't then be
    // overwritten, we decode that chunk here. This way is slower but doesn't write the 2
    // trailing bytes.
    // However, we still need to avoid the last chunk (partial or complete) because it could
    // have padding, so we always do 1 fewer to avoid the last chunk.
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;

        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }

    // always have one more (possibly partial) block of 8 input
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);

    super::decode_suffix::decode_suffix(
        input,
        input_index,
        output,
        output_index,
        decode_table,
        decode_allow_trailing_bits,
        padding_mode,
    )
}

/// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the
/// first 6 of those contain meaningful data.
///
/// `input` is the bytes to decode, of which the first 8 bytes will be processed.
/// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
/// accurately)
/// `decode_table` is the lookup table for the particular base64 alphabet.
/// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded
/// data.
// yes, really inline (worth 30-50% speedup)
#[inline(always)]
fn decode_chunk(
    input: &[u8],
    index_at_start_of_input: usize,
    decode_table: &[u8; 256],
    output: &mut [u8],
) -> Result<(), DecodeError> {
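    // Each 6-bit morsel is OR'd into `accum` at shifts 58, 52, ..., 16, filling the top 48 bits
    // of the u64; the big-endian write below therefore puts the 6 decoded bytes first, with the
    // low 16 bits becoming the 2 trailing bytes that callers ignore.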
    let morsel = decode_table[input[0] as usize];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
    }
    let mut accum = (morsel as u64) << 58;

    let morsel = decode_table[input[1] as usize];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 1,
            input[1],
        ));
    }
    accum |= (morsel as u64) << 52;

    let morsel = decode_table[input[2] as usize];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 2,
            input[2],
        ));
    }
    accum |= (morsel as u64) << 46;

    let morsel = decode_table[input[3] as usize];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 3,
            input[3],
        ));
    }
    accum |= (morsel as u64) << 40;

    let morsel = decode_table[input[4] as usize];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 4,
            input[4],
        ));
    }
    accum |= (morsel as u64) << 34;

    let morsel = decode_table[input[5] as usize];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 5,
            input[5],
        ));
    }
    accum |= (morsel as u64) << 28;

    let morsel = decode_table[input[6] as usize];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 6,
            input[6],
        ));
    }
    accum |= (morsel as u64) << 22;

    let morsel = decode_table[input[7] as usize];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 7,
            input[7],
        ));
    }
    accum |= (morsel as u64) << 16;

    write_u64(output, accum);

    Ok(())
}

/// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2
/// trailing garbage bytes.
#[inline]
fn decode_chunk_precise(
    input: &[u8],
    index_at_start_of_input: usize,
    decode_table: &[u8; 256],
    output: &mut [u8],
) -> Result<(), DecodeError> {
    let mut tmp_buf = [0_u8; 8];

    decode_chunk(
        input,
        index_at_start_of_input,
        decode_table,
        &mut tmp_buf[..],
    )?;

    output[0..6].copy_from_slice(&tmp_buf[0..6]);

    Ok(())
}

#[inline]
fn write_u64(output: &mut [u8], value: u64) {
    output[..8].copy_from_slice(&value.to_be_bytes());
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::engine::general_purpose::STANDARD;

    #[test]
    fn decode_chunk_precise_writes_only_6_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk_precise(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output);
    }
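
    // decode_chunk should report an invalid byte at its position in the overall input, i.e.
    // index_at_start_of_input plus the offset within the chunk.
    #[test]
    fn decode_chunk_reports_invalid_byte_at_overall_index() {
        let input = b"Zm9v!mFy"; // '!' is not in the standard alphabet
        let mut output = [0_u8; 8];

        assert_eq!(
            Err(DecodeError::InvalidByte(100 + 4, b'!')),
            decode_chunk(&input[..], 100, &STANDARD.decode_table, &mut output)
        );
    }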

    #[test]
    fn decode_chunk_writes_8_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output);
    }
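
    // End-to-end sketch of decode_helper on a small padded input; assumes
    // DecodePaddingMode::Indifferent accepts canonically padded input.
    #[test]
    fn decode_helper_decodes_padded_input() {
        let input = b"Zm9vYmE="; // "fooba"
        let estimate = GeneralPurposeEstimate::new(input.len());
        let mut output = vec![0_u8; estimate.decoded_len_estimate];

        let _metadata = decode_helper(
            input,
            estimate,
            &mut output,
            &STANDARD.decode_table,
            false,
            DecodePaddingMode::Indifferent,
        )
        .unwrap();

        assert_eq!(&b"fooba"[..], &output[..5]);
    }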

    #[test]
    fn estimate_short_lengths() {
        for (range, (num_chunks, decoded_len_estimate)) in [
            (0..=0, (0, 0)),
            (1..=4, (1, 3)),
            (5..=8, (1, 6)),
            (9..=12, (2, 9)),
            (13..=16, (2, 12)),
            (17..=20, (3, 15)),
        ] {
            for encoded_len in range {
                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(num_chunks, estimate.num_chunks);
                assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate);
            }
        }
    }

    #[test]
    fn estimate_via_u128_inflation() {
        // cover both ends of usize
        (0..1000)
            .chain(usize::MAX - 1000..=usize::MAX)
            .for_each(|encoded_len| {
                // inflate to 128 bit type to be able to safely use the easy formulas
                let len_128 = encoded_len as u128;

                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(
                    ((len_128 + (INPUT_CHUNK_LEN - 1) as u128) / (INPUT_CHUNK_LEN as u128))
                        as usize,
                    estimate.num_chunks
                );
                assert_eq!(
                    ((len_128 + 3) / 4 * 3) as usize,
                    estimate.decoded_len_estimate
                );
            })
    }
}