p3_keccak_air/
generation.rs1use alloc::vec::Vec;
2use core::array;
3use core::mem::transmute;
4
5use p3_air::utils::{u64_to_16_bit_limbs, u64_to_bits_le};
6use p3_field::PrimeField64;
7use p3_matrix::dense::RowMajorMatrix;
8use p3_maybe_rayon::iter::repeat_n;
9use p3_maybe_rayon::prelude::*;
10use tracing::instrument;
11
12use crate::columns::{KeccakCols, NUM_KECCAK_COLS};
13use crate::{NUM_ROUNDS, R, RC, U64_LIMBS};
14
15#[instrument(name = "generate Keccak trace", skip_all)]
17pub fn generate_trace_rows<F: PrimeField64>(
18 inputs: Vec<[u64; 25]>,
19 extra_capacity_bits: usize,
20) -> RowMajorMatrix<F> {
21 let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two();
22 let trace_length = num_rows * NUM_KECCAK_COLS;
23
24 let mut long_trace = F::zero_vec(trace_length << extra_capacity_bits);
26 long_trace.truncate(trace_length);
27
28 let mut trace = RowMajorMatrix::new(long_trace, NUM_KECCAK_COLS);
29 let (prefix, rows, suffix) = unsafe { trace.values.align_to_mut::<KeccakCols<F>>() };
30 assert!(prefix.is_empty(), "Alignment should match");
31 assert!(suffix.is_empty(), "Alignment should match");
32 assert_eq!(rows.len(), num_rows);
33
34 let num_padding_inputs = num_rows.div_ceil(NUM_ROUNDS) - inputs.len();
35 let padded_inputs = inputs
36 .into_par_iter()
37 .chain(repeat_n([0; 25], num_padding_inputs));
38
39 rows.par_chunks_mut(NUM_ROUNDS)
40 .zip(padded_inputs)
41 .for_each(|(row, input)| {
42 generate_trace_rows_for_perm(row, input);
43 });
44
45 trace
46}
47
48fn generate_trace_rows_for_perm<F: PrimeField64>(rows: &mut [KeccakCols<F>], input: [u64; 25]) {
50 let transmuted: [[u64; 5]; 5] = unsafe { transmute(input) };
55 let mut current_state: [[u64; 5]; 5] = array::from_fn(|x| array::from_fn(|y| transmuted[y][x]));
56
57 let initial_state: [[[F; 4]; 5]; 5] =
59 array::from_fn(|y| array::from_fn(|x| u64_to_16_bit_limbs(current_state[x][y])));
60
61 rows[0].a = initial_state;
63 rows[0].preimage = initial_state;
64
65 generate_trace_row_for_round(&mut rows[0], 0, &mut current_state);
66
67 for round in 1..rows.len() {
68 rows[round].preimage = initial_state;
69
70 for y in 0..5 {
72 for x in 0..5 {
73 for limb in 0..U64_LIMBS {
74 rows[round].a[y][x][limb] = rows[round - 1].a_prime_prime_prime(y, x, limb);
75 }
76 }
77 }
78
79 generate_trace_row_for_round(&mut rows[round], round, &mut current_state);
80 }
81}
82
83fn generate_trace_row_for_round<F: PrimeField64>(
84 row: &mut KeccakCols<F>,
85 round: usize,
86 current_state: &mut [[u64; 5]; 5],
87) {
88 row.step_flags[round] = F::ONE;
89
90 let state_c: [u64; 5] = current_state.map(|row| row.iter().fold(0, |acc, y| acc ^ y));
92 for (x, elem) in state_c.iter().enumerate() {
93 row.c[x] = u64_to_bits_le(*elem);
94 }
95
96 let state_c_prime: [u64; 5] =
98 array::from_fn(|x| state_c[x] ^ state_c[(x + 4) % 5] ^ state_c[(x + 1) % 5].rotate_left(1));
99 for (x, elem) in state_c_prime.iter().enumerate() {
100 row.c_prime[x] = u64_to_bits_le(*elem);
101 }
102
103 *current_state =
108 array::from_fn(|i| array::from_fn(|j| current_state[i][j] ^ state_c[i] ^ state_c_prime[i]));
109 for (x, x_row) in current_state.iter().enumerate() {
110 for (y, elem) in x_row.iter().enumerate() {
111 row.a_prime[y][x] = u64_to_bits_le(*elem);
112 }
113 }
114
115 *current_state = array::from_fn(|i| {
117 array::from_fn(|j| {
118 let new_i = (i + 3 * j) % 5;
119 let new_j = i;
120 current_state[new_i][new_j].rotate_left(R[new_i][new_j] as u32)
121 })
122 });
123
124 *current_state = array::from_fn(|i| {
127 array::from_fn(|j| {
128 current_state[i][j] ^ ((!current_state[(i + 1) % 5][j]) & current_state[(i + 2) % 5][j])
129 })
130 });
131 for (x, x_row) in current_state.iter().enumerate() {
132 for (y, elem) in x_row.iter().enumerate() {
133 row.a_prime_prime[y][x] = u64_to_16_bit_limbs(*elem);
134 }
135 }
136
137 row.a_prime_prime_0_0_bits = u64_to_bits_le(current_state[0][0]);
138
139 current_state[0][0] ^= RC[round];
141
142 row.a_prime_prime_prime_0_0_limbs = u64_to_16_bit_limbs(current_state[0][0]);
143}
144
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_goldilocks::Goldilocks;
    use p3_keccak::KeccakF;
    use p3_symmetric::Permutation;

    use super::*;

    /// Views the flat values of `trace` as typed rows, asserting that the
    /// view covers the whole buffer.
    fn trace_as_rows(trace: &RowMajorMatrix<Goldilocks>) -> &[KeccakCols<Goldilocks>] {
        // SAFETY: NOTE(review) same layout assumption as in
        // `generate_trace_rows`: `KeccakCols<F>` consists solely of `F`
        // values (confirm in `columns.rs`).
        let (prefix, rows, suffix) = unsafe { trace.values.align_to::<KeccakCols<Goldilocks>>() };
        assert!(prefix.is_empty());
        assert!(suffix.is_empty());
        rows
    }

    /// Runs the `p3_keccak` reference permutation on a copy of `input`.
    fn reference_permutation(input: [u64; 25]) -> [u64; 25] {
        let mut state = input;
        KeccakF.permute_mut(&mut state);
        state
    }

    /// Reassembles a `u64` from `U64_LIMBS` 16-bit limbs, least significant
    /// limb first; `limb_at(i)` supplies limb `i`.
    fn pack_limbs(limb_at: impl Fn(usize) -> u64) -> u64 {
        (0..U64_LIMBS).fold(0u64, |acc, limb| acc | (limb_at(limb) << (limb * 16)))
    }

    /// Reads the permutation output from row `NUM_ROUNDS - 1` of `rows`.
    ///
    /// Lane `(y, x)` ends up at index `x + 5 * y` of the result.
    fn extract_output_from_trace<F: PrimeField64>(rows: &[KeccakCols<F>]) -> [u64; 25] {
        let last_row = &rows[NUM_ROUNDS - 1];
        core::array::from_fn(|idx| {
            let (y, x) = (idx / 5, idx % 5);
            pack_limbs(|limb| last_row.a_prime_prime_prime(y, x, limb).as_canonical_u64())
        })
    }

    /// Reads the permutation input back from the `preimage` columns of the
    /// first row of `rows`.
    ///
    /// Lane `(y, x)` ends up at index `x + 5 * y` of the result.
    fn extract_input_from_trace<F: PrimeField64>(rows: &[KeccakCols<F>]) -> [u64; 25] {
        let first_row = &rows[0];
        core::array::from_fn(|idx| {
            let (y, x) = (idx / 5, idx % 5);
            pack_limbs(|limb| first_row.preimage[y][x][limb].as_canonical_u64())
        })
    }

    #[test]
    fn test_keccak_permutation_matches_p3_keccak() {
        let input: [u64; 25] = core::array::from_fn(|i| i as u64 * 0x0123456789ABCDEFu64);
        let expected_output = reference_permutation(input);

        let trace = generate_trace_rows::<Goldilocks>(vec![input], 0);
        let perm_rows = &trace_as_rows(&trace)[..NUM_ROUNDS];

        let stored_input = extract_input_from_trace(perm_rows);
        assert_eq!(
            stored_input, input,
            "Input state should match the provided input"
        );

        let our_output = extract_output_from_trace(perm_rows);
        assert_eq!(
            our_output, expected_output,
            "Keccak-f output should match p3-keccak reference implementation"
        );
    }

    #[test]
    fn test_keccak_permutation_zero_state() {
        let input = [0u64; 25];
        let expected_output = reference_permutation(input);

        let trace = generate_trace_rows::<Goldilocks>(vec![input], 0);
        let rows = trace_as_rows(&trace);

        let our_output = extract_output_from_trace(&rows[..NUM_ROUNDS]);
        assert_eq!(
            our_output, expected_output,
            "Keccak-f on zero state should match p3-keccak"
        );
    }

    #[test]
    fn test_keccak_permutation_known_vector() {
        let mut input = [0u64; 25];
        input[0] = 1;
        let expected_output = reference_permutation(input);

        let trace = generate_trace_rows::<Goldilocks>(vec![input], 0);
        let rows = trace_as_rows(&trace);

        let our_output = extract_output_from_trace(&rows[..NUM_ROUNDS]);
        assert_eq!(
            our_output, expected_output,
            "Keccak-f with input[0]=1 should match p3-keccak"
        );
    }

    #[test]
    fn test_multiple_permutations() {
        let inputs: Vec<[u64; 25]> = (0..4)
            .map(|i| core::array::from_fn(|j| (i * 25 + j) as u64))
            .collect();
        let expected_outputs: Vec<[u64; 25]> = inputs
            .iter()
            .copied()
            .map(reference_permutation)
            .collect();

        let trace = generate_trace_rows::<Goldilocks>(inputs, 0);
        let rows = trace_as_rows(&trace);

        // Permutation `i` occupies rows `i * NUM_ROUNDS .. (i + 1) * NUM_ROUNDS`.
        for (i, (perm_rows, expected)) in rows
            .chunks_exact(NUM_ROUNDS)
            .zip(&expected_outputs)
            .enumerate()
        {
            let our_output = extract_output_from_trace(perm_rows);
            assert_eq!(
                our_output, *expected,
                "Permutation {} should match p3-keccak",
                i
            );
        }
    }

    #[test]
    fn test_input_output_limb_indexing() {
        let input: [u64; 25] = core::array::from_fn(|i| i as u64 + 1);
        let trace = generate_trace_rows::<Goldilocks>(vec![input], 0);
        let first_row = &trace_as_rows(&trace)[0];

        // Every lane has a distinct value, so a transposed or shifted layout
        // would be detected.
        for (i_u64, &expected_val) in input.iter().enumerate() {
            let (y, x) = (i_u64 / 5, i_u64 % 5);
            let stored_value =
                pack_limbs(|limb| first_row.preimage[y][x][limb].as_canonical_u64());

            assert_eq!(
                stored_value, expected_val,
                "preimage[{}][{}] should equal input[{}]",
                y, x, i_u64
            );
        }
    }
}