1use core::borrow::Borrow;
2use core::marker::PhantomData;
3
4use p3_air::{Air, AirBuilder, BaseAir};
5use p3_field::{Field, FieldAlgebra};
6use p3_matrix::Matrix;
7use p3_poseidon2::GenericPoseidon2LinearLayers;
8
9use crate::columns::{num_cols, Poseidon2Cols};
10use crate::constants::RoundConstants;
11use crate::{FullRound, PartialRound, SBox};
12
13#[derive(Debug)]
15pub struct Poseidon2Air<
16 F: Field,
17 LinearLayers,
18 const WIDTH: usize,
19 const SBOX_DEGREE: u64,
20 const SBOX_REGISTERS: usize,
21 const HALF_FULL_ROUNDS: usize,
22 const PARTIAL_ROUNDS: usize,
23> {
24 pub(crate) constants: RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
25 _phantom: PhantomData<LinearLayers>,
26}
27
28impl<
29 F: Field,
30 LinearLayers,
31 const WIDTH: usize,
32 const SBOX_DEGREE: u64,
33 const SBOX_REGISTERS: usize,
34 const HALF_FULL_ROUNDS: usize,
35 const PARTIAL_ROUNDS: usize,
36 >
37 Poseidon2Air<
38 F,
39 LinearLayers,
40 WIDTH,
41 SBOX_DEGREE,
42 SBOX_REGISTERS,
43 HALF_FULL_ROUNDS,
44 PARTIAL_ROUNDS,
45 >
46{
47 pub fn new(constants: RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>) -> Self {
48 Self {
49 constants,
50 _phantom: PhantomData,
51 }
52 }
53}
54
55impl<
56 F: Field,
57 LinearLayers: Sync,
58 const WIDTH: usize,
59 const SBOX_DEGREE: u64,
60 const SBOX_REGISTERS: usize,
61 const HALF_FULL_ROUNDS: usize,
62 const PARTIAL_ROUNDS: usize,
63 > BaseAir<F>
64 for Poseidon2Air<
65 F,
66 LinearLayers,
67 WIDTH,
68 SBOX_DEGREE,
69 SBOX_REGISTERS,
70 HALF_FULL_ROUNDS,
71 PARTIAL_ROUNDS,
72 >
73{
74 fn width(&self) -> usize {
75 num_cols::<WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>()
76 }
77}
78
79pub(crate) fn eval<
80 AB: AirBuilder,
81 LinearLayers: GenericPoseidon2LinearLayers<AB::Expr, WIDTH>,
82 const WIDTH: usize,
83 const SBOX_DEGREE: u64,
84 const SBOX_REGISTERS: usize,
85 const HALF_FULL_ROUNDS: usize,
86 const PARTIAL_ROUNDS: usize,
87>(
88 air: &Poseidon2Air<
89 AB::F,
90 LinearLayers,
91 WIDTH,
92 SBOX_DEGREE,
93 SBOX_REGISTERS,
94 HALF_FULL_ROUNDS,
95 PARTIAL_ROUNDS,
96 >,
97 builder: &mut AB,
98 local: &Poseidon2Cols<
99 AB::Var,
100 WIDTH,
101 SBOX_DEGREE,
102 SBOX_REGISTERS,
103 HALF_FULL_ROUNDS,
104 PARTIAL_ROUNDS,
105 >,
106) {
107 let mut state: [AB::Expr; WIDTH] = local.inputs.map(|x| x.into());
108
109 LinearLayers::external_linear_layer(&mut state);
110
111 for round in 0..HALF_FULL_ROUNDS {
112 eval_full_round::<AB, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
113 &mut state,
114 &local.beginning_full_rounds[round],
115 &air.constants.beginning_full_round_constants[round],
116 builder,
117 );
118 }
119
120 for round in 0..PARTIAL_ROUNDS {
121 eval_partial_round::<AB, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
122 &mut state,
123 &local.partial_rounds[round],
124 &air.constants.partial_round_constants[round],
125 builder,
126 );
127 }
128
129 for round in 0..HALF_FULL_ROUNDS {
130 eval_full_round::<AB, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
131 &mut state,
132 &local.ending_full_rounds[round],
133 &air.constants.ending_full_round_constants[round],
134 builder,
135 );
136 }
137}
138
139impl<
140 AB: AirBuilder,
141 LinearLayers: GenericPoseidon2LinearLayers<AB::Expr, WIDTH>,
142 const WIDTH: usize,
143 const SBOX_DEGREE: u64,
144 const SBOX_REGISTERS: usize,
145 const HALF_FULL_ROUNDS: usize,
146 const PARTIAL_ROUNDS: usize,
147 > Air<AB>
148 for Poseidon2Air<
149 AB::F,
150 LinearLayers,
151 WIDTH,
152 SBOX_DEGREE,
153 SBOX_REGISTERS,
154 HALF_FULL_ROUNDS,
155 PARTIAL_ROUNDS,
156 >
157{
158 #[inline]
159 fn eval(&self, builder: &mut AB) {
160 let main = builder.main();
161 let local = main.row_slice(0);
162 let local: &Poseidon2Cols<
163 AB::Var,
164 WIDTH,
165 SBOX_DEGREE,
166 SBOX_REGISTERS,
167 HALF_FULL_ROUNDS,
168 PARTIAL_ROUNDS,
169 > = (*local).borrow();
170
171 eval::<
172 AB,
173 LinearLayers,
174 WIDTH,
175 SBOX_DEGREE,
176 SBOX_REGISTERS,
177 HALF_FULL_ROUNDS,
178 PARTIAL_ROUNDS,
179 >(self, builder, local);
180 }
181}
182
183#[inline]
184fn eval_full_round<
185 AB: AirBuilder,
186 LinearLayers: GenericPoseidon2LinearLayers<AB::Expr, WIDTH>,
187 const WIDTH: usize,
188 const SBOX_DEGREE: u64,
189 const SBOX_REGISTERS: usize,
190>(
191 state: &mut [AB::Expr; WIDTH],
192 full_round: &FullRound<AB::Var, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
193 round_constants: &[AB::F; WIDTH],
194 builder: &mut AB,
195) {
196 for (i, (s, r)) in state.iter_mut().zip(round_constants.iter()).enumerate() {
197 *s = s.clone() + *r;
198 eval_sbox(&full_round.sbox[i], s, builder);
199 }
200 LinearLayers::external_linear_layer(state);
201 for (state_i, post_i) in state.iter_mut().zip(full_round.post) {
202 builder.assert_eq(state_i.clone(), post_i);
203 *state_i = post_i.into();
204 }
205}
206
207#[inline]
208fn eval_partial_round<
209 AB: AirBuilder,
210 LinearLayers: GenericPoseidon2LinearLayers<AB::Expr, WIDTH>,
211 const WIDTH: usize,
212 const SBOX_DEGREE: u64,
213 const SBOX_REGISTERS: usize,
214>(
215 state: &mut [AB::Expr; WIDTH],
216 partial_round: &PartialRound<AB::Var, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
217 round_constant: &AB::F,
218 builder: &mut AB,
219) {
220 state[0] = state[0].clone() + *round_constant;
221 eval_sbox(&partial_round.sbox, &mut state[0], builder);
222
223 builder.assert_eq(state[0].clone(), partial_round.post_sbox);
224 state[0] = partial_round.post_sbox.into();
225
226 LinearLayers::internal_linear_layer(state);
227}
228
229#[inline]
237fn eval_sbox<AB, const DEGREE: u64, const REGISTERS: usize>(
238 sbox: &SBox<AB::Var, DEGREE, REGISTERS>,
239 x: &mut AB::Expr,
240 builder: &mut AB,
241) where
242 AB: AirBuilder,
243{
244 *x = match (DEGREE, REGISTERS) {
245 (3, 0) => x.cube(),
246 (5, 0) => x.exp_const_u64::<5>(),
247 (7, 0) => x.exp_const_u64::<7>(),
248 (5, 1) => {
249 let committed_x3 = sbox.0[0].into();
250 let x2 = x.square();
251 builder.assert_eq(committed_x3.clone(), x2.clone() * x.clone());
252 committed_x3 * x2
253 }
254 (7, 1) => {
255 let committed_x3 = sbox.0[0].into();
256 builder.assert_eq(committed_x3.clone(), x.cube());
257 committed_x3.square() * x.clone()
258 }
259 (11, 2) => {
260 let committed_x3 = sbox.0[0].into();
261 let committed_x9 = sbox.0[1].into();
262 let x2 = x.square();
263 builder.assert_eq(committed_x3.clone(), x2.clone() * x.clone());
264 builder.assert_eq(committed_x9.clone(), committed_x3.cube());
265 committed_x9 * x2
266 }
267 _ => panic!(
268 "Unexpected (DEGREE, REGISTERS) of ({}, {})",
269 DEGREE, REGISTERS
270 ),
271 }
272}