halo2_base/poseidon/hasher/state.rs

1use std::iter;
2
3use itertools::Itertools;
4
5use crate::{
6    gates::GateInstructions,
7    poseidon::hasher::{mds::SparseMDSMatrix, spec::OptimizedPoseidonSpec},
8    safe_types::SafeBool,
9    utils::ScalarField,
10    AssignedValue, Context,
11    QuantumCell::{Constant, Existing},
12};
13
14#[derive(Clone, Debug)]
15pub(crate) struct PoseidonState<F: ScalarField, const T: usize, const RATE: usize> {
16    pub(crate) s: [AssignedValue<F>; T],
17}
18
19impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonState<F, T, RATE> {
20    pub fn default(ctx: &mut Context<F>) -> Self {
21        let mut default_state = [F::ZERO; T];
22        // from Section 4.2 of https://eprint.iacr.org/2019/458.pdf
23        // • Variable-Input-Length Hashing. The capacity value is 2^64 + (o−1) where o the output length.
24        // for our transcript use cases, o = 1
25        default_state[0] = F::from_u128(1u128 << 64);
26        Self { s: default_state.map(|f| ctx.load_constant(f)) }
27    }
28
29    /// Perform permutation on this state.
30    ///
31    /// ATTETION: inputs.len() needs to be fixed at compile time.
32    /// Assume len <= inputs.len().
33    /// `inputs` is right padded.
34    /// If `len` is `None`, treat `inputs` as a fixed length array.
35    pub fn permutation(
36        &mut self,
37        ctx: &mut Context<F>,
38        gate: &impl GateInstructions<F>,
39        inputs: &[AssignedValue<F>],
40        len: Option<AssignedValue<F>>,
41        spec: &OptimizedPoseidonSpec<F, T, RATE>,
42    ) {
43        let r_f = spec.r_f / 2;
44        let mds = &spec.mds_matrices.mds.0;
45        let pre_sparse_mds = &spec.mds_matrices.pre_sparse_mds.0;
46        let sparse_matrices = &spec.mds_matrices.sparse_matrices;
47
48        // First half of the full round
49        let constants = &spec.constants.start;
50        if let Some(len) = len {
51            // Note: this doesn't mean `padded_inputs` is 0 padded because there is no constraints on `inputs[len..]`
52            let padded_inputs: [AssignedValue<F>; RATE] =
53                core::array::from_fn(
54                    |i| if i < inputs.len() { inputs[i] } else { ctx.load_zero() },
55                );
56            self.absorb_var_len_with_pre_constants(ctx, gate, padded_inputs, len, &constants[0]);
57        } else {
58            self.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]);
59        }
60        for constants in constants.iter().skip(1).take(r_f - 1) {
61            self.sbox_full(ctx, gate, constants);
62            self.apply_mds(ctx, gate, mds);
63        }
64        self.sbox_full(ctx, gate, constants.last().unwrap());
65        self.apply_mds(ctx, gate, pre_sparse_mds);
66
67        // Partial rounds
68        let constants = &spec.constants.partial;
69        for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) {
70            self.sbox_part(ctx, gate, constant);
71            self.apply_sparse_mds(ctx, gate, sparse_mds);
72        }
73
74        // Second half of the full rounds
75        let constants = &spec.constants.end;
76        for constants in constants.iter() {
77            self.sbox_full(ctx, gate, constants);
78            self.apply_mds(ctx, gate, mds);
79        }
80        self.sbox_full(ctx, gate, &[F::ZERO; T]);
81        self.apply_mds(ctx, gate, mds);
82    }
83
84    /// Constrains and set self to a specific state if `selector` is true.
85    pub fn select(
86        &mut self,
87        ctx: &mut Context<F>,
88        gate: &impl GateInstructions<F>,
89        selector: SafeBool<F>,
90        set_to: &Self,
91    ) {
92        for i in 0..T {
93            self.s[i] = gate.select(ctx, set_to.s[i], self.s[i], *selector.as_ref());
94        }
95    }
96
97    fn x_power5_with_constant(
98        ctx: &mut Context<F>,
99        gate: &impl GateInstructions<F>,
100        x: AssignedValue<F>,
101        constant: &F,
102    ) -> AssignedValue<F> {
103        let x2 = gate.mul(ctx, x, x);
104        let x4 = gate.mul(ctx, x2, x2);
105        gate.mul_add(ctx, x, x4, Constant(*constant))
106    }
107
108    fn sbox_full(
109        &mut self,
110        ctx: &mut Context<F>,
111        gate: &impl GateInstructions<F>,
112        constants: &[F; T],
113    ) {
114        for (x, constant) in self.s.iter_mut().zip(constants.iter()) {
115            *x = Self::x_power5_with_constant(ctx, gate, *x, constant);
116        }
117    }
118
119    fn sbox_part(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>, constant: &F) {
120        let x = &mut self.s[0];
121        *x = Self::x_power5_with_constant(ctx, gate, *x, constant);
122    }
123
124    fn absorb_with_pre_constants(
125        &mut self,
126        ctx: &mut Context<F>,
127        gate: &impl GateInstructions<F>,
128        inputs: &[AssignedValue<F>],
129        pre_constants: &[F; T],
130    ) {
131        assert!(inputs.len() < T);
132
133        // Explanation of what's going on: before each round of the poseidon permutation,
134        // two things have to be added to the state: inputs (the absorbed elements) and
135        // preconstants. Imagine the state as a list of T elements, the first of which is
136        // the capacity:  |--cap--|--el1--|--el2--|--elR--|
137        // - A preconstant is added to each of all T elements (which is different for each)
138        // - The inputs are added to all elements starting from el1 (so, not to the capacity),
139        //   to as many elements as inputs are available.
140        // - To the first element for which no input is left (if any), an extra 1 is added.
141
142        // adding preconstant to the distinguished capacity element (only one)
143        self.s[0] = gate.add(ctx, self.s[0], Constant(pre_constants[0]));
144
145        // adding pre-constants and inputs to the elements for which both are available
146        for ((x, constant), input) in
147            self.s.iter_mut().zip(pre_constants.iter()).skip(1).zip(inputs.iter())
148        {
149            *x = gate.sum(ctx, [Existing(*x), Existing(*input), Constant(*constant)]);
150        }
151
152        let offset = inputs.len() + 1;
153        // adding only pre-constants when no input is left
154        for (i, (x, constant)) in
155            self.s.iter_mut().zip(pre_constants.iter()).skip(offset).enumerate()
156        {
157            *x = gate.add(ctx, *x, Constant(if i == 0 { F::ONE + constant } else { *constant }));
158            // the if idx == 0 { F::one() } else { F::zero() } is to pad the input with a single 1 and then 0s
159            // this is the padding suggested in pg 31 of https://eprint.iacr.org/2019/458.pdf and in Section 4.2 (Variable-Input-Length Hashing. The padding consists of one field element being 1, and the remaining elements being 0.)
160        }
161    }
162
163    /// Absorb inputs with a variable length.
164    ///
165    /// `inputs` is right padded.
166    fn absorb_var_len_with_pre_constants(
167        &mut self,
168        ctx: &mut Context<F>,
169        gate: &impl GateInstructions<F>,
170        inputs: [AssignedValue<F>; RATE],
171        len: AssignedValue<F>,
172        pre_constants: &[F; T],
173    ) {
174        // Explanation of what's going on: before each round of the poseidon permutation,
175        // two things have to be added to the state: inputs (the absorbed elements) and
176        // preconstants. Imagine the state as a list of T elements, the first of which is
177        // the capacity:  |--cap--|--el1--|--el2--|--elR--|
178        // - A preconstant is added to each of all T elements (which is different for each)
179        // - The inputs are added to all elements starting from el1 (so, not to the capacity),
180        //   to as many elements as inputs are available.
181        // - To the first element for which no input is left (if any), an extra 1 is added.
182
183        // Adding preconstants to the current state.
184        for (i, pre_const) in pre_constants.iter().enumerate() {
185            self.s[i] = gate.add(ctx, self.s[i], Constant(*pre_const));
186        }
187
188        // Generate a mask array where a[i] = i < len for i = 0..RATE.
189        let idx = gate.dec(ctx, len);
190        let len_indicator = gate.idx_to_indicator(ctx, idx, RATE);
191        // inputs_mask[i] = sum(len_indicator[i..])
192        let mut inputs_mask =
193            gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec();
194        inputs_mask.reverse();
195
196        let padded_inputs = inputs
197            .iter()
198            .zip(inputs_mask.iter())
199            .map(|(input, mask)| gate.mul(ctx, *input, *mask))
200            .collect_vec();
201        for i in 0..RATE {
202            // Add all inputs.
203            self.s[i + 1] = gate.add(ctx, self.s[i + 1], padded_inputs[i]);
204            // Add the extra 1 after inputs.
205            if i + 2 < T {
206                self.s[i + 2] = gate.add(ctx, self.s[i + 2], len_indicator[i]);
207            }
208        }
209        // If len == 0, inputs_mask is all 0. Then the extra 1 should be added into s[1].
210        let empty_extra_one = gate.not(ctx, inputs_mask[0]);
211        self.s[1] = gate.add(ctx, self.s[1], empty_extra_one);
212    }
213
214    fn apply_mds(
215        &mut self,
216        ctx: &mut Context<F>,
217        gate: &impl GateInstructions<F>,
218        mds: &[[F; T]; T],
219    ) {
220        let res = mds
221            .iter()
222            .map(|row| {
223                gate.inner_product(ctx, self.s.iter().copied(), row.iter().map(|c| Constant(*c)))
224            })
225            .collect::<Vec<_>>();
226
227        self.s = res.try_into().unwrap();
228    }
229
230    fn apply_sparse_mds(
231        &mut self,
232        ctx: &mut Context<F>,
233        gate: &impl GateInstructions<F>,
234        mds: &SparseMDSMatrix<F, T, RATE>,
235    ) {
236        self.s = iter::once(gate.inner_product(
237            ctx,
238            self.s.iter().copied(),
239            mds.row.iter().map(|c| Constant(*c)),
240        ))
241        .chain(
242            mds.col_hat
243                .iter()
244                .zip(self.s.iter().skip(1))
245                .map(|(coeff, state)| gate.mul_add(ctx, self.s[0], Constant(*coeff), *state)),
246        )
247        .collect::<Vec<_>>()
248        .try_into()
249        .unwrap();
250    }
251}