p3_keccak_air/
generation.rs

1use alloc::vec::Vec;
2use core::array;
3use core::mem::transmute;
4
5use p3_air::utils::{u64_to_16_bit_limbs, u64_to_bits_le};
6use p3_field::PrimeField64;
7use p3_matrix::dense::RowMajorMatrix;
8use p3_maybe_rayon::iter::repeat_n;
9use p3_maybe_rayon::prelude::*;
10use tracing::instrument;
11
12use crate::columns::{KeccakCols, NUM_KECCAK_COLS};
13use crate::{NUM_ROUNDS, R, RC, U64_LIMBS};
14
15// TODO: Take generic iterable
16#[instrument(name = "generate Keccak trace", skip_all)]
17pub fn generate_trace_rows<F: PrimeField64>(
18    inputs: Vec<[u64; 25]>,
19    extra_capacity_bits: usize,
20) -> RowMajorMatrix<F> {
21    let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two();
22    let trace_length = num_rows * NUM_KECCAK_COLS;
23
24    // We allocate extra_capacity_bits now as this will be needed by the dft.
25    let mut long_trace = F::zero_vec(trace_length << extra_capacity_bits);
26    long_trace.truncate(trace_length);
27
28    let mut trace = RowMajorMatrix::new(long_trace, NUM_KECCAK_COLS);
29    let (prefix, rows, suffix) = unsafe { trace.values.align_to_mut::<KeccakCols<F>>() };
30    assert!(prefix.is_empty(), "Alignment should match");
31    assert!(suffix.is_empty(), "Alignment should match");
32    assert_eq!(rows.len(), num_rows);
33
34    let num_padding_inputs = num_rows.div_ceil(NUM_ROUNDS) - inputs.len();
35    let padded_inputs = inputs
36        .into_par_iter()
37        .chain(repeat_n([0; 25], num_padding_inputs));
38
39    rows.par_chunks_mut(NUM_ROUNDS)
40        .zip(padded_inputs)
41        .for_each(|(row, input)| {
42            generate_trace_rows_for_perm(row, input);
43        });
44
45    trace
46}
47
48/// `rows` will normally consist of 24 rows, with an exception for the final row.
49fn generate_trace_rows_for_perm<F: PrimeField64>(rows: &mut [KeccakCols<F>], input: [u64; 25]) {
50    // Convert flat input array to 5x5 matrix.
51    // The input uses standard Keccak indexing: input[x + 5*y] corresponds to state[x][y].
52    // After transmute, we get row-major layout: transmuted[i][j] = input[i*5 + j].
53    // To align with Keccak's state[x][y] = input[x + 5*y], we need to transpose.
54    let transmuted: [[u64; 5]; 5] = unsafe { transmute(input) };
55    let mut current_state: [[u64; 5]; 5] = array::from_fn(|x| array::from_fn(|y| transmuted[y][x]));
56
57    // initial_state is stored in y-major order for the AIR columns (preimage[y][x]).
58    let initial_state: [[[F; 4]; 5]; 5] =
59        array::from_fn(|y| array::from_fn(|x| u64_to_16_bit_limbs(current_state[x][y])));
60
61    // Populate the round input for the first round.
62    rows[0].a = initial_state;
63    rows[0].preimage = initial_state;
64
65    generate_trace_row_for_round(&mut rows[0], 0, &mut current_state);
66
67    for round in 1..rows.len() {
68        rows[round].preimage = initial_state;
69
70        // Copy previous row's output to next row's input.
71        for y in 0..5 {
72            for x in 0..5 {
73                for limb in 0..U64_LIMBS {
74                    rows[round].a[y][x][limb] = rows[round - 1].a_prime_prime_prime(y, x, limb);
75                }
76            }
77        }
78
79        generate_trace_row_for_round(&mut rows[round], round, &mut current_state);
80    }
81}
82
83fn generate_trace_row_for_round<F: PrimeField64>(
84    row: &mut KeccakCols<F>,
85    round: usize,
86    current_state: &mut [[u64; 5]; 5],
87) {
88    row.step_flags[round] = F::ONE;
89
90    // Populate C[x] = xor(A[x, 0], A[x, 1], A[x, 2], A[x, 3], A[x, 4]).
91    let state_c: [u64; 5] = current_state.map(|row| row.iter().fold(0, |acc, y| acc ^ y));
92    for (x, elem) in state_c.iter().enumerate() {
93        row.c[x] = u64_to_bits_le(*elem);
94    }
95
96    // Populate C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]).
97    let state_c_prime: [u64; 5] =
98        array::from_fn(|x| state_c[x] ^ state_c[(x + 4) % 5] ^ state_c[(x + 1) % 5].rotate_left(1));
99    for (x, elem) in state_c_prime.iter().enumerate() {
100        row.c_prime[x] = u64_to_bits_le(*elem);
101    }
102
103    // Populate A'. To avoid shifting indices, we rewrite
104    //     A'[x, y, z] = xor(A[x, y, z], C[x - 1, z], C[x + 1, z - 1])
105    // as
106    //     A'[x, y, z] = xor(A[x, y, z], C[x, z], C'[x, z]).
107    *current_state =
108        array::from_fn(|i| array::from_fn(|j| current_state[i][j] ^ state_c[i] ^ state_c_prime[i]));
109    for (x, x_row) in current_state.iter().enumerate() {
110        for (y, elem) in x_row.iter().enumerate() {
111            row.a_prime[y][x] = u64_to_bits_le(*elem);
112        }
113    }
114
115    // Rotate the current state to get the B array.
116    *current_state = array::from_fn(|i| {
117        array::from_fn(|j| {
118            let new_i = (i + 3 * j) % 5;
119            let new_j = i;
120            current_state[new_i][new_j].rotate_left(R[new_i][new_j] as u32)
121        })
122    });
123
124    // Populate A''.
125    // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])).
126    *current_state = array::from_fn(|i| {
127        array::from_fn(|j| {
128            current_state[i][j] ^ ((!current_state[(i + 1) % 5][j]) & current_state[(i + 2) % 5][j])
129        })
130    });
131    for (x, x_row) in current_state.iter().enumerate() {
132        for (y, elem) in x_row.iter().enumerate() {
133            row.a_prime_prime[y][x] = u64_to_16_bit_limbs(*elem);
134        }
135    }
136
137    row.a_prime_prime_0_0_bits = u64_to_bits_le(current_state[0][0]);
138
139    // A''[0, 0] is additionally xor'd with RC.
140    current_state[0][0] ^= RC[round];
141
142    row.a_prime_prime_prime_0_0_limbs = u64_to_16_bit_limbs(current_state[0][0]);
143}
144
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_goldilocks::Goldilocks;
    use p3_keccak::KeccakF;
    use p3_symmetric::Permutation;

    use super::*;

    /// Reassembles a 64-bit lane from its 16-bit limbs, least-significant limb first.
    /// `limb_at(i)` must return the canonical value of limb `i`.
    fn lane_from_limbs(limb_at: impl Fn(usize) -> u64) -> u64 {
        (0..U64_LIMBS).fold(0u64, |lane, limb| lane | (limb_at(limb) << (limb * 16)))
    }

    /// Views the flat trace values as a slice of `KeccakCols` rows, asserting
    /// that the reinterpretation covers the buffer exactly.
    fn keccak_rows(trace: &RowMajorMatrix<Goldilocks>) -> &[KeccakCols<Goldilocks>] {
        // SAFETY: NOTE(review) relies on `KeccakCols<Goldilocks>` being laid out as
        // consecutive `Goldilocks` values; the assertions reject any partial split.
        let (prefix, rows, suffix) = unsafe { trace.values.align_to::<KeccakCols<Goldilocks>>() };
        assert!(prefix.is_empty());
        assert!(suffix.is_empty());
        rows
    }

    /// Runs the p3-keccak reference permutation on a copy of `input`.
    fn reference_permutation(input: [u64; 25]) -> [u64; 25] {
        let mut state = input;
        KeccakF.permute_mut(&mut state);
        state
    }

    /// Reads the permutation output from the last of the 24 round rows.
    /// Every lane is read through the `a_prime_prime_prime` accessor.
    fn extract_output_from_trace<F: PrimeField64>(rows: &[KeccakCols<F>]) -> [u64; 25] {
        let last_row = &rows[NUM_ROUNDS - 1];
        // Standard Keccak indexing: state[x + 5*y]
        core::array::from_fn(|i| {
            let (x, y) = (i % 5, i / 5);
            lane_from_limbs(|limb| last_row.a_prime_prime_prime(y, x, limb).as_canonical_u64())
        })
    }

    /// Reads the input preimage back out of the first row of a permutation.
    fn extract_input_from_trace<F: PrimeField64>(rows: &[KeccakCols<F>]) -> [u64; 25] {
        let first_row = &rows[0];
        // Standard Keccak indexing: state[x + 5*y]
        core::array::from_fn(|i| {
            let (x, y) = (i % 5, i / 5);
            lane_from_limbs(|limb| first_row.preimage[y][x][limb].as_canonical_u64())
        })
    }

    #[test]
    fn test_keccak_permutation_matches_p3_keccak() {
        // A non-trivial input state.
        let input: [u64; 25] = core::array::from_fn(|i| i as u64 * 0x0123456789ABCDEFu64);

        // Expected output from p3-keccak (reference implementation).
        let expected_output = reference_permutation(input);

        // Trace produced by the generator under test.
        let trace = generate_trace_rows::<Goldilocks>(vec![input], 0);
        let perm_rows = &keccak_rows(&trace)[..NUM_ROUNDS];

        // The input must be stored unchanged.
        let stored_input = extract_input_from_trace(perm_rows);
        assert_eq!(
            stored_input, input,
            "Input state should match the provided input"
        );

        // The output must agree with p3-keccak.
        let our_output = extract_output_from_trace(perm_rows);
        assert_eq!(
            our_output, expected_output,
            "Keccak-f output should match p3-keccak reference implementation"
        );
    }

    #[test]
    fn test_keccak_permutation_zero_state() {
        // All-zero state.
        let input = [0u64; 25];
        let expected_output = reference_permutation(input);

        let trace = generate_trace_rows::<Goldilocks>(vec![input], 0);
        let perm_rows = &keccak_rows(&trace)[..NUM_ROUNDS];

        let our_output = extract_output_from_trace(perm_rows);
        assert_eq!(
            our_output, expected_output,
            "Keccak-f on zero state should match p3-keccak"
        );
    }

    #[test]
    fn test_keccak_permutation_known_vector() {
        // State with only the first element set to 1.
        let mut input = [0u64; 25];
        input[0] = 1;
        let expected_output = reference_permutation(input);

        let trace = generate_trace_rows::<Goldilocks>(vec![input], 0);
        let perm_rows = &keccak_rows(&trace)[..NUM_ROUNDS];

        let our_output = extract_output_from_trace(perm_rows);
        assert_eq!(
            our_output, expected_output,
            "Keccak-f with input[0]=1 should match p3-keccak"
        );
    }

    #[test]
    fn test_multiple_permutations() {
        // Several permutations packed into a single trace.
        let inputs: Vec<[u64; 25]> = (0..4)
            .map(|i| core::array::from_fn(|j| (i * 25 + j) as u64))
            .collect();

        let expected_outputs: Vec<[u64; 25]> = inputs
            .iter()
            .map(|input| reference_permutation(*input))
            .collect();

        let trace = generate_trace_rows::<Goldilocks>(inputs, 0);
        let rows = keccak_rows(&trace);

        // Permutation `i` occupies rows `i * NUM_ROUNDS .. (i + 1) * NUM_ROUNDS`.
        let perm_chunks = rows.chunks_exact(NUM_ROUNDS);
        for (i, (perm_rows, expected)) in perm_chunks.zip(&expected_outputs).enumerate() {
            let our_output = extract_output_from_trace(perm_rows);
            assert_eq!(
                our_output, *expected,
                "Permutation {} should match p3-keccak",
                i
            );
        }
    }

    #[test]
    fn test_input_output_limb_indexing() {
        // Checks the column mapping of the preimage: input[x + 5*y] must be
        // stored, limb by limb, at preimage[y][x].
        let input: [u64; 25] = core::array::from_fn(|i| i as u64 + 1);
        let trace = generate_trace_rows::<Goldilocks>(vec![input], 0);
        let first_row = &keccak_rows(&trace)[0];

        for (i_u64, &expected_val) in input.iter().enumerate() {
            let (y, x) = (i_u64 / 5, i_u64 % 5);

            let stored_value =
                lane_from_limbs(|limb| first_row.preimage[y][x][limb].as_canonical_u64());

            assert_eq!(
                stored_value, expected_val,
                "preimage[{}][{}] should equal input[{}]",
                y, x, i_u64
            );
        }
    }
}