1use core::borrow::Borrow;
2use core::marker::PhantomData;
3
4use p3_air::{Air, AirBuilder, BaseAir};
5use p3_field::{PrimeCharacteristicRing, PrimeField};
6use p3_matrix::Matrix;
7use p3_matrix::dense::RowMajorMatrix;
8use p3_poseidon2::GenericPoseidon2LinearLayers;
9use rand::distr::{Distribution, StandardUniform};
10use rand::rngs::SmallRng;
11use rand::{Rng, SeedableRng};
12
13use crate::columns::{Poseidon2Cols, num_cols};
14use crate::constants::RoundConstants;
15use crate::{FullRound, PartialRound, SBox, generate_trace_rows};
16
/// AIR whose constraints (see `eval` in this module) check that a trace row holds a full
/// Poseidon2 permutation of that row's `inputs` columns.
///
/// Generic parameters:
/// - `F`: ring/field the round constants are expressed in.
/// - `LinearLayers`: supplier of the external and internal linear layers; it is only used at
///   the type level (stored as `PhantomData`).
/// - `WIDTH`: number of state elements.
/// - `SBOX_DEGREE`: exponent of the S-box monomial.
/// - `SBOX_REGISTERS`: number of committed auxiliary columns used per S-box evaluation.
/// - `HALF_FULL_ROUNDS`: number of full rounds performed before the partial rounds (the same
///   number is performed again after them).
/// - `PARTIAL_ROUNDS`: number of partial rounds.
#[derive(Debug)]
pub struct Poseidon2Air<
    F: PrimeCharacteristicRing,
    LinearLayers,
    const WIDTH: usize,
    const SBOX_DEGREE: u64,
    const SBOX_REGISTERS: usize,
    const HALF_FULL_ROUNDS: usize,
    const PARTIAL_ROUNDS: usize,
> {
    /// Round constants read during constraint evaluation and passed to trace generation.
    pub(crate) constants: RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
    /// Type-level marker for the linear-layer implementation; carries no data.
    _phantom: PhantomData<LinearLayers>,
}
31
32impl<
33 F: PrimeCharacteristicRing,
34 LinearLayers,
35 const WIDTH: usize,
36 const SBOX_DEGREE: u64,
37 const SBOX_REGISTERS: usize,
38 const HALF_FULL_ROUNDS: usize,
39 const PARTIAL_ROUNDS: usize,
40>
41 Poseidon2Air<
42 F,
43 LinearLayers,
44 WIDTH,
45 SBOX_DEGREE,
46 SBOX_REGISTERS,
47 HALF_FULL_ROUNDS,
48 PARTIAL_ROUNDS,
49 >
50{
51 pub const fn new(
52 constants: RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
53 ) -> Self {
54 Self {
55 constants,
56 _phantom: PhantomData,
57 }
58 }
59
60 pub fn generate_trace_rows(
61 &self,
62 num_hashes: usize,
63 extra_capacity_bits: usize,
64 ) -> RowMajorMatrix<F>
65 where
66 F: PrimeField,
67 LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
68 StandardUniform: Distribution<[F; WIDTH]>,
69 {
70 let mut rng = SmallRng::seed_from_u64(1);
71 let inputs = (0..num_hashes).map(|_| rng.random()).collect();
72 generate_trace_rows::<
73 _,
74 LinearLayers,
75 WIDTH,
76 SBOX_DEGREE,
77 SBOX_REGISTERS,
78 HALF_FULL_ROUNDS,
79 PARTIAL_ROUNDS,
80 >(inputs, &self.constants, extra_capacity_bits)
81 }
82}
83
84impl<
85 F: PrimeCharacteristicRing + Sync,
86 LinearLayers: Sync,
87 const WIDTH: usize,
88 const SBOX_DEGREE: u64,
89 const SBOX_REGISTERS: usize,
90 const HALF_FULL_ROUNDS: usize,
91 const PARTIAL_ROUNDS: usize,
92> BaseAir<F>
93 for Poseidon2Air<
94 F,
95 LinearLayers,
96 WIDTH,
97 SBOX_DEGREE,
98 SBOX_REGISTERS,
99 HALF_FULL_ROUNDS,
100 PARTIAL_ROUNDS,
101 >
102{
103 fn width(&self) -> usize {
104 num_cols::<WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>()
105 }
106}
107
108pub(crate) fn eval<
109 AB: AirBuilder,
110 LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
111 const WIDTH: usize,
112 const SBOX_DEGREE: u64,
113 const SBOX_REGISTERS: usize,
114 const HALF_FULL_ROUNDS: usize,
115 const PARTIAL_ROUNDS: usize,
116>(
117 air: &Poseidon2Air<
118 AB::F,
119 LinearLayers,
120 WIDTH,
121 SBOX_DEGREE,
122 SBOX_REGISTERS,
123 HALF_FULL_ROUNDS,
124 PARTIAL_ROUNDS,
125 >,
126 builder: &mut AB,
127 local: &Poseidon2Cols<
128 AB::Var,
129 WIDTH,
130 SBOX_DEGREE,
131 SBOX_REGISTERS,
132 HALF_FULL_ROUNDS,
133 PARTIAL_ROUNDS,
134 >,
135) {
136 let mut state: [_; WIDTH] = local.inputs.clone().map(|x| x.into());
137
138 LinearLayers::external_linear_layer(&mut state);
139
140 for round in 0..HALF_FULL_ROUNDS {
141 eval_full_round::<_, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
142 &mut state,
143 &local.beginning_full_rounds[round],
144 &air.constants.beginning_full_round_constants[round],
145 builder,
146 );
147 }
148
149 for round in 0..PARTIAL_ROUNDS {
150 eval_partial_round::<_, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
151 &mut state,
152 &local.partial_rounds[round],
153 &air.constants.partial_round_constants[round],
154 builder,
155 );
156 }
157
158 for round in 0..HALF_FULL_ROUNDS {
159 eval_full_round::<_, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
160 &mut state,
161 &local.ending_full_rounds[round],
162 &air.constants.ending_full_round_constants[round],
163 builder,
164 );
165 }
166}
167
168impl<
169 AB: AirBuilder,
170 LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
171 const WIDTH: usize,
172 const SBOX_DEGREE: u64,
173 const SBOX_REGISTERS: usize,
174 const HALF_FULL_ROUNDS: usize,
175 const PARTIAL_ROUNDS: usize,
176> Air<AB>
177 for Poseidon2Air<
178 AB::F,
179 LinearLayers,
180 WIDTH,
181 SBOX_DEGREE,
182 SBOX_REGISTERS,
183 HALF_FULL_ROUNDS,
184 PARTIAL_ROUNDS,
185 >
186{
187 #[inline]
188 fn eval(&self, builder: &mut AB) {
189 let main = builder.main();
190 let local = main.row_slice(0).expect("The matrix is empty?");
191 let local = (*local).borrow();
192
193 eval::<_, _, WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>(
194 self, builder, local,
195 );
196 }
197}
198
199#[inline]
200fn eval_full_round<
201 AB: AirBuilder,
202 LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
203 const WIDTH: usize,
204 const SBOX_DEGREE: u64,
205 const SBOX_REGISTERS: usize,
206>(
207 state: &mut [AB::Expr; WIDTH],
208 full_round: &FullRound<AB::Var, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
209 round_constants: &[AB::F; WIDTH],
210 builder: &mut AB,
211) {
212 for (i, (s, r)) in state.iter_mut().zip(round_constants.iter()).enumerate() {
213 *s += r.clone();
214 eval_sbox(&full_round.sbox[i], s, builder);
215 }
216 LinearLayers::external_linear_layer(state);
217 for (state_i, post_i) in state.iter_mut().zip(&full_round.post) {
218 builder.assert_eq(state_i.clone(), post_i.clone());
219 *state_i = post_i.clone().into();
220 }
221}
222
223#[inline]
224fn eval_partial_round<
225 AB: AirBuilder,
226 LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
227 const WIDTH: usize,
228 const SBOX_DEGREE: u64,
229 const SBOX_REGISTERS: usize,
230>(
231 state: &mut [AB::Expr; WIDTH],
232 partial_round: &PartialRound<AB::Var, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
233 round_constant: &AB::F,
234 builder: &mut AB,
235) {
236 state[0] += round_constant.clone();
237 eval_sbox(&partial_round.sbox, &mut state[0], builder);
238
239 builder.assert_eq(state[0].clone(), partial_round.post_sbox.clone());
240 state[0] = partial_round.post_sbox.clone().into();
241
242 LinearLayers::internal_linear_layer(state);
243}
244
245#[inline]
253fn eval_sbox<AB, const DEGREE: u64, const REGISTERS: usize>(
254 sbox: &SBox<AB::Var, DEGREE, REGISTERS>,
255 x: &mut AB::Expr,
256 builder: &mut AB,
257) where
258 AB: AirBuilder,
259{
260 *x = match (DEGREE, REGISTERS) {
261 (3, 0) => x.cube(),
262 (5, 0) => x.exp_const_u64::<5>(),
263 (7, 0) => x.exp_const_u64::<7>(),
264 (5, 1) => {
265 let committed_x3 = sbox.0[0].clone().into();
266 let x2 = x.square();
267 builder.assert_eq(committed_x3.clone(), x2.clone() * x.clone());
268 committed_x3 * x2
269 }
270 (7, 1) => {
271 let committed_x3 = sbox.0[0].clone().into();
272 builder.assert_eq(committed_x3.clone(), x.cube());
273 committed_x3.square() * x.clone()
274 }
275 (11, 2) => {
276 let committed_x3 = sbox.0[0].clone().into();
277 let committed_x9 = sbox.0[1].clone().into();
278 let x2 = x.square();
279 builder.assert_eq(committed_x3.clone(), x2.clone() * x.clone());
280 builder.assert_eq(committed_x9.clone(), committed_x3.cube());
281 committed_x9 * x2
282 }
283 _ => panic!(
284 "Unexpected (DEGREE, REGISTERS) of ({}, {})",
285 DEGREE, REGISTERS
286 ),
287 }
288}