p3_poseidon2_air/
vectorized.rs

1use alloc::vec::Vec;
2use core::borrow::{Borrow, BorrowMut};
3
4use p3_air::{Air, AirBuilder, BaseAir};
5use p3_field::{Field, PrimeField};
6use p3_matrix::dense::RowMajorMatrix;
7use p3_matrix::Matrix;
8use p3_poseidon2::GenericPoseidon2LinearLayers;
9use rand::distributions::Standard;
10use rand::prelude::Distribution;
11use rand::random;
12
13use crate::air::eval;
14use crate::constants::RoundConstants;
15use crate::{generate_vectorized_trace_rows, Poseidon2Air, Poseidon2Cols};
16
17/// A "vectorized" version of Poseidon2Cols, for computing multiple Poseidon2 permutations per row.
18#[repr(C)]
19pub struct VectorizedPoseidon2Cols<
20    T,
21    const WIDTH: usize,
22    const SBOX_DEGREE: u64,
23    const SBOX_REGISTERS: usize,
24    const HALF_FULL_ROUNDS: usize,
25    const PARTIAL_ROUNDS: usize,
26    const VECTOR_LEN: usize,
27> {
28    pub(crate) cols:
29        [Poseidon2Cols<T, WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>;
30            VECTOR_LEN],
31}
32
33impl<
34        T,
35        const WIDTH: usize,
36        const SBOX_DEGREE: u64,
37        const SBOX_REGISTERS: usize,
38        const HALF_FULL_ROUNDS: usize,
39        const PARTIAL_ROUNDS: usize,
40        const VECTOR_LEN: usize,
41    >
42    Borrow<
43        VectorizedPoseidon2Cols<
44            T,
45            WIDTH,
46            SBOX_DEGREE,
47            SBOX_REGISTERS,
48            HALF_FULL_ROUNDS,
49            PARTIAL_ROUNDS,
50            VECTOR_LEN,
51        >,
52    > for [T]
53{
54    fn borrow(
55        &self,
56    ) -> &VectorizedPoseidon2Cols<
57        T,
58        WIDTH,
59        SBOX_DEGREE,
60        SBOX_REGISTERS,
61        HALF_FULL_ROUNDS,
62        PARTIAL_ROUNDS,
63        VECTOR_LEN,
64    > {
65        // debug_assert_eq!(self.len(), NUM_COLS);
66        let (prefix, shorts, suffix) = unsafe {
67            self.align_to::<VectorizedPoseidon2Cols<
68                T,
69                WIDTH,
70                SBOX_DEGREE,
71                SBOX_REGISTERS,
72                HALF_FULL_ROUNDS,
73                PARTIAL_ROUNDS,
74                VECTOR_LEN,
75            >>()
76        };
77        debug_assert!(prefix.is_empty(), "Alignment should match");
78        debug_assert!(suffix.is_empty(), "Alignment should match");
79        debug_assert_eq!(shorts.len(), 1);
80        &shorts[0]
81    }
82}
83
84impl<
85        T,
86        const WIDTH: usize,
87        const SBOX_DEGREE: u64,
88        const SBOX_REGISTERS: usize,
89        const HALF_FULL_ROUNDS: usize,
90        const PARTIAL_ROUNDS: usize,
91        const VECTOR_LEN: usize,
92    >
93    BorrowMut<
94        VectorizedPoseidon2Cols<
95            T,
96            WIDTH,
97            SBOX_DEGREE,
98            SBOX_REGISTERS,
99            HALF_FULL_ROUNDS,
100            PARTIAL_ROUNDS,
101            VECTOR_LEN,
102        >,
103    > for [T]
104{
105    fn borrow_mut(
106        &mut self,
107    ) -> &mut VectorizedPoseidon2Cols<
108        T,
109        WIDTH,
110        SBOX_DEGREE,
111        SBOX_REGISTERS,
112        HALF_FULL_ROUNDS,
113        PARTIAL_ROUNDS,
114        VECTOR_LEN,
115    > {
116        // debug_assert_eq!(self.len(), NUM_COLS);
117        let (prefix, shorts, suffix) = unsafe {
118            self.align_to_mut::<VectorizedPoseidon2Cols<
119                T,
120                WIDTH,
121                SBOX_DEGREE,
122                SBOX_REGISTERS,
123                HALF_FULL_ROUNDS,
124                PARTIAL_ROUNDS,
125                VECTOR_LEN,
126            >>()
127        };
128        debug_assert!(prefix.is_empty(), "Alignment should match");
129        debug_assert!(suffix.is_empty(), "Alignment should match");
130        debug_assert_eq!(shorts.len(), 1);
131        &mut shorts[0]
132    }
133}
134
135/// A "vectorized" version of Poseidon2Air, for computing multiple Poseidon2 permutations per row.
136pub struct VectorizedPoseidon2Air<
137    F: Field,
138    LinearLayers,
139    const WIDTH: usize,
140    const SBOX_DEGREE: u64,
141    const SBOX_REGISTERS: usize,
142    const HALF_FULL_ROUNDS: usize,
143    const PARTIAL_ROUNDS: usize,
144    const VECTOR_LEN: usize,
145> {
146    pub(crate) air: Poseidon2Air<
147        F,
148        LinearLayers,
149        WIDTH,
150        SBOX_DEGREE,
151        SBOX_REGISTERS,
152        HALF_FULL_ROUNDS,
153        PARTIAL_ROUNDS,
154    >,
155}
156
157impl<
158        F: Field,
159        LinearLayers,
160        const WIDTH: usize,
161        const SBOX_DEGREE: u64,
162        const SBOX_REGISTERS: usize,
163        const HALF_FULL_ROUNDS: usize,
164        const PARTIAL_ROUNDS: usize,
165        const VECTOR_LEN: usize,
166    >
167    VectorizedPoseidon2Air<
168        F,
169        LinearLayers,
170        WIDTH,
171        SBOX_DEGREE,
172        SBOX_REGISTERS,
173        HALF_FULL_ROUNDS,
174        PARTIAL_ROUNDS,
175        VECTOR_LEN,
176    >
177{
178    pub fn new(constants: RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>) -> Self {
179        Self {
180            air: Poseidon2Air::new(constants),
181        }
182    }
183
184    pub fn generate_vectorized_trace_rows(
185        &self,
186        num_hashes: usize,
187        extra_capacity_bits: usize,
188    ) -> RowMajorMatrix<F>
189    where
190        F: PrimeField,
191        LinearLayers: GenericPoseidon2LinearLayers<F, WIDTH>,
192        Standard: Distribution<[F; WIDTH]>,
193    {
194        let inputs = (0..num_hashes).map(|_| random()).collect::<Vec<_>>();
195        generate_vectorized_trace_rows::<
196            F,
197            LinearLayers,
198            WIDTH,
199            SBOX_DEGREE,
200            SBOX_REGISTERS,
201            HALF_FULL_ROUNDS,
202            PARTIAL_ROUNDS,
203            VECTOR_LEN,
204        >(inputs, &self.air.constants, extra_capacity_bits)
205    }
206}
207
208impl<
209        F: Field,
210        LinearLayers: Sync,
211        const WIDTH: usize,
212        const SBOX_DEGREE: u64,
213        const SBOX_REGISTERS: usize,
214        const HALF_FULL_ROUNDS: usize,
215        const PARTIAL_ROUNDS: usize,
216        const VECTOR_LEN: usize,
217    > BaseAir<F>
218    for VectorizedPoseidon2Air<
219        F,
220        LinearLayers,
221        WIDTH,
222        SBOX_DEGREE,
223        SBOX_REGISTERS,
224        HALF_FULL_ROUNDS,
225        PARTIAL_ROUNDS,
226        VECTOR_LEN,
227    >
228{
229    fn width(&self) -> usize {
230        self.air.width() * VECTOR_LEN
231    }
232}
233
234impl<
235        AB: AirBuilder,
236        LinearLayers: GenericPoseidon2LinearLayers<AB::Expr, WIDTH>,
237        const WIDTH: usize,
238        const SBOX_DEGREE: u64,
239        const SBOX_REGISTERS: usize,
240        const HALF_FULL_ROUNDS: usize,
241        const PARTIAL_ROUNDS: usize,
242        const VECTOR_LEN: usize,
243    > Air<AB>
244    for VectorizedPoseidon2Air<
245        AB::F,
246        LinearLayers,
247        WIDTH,
248        SBOX_DEGREE,
249        SBOX_REGISTERS,
250        HALF_FULL_ROUNDS,
251        PARTIAL_ROUNDS,
252        VECTOR_LEN,
253    >
254{
255    #[inline]
256    fn eval(&self, builder: &mut AB) {
257        let main = builder.main();
258        let local = main.row_slice(0);
259        let local: &VectorizedPoseidon2Cols<
260            AB::Var,
261            WIDTH,
262            SBOX_DEGREE,
263            SBOX_REGISTERS,
264            HALF_FULL_ROUNDS,
265            PARTIAL_ROUNDS,
266            VECTOR_LEN,
267        > = (*local).borrow();
268        for perm in &local.cols {
269            eval(&self.air, builder, perm);
270        }
271    }
272}