1use crate::{
2 gates::{GateInstructions, RangeInstructions},
3 poseidon::hasher::{spec::OptimizedPoseidonSpec, state::PoseidonState},
4 safe_types::{SafeBool, SafeTypeChip},
5 utils::BigPrimeField,
6 AssignedValue, Context,
7 QuantumCell::Constant,
8 ScalarField,
9};
10
11use getset::{CopyGetters, Getters};
12use num_bigint::BigUint;
13use std::{cell::OnceCell, mem};
14
15#[cfg(test)]
16mod tests;
17
18pub mod mds;
20pub mod spec;
22pub mod state;
24
25#[derive(Clone, Debug, Getters)]
27pub struct PoseidonHasher<F: ScalarField, const T: usize, const RATE: usize> {
28 #[getset(get = "pub")]
30 spec: OptimizedPoseidonSpec<F, T, RATE>,
31 consts: OnceCell<PoseidonHasherConsts<F, T, RATE>>,
32}
33#[derive(Clone, Debug, Getters)]
34struct PoseidonHasherConsts<F: ScalarField, const T: usize, const RATE: usize> {
35 #[getset(get = "pub")]
36 init_state: PoseidonState<F, T, RATE>,
37 #[getset(get = "pub")]
39 empty_hash: AssignedValue<F>,
40}
41
42impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasherConsts<F, T, RATE> {
43 pub fn new(
44 ctx: &mut Context<F>,
45 gate: &impl GateInstructions<F>,
46 spec: &OptimizedPoseidonSpec<F, T, RATE>,
47 ) -> Self {
48 let init_state = PoseidonState::default(ctx);
49 let mut state = init_state.clone();
50 let empty_hash = fix_len_array_squeeze(ctx, gate, &[], &mut state, spec);
51 Self { init_state, empty_hash }
52 }
53}
54
55#[derive(Copy, Clone, Debug, Getters, CopyGetters)]
57pub struct PoseidonCompactInput<F: ScalarField, const RATE: usize> {
58 #[getset(get = "pub")]
60 inputs: [AssignedValue<F>; RATE],
61 #[getset(get_copy = "pub")]
63 is_final: SafeBool<F>,
64 #[getset(get_copy = "pub")]
66 len: AssignedValue<F>,
67}
68
69impl<F: ScalarField, const RATE: usize> PoseidonCompactInput<F, RATE> {
70 pub fn new(
72 inputs: [AssignedValue<F>; RATE],
73 is_final: SafeBool<F>,
74 len: AssignedValue<F>,
75 ) -> Self {
76 Self { inputs, is_final, len }
77 }
78
79 pub fn add_validation_constraints(
81 &self,
82 ctx: &mut Context<F>,
83 range: &impl RangeInstructions<F>,
84 ) {
85 range.is_less_than_safe(ctx, self.len, (RATE + 1) as u64);
86 let is_full: AssignedValue<F> =
88 range.gate().is_equal(ctx, self.len, Constant(F::from(RATE as u64)));
89 let invalid_cond = range.gate().or(ctx, *self.is_final.as_ref(), is_full);
90 range.gate().assert_is_const(ctx, &invalid_cond, &F::ZERO);
91 }
92}
93
94#[derive(Clone, Debug, Getters, CopyGetters)]
96pub struct PoseidonCompactChunkInput<F: ScalarField, const RATE: usize> {
97 #[getset(get = "pub")]
99 inputs: Vec<[AssignedValue<F>; RATE]>,
100 #[getset(get_copy = "pub")]
102 is_final: SafeBool<F>,
103}
104
105impl<F: ScalarField, const RATE: usize> PoseidonCompactChunkInput<F, RATE> {
106 pub fn new(inputs: Vec<[AssignedValue<F>; RATE]>, is_final: SafeBool<F>) -> Self {
108 Self { inputs, is_final }
109 }
110}
111
112#[derive(Copy, Clone, Debug, CopyGetters)]
114pub struct PoseidonCompactOutput<F: ScalarField> {
115 #[getset(get_copy = "pub")]
117 hash: AssignedValue<F>,
118 #[getset(get_copy = "pub")]
120 is_final: SafeBool<F>,
121}
122
123impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasher<F, T, RATE> {
124 pub fn new(spec: OptimizedPoseidonSpec<F, T, RATE>) -> Self {
126 Self { spec, consts: OnceCell::new() }
127 }
128 pub fn initialize_consts(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) {
130 self.consts.get_or_init(|| PoseidonHasherConsts::<F, T, RATE>::new(ctx, gate, &self.spec));
131 }
132
133 pub fn clear(&mut self) {
135 self.consts.take();
136 }
137
138 fn empty_hash(&self) -> &AssignedValue<F> {
139 self.consts.get().unwrap().empty_hash()
140 }
141 fn init_state(&self) -> &PoseidonState<F, T, RATE> {
142 self.consts.get().unwrap().init_state()
143 }
144
145 pub fn hash_var_len_array(
153 &self,
154 ctx: &mut Context<F>,
155 range: &impl RangeInstructions<F>,
156 inputs: &[AssignedValue<F>],
157 len: AssignedValue<F>,
158 ) -> AssignedValue<F>
159 where
160 F: BigPrimeField,
161 {
162 let max_len = inputs.len();
164 if max_len == 0 {
165 return *self.empty_hash();
166 };
167
168 let num_bits = (usize::BITS - max_len.leading_zeros()) as usize;
170 let (mut num_perm, len_last_chunk) = range.div_mod(ctx, len, BigUint::from(RATE), num_bits);
172 num_perm = range.gate().inc(ctx, num_perm);
173
174 let mut state = self.init_state().clone();
175 let mut result_state = state.clone();
176 for (i, chunk) in inputs.chunks(RATE).enumerate() {
177 let is_last_perm =
178 range.gate().is_equal(ctx, num_perm, Constant(F::from((i + 1) as u64)));
179 let len_chunk = range.gate().select(
180 ctx,
181 len_last_chunk,
182 Constant(F::from(RATE as u64)),
183 is_last_perm,
184 );
185
186 state.permutation(ctx, range.gate(), chunk, Some(len_chunk), &self.spec);
187 result_state.select(
188 ctx,
189 range.gate(),
190 SafeTypeChip::<F>::unsafe_to_bool(is_last_perm),
191 &state,
192 );
193 }
194 if max_len % RATE == 0 {
195 let is_last_perm = range.gate().is_equal(
196 ctx,
197 num_perm,
198 Constant(F::from((max_len / RATE + 1) as u64)),
199 );
200 let len_chunk = ctx.load_zero();
201 state.permutation(ctx, range.gate(), &[], Some(len_chunk), &self.spec);
202 result_state.select(
203 ctx,
204 range.gate(),
205 SafeTypeChip::<F>::unsafe_to_bool(is_last_perm),
206 &state,
207 );
208 }
209 result_state.s[1]
210 }
211
212 pub fn hash_fix_len_array(
218 &self,
219 ctx: &mut Context<F>,
220 gate: &impl GateInstructions<F>,
221 inputs: &[AssignedValue<F>],
222 ) -> AssignedValue<F>
223 where
224 F: BigPrimeField,
225 {
226 let mut state = self.init_state().clone();
227 fix_len_array_squeeze(ctx, gate, inputs, &mut state, &self.spec)
228 }
229
230 pub fn hash_compact_input(
232 &self,
233 ctx: &mut Context<F>,
234 gate: &impl GateInstructions<F>,
235 compact_inputs: &[PoseidonCompactInput<F, RATE>],
236 ) -> Vec<PoseidonCompactOutput<F>>
237 where
238 F: BigPrimeField,
239 {
240 let mut outputs = Vec::with_capacity(compact_inputs.len());
241 let mut state = self.init_state().clone();
242 for input in compact_inputs {
243 let is_full = gate.is_equal(ctx, input.len, Constant(F::from(RATE as u64)));
246 state.permutation(ctx, gate, &input.inputs, Some(input.len), &self.spec);
248 let mut state_2 = state.clone();
250 state_2.permutation(ctx, gate, &[], None, &self.spec);
251 let hash = gate.select(ctx, state_2.s[1], state.s[1], is_full);
253 outputs.push(PoseidonCompactOutput { hash, is_final: input.is_final });
254 state.select(ctx, gate, input.is_final, self.init_state());
257 }
258 outputs
259 }
260
261 pub fn hash_compact_chunk_inputs(
263 &self,
264 ctx: &mut Context<F>,
265 gate: &impl GateInstructions<F>,
266 chunk_inputs: &[PoseidonCompactChunkInput<F, RATE>],
267 ) -> Vec<PoseidonCompactOutput<F>>
268 where
269 F: BigPrimeField,
270 {
271 let zero_witness = ctx.load_zero();
272 let mut outputs = Vec::with_capacity(chunk_inputs.len());
273 let mut state = self.init_state().clone();
274 for chunk_input in chunk_inputs {
275 let is_final = chunk_input.is_final;
276 for absorb in &chunk_input.inputs {
277 state.permutation(ctx, gate, absorb, None, &self.spec);
278 }
279 let mut output_state = state.clone();
281 output_state.permutation(ctx, gate, &[], None, &self.spec);
282 let hash = gate.select(ctx, output_state.s[1], zero_witness, *is_final.as_ref());
283 outputs.push(PoseidonCompactOutput { hash, is_final });
284 state.select(ctx, gate, is_final, self.init_state());
286 }
287 outputs
288 }
289}
290
291pub struct PoseidonSponge<F: ScalarField, const T: usize, const RATE: usize> {
293 init_state: PoseidonState<F, T, RATE>,
294 state: PoseidonState<F, T, RATE>,
295 spec: OptimizedPoseidonSpec<F, T, RATE>,
296 absorbing: Vec<AssignedValue<F>>,
297}
298
299impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonSponge<F, T, RATE> {
300 pub fn new<const R_F: usize, const R_P: usize, const SECURE_MDS: usize>(
302 ctx: &mut Context<F>,
303 ) -> Self {
304 let init_state = PoseidonState::default(ctx);
305 let state = init_state.clone();
306 Self {
307 init_state,
308 state,
309 spec: OptimizedPoseidonSpec::new::<R_F, R_P, SECURE_MDS>(),
310 absorbing: Vec::new(),
311 }
312 }
313
314 pub fn from_spec(ctx: &mut Context<F>, spec: OptimizedPoseidonSpec<F, T, RATE>) -> Self {
316 let init_state = PoseidonState::default(ctx);
317 Self { spec, state: init_state.clone(), init_state, absorbing: Vec::new() }
318 }
319
320 pub fn clear(&mut self) {
322 self.state = self.init_state.clone();
323 self.absorbing.clear();
324 }
325
326 pub fn update(&mut self, elements: &[AssignedValue<F>]) {
328 self.absorbing.extend_from_slice(elements);
329 }
330
331 pub fn squeeze(
334 &mut self,
335 ctx: &mut Context<F>,
336 gate: &impl GateInstructions<F>,
337 ) -> AssignedValue<F> {
338 let input_elements = mem::take(&mut self.absorbing);
339 fix_len_array_squeeze(ctx, gate, &input_elements, &mut self.state, &self.spec)
340 }
341}
342
343fn fix_len_array_squeeze<F: ScalarField, const T: usize, const RATE: usize>(
345 ctx: &mut Context<F>,
346 gate: &impl GateInstructions<F>,
347 input_elements: &[AssignedValue<F>],
348 state: &mut PoseidonState<F, T, RATE>,
349 spec: &OptimizedPoseidonSpec<F, T, RATE>,
350) -> AssignedValue<F> {
351 let exact = input_elements.len() % RATE == 0;
352
353 for chunk in input_elements.chunks(RATE) {
354 state.permutation(ctx, gate, chunk, None, spec);
355 }
356 if exact {
357 state.permutation(ctx, gate, &[], None, spec);
358 }
359
360 state.s[1]
361}