1use alloc::vec::Vec;
2use core::mem::MaybeUninit;
3
4use p3_field::PrimeField;
5use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
6use p3_maybe_rayon::prelude::*;
7use p3_poseidon2::GenericPoseidon2LinearLayers;
8use tracing::instrument;
9
10use crate::columns::{Poseidon2Cols, num_cols};
11use crate::{FullRound, PartialRound, RoundConstants, SBox};
12
13#[instrument(name = "generate vectorized Poseidon2 trace", skip_all)]
14pub fn generate_vectorized_trace_rows<
15 F: PrimeField,
16 LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
17 const WIDTH: usize,
18 const SBOX_DEGREE: u64,
19 const SBOX_REGISTERS: usize,
20 const HALF_FULL_ROUNDS: usize,
21 const PARTIAL_ROUNDS: usize,
22 const VECTOR_LEN: usize,
23>(
24 inputs: Vec<[F; WIDTH]>,
25 round_constants: &RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
26 extra_capacity_bits: usize,
27) -> RowMajorMatrix<F> {
28 let n = inputs.len();
29 assert!(
30 n.is_multiple_of(VECTOR_LEN) && (n / VECTOR_LEN).is_power_of_two(),
31 "Callers expected to pad inputs to VECTOR_LEN times a power of two"
32 );
33
34 let nrows = n.div_ceil(VECTOR_LEN);
35 let ncols = num_cols::<WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>()
36 * VECTOR_LEN;
37 let mut vec = Vec::with_capacity((nrows * ncols) << extra_capacity_bits);
38 let trace = &mut vec.spare_capacity_mut()[..nrows * ncols];
39 let trace = RowMajorMatrixViewMut::new(trace, ncols);
40
41 let (prefix, perms, suffix) = unsafe {
42 trace.values.align_to_mut::<Poseidon2Cols<
43 MaybeUninit<F>,
44 WIDTH,
45 SBOX_DEGREE,
46 SBOX_REGISTERS,
47 HALF_FULL_ROUNDS,
48 PARTIAL_ROUNDS,
49 >>()
50 };
51 assert!(prefix.is_empty(), "Alignment should match");
52 assert!(suffix.is_empty(), "Alignment should match");
53 assert_eq!(perms.len(), n);
54
55 perms.par_iter_mut().zip(inputs).for_each(|(perm, input)| {
56 generate_trace_rows_for_perm::<
57 F,
58 LinearLayers,
59 WIDTH,
60 SBOX_DEGREE,
61 SBOX_REGISTERS,
62 HALF_FULL_ROUNDS,
63 PARTIAL_ROUNDS,
64 >(perm, input, round_constants);
65 });
66
67 unsafe {
68 vec.set_len(nrows * ncols);
69 }
70
71 RowMajorMatrix::new(vec, ncols)
72}
73
74#[instrument(name = "generate Poseidon2 trace", skip_all)]
76pub fn generate_trace_rows<
77 F: PrimeField,
78 LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
79 const WIDTH: usize,
80 const SBOX_DEGREE: u64,
81 const SBOX_REGISTERS: usize,
82 const HALF_FULL_ROUNDS: usize,
83 const PARTIAL_ROUNDS: usize,
84>(
85 inputs: Vec<[F; WIDTH]>,
86 constants: &RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
87 extra_capacity_bits: usize,
88) -> RowMajorMatrix<F> {
89 let n = inputs.len();
90 assert!(
91 n.is_power_of_two(),
92 "Callers expected to pad inputs to a power of two"
93 );
94
95 let ncols = num_cols::<WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>();
96 let mut vec = Vec::with_capacity((n * ncols) << extra_capacity_bits);
97 let trace = &mut vec.spare_capacity_mut()[..n * ncols];
98 let trace = RowMajorMatrixViewMut::new(trace, ncols);
99
100 let (prefix, perms, suffix) = unsafe {
101 trace.values.align_to_mut::<Poseidon2Cols<
102 MaybeUninit<F>,
103 WIDTH,
104 SBOX_DEGREE,
105 SBOX_REGISTERS,
106 HALF_FULL_ROUNDS,
107 PARTIAL_ROUNDS,
108 >>()
109 };
110 assert!(prefix.is_empty(), "Alignment should match");
111 assert!(suffix.is_empty(), "Alignment should match");
112 assert_eq!(perms.len(), n);
113
114 perms.par_iter_mut().zip(inputs).for_each(|(perm, input)| {
115 generate_trace_rows_for_perm::<
116 F,
117 LinearLayers,
118 WIDTH,
119 SBOX_DEGREE,
120 SBOX_REGISTERS,
121 HALF_FULL_ROUNDS,
122 PARTIAL_ROUNDS,
123 >(perm, input, constants);
124 });
125
126 unsafe {
127 vec.set_len(n * ncols);
128 }
129
130 RowMajorMatrix::new(vec, ncols)
131}
132
133pub fn generate_trace_rows_for_perm<
135 F: PrimeField,
136 LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
137 const WIDTH: usize,
138 const SBOX_DEGREE: u64,
139 const SBOX_REGISTERS: usize,
140 const HALF_FULL_ROUNDS: usize,
141 const PARTIAL_ROUNDS: usize,
142>(
143 perm: &mut Poseidon2Cols<
144 MaybeUninit<F>,
145 WIDTH,
146 SBOX_DEGREE,
147 SBOX_REGISTERS,
148 HALF_FULL_ROUNDS,
149 PARTIAL_ROUNDS,
150 >,
151 mut state: [F; WIDTH],
152 constants: &RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
153) {
154 perm.export.write(F::ONE);
155 perm.inputs
156 .iter_mut()
157 .zip(state.iter())
158 .for_each(|(input, &x)| {
159 input.write(x);
160 });
161
162 LinearLayers::external_linear_layer(&mut state);
163
164 for (full_round, constants) in perm
165 .beginning_full_rounds
166 .iter_mut()
167 .zip(&constants.beginning_full_round_constants)
168 {
169 generate_full_round::<_, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
170 &mut state, full_round, constants,
171 );
172 }
173
174 for (partial_round, constant) in perm
175 .partial_rounds
176 .iter_mut()
177 .zip(&constants.partial_round_constants)
178 {
179 generate_partial_round::<_, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
180 &mut state,
181 partial_round,
182 *constant,
183 );
184 }
185
186 for (full_round, constants) in perm
187 .ending_full_rounds
188 .iter_mut()
189 .zip(&constants.ending_full_round_constants)
190 {
191 generate_full_round::<_, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
192 &mut state, full_round, constants,
193 );
194 }
195}
196
197#[inline]
198fn generate_full_round<
199 F: PrimeField,
200 LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
201 const WIDTH: usize,
202 const SBOX_DEGREE: u64,
203 const SBOX_REGISTERS: usize,
204>(
205 state: &mut [F; WIDTH],
206 full_round: &mut FullRound<MaybeUninit<F>, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
207 round_constants: &[F; WIDTH],
208) {
209 for ((state_i, const_i), sbox_i) in state
211 .iter_mut()
212 .zip(round_constants.iter())
213 .zip(full_round.sbox.iter_mut())
214 {
215 *state_i += *const_i;
216 generate_sbox(sbox_i, state_i);
217 }
218
219 LinearLayers::external_linear_layer(state);
220 full_round
221 .post
222 .iter_mut()
223 .zip(*state)
224 .for_each(|(post, x)| {
225 post.write(x);
226 });
227}
228
229#[inline]
230fn generate_partial_round<
231 F: PrimeField,
232 LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
233 const WIDTH: usize,
234 const SBOX_DEGREE: u64,
235 const SBOX_REGISTERS: usize,
236>(
237 state: &mut [F; WIDTH],
238 partial_round: &mut PartialRound<MaybeUninit<F>, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
239 round_constant: F,
240) {
241 state[0] += round_constant;
242 generate_sbox(&mut partial_round.sbox, &mut state[0]);
243 partial_round.post_sbox.write(state[0]);
244 LinearLayers::internal_linear_layer(state);
245}
246
247#[inline]
256fn generate_sbox<F: PrimeField, const DEGREE: u64, const REGISTERS: usize>(
257 sbox: &mut SBox<MaybeUninit<F>, DEGREE, REGISTERS>,
258 x: &mut F,
259) {
260 *x = match (DEGREE, REGISTERS) {
261 (3, 0) => x.cube(),
262 (5, 0) => x.exp_const_u64::<5>(),
263 (7, 0) => x.exp_const_u64::<7>(),
264 (5, 1) => {
265 let x2 = x.square();
266 let x3 = x2 * *x;
267 sbox.0[0].write(x3);
268 x3 * x2
269 }
270 (7, 1) => {
271 let x3 = x.cube();
272 sbox.0[0].write(x3);
273 x3 * x3 * *x
274 }
275 (11, 2) => {
276 let x2 = x.square();
277 let x3 = x2 * *x;
278 let x9 = x3.cube();
279 sbox.0[0].write(x3);
280 sbox.0[1].write(x9);
281 x9 * x2
282 }
283 _ => panic!(
284 "Unexpected (DEGREE, REGISTERS) of ({}, {})",
285 DEGREE, REGISTERS
286 ),
287 }
288}