halo2_base/poseidon/hasher/
mod.rs

1use crate::{
2    gates::{GateInstructions, RangeInstructions},
3    poseidon::hasher::{spec::OptimizedPoseidonSpec, state::PoseidonState},
4    safe_types::{SafeBool, SafeTypeChip},
5    utils::BigPrimeField,
6    AssignedValue, Context,
7    QuantumCell::Constant,
8    ScalarField,
9};
10
11use getset::{CopyGetters, Getters};
12use num_bigint::BigUint;
13use std::{cell::OnceCell, mem};
14
15#[cfg(test)]
16mod tests;
17
18/// Module for maximum distance separable matrix operations.
19pub mod mds;
20/// Module for poseidon specification.
21pub mod spec;
22/// Module for poseidon states.
23pub mod state;
24
25/// Stateless Poseidon hasher.
26#[derive(Clone, Debug, Getters)]
27pub struct PoseidonHasher<F: ScalarField, const T: usize, const RATE: usize> {
28    /// Spec, contains round constants and optimized matrices.
29    #[getset(get = "pub")]
30    spec: OptimizedPoseidonSpec<F, T, RATE>,
31    consts: OnceCell<PoseidonHasherConsts<F, T, RATE>>,
32}
33#[derive(Clone, Debug, Getters)]
34struct PoseidonHasherConsts<F: ScalarField, const T: usize, const RATE: usize> {
35    #[getset(get = "pub")]
36    init_state: PoseidonState<F, T, RATE>,
37    // hash of an empty input("").
38    #[getset(get = "pub")]
39    empty_hash: AssignedValue<F>,
40}
41
42impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasherConsts<F, T, RATE> {
43    pub fn new(
44        ctx: &mut Context<F>,
45        gate: &impl GateInstructions<F>,
46        spec: &OptimizedPoseidonSpec<F, T, RATE>,
47    ) -> Self {
48        let init_state = PoseidonState::default(ctx);
49        let mut state = init_state.clone();
50        let empty_hash = fix_len_array_squeeze(ctx, gate, &[], &mut state, spec);
51        Self { init_state, empty_hash }
52    }
53}
54
55/// 1 logical row of compact input for Poseidon hasher.
56#[derive(Copy, Clone, Debug, Getters, CopyGetters)]
57pub struct PoseidonCompactInput<F: ScalarField, const RATE: usize> {
58    /// Right padded inputs. No constrains on paddings.
59    #[getset(get = "pub")]
60    inputs: [AssignedValue<F>; RATE],
61    /// is_final = 1 triggers squeeze.
62    #[getset(get_copy = "pub")]
63    is_final: SafeBool<F>,
64    /// Length of `inputs`.
65    #[getset(get_copy = "pub")]
66    len: AssignedValue<F>,
67}
68
69impl<F: ScalarField, const RATE: usize> PoseidonCompactInput<F, RATE> {
70    /// Create a new PoseidonCompactInput.
71    pub fn new(
72        inputs: [AssignedValue<F>; RATE],
73        is_final: SafeBool<F>,
74        len: AssignedValue<F>,
75    ) -> Self {
76        Self { inputs, is_final, len }
77    }
78
79    /// Add data validation constraints.
80    pub fn add_validation_constraints(
81        &self,
82        ctx: &mut Context<F>,
83        range: &impl RangeInstructions<F>,
84    ) {
85        range.is_less_than_safe(ctx, self.len, (RATE + 1) as u64);
86        // Invalid case: (!is_final && len != RATE) ==> !(is_final || len == RATE)
87        let is_full: AssignedValue<F> =
88            range.gate().is_equal(ctx, self.len, Constant(F::from(RATE as u64)));
89        let invalid_cond = range.gate().or(ctx, *self.is_final.as_ref(), is_full);
90        range.gate().assert_is_const(ctx, &invalid_cond, &F::ZERO);
91    }
92}
93
94/// A compact chunk input for Poseidon hasher. The end of a logical input could only be at the boundary of a chunk.
95#[derive(Clone, Debug, Getters, CopyGetters)]
96pub struct PoseidonCompactChunkInput<F: ScalarField, const RATE: usize> {
97    /// Inputs of a chunk. All witnesses will be absorbed.
98    #[getset(get = "pub")]
99    inputs: Vec<[AssignedValue<F>; RATE]>,
100    /// is_final = 1 triggers squeeze.
101    #[getset(get_copy = "pub")]
102    is_final: SafeBool<F>,
103}
104
105impl<F: ScalarField, const RATE: usize> PoseidonCompactChunkInput<F, RATE> {
106    /// Create a new PoseidonCompactInput.
107    pub fn new(inputs: Vec<[AssignedValue<F>; RATE]>, is_final: SafeBool<F>) -> Self {
108        Self { inputs, is_final }
109    }
110}
111
112/// 1 logical row of compact output for Poseidon hasher.
113#[derive(Copy, Clone, Debug, CopyGetters)]
114pub struct PoseidonCompactOutput<F: ScalarField> {
115    /// hash of 1 logical input.
116    #[getset(get_copy = "pub")]
117    hash: AssignedValue<F>,
118    /// is_final = 1 ==> this is the end of a logical input.
119    #[getset(get_copy = "pub")]
120    is_final: SafeBool<F>,
121}
122
123impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasher<F, T, RATE> {
124    /// Create a poseidon hasher from an existing spec.
125    pub fn new(spec: OptimizedPoseidonSpec<F, T, RATE>) -> Self {
126        Self { spec, consts: OnceCell::new() }
127    }
128    /// Initialize necessary consts of hasher. Must be called before any computation.
129    pub fn initialize_consts(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) {
130        self.consts.get_or_init(|| PoseidonHasherConsts::<F, T, RATE>::new(ctx, gate, &self.spec));
131    }
132
133    /// Clear all consts.
134    pub fn clear(&mut self) {
135        self.consts.take();
136    }
137
138    fn empty_hash(&self) -> &AssignedValue<F> {
139        self.consts.get().unwrap().empty_hash()
140    }
141    fn init_state(&self) -> &PoseidonState<F, T, RATE> {
142        self.consts.get().unwrap().init_state()
143    }
144
145    /// Constrains and returns hash of a witness array with a variable length.
146    ///
147    /// Assumes `len` is within [usize] and `len <= inputs.len()`.
148    /// * inputs: An right-padded array of [AssignedValue]. Constraints on paddings are not required.
149    /// * len: Length of `inputs`.
150    ///
151    /// Return hash of `inputs`.
152    pub fn hash_var_len_array(
153        &self,
154        ctx: &mut Context<F>,
155        range: &impl RangeInstructions<F>,
156        inputs: &[AssignedValue<F>],
157        len: AssignedValue<F>,
158    ) -> AssignedValue<F>
159    where
160        F: BigPrimeField,
161    {
162        // TODO: rewrite this using hash_compact_input.
163        let max_len = inputs.len();
164        if max_len == 0 {
165            return *self.empty_hash();
166        };
167
168        // len <= max_len --> num_of_bits(len) <= num_of_bits(max_len)
169        let num_bits = (usize::BITS - max_len.leading_zeros()) as usize;
170        // num_perm = len // RATE + 1, len_last_chunk = len % RATE
171        let (mut num_perm, len_last_chunk) = range.div_mod(ctx, len, BigUint::from(RATE), num_bits);
172        num_perm = range.gate().inc(ctx, num_perm);
173
174        let mut state = self.init_state().clone();
175        let mut result_state = state.clone();
176        for (i, chunk) in inputs.chunks(RATE).enumerate() {
177            let is_last_perm =
178                range.gate().is_equal(ctx, num_perm, Constant(F::from((i + 1) as u64)));
179            let len_chunk = range.gate().select(
180                ctx,
181                len_last_chunk,
182                Constant(F::from(RATE as u64)),
183                is_last_perm,
184            );
185
186            state.permutation(ctx, range.gate(), chunk, Some(len_chunk), &self.spec);
187            result_state.select(
188                ctx,
189                range.gate(),
190                SafeTypeChip::<F>::unsafe_to_bool(is_last_perm),
191                &state,
192            );
193        }
194        if max_len % RATE == 0 {
195            let is_last_perm = range.gate().is_equal(
196                ctx,
197                num_perm,
198                Constant(F::from((max_len / RATE + 1) as u64)),
199            );
200            let len_chunk = ctx.load_zero();
201            state.permutation(ctx, range.gate(), &[], Some(len_chunk), &self.spec);
202            result_state.select(
203                ctx,
204                range.gate(),
205                SafeTypeChip::<F>::unsafe_to_bool(is_last_perm),
206                &state,
207            );
208        }
209        result_state.s[1]
210    }
211
212    /// Constrains and returns hash of a witness array.
213    ///
214    /// * inputs: An array of [AssignedValue].
215    ///
216    /// Return hash of `inputs`.
217    pub fn hash_fix_len_array(
218        &self,
219        ctx: &mut Context<F>,
220        gate: &impl GateInstructions<F>,
221        inputs: &[AssignedValue<F>],
222    ) -> AssignedValue<F>
223    where
224        F: BigPrimeField,
225    {
226        let mut state = self.init_state().clone();
227        fix_len_array_squeeze(ctx, gate, inputs, &mut state, &self.spec)
228    }
229
230    /// Constrains and returns hashes of inputs in a compact format. Length of `compact_inputs` should be determined at compile time.
231    pub fn hash_compact_input(
232        &self,
233        ctx: &mut Context<F>,
234        gate: &impl GateInstructions<F>,
235        compact_inputs: &[PoseidonCompactInput<F, RATE>],
236    ) -> Vec<PoseidonCompactOutput<F>>
237    where
238        F: BigPrimeField,
239    {
240        let mut outputs = Vec::with_capacity(compact_inputs.len());
241        let mut state = self.init_state().clone();
242        for input in compact_inputs {
243            // Assume this is the last row of a logical input:
244            // Depending on if len == RATE.
245            let is_full = gate.is_equal(ctx, input.len, Constant(F::from(RATE as u64)));
246            // Case 1: if len != RATE.
247            state.permutation(ctx, gate, &input.inputs, Some(input.len), &self.spec);
248            // Case 2: if len == RATE, an extra permuation is needed for squeeze.
249            let mut state_2 = state.clone();
250            state_2.permutation(ctx, gate, &[], None, &self.spec);
251            // Select the result of case 1/2 depending on if len == RATE.
252            let hash = gate.select(ctx, state_2.s[1], state.s[1], is_full);
253            outputs.push(PoseidonCompactOutput { hash, is_final: input.is_final });
254            // Reset state to init_state if this is the end of a logical input.
255            // TODO: skip this if this is the last row.
256            state.select(ctx, gate, input.is_final, self.init_state());
257        }
258        outputs
259    }
260
261    /// Constrains and returns hashes of chunk inputs in a compact format. Length of `chunk_inputs` should be determined at compile time.
262    pub fn hash_compact_chunk_inputs(
263        &self,
264        ctx: &mut Context<F>,
265        gate: &impl GateInstructions<F>,
266        chunk_inputs: &[PoseidonCompactChunkInput<F, RATE>],
267    ) -> Vec<PoseidonCompactOutput<F>>
268    where
269        F: BigPrimeField,
270    {
271        let zero_witness = ctx.load_zero();
272        let mut outputs = Vec::with_capacity(chunk_inputs.len());
273        let mut state = self.init_state().clone();
274        for chunk_input in chunk_inputs {
275            let is_final = chunk_input.is_final;
276            for absorb in &chunk_input.inputs {
277                state.permutation(ctx, gate, absorb, None, &self.spec);
278            }
279            // Because the length of each absorb is always RATE. An extra permutation is needed for squeeze.
280            let mut output_state = state.clone();
281            output_state.permutation(ctx, gate, &[], None, &self.spec);
282            let hash = gate.select(ctx, output_state.s[1], zero_witness, *is_final.as_ref());
283            outputs.push(PoseidonCompactOutput { hash, is_final });
284            // Reset state to init_state if this is the end of a logical input.
285            state.select(ctx, gate, is_final, self.init_state());
286        }
287        outputs
288    }
289}
290
291/// Poseidon sponge. This is stateful.
292pub struct PoseidonSponge<F: ScalarField, const T: usize, const RATE: usize> {
293    init_state: PoseidonState<F, T, RATE>,
294    state: PoseidonState<F, T, RATE>,
295    spec: OptimizedPoseidonSpec<F, T, RATE>,
296    absorbing: Vec<AssignedValue<F>>,
297}
298
299impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonSponge<F, T, RATE> {
300    /// Create new Poseidon hasher.
301    pub fn new<const R_F: usize, const R_P: usize, const SECURE_MDS: usize>(
302        ctx: &mut Context<F>,
303    ) -> Self {
304        let init_state = PoseidonState::default(ctx);
305        let state = init_state.clone();
306        Self {
307            init_state,
308            state,
309            spec: OptimizedPoseidonSpec::new::<R_F, R_P, SECURE_MDS>(),
310            absorbing: Vec::new(),
311        }
312    }
313
314    /// Initialize a poseidon hasher from an existing spec.
315    pub fn from_spec(ctx: &mut Context<F>, spec: OptimizedPoseidonSpec<F, T, RATE>) -> Self {
316        let init_state = PoseidonState::default(ctx);
317        Self { spec, state: init_state.clone(), init_state, absorbing: Vec::new() }
318    }
319
320    /// Reset state to default and clear the buffer.
321    pub fn clear(&mut self) {
322        self.state = self.init_state.clone();
323        self.absorbing.clear();
324    }
325
326    /// Store given `elements` into buffer.
327    pub fn update(&mut self, elements: &[AssignedValue<F>]) {
328        self.absorbing.extend_from_slice(elements);
329    }
330
331    /// Consume buffer and perform permutation, then output second element of
332    /// state.
333    pub fn squeeze(
334        &mut self,
335        ctx: &mut Context<F>,
336        gate: &impl GateInstructions<F>,
337    ) -> AssignedValue<F> {
338        let input_elements = mem::take(&mut self.absorbing);
339        fix_len_array_squeeze(ctx, gate, &input_elements, &mut self.state, &self.spec)
340    }
341}
342
343/// ATTETION: input_elements.len() needs to be fixed at compile time.
344fn fix_len_array_squeeze<F: ScalarField, const T: usize, const RATE: usize>(
345    ctx: &mut Context<F>,
346    gate: &impl GateInstructions<F>,
347    input_elements: &[AssignedValue<F>],
348    state: &mut PoseidonState<F, T, RATE>,
349    spec: &OptimizedPoseidonSpec<F, T, RATE>,
350) -> AssignedValue<F> {
351    let exact = input_elements.len() % RATE == 0;
352
353    for chunk in input_elements.chunks(RATE) {
354        state.permutation(ctx, gate, chunk, None, spec);
355    }
356    if exact {
357        state.permutation(ctx, gate, &[], None, spec);
358    }
359
360    state.s[1]
361}