p3_poseidon2_air/
generation.rs

1use alloc::vec::Vec;
2use core::mem::MaybeUninit;
3
4use p3_field::PrimeField;
5use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
6use p3_maybe_rayon::prelude::*;
7use p3_poseidon2::GenericPoseidon2LinearLayers;
8use tracing::instrument;
9
10use crate::columns::{Poseidon2Cols, num_cols};
11use crate::{FullRound, PartialRound, RoundConstants, SBox};
12
13#[instrument(name = "generate vectorized Poseidon2 trace", skip_all)]
14pub fn generate_vectorized_trace_rows<
15    F: PrimeField,
16    LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
17    const WIDTH: usize,
18    const SBOX_DEGREE: u64,
19    const SBOX_REGISTERS: usize,
20    const HALF_FULL_ROUNDS: usize,
21    const PARTIAL_ROUNDS: usize,
22    const VECTOR_LEN: usize,
23>(
24    inputs: Vec<[F; WIDTH]>,
25    round_constants: &RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
26    extra_capacity_bits: usize,
27) -> RowMajorMatrix<F> {
28    let n = inputs.len();
29    assert!(
30        n.is_multiple_of(VECTOR_LEN) && (n / VECTOR_LEN).is_power_of_two(),
31        "Callers expected to pad inputs to VECTOR_LEN times a power of two"
32    );
33
34    let nrows = n.div_ceil(VECTOR_LEN);
35    let ncols = num_cols::<WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>()
36        * VECTOR_LEN;
37    let mut vec = Vec::with_capacity((nrows * ncols) << extra_capacity_bits);
38    let trace = &mut vec.spare_capacity_mut()[..nrows * ncols];
39    let trace = RowMajorMatrixViewMut::new(trace, ncols);
40
41    let (prefix, perms, suffix) = unsafe {
42        trace.values.align_to_mut::<Poseidon2Cols<
43            MaybeUninit<F>,
44            WIDTH,
45            SBOX_DEGREE,
46            SBOX_REGISTERS,
47            HALF_FULL_ROUNDS,
48            PARTIAL_ROUNDS,
49        >>()
50    };
51    assert!(prefix.is_empty(), "Alignment should match");
52    assert!(suffix.is_empty(), "Alignment should match");
53    assert_eq!(perms.len(), n);
54
55    perms.par_iter_mut().zip(inputs).for_each(|(perm, input)| {
56        generate_trace_rows_for_perm::<
57            F,
58            LinearLayers,
59            WIDTH,
60            SBOX_DEGREE,
61            SBOX_REGISTERS,
62            HALF_FULL_ROUNDS,
63            PARTIAL_ROUNDS,
64        >(perm, input, round_constants);
65    });
66
67    unsafe {
68        vec.set_len(nrows * ncols);
69    }
70
71    RowMajorMatrix::new(vec, ncols)
72}
73
74// TODO: Take generic iterable
75#[instrument(name = "generate Poseidon2 trace", skip_all)]
76pub fn generate_trace_rows<
77    F: PrimeField,
78    LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
79    const WIDTH: usize,
80    const SBOX_DEGREE: u64,
81    const SBOX_REGISTERS: usize,
82    const HALF_FULL_ROUNDS: usize,
83    const PARTIAL_ROUNDS: usize,
84>(
85    inputs: Vec<[F; WIDTH]>,
86    constants: &RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
87    extra_capacity_bits: usize,
88) -> RowMajorMatrix<F> {
89    let n = inputs.len();
90    assert!(
91        n.is_power_of_two(),
92        "Callers expected to pad inputs to a power of two"
93    );
94
95    let ncols = num_cols::<WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>();
96    let mut vec = Vec::with_capacity((n * ncols) << extra_capacity_bits);
97    let trace = &mut vec.spare_capacity_mut()[..n * ncols];
98    let trace = RowMajorMatrixViewMut::new(trace, ncols);
99
100    let (prefix, perms, suffix) = unsafe {
101        trace.values.align_to_mut::<Poseidon2Cols<
102            MaybeUninit<F>,
103            WIDTH,
104            SBOX_DEGREE,
105            SBOX_REGISTERS,
106            HALF_FULL_ROUNDS,
107            PARTIAL_ROUNDS,
108        >>()
109    };
110    assert!(prefix.is_empty(), "Alignment should match");
111    assert!(suffix.is_empty(), "Alignment should match");
112    assert_eq!(perms.len(), n);
113
114    perms.par_iter_mut().zip(inputs).for_each(|(perm, input)| {
115        generate_trace_rows_for_perm::<
116            F,
117            LinearLayers,
118            WIDTH,
119            SBOX_DEGREE,
120            SBOX_REGISTERS,
121            HALF_FULL_ROUNDS,
122            PARTIAL_ROUNDS,
123        >(perm, input, constants);
124    });
125
126    unsafe {
127        vec.set_len(n * ncols);
128    }
129
130    RowMajorMatrix::new(vec, ncols)
131}
132
133/// `rows` will normally consist of 24 rows, with an exception for the final row.
134pub fn generate_trace_rows_for_perm<
135    F: PrimeField,
136    LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
137    const WIDTH: usize,
138    const SBOX_DEGREE: u64,
139    const SBOX_REGISTERS: usize,
140    const HALF_FULL_ROUNDS: usize,
141    const PARTIAL_ROUNDS: usize,
142>(
143    perm: &mut Poseidon2Cols<
144        MaybeUninit<F>,
145        WIDTH,
146        SBOX_DEGREE,
147        SBOX_REGISTERS,
148        HALF_FULL_ROUNDS,
149        PARTIAL_ROUNDS,
150    >,
151    mut state: [F; WIDTH],
152    constants: &RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
153) {
154    perm.export.write(F::ONE);
155    perm.inputs
156        .iter_mut()
157        .zip(state.iter())
158        .for_each(|(input, &x)| {
159            input.write(x);
160        });
161
162    LinearLayers::external_linear_layer(&mut state);
163
164    for (full_round, constants) in perm
165        .beginning_full_rounds
166        .iter_mut()
167        .zip(&constants.beginning_full_round_constants)
168    {
169        generate_full_round::<_, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
170            &mut state, full_round, constants,
171        );
172    }
173
174    for (partial_round, constant) in perm
175        .partial_rounds
176        .iter_mut()
177        .zip(&constants.partial_round_constants)
178    {
179        generate_partial_round::<_, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
180            &mut state,
181            partial_round,
182            *constant,
183        );
184    }
185
186    for (full_round, constants) in perm
187        .ending_full_rounds
188        .iter_mut()
189        .zip(&constants.ending_full_round_constants)
190    {
191        generate_full_round::<_, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
192            &mut state, full_round, constants,
193        );
194    }
195}
196
197#[inline]
198fn generate_full_round<
199    F: PrimeField,
200    LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
201    const WIDTH: usize,
202    const SBOX_DEGREE: u64,
203    const SBOX_REGISTERS: usize,
204>(
205    state: &mut [F; WIDTH],
206    full_round: &mut FullRound<MaybeUninit<F>, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
207    round_constants: &[F; WIDTH],
208) {
209    // Combine addition of round constants and S-box application in a single loop
210    for ((state_i, const_i), sbox_i) in state
211        .iter_mut()
212        .zip(round_constants.iter())
213        .zip(full_round.sbox.iter_mut())
214    {
215        *state_i += *const_i;
216        generate_sbox(sbox_i, state_i);
217    }
218
219    LinearLayers::external_linear_layer(state);
220    full_round
221        .post
222        .iter_mut()
223        .zip(*state)
224        .for_each(|(post, x)| {
225            post.write(x);
226        });
227}
228
229#[inline]
230fn generate_partial_round<
231    F: PrimeField,
232    LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
233    const WIDTH: usize,
234    const SBOX_DEGREE: u64,
235    const SBOX_REGISTERS: usize,
236>(
237    state: &mut [F; WIDTH],
238    partial_round: &mut PartialRound<MaybeUninit<F>, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
239    round_constant: F,
240) {
241    state[0] += round_constant;
242    generate_sbox(&mut partial_round.sbox, &mut state[0]);
243    partial_round.post_sbox.write(state[0]);
244    LinearLayers::internal_linear_layer(state);
245}
246
247/// Computes the S-box `x -> x^{DEGREE}` and stores the partial data required to
248/// verify the computation.
249///
250/// # Panics
251///
252/// This method panics if the number of `REGISTERS` is not chosen optimally for the given
253/// `DEGREE` or if the `DEGREE` is not supported by the S-box. The supported degrees are
254/// `3`, `5`, `7`, and `11`.
255#[inline]
256fn generate_sbox<F: PrimeField, const DEGREE: u64, const REGISTERS: usize>(
257    sbox: &mut SBox<MaybeUninit<F>, DEGREE, REGISTERS>,
258    x: &mut F,
259) {
260    *x = match (DEGREE, REGISTERS) {
261        (3, 0) => x.cube(),
262        (5, 0) => x.exp_const_u64::<5>(),
263        (7, 0) => x.exp_const_u64::<7>(),
264        (5, 1) => {
265            let x2 = x.square();
266            let x3 = x2 * *x;
267            sbox.0[0].write(x3);
268            x3 * x2
269        }
270        (7, 1) => {
271            let x3 = x.cube();
272            sbox.0[0].write(x3);
273            x3 * x3 * *x
274        }
275        (11, 2) => {
276            let x2 = x.square();
277            let x3 = x2 * *x;
278            let x9 = x3.cube();
279            sbox.0[0].write(x3);
280            sbox.0[1].write(x9);
281            x9 * x2
282        }
283        _ => panic!(
284            "Unexpected (DEGREE, REGISTERS) of ({}, {})",
285            DEGREE, REGISTERS
286        ),
287    }
288}