// blake2b_simd/guts.rs

1use crate::*;
2use arrayref::array_ref;
3use core::cmp;
4
// The largest SIMD degree any implementation in this crate can use: 4-way
// (AVX2) on x86/x86_64, and 1 (portable only) everywhere else.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub const MAX_DEGREE: usize = 4;

#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub const MAX_DEGREE: usize = 1;
10
// Which SIMD backend an Implementation dispatches to.
//
// Variants other than Portable are unreachable in no_std, unless CPU features
// are explicitly enabled for the build with e.g. RUSTFLAGS="-C target-feature=avx2".
// This might change in the future if is_x86_feature_detected moves into libcore.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Platform {
    Portable,
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    SSE41,
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    AVX2,
}
23
/// A cheap, copyable handle selecting which compression backend (portable,
/// SSE4.1, or AVX2) the `compress*_loop` methods dispatch to.
#[derive(Clone, Copy, Debug)]
pub struct Implementation(Platform);
26
impl Implementation {
    /// Pick the fastest backend available: AVX2 if supported, then SSE4.1,
    /// then the portable fallback.
    pub fn detect() -> Self {
        // Try the different implementations in order of how fast/modern they
        // are. Currently on non-x86, everything just uses portable.
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if let Some(avx2_impl) = Self::avx2_if_supported() {
                return avx2_impl;
            }
        }
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if let Some(sse41_impl) = Self::sse41_if_supported() {
                return sse41_impl;
            }
        }
        Self::portable()
    }

    /// The portable, pure-Rust backend. Always available.
    pub fn portable() -> Self {
        Implementation(Platform::Portable)
    }

    /// The SSE4.1 backend, if SSE4.1 is statically enabled for this build or
    /// (with the `std` feature) detected at runtime. `None` otherwise.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[allow(unreachable_code)]
    pub fn sse41_if_supported() -> Option<Self> {
        // Check whether SSE4.1 support is assumed by the build.
        #[cfg(target_feature = "sse4.1")]
        {
            return Some(Implementation(Platform::SSE41));
        }
        // Otherwise dynamically check for support if we can.
        #[cfg(feature = "std")]
        {
            if is_x86_feature_detected!("sse4.1") {
                return Some(Implementation(Platform::SSE41));
            }
        }
        None
    }

    /// The AVX2 backend, if AVX2 is statically enabled for this build or
    /// (with the `std` feature) detected at runtime. `None` otherwise.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[allow(unreachable_code)]
    pub fn avx2_if_supported() -> Option<Self> {
        // Check whether AVX2 support is assumed by the build.
        #[cfg(target_feature = "avx2")]
        {
            return Some(Implementation(Platform::AVX2));
        }
        // Otherwise dynamically check for support if we can.
        #[cfg(feature = "std")]
        {
            if is_x86_feature_detected!("avx2") {
                return Some(Implementation(Platform::AVX2));
            }
        }
        None
    }

    /// How many inputs this backend can compress in parallel (1, 2, or 4).
    pub fn degree(&self) -> usize {
        match self.0 {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Platform::AVX2 => avx2::DEGREE,
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Platform::SSE41 => sse41::DEGREE,
            Platform::Portable => 1,
        }
    }

    /// Compress a single input stream, updating `words` in place.
    pub fn compress1_loop(
        &self,
        input: &[u8],
        words: &mut [Word; 8],
        count: Count,
        last_node: LastNode,
        finalize: Finalize,
        stride: Stride,
    ) {
        match self.0 {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Platform::AVX2 => unsafe {
                avx2::compress1_loop(input, words, count, last_node, finalize, stride);
            },
            // Note that there's an SSE version of compress1 in the official C
            // implementation, but I haven't ported it yet.
            _ => {
                portable::compress1_loop(input, words, count, last_node, finalize, stride);
            }
        }
    }

    /// Compress two jobs in parallel via the SSE4.1 kernel. Panics on the
    /// portable backend; callers are expected to consult `degree()` first.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    pub fn compress2_loop(&self, jobs: &mut [Job; 2], finalize: Finalize, stride: Stride) {
        match self.0 {
            Platform::AVX2 | Platform::SSE41 => unsafe {
                sse41::compress2_loop(jobs, finalize, stride)
            },
            _ => panic!("unsupported"),
        }
    }

    /// Compress four jobs in parallel via the AVX2 kernel. Panics on any
    /// other backend; callers are expected to consult `degree()` first.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    pub fn compress4_loop(&self, jobs: &mut [Job; 4], finalize: Finalize, stride: Stride) {
        match self.0 {
            Platform::AVX2 => unsafe { avx2::compress4_loop(jobs, finalize, stride) },
            _ => panic!("unsupported"),
        }
    }
}
136
/// One lane of work for the multi-way compression loops: an input slice, the
/// state words to update, the starting byte count, and the last-node flag.
pub struct Job<'a, 'b> {
    pub input: &'a [u8],
    pub words: &'b mut [Word; 8],
    pub count: Count,
    pub last_node: LastNode,
}
143
144impl<'a, 'b> core::fmt::Debug for Job<'a, 'b> {
145    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
146        // NB: Don't print the words. Leaking them would allow length extension.
147        write!(
148            f,
149            "Job {{ input_len: {}, count: {}, last_node: {} }}",
150            self.input.len(),
151            self.count,
152            self.last_node.yes(),
153        )
154    }
155}
156
// Whether a compression call is the last one for its input.
// Finalize could just be a bool, but this is easier to read at callsites.
#[derive(Clone, Copy, Debug)]
pub enum Finalize {
    Yes,
    No,
}
163
164impl Finalize {
165    pub fn yes(&self) -> bool {
166        match self {
167            Finalize::Yes => true,
168            Finalize::No => false,
169        }
170    }
171}
172
// Whether this input is the last node of its BLAKE2 tree level.
// Like Finalize, this is easier to read at callsites.
#[derive(Clone, Copy, Debug)]
pub enum LastNode {
    Yes,
    No,
}
179
180impl LastNode {
181    pub fn yes(&self) -> bool {
182        match self {
183            LastNode::Yes => true,
184            LastNode::No => false,
185        }
186    }
187}
188
// How far apart consecutive blocks of one lane are laid out in the input.
#[derive(Clone, Copy, Debug)]
pub enum Stride {
    Serial,   // BLAKE2b/BLAKE2s
    Parallel, // BLAKE2bp/BLAKE2sp
}
194
195impl Stride {
196    pub fn padded_blockbytes(&self) -> usize {
197        match self {
198            Stride::Serial => BLOCKBYTES,
199            Stride::Parallel => blake2bp::DEGREE * BLOCKBYTES,
200        }
201    }
202}
203
// The low word of the double-width block counter (truncating cast).
pub(crate) fn count_low(count: Count) -> Word {
    count as Word
}
207
// The high word of the double-width block counter: shift the low word's bits
// out, then truncate.
pub(crate) fn count_high(count: Count) -> Word {
    (count >> 8 * size_of::<Word>()) as Word
}
211
// Inverse of count_low/count_high: rebuild the double-width counter from its
// two halves.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn assemble_count(low: Word, high: Word) -> Count {
    low as Count + ((high as Count) << 8 * size_of::<Word>())
}
216
217pub(crate) fn flag_word(flag: bool) -> Word {
218    if flag {
219        !0
220    } else {
221        0
222    }
223}
224
// Pull an array reference at the given offset straight from the input, if
// there's a full block of input available. If there's only a partial block,
// copy it into the provided buffer, and return an array reference to that.
// Along with the array, return the number of bytes of real input, and whether
// the input can be finalized (i.e. whether there aren't any more bytes after
// this block). Note that this is written so that the optimizer can elide
// bounds checks, see: https://godbolt.org/z/0hH2bC
pub fn final_block<'a>(
    input: &'a [u8],
    offset: usize,
    buffer: &'a mut [u8; BLOCKBYTES],
    stride: Stride,
) -> (&'a [u8; BLOCKBYTES], usize, bool) {
    // Clamp the offset so the slice below can't panic when it's past the end.
    let capped_offset = cmp::min(offset, input.len());
    let offset_slice = &input[capped_offset..];
    if offset_slice.len() >= BLOCKBYTES {
        let block = array_ref!(offset_slice, 0, BLOCKBYTES);
        // Finalizable when no further block for this lane starts within the
        // remaining input (i.e. the remainder fits in one stride step).
        let should_finalize = offset_slice.len() <= stride.padded_blockbytes();
        (block, BLOCKBYTES, should_finalize)
    } else {
        // Copy the final block to the front of the block buffer. The rest of
        // the buffer is assumed to be initialized to zero.
        buffer[..offset_slice.len()].copy_from_slice(offset_slice);
        (buffer, offset_slice.len(), true)
    }
}
251
252pub fn input_debug_asserts(input: &[u8], finalize: Finalize) {
253    // If we're not finalizing, the input must not be empty, and it must be an
254    // even multiple of the block size.
255    if !finalize.yes() {
256        debug_assert!(!input.is_empty());
257        debug_assert_eq!(0, input.len() % BLOCKBYTES);
258    }
259}
260
#[cfg(test)]
mod test {
    use super::*;
    use core::mem::size_of;

    #[test]
    fn test_detection() {
        assert_eq!(Platform::Portable, Implementation::portable().0);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        #[cfg(feature = "std")]
        {
            if is_x86_feature_detected!("avx2") {
                assert_eq!(Platform::AVX2, Implementation::detect().0);
                assert_eq!(
                    Platform::AVX2,
                    Implementation::avx2_if_supported().unwrap().0
                );
                assert_eq!(
                    Platform::SSE41,
                    Implementation::sse41_if_supported().unwrap().0
                );
            } else if is_x86_feature_detected!("sse4.1") {
                assert_eq!(Platform::SSE41, Implementation::detect().0);
                assert!(Implementation::avx2_if_supported().is_none());
                assert_eq!(
                    Platform::SSE41,
                    Implementation::sse41_if_supported().unwrap().0
                );
            } else {
                assert_eq!(Platform::Portable, Implementation::detect().0);
                assert!(Implementation::avx2_if_supported().is_none());
                assert!(Implementation::sse41_if_supported().is_none());
            }
        }
    }

    // TODO: Move all of these case tests into the implementation files.
    // Invoke `f` with every combination of stride, input length, last-node
    // flag, finalization flag, and counter value the loop tests care about,
    // skipping combinations that violate the loops' preconditions.
    fn exercise_cases<F>(mut f: F)
    where
        F: FnMut(Stride, usize, LastNode, Finalize, Count),
    {
        // Counts chosen to hit the relevant overflow cases.
        let counts = &[
            (0 as Count),
            ((1 as Count) << (8 * size_of::<Word>())) - BLOCKBYTES as Count,
            (0 as Count).wrapping_sub(BLOCKBYTES as Count),
        ];
        for &stride in &[Stride::Serial, Stride::Parallel] {
            // Lengths bracketing the block-size and padded-stride boundaries.
            let lengths = [
                0,
                1,
                BLOCKBYTES - 1,
                BLOCKBYTES,
                BLOCKBYTES + 1,
                2 * BLOCKBYTES - 1,
                2 * BLOCKBYTES,
                2 * BLOCKBYTES + 1,
                stride.padded_blockbytes() - 1,
                stride.padded_blockbytes(),
                stride.padded_blockbytes() + 1,
                2 * stride.padded_blockbytes() - 1,
                2 * stride.padded_blockbytes(),
                2 * stride.padded_blockbytes() + 1,
            ];
            for &length in &lengths {
                for &last_node in &[LastNode::No, LastNode::Yes] {
                    for &finalize in &[Finalize::No, Finalize::Yes] {
                        if !finalize.yes() && (length == 0 || length % BLOCKBYTES != 0) {
                            // Skip these cases, they're invalid.
                            continue;
                        }
                        for &count in counts {
                            // eprintln!("\ncase -----");
                            // dbg!(stride);
                            // dbg!(length);
                            // dbg!(last_node);
                            // dbg!(finalize);
                            // dbg!(count);

                            f(stride, length, last_node, finalize, count);
                        }
                    }
                }
            }
        }
    }

    // Fresh state words for test input number `input_index`, using the index
    // as the node offset so different lanes start from distinct states.
    fn initial_test_words(input_index: usize) -> [Word; 8] {
        crate::Params::new()
            .node_offset(input_index as u64)
            .to_words()
    }

    // Use the portable implementation, one block at a time, to compute the
    // final state words expected for a given test case.
    fn reference_compression(
        input: &[u8],
        stride: Stride,
        last_node: LastNode,
        finalize: Finalize,
        mut count: Count,
        input_index: usize,
    ) -> [Word; 8] {
        let mut words = initial_test_words(input_index);
        let mut offset = 0;
        // The `offset == 0` clause guarantees at least one compression call,
        // even when the input is empty.
        while offset == 0 || offset < input.len() {
            let block_size = cmp::min(BLOCKBYTES, input.len() - offset);
            let maybe_finalize = if offset + stride.padded_blockbytes() < input.len() {
                Finalize::No
            } else {
                finalize
            };
            portable::compress1_loop(
                &input[offset..][..block_size],
                &mut words,
                count,
                last_node,
                maybe_finalize,
                Stride::Serial,
            );
            offset += stride.padded_blockbytes();
            count = count.wrapping_add(BLOCKBYTES as Count);
        }
        words
    }

    // For various loop lengths and finalization parameters, make sure that the
    // implementation gives the same answer as the portable implementation does
    // when invoked one block at a time. (So even the portable implementation
    // itself is being tested here, to make sure its loop is correct.) Note
    // that this doesn't include any fixed test vectors; those are taken from
    // the blake2-kat.json file (copied from upstream) and tested elsewhere.
    fn exercise_compress1_loop(implementation: Implementation) {
        let mut input = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut input);

        exercise_cases(|stride, length, last_node, finalize, count| {
            let reference_words =
                reference_compression(&input[..length], stride, last_node, finalize, count, 0);

            let mut test_words = initial_test_words(0);
            implementation.compress1_loop(
                &input[..length],
                &mut test_words,
                count,
                last_node,
                finalize,
                stride,
            );
            assert_eq!(reference_words, test_words);
        });
    }

    #[test]
    fn test_compress1_loop_portable() {
        exercise_compress1_loop(Implementation::portable());
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress1_loop_sse41() {
        // Currently this just falls back to portable, but we test it anyway.
        if let Some(imp) = Implementation::sse41_if_supported() {
            exercise_compress1_loop(imp);
        }
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress1_loop_avx2() {
        if let Some(imp) = Implementation::avx2_if_supported() {
            exercise_compress1_loop(imp);
        }
    }

    // I use ArrayVec everywhere in here because currently these tests pass
    // under no_std. I might decide that's not worth maintaining at some point,
    // since really all we care about with no_std is that the library builds,
    // but for now it's here. Everything is keyed off of this N constant so
    // that it's easy to copy the code to exercise_compress4_loop.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn exercise_compress2_loop(implementation: Implementation) {
        const N: usize = 2;

        let mut input_buffer = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut input_buffer);
        // Stagger each lane's start by one byte so the lanes aren't identical.
        let mut inputs = arrayvec::ArrayVec::<_, N>::new();
        for i in 0..N {
            inputs.push(&input_buffer[i..]);
        }

        exercise_cases(|stride, length, last_node, finalize, count| {
            let mut reference_words = arrayvec::ArrayVec::<_, N>::new();
            for i in 0..N {
                let words = reference_compression(
                    &inputs[i][..length],
                    stride,
                    last_node,
                    finalize,
                    count.wrapping_add((i * BLOCKBYTES) as Count),
                    i,
                );
                reference_words.push(words);
            }

            let mut test_words = arrayvec::ArrayVec::<_, N>::new();
            for i in 0..N {
                test_words.push(initial_test_words(i));
            }
            let mut jobs = arrayvec::ArrayVec::<_, N>::new();
            for (i, words) in test_words.iter_mut().enumerate() {
                jobs.push(Job {
                    input: &inputs[i][..length],
                    words,
                    count: count.wrapping_add((i * BLOCKBYTES) as Count),
                    last_node,
                });
            }
            let mut jobs = jobs.into_inner().expect("full");
            implementation.compress2_loop(&mut jobs, finalize, stride);

            for i in 0..N {
                assert_eq!(reference_words[i], test_words[i], "words {} unequal", i);
            }
        });
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress2_loop_sse41() {
        if let Some(imp) = Implementation::sse41_if_supported() {
            exercise_compress2_loop(imp);
        }
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress2_loop_avx2() {
        // Currently this just falls back to SSE4.1, but we test it anyway.
        if let Some(imp) = Implementation::avx2_if_supported() {
            exercise_compress2_loop(imp);
        }
    }

    // Copied from exercise_compress2_loop, with a different value of N and an
    // interior call to compress4_loop.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn exercise_compress4_loop(implementation: Implementation) {
        const N: usize = 4;

        let mut input_buffer = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut input_buffer);
        // Stagger each lane's start by one byte so the lanes aren't identical.
        let mut inputs = arrayvec::ArrayVec::<_, N>::new();
        for i in 0..N {
            inputs.push(&input_buffer[i..]);
        }

        exercise_cases(|stride, length, last_node, finalize, count| {
            let mut reference_words = arrayvec::ArrayVec::<_, N>::new();
            for i in 0..N {
                let words = reference_compression(
                    &inputs[i][..length],
                    stride,
                    last_node,
                    finalize,
                    count.wrapping_add((i * BLOCKBYTES) as Count),
                    i,
                );
                reference_words.push(words);
            }

            let mut test_words = arrayvec::ArrayVec::<_, N>::new();
            for i in 0..N {
                test_words.push(initial_test_words(i));
            }
            let mut jobs = arrayvec::ArrayVec::<_, N>::new();
            for (i, words) in test_words.iter_mut().enumerate() {
                jobs.push(Job {
                    input: &inputs[i][..length],
                    words,
                    count: count.wrapping_add((i * BLOCKBYTES) as Count),
                    last_node,
                });
            }
            let mut jobs = jobs.into_inner().expect("full");
            implementation.compress4_loop(&mut jobs, finalize, stride);

            for i in 0..N {
                assert_eq!(reference_words[i], test_words[i], "words {} unequal", i);
            }
        });
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress4_loop_avx2() {
        if let Some(imp) = Implementation::avx2_if_supported() {
            exercise_compress4_loop(imp);
        }
    }

    // count_low/count_high assume Count is exactly two Words wide.
    #[test]
    fn sanity_check_count_size() {
        assert_eq!(size_of::<Count>(), 2 * size_of::<Word>());
    }
}
567}