const_hex/arch/generic.rs

1use crate::{byte2hex, HEX_DECODE_LUT, NIL};
2
/// Chooses between the two decoding strategies of this backend.
///
/// - `true`: validate the whole input with `check`, then decode it with `decode_unchecked`.
/// - `false`: validate while decoding, via `decode_checked`.
///
/// Leave this `false` whenever `check` is not specialized.
#[allow(dead_code)] // May be unreferenced depending on which backend is compiled in.
pub(crate) const USE_CHECK_FN: bool = false;
8
9/// Default encoding function.
10///
11/// # Safety
12///
13/// `output` must be a valid pointer to at least `2 * input.len()` bytes.
14pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
15    for (i, byte) in input.iter().enumerate() {
16        let (high, low) = byte2hex::<UPPER>(*byte);
17        unsafe {
18            output.add(i * 2).write(high);
19            output.add(i * 2 + 1).write(low);
20        }
21    }
22}
23
24/// Encodes unaligned chunks of `T` in `input` to `output` using `encode_chunk`.
25///
26/// The remainder is encoded using the generic [`encode`].
27#[inline]
28#[allow(dead_code)]
29pub(crate) unsafe fn encode_unaligned_chunks<const UPPER: bool, T: Copy>(
30    input: &[u8],
31    output: *mut u8,
32    mut encode_chunk: impl FnMut(T) -> (T, T),
33) {
34    let (chunks, remainder) = chunks_unaligned::<T>(input);
35    let n_in_chunks = chunks.len();
36    let chunk_output = output.cast::<T>();
37    for (i, chunk) in chunks.enumerate() {
38        let (lo, hi) = encode_chunk(chunk);
39        unsafe {
40            chunk_output.add(i * 2).write_unaligned(lo);
41            chunk_output.add(i * 2 + 1).write_unaligned(hi);
42        }
43    }
44    let n_out_chunks = n_in_chunks * 2;
45    unsafe { encode::<UPPER>(remainder, unsafe { chunk_output.add(n_out_chunks).cast() }) };
46}
47
48/// Default check function.
49#[inline]
50pub(crate) const fn check(mut input: &[u8]) -> bool {
51    while let &[byte, ref rest @ ..] = input {
52        if HEX_DECODE_LUT[byte as usize] == NIL {
53            return false;
54        }
55        input = rest;
56    }
57    true
58}
59
60/// Runs the given check function on unaligned chunks of `T` in `input`, with the remainder passed
61/// to the generic [`check`].
62#[inline]
63#[allow(dead_code)]
64pub(crate) fn check_unaligned_chunks<T: Copy>(
65    input: &[u8],
66    check_chunk: impl FnMut(T) -> bool,
67) -> bool {
68    let (mut chunks, remainder) = chunks_unaligned(input);
69    chunks.all(check_chunk) && check(remainder)
70}
71
72/// Default checked decoding function.
73///
74/// # Safety
75///
76/// Assumes `output.len() == input.len() / 2`.
77pub(crate) unsafe fn decode_checked(input: &[u8], output: &mut [u8]) -> bool {
78    unsafe { decode_maybe_check::<true>(input, output) }
79}
80
81/// Default unchecked decoding function.
82///
83/// # Safety
84///
85/// Assumes `output.len() == input.len() / 2` and that the input is valid hex.
86pub(crate) unsafe fn decode_unchecked(input: &[u8], output: &mut [u8]) {
87    #[allow(unused_braces)] // False positive on older rust versions.
88    let success = unsafe { decode_maybe_check::<{ cfg!(debug_assertions) }>(input, output) };
89    debug_assert!(success);
90}
91
92/// Default decoding function. Checks input validity if `CHECK` is `true`, otherwise assumes it.
93///
94/// # Safety
95///
96/// Assumes `output.len() == input.len() / 2` and that the input is valid hex if `CHECK` is `true`.
97#[inline(always)]
98unsafe fn decode_maybe_check<const CHECK: bool>(input: &[u8], output: &mut [u8]) -> bool {
99    macro_rules! next {
100        ($var:ident, $i:expr) => {
101            let hex = unsafe { *input.get_unchecked($i) };
102            let $var = HEX_DECODE_LUT[hex as usize];
103            if CHECK {
104                if $var == NIL {
105                    return false;
106                }
107            }
108        };
109    }
110
111    debug_assert_eq!(output.len(), input.len() / 2);
112    let mut i = 0;
113    while i < output.len() {
114        next!(high, i * 2);
115        next!(low, i * 2 + 1);
116        output[i] = high << 4 | low;
117        i += 1;
118    }
119    true
120}
121
/// Splits `input` into consecutive `size_of::<T>()`-byte chunks, each read as a `T` without any
/// alignment requirement. Returns them together with the trailing bytes that do not fill a whole
/// chunk.
///
/// The iterator yields exactly `input.len() / size_of::<T>()` items, and the returned slice holds
/// the remaining `input.len() % size_of::<T>()` bytes.
///
/// # Panics
///
/// Panics if `T` is zero-sized, because `chunks_exact` rejects a chunk size of 0.
///
/// # Soundness
///
/// Every chunk is reinterpreted as a `T` via `read_unaligned`, so `T` must be valid for *every*
/// bit pattern. Plain integers and byte arrays qualify. Instantiating this function with a type
/// such as `bool`, `char`, a reference, or a fieldless enum is undefined behavior, even though
/// the function is not `unsafe`. The `Copy` bound alone does not exclude such types.
#[inline]
fn chunks_unaligned<T: Copy>(input: &[u8]) -> (impl ExactSizeIterator<Item = T> + '_, &[u8]) {
    let chunks = input.chunks_exact(core::mem::size_of::<T>());
    let remainder = chunks.remainder();
    (
        // SAFETY: `chunks_exact` guarantees every `chunk` is exactly `size_of::<T>()` bytes, so
        // the read stays in bounds, and `read_unaligned` imposes no alignment requirement.
        // Bit validity of the produced `T` is the caller's obligation (see `# Soundness`).
        chunks.map(|chunk| unsafe { chunk.as_ptr().cast::<T>().read_unaligned() }),
        remainder,
    )
}