p3_poseidon2_air/
generation.rs

1use alloc::vec::Vec;
2use core::mem::MaybeUninit;
3
4use p3_field::PrimeField;
5use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
6use p3_maybe_rayon::prelude::*;
7use p3_poseidon2::GenericPoseidon2LinearLayers;
8use tracing::instrument;
9
10use crate::columns::{num_cols, Poseidon2Cols};
11use crate::{FullRound, PartialRound, RoundConstants, SBox};
12
13#[instrument(name = "generate vectorized Poseidon2 trace", skip_all)]
14pub fn generate_vectorized_trace_rows<
15    F: PrimeField,
16    LinearLayers: GenericPoseidon2LinearLayers<F, WIDTH>,
17    const WIDTH: usize,
18    const SBOX_DEGREE: u64,
19    const SBOX_REGISTERS: usize,
20    const HALF_FULL_ROUNDS: usize,
21    const PARTIAL_ROUNDS: usize,
22    const VECTOR_LEN: usize,
23>(
24    inputs: Vec<[F; WIDTH]>,
25    round_constants: &RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
26    extra_capacity_bits: usize,
27) -> RowMajorMatrix<F> {
28    let n = inputs.len();
29    assert!(
30        n % VECTOR_LEN == 0 && (n / VECTOR_LEN).is_power_of_two(),
31        "Callers expected to pad inputs to VECTOR_LEN times a power of two"
32    );
33
34    let nrows = n.div_ceil(VECTOR_LEN);
35    let ncols = num_cols::<WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>()
36        * VECTOR_LEN;
37    let mut vec = Vec::with_capacity((nrows * ncols) << extra_capacity_bits);
38    let trace: &mut [MaybeUninit<F>] = &mut vec.spare_capacity_mut()[..nrows * ncols];
39    let trace: RowMajorMatrixViewMut<MaybeUninit<F>> = RowMajorMatrixViewMut::new(trace, ncols);
40
41    let (prefix, perms, suffix) = unsafe {
42        trace.values.align_to_mut::<Poseidon2Cols<
43            MaybeUninit<F>,
44            WIDTH,
45            SBOX_DEGREE,
46            SBOX_REGISTERS,
47            HALF_FULL_ROUNDS,
48            PARTIAL_ROUNDS,
49        >>()
50    };
51    assert!(prefix.is_empty(), "Alignment should match");
52    assert!(suffix.is_empty(), "Alignment should match");
53    assert_eq!(perms.len(), n);
54
55    perms.par_iter_mut().zip(inputs).for_each(|(perm, input)| {
56        generate_trace_rows_for_perm::<
57            F,
58            LinearLayers,
59            WIDTH,
60            SBOX_DEGREE,
61            SBOX_REGISTERS,
62            HALF_FULL_ROUNDS,
63            PARTIAL_ROUNDS,
64        >(perm, input, round_constants);
65    });
66
67    unsafe {
68        vec.set_len(nrows * ncols);
69    }
70
71    RowMajorMatrix::new(vec, ncols)
72}
73
74// TODO: Take generic iterable
75#[instrument(name = "generate Poseidon2 trace", skip_all)]
76pub fn generate_trace_rows<
77    F: PrimeField,
78    LinearLayers: GenericPoseidon2LinearLayers<F, WIDTH>,
79    const WIDTH: usize,
80    const SBOX_DEGREE: u64,
81    const SBOX_REGISTERS: usize,
82    const HALF_FULL_ROUNDS: usize,
83    const PARTIAL_ROUNDS: usize,
84>(
85    inputs: Vec<[F; WIDTH]>,
86    constants: &RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
87) -> RowMajorMatrix<F> {
88    let n = inputs.len();
89    assert!(
90        n.is_power_of_two(),
91        "Callers expected to pad inputs to a power of two"
92    );
93
94    let ncols = num_cols::<WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>();
95    let mut vec = Vec::with_capacity(n * ncols * 2);
96    let trace: &mut [MaybeUninit<F>] = &mut vec.spare_capacity_mut()[..n * ncols];
97    let trace: RowMajorMatrixViewMut<MaybeUninit<F>> = RowMajorMatrixViewMut::new(trace, ncols);
98
99    let (prefix, perms, suffix) = unsafe {
100        trace.values.align_to_mut::<Poseidon2Cols<
101            MaybeUninit<F>,
102            WIDTH,
103            SBOX_DEGREE,
104            SBOX_REGISTERS,
105            HALF_FULL_ROUNDS,
106            PARTIAL_ROUNDS,
107        >>()
108    };
109    assert!(prefix.is_empty(), "Alignment should match");
110    assert!(suffix.is_empty(), "Alignment should match");
111    assert_eq!(perms.len(), n);
112
113    perms.par_iter_mut().zip(inputs).for_each(|(perm, input)| {
114        generate_trace_rows_for_perm::<
115            F,
116            LinearLayers,
117            WIDTH,
118            SBOX_DEGREE,
119            SBOX_REGISTERS,
120            HALF_FULL_ROUNDS,
121            PARTIAL_ROUNDS,
122        >(perm, input, constants);
123    });
124
125    unsafe {
126        vec.set_len(n * ncols);
127    }
128
129    RowMajorMatrix::new(vec, ncols)
130}
131
132/// `rows` will normally consist of 24 rows, with an exception for the final row.
133fn generate_trace_rows_for_perm<
134    F: PrimeField,
135    LinearLayers: GenericPoseidon2LinearLayers<F, WIDTH>,
136    const WIDTH: usize,
137    const SBOX_DEGREE: u64,
138    const SBOX_REGISTERS: usize,
139    const HALF_FULL_ROUNDS: usize,
140    const PARTIAL_ROUNDS: usize,
141>(
142    perm: &mut Poseidon2Cols<
143        MaybeUninit<F>,
144        WIDTH,
145        SBOX_DEGREE,
146        SBOX_REGISTERS,
147        HALF_FULL_ROUNDS,
148        PARTIAL_ROUNDS,
149    >,
150    mut state: [F; WIDTH],
151    constants: &RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
152) {
153    perm.export.write(F::ONE);
154    perm.inputs
155        .iter_mut()
156        .zip(state.iter())
157        .for_each(|(input, &x)| {
158            input.write(x);
159        });
160
161    LinearLayers::external_linear_layer(&mut state);
162
163    for (full_round, constants) in perm
164        .beginning_full_rounds
165        .iter_mut()
166        .zip(&constants.beginning_full_round_constants)
167    {
168        generate_full_round::<F, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
169            &mut state, full_round, constants,
170        );
171    }
172
173    for (partial_round, constant) in perm
174        .partial_rounds
175        .iter_mut()
176        .zip(&constants.partial_round_constants)
177    {
178        generate_partial_round::<F, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
179            &mut state,
180            partial_round,
181            *constant,
182        );
183    }
184
185    for (full_round, constants) in perm
186        .ending_full_rounds
187        .iter_mut()
188        .zip(&constants.ending_full_round_constants)
189    {
190        generate_full_round::<F, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
191            &mut state, full_round, constants,
192        );
193    }
194}
195
196#[inline]
197fn generate_full_round<
198    F: PrimeField,
199    LinearLayers: GenericPoseidon2LinearLayers<F, WIDTH>,
200    const WIDTH: usize,
201    const SBOX_DEGREE: u64,
202    const SBOX_REGISTERS: usize,
203>(
204    state: &mut [F; WIDTH],
205    full_round: &mut FullRound<MaybeUninit<F>, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
206    round_constants: &[F; WIDTH],
207) {
208    for (state_i, const_i) in state.iter_mut().zip(round_constants) {
209        *state_i += *const_i;
210    }
211    for (state_i, sbox_i) in state.iter_mut().zip(full_round.sbox.iter_mut()) {
212        generate_sbox(sbox_i, state_i);
213    }
214    LinearLayers::external_linear_layer(state);
215    full_round
216        .post
217        .iter_mut()
218        .zip(*state)
219        .for_each(|(post, x)| {
220            post.write(x);
221        });
222}
223
224#[inline]
225fn generate_partial_round<
226    F: PrimeField,
227    LinearLayers: GenericPoseidon2LinearLayers<F, WIDTH>,
228    const WIDTH: usize,
229    const SBOX_DEGREE: u64,
230    const SBOX_REGISTERS: usize,
231>(
232    state: &mut [F; WIDTH],
233    partial_round: &mut PartialRound<MaybeUninit<F>, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
234    round_constant: F,
235) {
236    state[0] += round_constant;
237    generate_sbox(&mut partial_round.sbox, &mut state[0]);
238    partial_round.post_sbox.write(state[0]);
239    LinearLayers::internal_linear_layer(state);
240}
241
242#[inline]
243fn generate_sbox<F: PrimeField, const DEGREE: u64, const REGISTERS: usize>(
244    sbox: &mut SBox<MaybeUninit<F>, DEGREE, REGISTERS>,
245    x: &mut F,
246) {
247    *x = match (DEGREE, REGISTERS) {
248        (3, 0) => x.cube(),
249        (5, 0) => x.exp_const_u64::<5>(),
250        (7, 0) => x.exp_const_u64::<7>(),
251        (5, 1) => {
252            let x2 = x.square();
253            let x3 = x2 * *x;
254            sbox.0[0].write(x3);
255            x3 * x2
256        }
257        (7, 1) => {
258            let x3 = x.cube();
259            sbox.0[0].write(x3);
260            x3 * x3 * *x
261        }
262        (11, 2) => {
263            let x2 = x.square();
264            let x3 = x2 * *x;
265            let x9 = x3.cube();
266            sbox.0[0].write(x3);
267            sbox.0[1].write(x9);
268            x9 * x2
269        }
270        _ => panic!(
271            "Unexpected (DEGREE, REGISTERS) of ({}, {})",
272            DEGREE, REGISTERS
273        ),
274    }
275}