1use alloc::vec::Vec;
2use core::mem::MaybeUninit;
3
4use p3_field::PrimeField;
5use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
6use p3_maybe_rayon::prelude::*;
7use p3_poseidon2::GenericPoseidon2LinearLayers;
8use tracing::instrument;
9
10use crate::columns::{num_cols, Poseidon2Cols};
11use crate::{FullRound, PartialRound, RoundConstants, SBox};
12
13#[instrument(name = "generate vectorized Poseidon2 trace", skip_all)]
14pub fn generate_vectorized_trace_rows<
15 F: PrimeField,
16 LinearLayers: GenericPoseidon2LinearLayers<F, WIDTH>,
17 const WIDTH: usize,
18 const SBOX_DEGREE: u64,
19 const SBOX_REGISTERS: usize,
20 const HALF_FULL_ROUNDS: usize,
21 const PARTIAL_ROUNDS: usize,
22 const VECTOR_LEN: usize,
23>(
24 inputs: Vec<[F; WIDTH]>,
25 round_constants: &RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
26 extra_capacity_bits: usize,
27) -> RowMajorMatrix<F> {
28 let n = inputs.len();
29 assert!(
30 n % VECTOR_LEN == 0 && (n / VECTOR_LEN).is_power_of_two(),
31 "Callers expected to pad inputs to VECTOR_LEN times a power of two"
32 );
33
34 let nrows = n.div_ceil(VECTOR_LEN);
35 let ncols = num_cols::<WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>()
36 * VECTOR_LEN;
37 let mut vec = Vec::with_capacity((nrows * ncols) << extra_capacity_bits);
38 let trace: &mut [MaybeUninit<F>] = &mut vec.spare_capacity_mut()[..nrows * ncols];
39 let trace: RowMajorMatrixViewMut<MaybeUninit<F>> = RowMajorMatrixViewMut::new(trace, ncols);
40
41 let (prefix, perms, suffix) = unsafe {
42 trace.values.align_to_mut::<Poseidon2Cols<
43 MaybeUninit<F>,
44 WIDTH,
45 SBOX_DEGREE,
46 SBOX_REGISTERS,
47 HALF_FULL_ROUNDS,
48 PARTIAL_ROUNDS,
49 >>()
50 };
51 assert!(prefix.is_empty(), "Alignment should match");
52 assert!(suffix.is_empty(), "Alignment should match");
53 assert_eq!(perms.len(), n);
54
55 perms.par_iter_mut().zip(inputs).for_each(|(perm, input)| {
56 generate_trace_rows_for_perm::<
57 F,
58 LinearLayers,
59 WIDTH,
60 SBOX_DEGREE,
61 SBOX_REGISTERS,
62 HALF_FULL_ROUNDS,
63 PARTIAL_ROUNDS,
64 >(perm, input, round_constants);
65 });
66
67 unsafe {
68 vec.set_len(nrows * ncols);
69 }
70
71 RowMajorMatrix::new(vec, ncols)
72}
73
74#[instrument(name = "generate Poseidon2 trace", skip_all)]
76pub fn generate_trace_rows<
77 F: PrimeField,
78 LinearLayers: GenericPoseidon2LinearLayers<F, WIDTH>,
79 const WIDTH: usize,
80 const SBOX_DEGREE: u64,
81 const SBOX_REGISTERS: usize,
82 const HALF_FULL_ROUNDS: usize,
83 const PARTIAL_ROUNDS: usize,
84>(
85 inputs: Vec<[F; WIDTH]>,
86 constants: &RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
87) -> RowMajorMatrix<F> {
88 let n = inputs.len();
89 assert!(
90 n.is_power_of_two(),
91 "Callers expected to pad inputs to a power of two"
92 );
93
94 let ncols = num_cols::<WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>();
95 let mut vec = Vec::with_capacity(n * ncols * 2);
96 let trace: &mut [MaybeUninit<F>] = &mut vec.spare_capacity_mut()[..n * ncols];
97 let trace: RowMajorMatrixViewMut<MaybeUninit<F>> = RowMajorMatrixViewMut::new(trace, ncols);
98
99 let (prefix, perms, suffix) = unsafe {
100 trace.values.align_to_mut::<Poseidon2Cols<
101 MaybeUninit<F>,
102 WIDTH,
103 SBOX_DEGREE,
104 SBOX_REGISTERS,
105 HALF_FULL_ROUNDS,
106 PARTIAL_ROUNDS,
107 >>()
108 };
109 assert!(prefix.is_empty(), "Alignment should match");
110 assert!(suffix.is_empty(), "Alignment should match");
111 assert_eq!(perms.len(), n);
112
113 perms.par_iter_mut().zip(inputs).for_each(|(perm, input)| {
114 generate_trace_rows_for_perm::<
115 F,
116 LinearLayers,
117 WIDTH,
118 SBOX_DEGREE,
119 SBOX_REGISTERS,
120 HALF_FULL_ROUNDS,
121 PARTIAL_ROUNDS,
122 >(perm, input, constants);
123 });
124
125 unsafe {
126 vec.set_len(n * ncols);
127 }
128
129 RowMajorMatrix::new(vec, ncols)
130}
131
132fn generate_trace_rows_for_perm<
134 F: PrimeField,
135 LinearLayers: GenericPoseidon2LinearLayers<F, WIDTH>,
136 const WIDTH: usize,
137 const SBOX_DEGREE: u64,
138 const SBOX_REGISTERS: usize,
139 const HALF_FULL_ROUNDS: usize,
140 const PARTIAL_ROUNDS: usize,
141>(
142 perm: &mut Poseidon2Cols<
143 MaybeUninit<F>,
144 WIDTH,
145 SBOX_DEGREE,
146 SBOX_REGISTERS,
147 HALF_FULL_ROUNDS,
148 PARTIAL_ROUNDS,
149 >,
150 mut state: [F; WIDTH],
151 constants: &RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
152) {
153 perm.export.write(F::ONE);
154 perm.inputs
155 .iter_mut()
156 .zip(state.iter())
157 .for_each(|(input, &x)| {
158 input.write(x);
159 });
160
161 LinearLayers::external_linear_layer(&mut state);
162
163 for (full_round, constants) in perm
164 .beginning_full_rounds
165 .iter_mut()
166 .zip(&constants.beginning_full_round_constants)
167 {
168 generate_full_round::<F, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
169 &mut state, full_round, constants,
170 );
171 }
172
173 for (partial_round, constant) in perm
174 .partial_rounds
175 .iter_mut()
176 .zip(&constants.partial_round_constants)
177 {
178 generate_partial_round::<F, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
179 &mut state,
180 partial_round,
181 *constant,
182 );
183 }
184
185 for (full_round, constants) in perm
186 .ending_full_rounds
187 .iter_mut()
188 .zip(&constants.ending_full_round_constants)
189 {
190 generate_full_round::<F, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
191 &mut state, full_round, constants,
192 );
193 }
194}
195
196#[inline]
197fn generate_full_round<
198 F: PrimeField,
199 LinearLayers: GenericPoseidon2LinearLayers<F, WIDTH>,
200 const WIDTH: usize,
201 const SBOX_DEGREE: u64,
202 const SBOX_REGISTERS: usize,
203>(
204 state: &mut [F; WIDTH],
205 full_round: &mut FullRound<MaybeUninit<F>, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
206 round_constants: &[F; WIDTH],
207) {
208 for (state_i, const_i) in state.iter_mut().zip(round_constants) {
209 *state_i += *const_i;
210 }
211 for (state_i, sbox_i) in state.iter_mut().zip(full_round.sbox.iter_mut()) {
212 generate_sbox(sbox_i, state_i);
213 }
214 LinearLayers::external_linear_layer(state);
215 full_round
216 .post
217 .iter_mut()
218 .zip(*state)
219 .for_each(|(post, x)| {
220 post.write(x);
221 });
222}
223
224#[inline]
225fn generate_partial_round<
226 F: PrimeField,
227 LinearLayers: GenericPoseidon2LinearLayers<F, WIDTH>,
228 const WIDTH: usize,
229 const SBOX_DEGREE: u64,
230 const SBOX_REGISTERS: usize,
231>(
232 state: &mut [F; WIDTH],
233 partial_round: &mut PartialRound<MaybeUninit<F>, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
234 round_constant: F,
235) {
236 state[0] += round_constant;
237 generate_sbox(&mut partial_round.sbox, &mut state[0]);
238 partial_round.post_sbox.write(state[0]);
239 LinearLayers::internal_linear_layer(state);
240}
241
242#[inline]
243fn generate_sbox<F: PrimeField, const DEGREE: u64, const REGISTERS: usize>(
244 sbox: &mut SBox<MaybeUninit<F>, DEGREE, REGISTERS>,
245 x: &mut F,
246) {
247 *x = match (DEGREE, REGISTERS) {
248 (3, 0) => x.cube(),
249 (5, 0) => x.exp_const_u64::<5>(),
250 (7, 0) => x.exp_const_u64::<7>(),
251 (5, 1) => {
252 let x2 = x.square();
253 let x3 = x2 * *x;
254 sbox.0[0].write(x3);
255 x3 * x2
256 }
257 (7, 1) => {
258 let x3 = x.cube();
259 sbox.0[0].write(x3);
260 x3 * x3 * *x
261 }
262 (11, 2) => {
263 let x2 = x.square();
264 let x3 = x2 * *x;
265 let x9 = x3.cube();
266 sbox.0[0].write(x3);
267 sbox.0[1].write(x9);
268 x9 * x2
269 }
270 _ => panic!(
271 "Unexpected (DEGREE, REGISTERS) of ({}, {})",
272 DEGREE, REGISTERS
273 ),
274 }
275}