p3_keccak_air/
generation.rs

1use alloc::vec::Vec;
2use core::array;
3use core::mem::transmute;
4
5use p3_air::utils::{u64_to_16_bit_limbs, u64_to_bits_le};
6use p3_field::PrimeField64;
7use p3_matrix::dense::RowMajorMatrix;
8use p3_maybe_rayon::iter::repeat;
9use p3_maybe_rayon::prelude::*;
10use tracing::instrument;
11
12use crate::columns::{KeccakCols, NUM_KECCAK_COLS};
13use crate::{NUM_ROUNDS, R, RC, U64_LIMBS};
14
15// TODO: Take generic iterable
16#[instrument(name = "generate Keccak trace", skip_all)]
17pub fn generate_trace_rows<F: PrimeField64>(
18    inputs: Vec<[u64; 25]>,
19    extra_capacity_bits: usize,
20) -> RowMajorMatrix<F> {
21    let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two();
22    let trace_length = num_rows * NUM_KECCAK_COLS;
23
24    // We allocate extra_capacity_bits now as this will be needed by the dft.
25    let mut long_trace = F::zero_vec(trace_length << extra_capacity_bits);
26    long_trace.truncate(trace_length);
27
28    let mut trace = RowMajorMatrix::new(long_trace, NUM_KECCAK_COLS);
29    let (prefix, rows, suffix) = unsafe { trace.values.align_to_mut::<KeccakCols<F>>() };
30    assert!(prefix.is_empty(), "Alignment should match");
31    assert!(suffix.is_empty(), "Alignment should match");
32    assert_eq!(rows.len(), num_rows);
33
34    let num_padding_inputs = num_rows.div_ceil(NUM_ROUNDS) - inputs.len();
35    let padded_inputs = inputs
36        .into_par_iter()
37        .chain(repeat([0; 25]).take(num_padding_inputs));
38
39    rows.par_chunks_mut(NUM_ROUNDS)
40        .zip(padded_inputs)
41        .for_each(|(row, input)| {
42            generate_trace_rows_for_perm(row, input);
43        });
44
45    trace
46}
47
48/// `rows` will normally consist of 24 rows, with an exception for the final row.
49fn generate_trace_rows_for_perm<F: PrimeField64>(rows: &mut [KeccakCols<F>], input: [u64; 25]) {
50    let mut current_state: [[u64; 5]; 5] = unsafe { transmute(input) };
51
52    let initial_state: [[[F; 4]; 5]; 5] =
53        array::from_fn(|y| array::from_fn(|x| u64_to_16_bit_limbs(current_state[x][y])));
54
55    // Populate the round input for the first round.
56    rows[0].a = initial_state;
57    rows[0].preimage = initial_state;
58
59    generate_trace_row_for_round(&mut rows[0], 0, &mut current_state);
60
61    for round in 1..rows.len() {
62        rows[round].preimage = initial_state;
63
64        // Copy previous row's output to next row's input.
65        for y in 0..5 {
66            for x in 0..5 {
67                for limb in 0..U64_LIMBS {
68                    rows[round].a[y][x][limb] = rows[round - 1].a_prime_prime_prime(y, x, limb);
69                }
70            }
71        }
72
73        generate_trace_row_for_round(&mut rows[round], round, &mut current_state);
74    }
75}
76
77fn generate_trace_row_for_round<F: PrimeField64>(
78    row: &mut KeccakCols<F>,
79    round: usize,
80    current_state: &mut [[u64; 5]; 5],
81) {
82    row.step_flags[round] = F::ONE;
83
84    // Populate C[x] = xor(A[x, 0], A[x, 1], A[x, 2], A[x, 3], A[x, 4]).
85    let state_c: [u64; 5] = current_state.map(|row| row.iter().fold(0, |acc, y| acc ^ y));
86    for (x, elem) in state_c.iter().enumerate() {
87        row.c[x] = u64_to_bits_le(*elem);
88    }
89
90    // Populate C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]).
91    let state_c_prime: [u64; 5] =
92        array::from_fn(|x| state_c[x] ^ state_c[(x + 4) % 5] ^ state_c[(x + 1) % 5].rotate_left(1));
93    for (x, elem) in state_c_prime.iter().enumerate() {
94        row.c_prime[x] = u64_to_bits_le(*elem);
95    }
96
97    // Populate A'. To avoid shifting indices, we rewrite
98    //     A'[x, y, z] = xor(A[x, y, z], C[x - 1, z], C[x + 1, z - 1])
99    // as
100    //     A'[x, y, z] = xor(A[x, y, z], C[x, z], C'[x, z]).
101    *current_state =
102        array::from_fn(|i| array::from_fn(|j| current_state[i][j] ^ state_c[i] ^ state_c_prime[i]));
103    for (x, x_row) in current_state.iter().enumerate() {
104        for (y, elem) in x_row.iter().enumerate() {
105            row.a_prime[y][x] = u64_to_bits_le(*elem);
106        }
107    }
108
109    // Rotate the current state to get the B array.
110    *current_state = array::from_fn(|i| {
111        array::from_fn(|j| {
112            let new_i = (i + 3 * j) % 5;
113            let new_j = i;
114            current_state[new_i][new_j].rotate_left(R[new_i][new_j] as u32)
115        })
116    });
117
118    // Populate A''.
119    // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])).
120    *current_state = array::from_fn(|i| {
121        array::from_fn(|j| {
122            current_state[i][j] ^ ((!current_state[(i + 1) % 5][j]) & current_state[(i + 2) % 5][j])
123        })
124    });
125    for (x, x_row) in current_state.iter().enumerate() {
126        for (y, elem) in x_row.iter().enumerate() {
127            row.a_prime_prime[y][x] = u64_to_16_bit_limbs(*elem);
128        }
129    }
130
131    row.a_prime_prime_0_0_bits = u64_to_bits_le(current_state[0][0]);
132
133    // A''[0, 0] is additionally xor'd with RC.
134    current_state[0][0] ^= RC[round];
135
136    row.a_prime_prime_prime_0_0_limbs = u64_to_16_bit_limbs(current_state[0][0]);
137}