p3_keccak_air/
generation.rs
1use alloc::vec::Vec;
2use core::array;
3use core::mem::transmute;
4
5use p3_air::utils::{u64_to_16_bit_limbs, u64_to_bits_le};
6use p3_field::PrimeField64;
7use p3_matrix::dense::RowMajorMatrix;
8use p3_maybe_rayon::iter::repeat;
9use p3_maybe_rayon::prelude::*;
10use tracing::instrument;
11
12use crate::columns::{KeccakCols, NUM_KECCAK_COLS};
13use crate::{NUM_ROUNDS, R, RC, U64_LIMBS};
14
15#[instrument(name = "generate Keccak trace", skip_all)]
17pub fn generate_trace_rows<F: PrimeField64>(
18 inputs: Vec<[u64; 25]>,
19 extra_capacity_bits: usize,
20) -> RowMajorMatrix<F> {
21 let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two();
22 let trace_length = num_rows * NUM_KECCAK_COLS;
23
24 let mut long_trace = F::zero_vec(trace_length << extra_capacity_bits);
26 long_trace.truncate(trace_length);
27
28 let mut trace = RowMajorMatrix::new(long_trace, NUM_KECCAK_COLS);
29 let (prefix, rows, suffix) = unsafe { trace.values.align_to_mut::<KeccakCols<F>>() };
30 assert!(prefix.is_empty(), "Alignment should match");
31 assert!(suffix.is_empty(), "Alignment should match");
32 assert_eq!(rows.len(), num_rows);
33
34 let num_padding_inputs = num_rows.div_ceil(NUM_ROUNDS) - inputs.len();
35 let padded_inputs = inputs
36 .into_par_iter()
37 .chain(repeat([0; 25]).take(num_padding_inputs));
38
39 rows.par_chunks_mut(NUM_ROUNDS)
40 .zip(padded_inputs)
41 .for_each(|(row, input)| {
42 generate_trace_rows_for_perm(row, input);
43 });
44
45 trace
46}
47
48fn generate_trace_rows_for_perm<F: PrimeField64>(rows: &mut [KeccakCols<F>], input: [u64; 25]) {
50 let mut current_state: [[u64; 5]; 5] = unsafe { transmute(input) };
51
52 let initial_state: [[[F; 4]; 5]; 5] =
53 array::from_fn(|y| array::from_fn(|x| u64_to_16_bit_limbs(current_state[x][y])));
54
55 rows[0].a = initial_state;
57 rows[0].preimage = initial_state;
58
59 generate_trace_row_for_round(&mut rows[0], 0, &mut current_state);
60
61 for round in 1..rows.len() {
62 rows[round].preimage = initial_state;
63
64 for y in 0..5 {
66 for x in 0..5 {
67 for limb in 0..U64_LIMBS {
68 rows[round].a[y][x][limb] = rows[round - 1].a_prime_prime_prime(y, x, limb);
69 }
70 }
71 }
72
73 generate_trace_row_for_round(&mut rows[round], round, &mut current_state);
74 }
75}
76
77fn generate_trace_row_for_round<F: PrimeField64>(
78 row: &mut KeccakCols<F>,
79 round: usize,
80 current_state: &mut [[u64; 5]; 5],
81) {
82 row.step_flags[round] = F::ONE;
83
84 let state_c: [u64; 5] = current_state.map(|row| row.iter().fold(0, |acc, y| acc ^ y));
86 for (x, elem) in state_c.iter().enumerate() {
87 row.c[x] = u64_to_bits_le(*elem);
88 }
89
90 let state_c_prime: [u64; 5] =
92 array::from_fn(|x| state_c[x] ^ state_c[(x + 4) % 5] ^ state_c[(x + 1) % 5].rotate_left(1));
93 for (x, elem) in state_c_prime.iter().enumerate() {
94 row.c_prime[x] = u64_to_bits_le(*elem);
95 }
96
97 *current_state =
102 array::from_fn(|i| array::from_fn(|j| current_state[i][j] ^ state_c[i] ^ state_c_prime[i]));
103 for (x, x_row) in current_state.iter().enumerate() {
104 for (y, elem) in x_row.iter().enumerate() {
105 row.a_prime[y][x] = u64_to_bits_le(*elem);
106 }
107 }
108
109 *current_state = array::from_fn(|i| {
111 array::from_fn(|j| {
112 let new_i = (i + 3 * j) % 5;
113 let new_j = i;
114 current_state[new_i][new_j].rotate_left(R[new_i][new_j] as u32)
115 })
116 });
117
118 *current_state = array::from_fn(|i| {
121 array::from_fn(|j| {
122 current_state[i][j] ^ ((!current_state[(i + 1) % 5][j]) & current_state[(i + 2) % 5][j])
123 })
124 });
125 for (x, x_row) in current_state.iter().enumerate() {
126 for (y, elem) in x_row.iter().enumerate() {
127 row.a_prime_prime[y][x] = u64_to_16_bit_limbs(*elem);
128 }
129 }
130
131 row.a_prime_prime_0_0_bits = u64_to_bits_le(current_state[0][0]);
132
133 current_state[0][0] ^= RC[round];
135
136 row.a_prime_prime_prime_0_0_limbs = u64_to_16_bit_limbs(current_state[0][0]);
137}