// p3_keccak/sse2.rs

1use core::arch::x86_64::{
2    __m128i, _mm_add_epi64, _mm_andnot_si128, _mm_or_si128, _mm_slli_epi64, _mm_srli_epi64,
3    _mm_xor_si128,
4};
5use core::mem::transmute;
6
7use p3_symmetric::{CryptographicPermutation, Permutation};
8
9use crate::KeccakF;
10
/// Number of independent Keccak states permuted at once: one 128-bit SSE2 register holds
/// this many 64-bit lanes.
pub const VECTOR_LEN: usize = 128 / 64;
12
/// Keccak-f[1600] round constants, one per round, each broadcast into both 64-bit lanes of
/// an SSE2 register so a single XOR applies it to both interleaved states.
const RC: [__m128i; 24] = {
    const SCALAR_RC: [u64; 24] = [
        0x0000_0000_0000_0001,
        0x0000_0000_0000_8082,
        0x8000_0000_0000_808a,
        0x8000_0000_8000_8000,
        0x0000_0000_0000_808b,
        0x0000_0000_8000_0001,
        0x8000_0000_8000_8081,
        0x8000_0000_0000_8009,
        0x0000_0000_0000_008a,
        0x0000_0000_0000_0088,
        0x0000_0000_8000_8009,
        0x0000_0000_8000_000a,
        0x0000_0000_8000_808b,
        0x8000_0000_0000_008b,
        0x8000_0000_0000_8089,
        0x8000_0000_0000_8003,
        0x8000_0000_0000_8002,
        0x8000_0000_0000_0080,
        0x0000_0000_0000_800a,
        0x8000_0000_8000_000a,
        0x8000_0000_8000_8081,
        0x8000_0000_0000_8080,
        0x0000_0000_8000_0001,
        0x8000_0000_8000_8008,
    ];

    // Duplicate every scalar constant into both lanes.
    let mut lanes = [[0u64; 2]; 24];
    let mut round = 0;
    while round < 24 {
        lanes[round] = [SCALAR_RC[round]; 2];
        round += 1;
    }

    // SAFETY: `[[u64; 2]; 24]` and `[__m128i; 24]` are both 384 bytes, and every bit pattern
    // is a valid `__m128i`.
    unsafe { transmute::<[[u64; 2]; 24], [__m128i; 24]>(lanes) }
};
41
/// Reinterprets a flat 25-lane state as a 5×5 matrix indexed `[y][x]`.
///
/// Flat lane `i` ends up at `[i / 5][i % 5]`.
#[inline(always)]
fn form_matrix(buf: [__m128i; 25]) -> [[__m128i; 5]; 5] {
    // SAFETY: both array types hold 25 contiguous `__m128i` values with identical size and
    // layout, so this only changes how the same bytes are indexed.
    unsafe { transmute::<[__m128i; 25], [[__m128i; 5]; 5]>(buf) }
}
46
/// Reinterprets a 5×5 `[y][x]` state matrix as a flat 25-lane array.
///
/// Lane `[y][x]` ends up at flat index `5 * y + x`.
#[inline(always)]
fn flatten(mat: [[__m128i; 5]; 5]) -> [__m128i; 25] {
    // SAFETY: both array types hold 25 contiguous `__m128i` values with identical size and
    // layout, so this only changes how the same bytes are indexed.
    unsafe { transmute::<[[__m128i; 5]; 5], [__m128i; 25]>(mat) }
}
51
/// Rotates each 64-bit lane of `a` left by one bit.
#[inline(always)]
fn rol_1(a: __m128i) -> __m128i {
    // `a + a` is `a << 1` within each 64-bit lane. The bit shifted out at the top is recovered
    // by `a >> 63` and OR-ed back into bit 0.
    // SAFETY: SSE2 intrinsics; SSE2 is part of the baseline x86_64 target.
    unsafe { _mm_or_si128(_mm_srli_epi64::<63>(a), _mm_add_epi64(a, a)) }
}
60
61#[cfg(target_feature = "ssse3")]
62#[inline(always)]
63fn rol_8(a: __m128i) -> __m128i {
64    use core::arch::x86_64::_mm_shuffle_epi8;
65    const ROL_8_CTRL: __m128i = unsafe {
66        transmute::<[u8; 16], _>([
67            0o07, 0o00, 0o01, 0o02, 0o03, 0o04, 0o05, 0o06, 0o17, 0o10, 0o11, 0o12, 0o13, 0o14,
68            0o15, 0o16,
69        ])
70    };
71    unsafe { _mm_shuffle_epi8(a, ROL_8_CTRL) }
72}
73
74#[cfg(not(target_feature = "ssse3"))]
75#[inline(always)]
76fn rol_8(a: __m128i) -> __m128i {
77    rol::<8, { 64 - 8 }>(a)
78}
79
80#[cfg(target_feature = "ssse3")]
81#[inline(always)]
82fn rol_56(a: __m128i) -> __m128i {
83    use core::arch::x86_64::_mm_shuffle_epi8;
84    const ROL_56_CTRL: __m128i = unsafe {
85        transmute::<[u8; 16], _>([
86            0o01, 0o02, 0o03, 0o04, 0o05, 0o06, 0o07, 0o00, 0o11, 0o12, 0o13, 0o14, 0o15, 0o16,
87            0o17, 0o10,
88        ])
89    };
90    unsafe { _mm_shuffle_epi8(a, ROL_56_CTRL) }
91}
92
93#[cfg(not(target_feature = "ssse3"))]
94#[inline(always)]
95fn rol_56(a: __m128i) -> __m128i {
96    rol::<56, { 64 - 56 }>(a)
97}
98
/// Rotates each 64-bit lane of `a` left by `SHL_AMT` bits.
///
/// `SHR_AMT` must equal `64 - SHL_AMT`. It is a separate parameter only because stable Rust
/// cannot compute `64 - SHL_AMT` from a generic const. A mismatched pair is rejected when the
/// function is instantiated (a compile-time error) instead of silently producing a wrong
/// rotation.
#[inline(always)]
fn rol<const SHL_AMT: i32, const SHR_AMT: i32>(a: __m128i) -> __m128i {
    struct AssertComplementary<const L: i32, const R: i32>;
    impl<const L: i32, const R: i32> AssertComplementary<L, R> {
        const OK: () = assert!(L + R == 64, "rol: SHL_AMT + SHR_AMT must equal 64");
    }
    // Naming the associated const forces its evaluation for this instantiation.
    let () = AssertComplementary::<SHL_AMT, SHR_AMT>::OK;

    // SAFETY: SSE2 intrinsics; SSE2 is part of the baseline x86_64 target.
    unsafe {
        let shl = _mm_slli_epi64::<SHL_AMT>(a);
        let shr = _mm_srli_epi64::<SHR_AMT>(a);
        _mm_or_si128(shl, shr)
    }
}
107
108#[inline(always)]
109fn get_theta_parities(state: [[__m128i; 5]; 5]) -> [__m128i; 5] {
110    unsafe {
111        let mut par0 = _mm_xor_si128(state[0][0], state[1][0]);
112        let mut par1 = _mm_xor_si128(state[0][1], state[1][1]);
113        let mut par2 = _mm_xor_si128(state[0][2], state[1][2]);
114        let mut par3 = _mm_xor_si128(state[0][3], state[1][3]);
115        let mut par4 = _mm_xor_si128(state[0][4], state[1][4]);
116
117        par0 = _mm_xor_si128(par0, state[2][0]);
118        par1 = _mm_xor_si128(par1, state[2][1]);
119        par2 = _mm_xor_si128(par2, state[2][2]);
120        par3 = _mm_xor_si128(par3, state[2][3]);
121        par4 = _mm_xor_si128(par4, state[2][4]);
122
123        par0 = _mm_xor_si128(par0, state[3][0]);
124        par1 = _mm_xor_si128(par1, state[3][1]);
125        par2 = _mm_xor_si128(par2, state[3][2]);
126        par3 = _mm_xor_si128(par3, state[3][3]);
127        par4 = _mm_xor_si128(par4, state[3][4]);
128
129        par0 = _mm_xor_si128(par0, state[4][0]);
130        par1 = _mm_xor_si128(par1, state[4][1]);
131        par2 = _mm_xor_si128(par2, state[4][2]);
132        par3 = _mm_xor_si128(par3, state[4][3]);
133        par4 = _mm_xor_si128(par4, state[4][4]);
134
135        [
136            _mm_xor_si128(par4, rol_1(par1)),
137            _mm_xor_si128(par0, rol_1(par2)),
138            _mm_xor_si128(par1, rol_1(par3)),
139            _mm_xor_si128(par2, rol_1(par4)),
140            _mm_xor_si128(par3, rol_1(par0)),
141        ]
142    }
143}
144
145#[inline(always)]
146fn theta(state: [[__m128i; 5]; 5]) -> [[__m128i; 5]; 5] {
147    let theta_parities = get_theta_parities(state);
148
149    unsafe {
150        [
151            [
152                _mm_xor_si128(state[0][0], theta_parities[0]),
153                _mm_xor_si128(state[0][1], theta_parities[1]),
154                _mm_xor_si128(state[0][2], theta_parities[2]),
155                _mm_xor_si128(state[0][3], theta_parities[3]),
156                _mm_xor_si128(state[0][4], theta_parities[4]),
157            ],
158            [
159                _mm_xor_si128(state[1][0], theta_parities[0]),
160                _mm_xor_si128(state[1][1], theta_parities[1]),
161                _mm_xor_si128(state[1][2], theta_parities[2]),
162                _mm_xor_si128(state[1][3], theta_parities[3]),
163                _mm_xor_si128(state[1][4], theta_parities[4]),
164            ],
165            [
166                _mm_xor_si128(state[2][0], theta_parities[0]),
167                _mm_xor_si128(state[2][1], theta_parities[1]),
168                _mm_xor_si128(state[2][2], theta_parities[2]),
169                _mm_xor_si128(state[2][3], theta_parities[3]),
170                _mm_xor_si128(state[2][4], theta_parities[4]),
171            ],
172            [
173                _mm_xor_si128(state[3][0], theta_parities[0]),
174                _mm_xor_si128(state[3][1], theta_parities[1]),
175                _mm_xor_si128(state[3][2], theta_parities[2]),
176                _mm_xor_si128(state[3][3], theta_parities[3]),
177                _mm_xor_si128(state[3][4], theta_parities[4]),
178            ],
179            [
180                _mm_xor_si128(state[4][0], theta_parities[0]),
181                _mm_xor_si128(state[4][1], theta_parities[1]),
182                _mm_xor_si128(state[4][2], theta_parities[2]),
183                _mm_xor_si128(state[4][3], theta_parities[3]),
184                _mm_xor_si128(state[4][4], theta_parities[4]),
185            ],
186        ]
187    }
188}
189
190#[inline(always)]
191fn rho(state: [[__m128i; 5]; 5]) -> [[__m128i; 5]; 5] {
192    [
193        [
194            state[0][0],
195            rol_1(state[0][1]),
196            rol::<62, { 64 - 62 }>(state[0][2]),
197            rol::<28, { 64 - 28 }>(state[0][3]),
198            rol::<27, { 64 - 27 }>(state[0][4]),
199        ],
200        [
201            rol::<36, { 64 - 36 }>(state[1][0]),
202            rol::<44, { 64 - 44 }>(state[1][1]),
203            rol::<6, { 64 - 6 }>(state[1][2]),
204            rol::<55, { 64 - 55 }>(state[1][3]),
205            rol::<20, { 64 - 20 }>(state[1][4]),
206        ],
207        [
208            rol::<3, { 64 - 3 }>(state[2][0]),
209            rol::<10, { 64 - 10 }>(state[2][1]),
210            rol::<43, { 64 - 43 }>(state[2][2]),
211            rol::<25, { 64 - 25 }>(state[2][3]),
212            rol::<39, { 64 - 39 }>(state[2][4]),
213        ],
214        [
215            rol::<41, { 64 - 41 }>(state[3][0]),
216            rol::<45, { 64 - 45 }>(state[3][1]),
217            rol::<15, { 64 - 15 }>(state[3][2]),
218            rol::<21, { 64 - 21 }>(state[3][3]),
219            rol_8(state[3][4]),
220        ],
221        [
222            rol::<18, { 64 - 18 }>(state[4][0]),
223            rol::<2, { 64 - 2 }>(state[4][1]),
224            rol::<61, { 64 - 61 }>(state[4][2]),
225            rol_56(state[4][3]),
226            rol::<14, { 64 - 14 }>(state[4][4]),
227        ],
228    ]
229}
230
/// π step: permutes lane positions.
///
/// With matrices indexed `[y][x]`, output lane `out[y][x]` is input lane
/// `state[x][(x + 3 * y) % 5]`.
#[inline(always)]
fn pi(state: [[__m128i; 5]; 5]) -> [[__m128i; 5]; 5] {
    let mut out = state;
    for (y, out_row) in out.iter_mut().enumerate() {
        for (x, lane) in out_row.iter_mut().enumerate() {
            *lane = state[x][(x + 3 * y) % 5];
        }
    }
    out
}
271
/// χ on one row of five lanes: `out[x] = row[x] ^ (!row[x + 1] & row[x + 2])`, indices mod 5.
#[inline(always)]
fn chi_row(row: [__m128i; 5]) -> [__m128i; 5] {
    let mut out = row;
    for (x, lane) in out.iter_mut().enumerate() {
        // SAFETY: SSE2 intrinsics; SSE2 is part of the baseline x86_64 target.
        unsafe {
            // `_mm_andnot_si128(a, b)` computes `!a & b`.
            let not_next_and_after = _mm_andnot_si128(row[(x + 1) % 5], row[(x + 2) % 5]);
            *lane = _mm_xor_si128(row[x], not_next_and_after);
        }
    }
    out
}
284
285#[inline(always)]
286fn chi(state: [[__m128i; 5]; 5]) -> [[__m128i; 5]; 5] {
287    [
288        chi_row(state[0]),
289        chi_row(state[1]),
290        chi_row(state[2]),
291        chi_row(state[3]),
292        chi_row(state[4]),
293    ]
294}
295
296#[inline(always)]
297fn iota(i: usize, state: [[__m128i; 5]; 5]) -> [[__m128i; 5]; 5] {
298    let mut res = state;
299    unsafe {
300        res[0][0] = _mm_xor_si128(state[0][0], RC[i]);
301    }
302    res
303}
304
305#[inline(always)]
306fn round(i: usize, state: [__m128i; 25]) -> [__m128i; 25] {
307    let mut state = form_matrix(state);
308    state = theta(state);
309    state = rho(state);
310    state = pi(state);
311    state = chi(state);
312    state = iota(i, state);
313    flatten(state)
314}
315
316fn keccak_perm(buf: &mut [[u64; VECTOR_LEN]; 25]) {
317    let mut state: [__m128i; 25] = unsafe { transmute(*buf) };
318    for i in 0..24 {
319        state = round(i, state);
320    }
321    *buf = unsafe { transmute::<[__m128i; 25], [[u64; VECTOR_LEN]; 25]>(state) };
322}
323
324impl Permutation<[[u64; VECTOR_LEN]; 25]> for KeccakF {
325    fn permute_mut(&self, state: &mut [[u64; VECTOR_LEN]; 25]) {
326        keccak_perm(state);
327    }
328}
329
330impl CryptographicPermutation<[[u64; VECTOR_LEN]; 25]> for KeccakF {}
331
#[cfg(test)]
mod tests {
    use tiny_keccak::keccakf;

    use super::*;

    const STATES: [[u64; 25]; 2] = [
        [
            0xc22c4c11dbedc46a,
            0x317f74268c4f5cd0,
            0x838719da5aa295b6,
            0x9e9b17211985a3ba,
            0x92927b963ce29d69,
            0xf9a7169e38cc7216,
            0x639a594d6fbfe341,
            0x2335ebd8d15777bd,
            0x44e1abc0d022823b,
            0xb3657f9d16b36c13,
            0x26d9217c32b3010a,
            0x6e73d6e9c7e5bcc8,
            0x400aa469d130a391,
            0x1aa7c8a2cb97188a,
            0xdc3084a09bd0a6e3,
            0xbcfe3b656841baea,
            0x325f41887c840166,
            0x844656e313674bfe,
            0xd63de8bad19d156c,
            0x49ef0ac0ab52e147,
            0x8b92ee811c654ca9,
            0x42a9310fedf09bda,
            0x182dbdac03a5358e,
            0x3b4692ce58af8cb5,
            0x534da610f01b8fb3,
        ],
        [
            0x1c322ff4aea07d26,
            0xbd67bde061c97612,
            0x517565bd02ab410a,
            0xb251273ddc12a725,
            0x24f0979fe4f4fedc,
            0xc32d063a64f0bf03,
            0xd33c6709a7b103d2,
            0xaf33a8224b5c8828,
            0x6544ca066e997f1c,
            0xd53ad41e39f06d68,
            0x67695f6fb71d77d9,
            0xd6378cf19ee510f2,
            0x49472ea57abcbd08,
            0xcf3739df1eefbbb4,
            0x0fac1bf30e8ef101,
            0x7ff04c9b90de0f27,
            0xf3d63ec0e64cb2ab,
            0x76388c05f377d4bd,
            0x7886dd8f5b14ef5b,
            0xb036d289ba24a513,
            0x011e8fd6be65a408,
            0x695e2d20848eec67,
            0x31f9e80c5f45f8ee,
            0xcdf873daf7a5fdeb,
            0xfe98ff5bf28d560a,
        ],
    ];

    /// Interleaves one scalar state per SIMD lane, runs `keccak_perm`, and de-interleaves
    /// the result back into separate scalar states.
    fn our_res(states: [[u64; 25]; VECTOR_LEN]) -> [[u64; 25]; VECTOR_LEN] {
        let mut packed = [[0u64; VECTOR_LEN]; 25];
        for (i, lanes) in packed.iter_mut().enumerate() {
            for (lane, state) in lanes.iter_mut().zip(states.iter()) {
                *lane = state[i];
            }
        }

        keccak_perm(&mut packed);

        let mut result = [[0u64; 25]; VECTOR_LEN];
        for (i, lanes) in packed.iter().enumerate() {
            for (state, &lane) in result.iter_mut().zip(lanes.iter()) {
                state[i] = lane;
            }
        }
        result
    }

    /// Reference result: applies `tiny_keccak::keccakf` to each state separately.
    fn tiny_keccak_res(states: [[u64; 25]; VECTOR_LEN]) -> [[u64; 25]; VECTOR_LEN] {
        let mut result = states;
        for state in result.iter_mut() {
            keccakf(state);
        }
        result
    }

    #[test]
    fn test_vs_tiny_keccak() {
        let expected = tiny_keccak_res(STATES);
        let computed = our_res(STATES);
        assert_eq!(expected, computed);
    }

    /// All-zero and all-ones inputs exercise the round constants and rotations without help
    /// from random lane contents, and put very different data in the two SIMD lanes.
    #[test]
    fn test_edge_states_vs_tiny_keccak() {
        let states = [[0u64; 25], [u64::MAX; 25]];
        assert_eq!(tiny_keccak_res(states), our_res(states));
    }
}
425}