p3_poseidon2_air/
air.rs

1use core::borrow::Borrow;
2use core::marker::PhantomData;
3
4use p3_air::{Air, AirBuilder, BaseAir};
5use p3_field::{PrimeCharacteristicRing, PrimeField};
6use p3_matrix::Matrix;
7use p3_matrix::dense::RowMajorMatrix;
8use p3_poseidon2::GenericPoseidon2LinearLayers;
9use rand::distr::{Distribution, StandardUniform};
10use rand::rngs::SmallRng;
11use rand::{Rng, SeedableRng};
12
13use crate::columns::{Poseidon2Cols, num_cols};
14use crate::constants::RoundConstants;
15use crate::{FullRound, PartialRound, SBox, generate_trace_rows};
16
17/// Assumes the field size is at least 16 bits.
18#[derive(Debug)]
19pub struct Poseidon2Air<
20    F: PrimeCharacteristicRing,
21    LinearLayers,
22    const WIDTH: usize,
23    const SBOX_DEGREE: u64,
24    const SBOX_REGISTERS: usize,
25    const HALF_FULL_ROUNDS: usize,
26    const PARTIAL_ROUNDS: usize,
27> {
28    pub(crate) constants: RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
29    _phantom: PhantomData<LinearLayers>,
30}
31
32impl<
33    F: PrimeCharacteristicRing,
34    LinearLayers,
35    const WIDTH: usize,
36    const SBOX_DEGREE: u64,
37    const SBOX_REGISTERS: usize,
38    const HALF_FULL_ROUNDS: usize,
39    const PARTIAL_ROUNDS: usize,
40>
41    Poseidon2Air<
42        F,
43        LinearLayers,
44        WIDTH,
45        SBOX_DEGREE,
46        SBOX_REGISTERS,
47        HALF_FULL_ROUNDS,
48        PARTIAL_ROUNDS,
49    >
50{
51    pub const fn new(
52        constants: RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
53    ) -> Self {
54        Self {
55            constants,
56            _phantom: PhantomData,
57        }
58    }
59
60    pub fn generate_trace_rows(
61        &self,
62        num_hashes: usize,
63        extra_capacity_bits: usize,
64    ) -> RowMajorMatrix<F>
65    where
66        F: PrimeField,
67        LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
68        StandardUniform: Distribution<[F; WIDTH]>,
69    {
70        let mut rng = SmallRng::seed_from_u64(1);
71        let inputs = (0..num_hashes).map(|_| rng.random()).collect();
72        generate_trace_rows::<
73            _,
74            LinearLayers,
75            WIDTH,
76            SBOX_DEGREE,
77            SBOX_REGISTERS,
78            HALF_FULL_ROUNDS,
79            PARTIAL_ROUNDS,
80        >(inputs, &self.constants, extra_capacity_bits)
81    }
82}
83
84impl<
85    F: PrimeCharacteristicRing + Sync,
86    LinearLayers: Sync,
87    const WIDTH: usize,
88    const SBOX_DEGREE: u64,
89    const SBOX_REGISTERS: usize,
90    const HALF_FULL_ROUNDS: usize,
91    const PARTIAL_ROUNDS: usize,
92> BaseAir<F>
93    for Poseidon2Air<
94        F,
95        LinearLayers,
96        WIDTH,
97        SBOX_DEGREE,
98        SBOX_REGISTERS,
99        HALF_FULL_ROUNDS,
100        PARTIAL_ROUNDS,
101    >
102{
103    fn width(&self) -> usize {
104        num_cols::<WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>()
105    }
106}
107
108pub(crate) fn eval<
109    AB: AirBuilder,
110    LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
111    const WIDTH: usize,
112    const SBOX_DEGREE: u64,
113    const SBOX_REGISTERS: usize,
114    const HALF_FULL_ROUNDS: usize,
115    const PARTIAL_ROUNDS: usize,
116>(
117    air: &Poseidon2Air<
118        AB::F,
119        LinearLayers,
120        WIDTH,
121        SBOX_DEGREE,
122        SBOX_REGISTERS,
123        HALF_FULL_ROUNDS,
124        PARTIAL_ROUNDS,
125    >,
126    builder: &mut AB,
127    local: &Poseidon2Cols<
128        AB::Var,
129        WIDTH,
130        SBOX_DEGREE,
131        SBOX_REGISTERS,
132        HALF_FULL_ROUNDS,
133        PARTIAL_ROUNDS,
134    >,
135) {
136    let mut state: [_; WIDTH] = local.inputs.clone().map(|x| x.into());
137
138    LinearLayers::external_linear_layer(&mut state);
139
140    for round in 0..HALF_FULL_ROUNDS {
141        eval_full_round::<_, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
142            &mut state,
143            &local.beginning_full_rounds[round],
144            &air.constants.beginning_full_round_constants[round],
145            builder,
146        );
147    }
148
149    for round in 0..PARTIAL_ROUNDS {
150        eval_partial_round::<_, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
151            &mut state,
152            &local.partial_rounds[round],
153            &air.constants.partial_round_constants[round],
154            builder,
155        );
156    }
157
158    for round in 0..HALF_FULL_ROUNDS {
159        eval_full_round::<_, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
160            &mut state,
161            &local.ending_full_rounds[round],
162            &air.constants.ending_full_round_constants[round],
163            builder,
164        );
165    }
166}
167
168impl<
169    AB: AirBuilder,
170    LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
171    const WIDTH: usize,
172    const SBOX_DEGREE: u64,
173    const SBOX_REGISTERS: usize,
174    const HALF_FULL_ROUNDS: usize,
175    const PARTIAL_ROUNDS: usize,
176> Air<AB>
177    for Poseidon2Air<
178        AB::F,
179        LinearLayers,
180        WIDTH,
181        SBOX_DEGREE,
182        SBOX_REGISTERS,
183        HALF_FULL_ROUNDS,
184        PARTIAL_ROUNDS,
185    >
186{
187    #[inline]
188    fn eval(&self, builder: &mut AB) {
189        let main = builder.main();
190        let local = main.row_slice(0).expect("The matrix is empty?");
191        let local = (*local).borrow();
192
193        eval::<_, _, WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>(
194            self, builder, local,
195        );
196    }
197}
198
199#[inline]
200fn eval_full_round<
201    AB: AirBuilder,
202    LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
203    const WIDTH: usize,
204    const SBOX_DEGREE: u64,
205    const SBOX_REGISTERS: usize,
206>(
207    state: &mut [AB::Expr; WIDTH],
208    full_round: &FullRound<AB::Var, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
209    round_constants: &[AB::F; WIDTH],
210    builder: &mut AB,
211) {
212    for (i, (s, r)) in state.iter_mut().zip(round_constants.iter()).enumerate() {
213        *s += r.clone();
214        eval_sbox(&full_round.sbox[i], s, builder);
215    }
216    LinearLayers::external_linear_layer(state);
217    for (state_i, post_i) in state.iter_mut().zip(&full_round.post) {
218        builder.assert_eq(state_i.clone(), post_i.clone());
219        *state_i = post_i.clone().into();
220    }
221}
222
223#[inline]
224fn eval_partial_round<
225    AB: AirBuilder,
226    LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
227    const WIDTH: usize,
228    const SBOX_DEGREE: u64,
229    const SBOX_REGISTERS: usize,
230>(
231    state: &mut [AB::Expr; WIDTH],
232    partial_round: &PartialRound<AB::Var, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
233    round_constant: &AB::F,
234    builder: &mut AB,
235) {
236    state[0] += round_constant.clone();
237    eval_sbox(&partial_round.sbox, &mut state[0], builder);
238
239    builder.assert_eq(state[0].clone(), partial_round.post_sbox.clone());
240    state[0] = partial_round.post_sbox.clone().into();
241
242    LinearLayers::internal_linear_layer(state);
243}
244
245/// Evaluates the S-box over a degree-1 expression `x`.
246///
247/// # Panics
248///
249/// This method panics if the number of `REGISTERS` is not chosen optimally for the given
250/// `DEGREE` or if the `DEGREE` is not supported by the S-box. The supported degrees are
251/// `3`, `5`, `7`, and `11`.
252#[inline]
253fn eval_sbox<AB, const DEGREE: u64, const REGISTERS: usize>(
254    sbox: &SBox<AB::Var, DEGREE, REGISTERS>,
255    x: &mut AB::Expr,
256    builder: &mut AB,
257) where
258    AB: AirBuilder,
259{
260    *x = match (DEGREE, REGISTERS) {
261        (3, 0) => x.cube(),
262        (5, 0) => x.exp_const_u64::<5>(),
263        (7, 0) => x.exp_const_u64::<7>(),
264        (5, 1) => {
265            let committed_x3 = sbox.0[0].clone().into();
266            let x2 = x.square();
267            builder.assert_eq(committed_x3.clone(), x2.clone() * x.clone());
268            committed_x3 * x2
269        }
270        (7, 1) => {
271            let committed_x3 = sbox.0[0].clone().into();
272            builder.assert_eq(committed_x3.clone(), x.cube());
273            committed_x3.square() * x.clone()
274        }
275        (11, 2) => {
276            let committed_x3 = sbox.0[0].clone().into();
277            let committed_x9 = sbox.0[1].clone().into();
278            let x2 = x.square();
279            builder.assert_eq(committed_x3.clone(), x2.clone() * x.clone());
280            builder.assert_eq(committed_x9.clone(), committed_x3.cube());
281            committed_x9 * x2
282        }
283        _ => panic!(
284            "Unexpected (DEGREE, REGISTERS) of ({}, {})",
285            DEGREE, REGISTERS
286        ),
287    }
288}