1use alloc::vec::Vec;
2use core::borrow::{Borrow, BorrowMut};
3
4use p3_air::{Air, AirBuilder, BaseAir};
5use p3_field::{Field, PrimeField};
6use p3_matrix::dense::RowMajorMatrix;
7use p3_matrix::Matrix;
8use p3_poseidon2::GenericPoseidon2LinearLayers;
9use rand::distributions::Standard;
10use rand::prelude::Distribution;
11use rand::random;
12
13use crate::air::eval;
14use crate::constants::RoundConstants;
15use crate::{generate_vectorized_trace_rows, Poseidon2Air, Poseidon2Cols};
16
17#[repr(C)]
19pub struct VectorizedPoseidon2Cols<
20 T,
21 const WIDTH: usize,
22 const SBOX_DEGREE: u64,
23 const SBOX_REGISTERS: usize,
24 const HALF_FULL_ROUNDS: usize,
25 const PARTIAL_ROUNDS: usize,
26 const VECTOR_LEN: usize,
27> {
28 pub(crate) cols:
29 [Poseidon2Cols<T, WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>;
30 VECTOR_LEN],
31}
32
33impl<
34 T,
35 const WIDTH: usize,
36 const SBOX_DEGREE: u64,
37 const SBOX_REGISTERS: usize,
38 const HALF_FULL_ROUNDS: usize,
39 const PARTIAL_ROUNDS: usize,
40 const VECTOR_LEN: usize,
41 >
42 Borrow<
43 VectorizedPoseidon2Cols<
44 T,
45 WIDTH,
46 SBOX_DEGREE,
47 SBOX_REGISTERS,
48 HALF_FULL_ROUNDS,
49 PARTIAL_ROUNDS,
50 VECTOR_LEN,
51 >,
52 > for [T]
53{
54 fn borrow(
55 &self,
56 ) -> &VectorizedPoseidon2Cols<
57 T,
58 WIDTH,
59 SBOX_DEGREE,
60 SBOX_REGISTERS,
61 HALF_FULL_ROUNDS,
62 PARTIAL_ROUNDS,
63 VECTOR_LEN,
64 > {
65 let (prefix, shorts, suffix) = unsafe {
67 self.align_to::<VectorizedPoseidon2Cols<
68 T,
69 WIDTH,
70 SBOX_DEGREE,
71 SBOX_REGISTERS,
72 HALF_FULL_ROUNDS,
73 PARTIAL_ROUNDS,
74 VECTOR_LEN,
75 >>()
76 };
77 debug_assert!(prefix.is_empty(), "Alignment should match");
78 debug_assert!(suffix.is_empty(), "Alignment should match");
79 debug_assert_eq!(shorts.len(), 1);
80 &shorts[0]
81 }
82}
83
84impl<
85 T,
86 const WIDTH: usize,
87 const SBOX_DEGREE: u64,
88 const SBOX_REGISTERS: usize,
89 const HALF_FULL_ROUNDS: usize,
90 const PARTIAL_ROUNDS: usize,
91 const VECTOR_LEN: usize,
92 >
93 BorrowMut<
94 VectorizedPoseidon2Cols<
95 T,
96 WIDTH,
97 SBOX_DEGREE,
98 SBOX_REGISTERS,
99 HALF_FULL_ROUNDS,
100 PARTIAL_ROUNDS,
101 VECTOR_LEN,
102 >,
103 > for [T]
104{
105 fn borrow_mut(
106 &mut self,
107 ) -> &mut VectorizedPoseidon2Cols<
108 T,
109 WIDTH,
110 SBOX_DEGREE,
111 SBOX_REGISTERS,
112 HALF_FULL_ROUNDS,
113 PARTIAL_ROUNDS,
114 VECTOR_LEN,
115 > {
116 let (prefix, shorts, suffix) = unsafe {
118 self.align_to_mut::<VectorizedPoseidon2Cols<
119 T,
120 WIDTH,
121 SBOX_DEGREE,
122 SBOX_REGISTERS,
123 HALF_FULL_ROUNDS,
124 PARTIAL_ROUNDS,
125 VECTOR_LEN,
126 >>()
127 };
128 debug_assert!(prefix.is_empty(), "Alignment should match");
129 debug_assert!(suffix.is_empty(), "Alignment should match");
130 debug_assert_eq!(shorts.len(), 1);
131 &mut shorts[0]
132 }
133}
134
135pub struct VectorizedPoseidon2Air<
137 F: Field,
138 LinearLayers,
139 const WIDTH: usize,
140 const SBOX_DEGREE: u64,
141 const SBOX_REGISTERS: usize,
142 const HALF_FULL_ROUNDS: usize,
143 const PARTIAL_ROUNDS: usize,
144 const VECTOR_LEN: usize,
145> {
146 pub(crate) air: Poseidon2Air<
147 F,
148 LinearLayers,
149 WIDTH,
150 SBOX_DEGREE,
151 SBOX_REGISTERS,
152 HALF_FULL_ROUNDS,
153 PARTIAL_ROUNDS,
154 >,
155}
156
157impl<
158 F: Field,
159 LinearLayers,
160 const WIDTH: usize,
161 const SBOX_DEGREE: u64,
162 const SBOX_REGISTERS: usize,
163 const HALF_FULL_ROUNDS: usize,
164 const PARTIAL_ROUNDS: usize,
165 const VECTOR_LEN: usize,
166 >
167 VectorizedPoseidon2Air<
168 F,
169 LinearLayers,
170 WIDTH,
171 SBOX_DEGREE,
172 SBOX_REGISTERS,
173 HALF_FULL_ROUNDS,
174 PARTIAL_ROUNDS,
175 VECTOR_LEN,
176 >
177{
178 pub fn new(constants: RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>) -> Self {
179 Self {
180 air: Poseidon2Air::new(constants),
181 }
182 }
183
184 pub fn generate_vectorized_trace_rows(
185 &self,
186 num_hashes: usize,
187 extra_capacity_bits: usize,
188 ) -> RowMajorMatrix<F>
189 where
190 F: PrimeField,
191 LinearLayers: GenericPoseidon2LinearLayers<F, WIDTH>,
192 Standard: Distribution<[F; WIDTH]>,
193 {
194 let inputs = (0..num_hashes).map(|_| random()).collect::<Vec<_>>();
195 generate_vectorized_trace_rows::<
196 F,
197 LinearLayers,
198 WIDTH,
199 SBOX_DEGREE,
200 SBOX_REGISTERS,
201 HALF_FULL_ROUNDS,
202 PARTIAL_ROUNDS,
203 VECTOR_LEN,
204 >(inputs, &self.air.constants, extra_capacity_bits)
205 }
206}
207
208impl<
209 F: Field,
210 LinearLayers: Sync,
211 const WIDTH: usize,
212 const SBOX_DEGREE: u64,
213 const SBOX_REGISTERS: usize,
214 const HALF_FULL_ROUNDS: usize,
215 const PARTIAL_ROUNDS: usize,
216 const VECTOR_LEN: usize,
217 > BaseAir<F>
218 for VectorizedPoseidon2Air<
219 F,
220 LinearLayers,
221 WIDTH,
222 SBOX_DEGREE,
223 SBOX_REGISTERS,
224 HALF_FULL_ROUNDS,
225 PARTIAL_ROUNDS,
226 VECTOR_LEN,
227 >
228{
229 fn width(&self) -> usize {
230 self.air.width() * VECTOR_LEN
231 }
232}
233
234impl<
235 AB: AirBuilder,
236 LinearLayers: GenericPoseidon2LinearLayers<AB::Expr, WIDTH>,
237 const WIDTH: usize,
238 const SBOX_DEGREE: u64,
239 const SBOX_REGISTERS: usize,
240 const HALF_FULL_ROUNDS: usize,
241 const PARTIAL_ROUNDS: usize,
242 const VECTOR_LEN: usize,
243 > Air<AB>
244 for VectorizedPoseidon2Air<
245 AB::F,
246 LinearLayers,
247 WIDTH,
248 SBOX_DEGREE,
249 SBOX_REGISTERS,
250 HALF_FULL_ROUNDS,
251 PARTIAL_ROUNDS,
252 VECTOR_LEN,
253 >
254{
255 #[inline]
256 fn eval(&self, builder: &mut AB) {
257 let main = builder.main();
258 let local = main.row_slice(0);
259 let local: &VectorizedPoseidon2Cols<
260 AB::Var,
261 WIDTH,
262 SBOX_DEGREE,
263 SBOX_REGISTERS,
264 HALF_FULL_ROUNDS,
265 PARTIAL_ROUNDS,
266 VECTOR_LEN,
267 > = (*local).borrow();
268 for perm in &local.cols {
269 eval(&self.air, builder, perm);
270 }
271 }
272}