1use std::iter;
23use itertools::Itertools;
45use crate::{
6 gates::GateInstructions,
7 poseidon::hasher::{mds::SparseMDSMatrix, spec::OptimizedPoseidonSpec},
8 safe_types::SafeBool,
9 utils::ScalarField,
10 AssignedValue, Context,
11 QuantumCell::{Constant, Existing},
12};
1314#[derive(Clone, Debug)]
15pub(crate) struct PoseidonState<F: ScalarField, const T: usize, const RATE: usize> {
16pub(crate) s: [AssignedValue<F>; T],
17}
1819impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonState<F, T, RATE> {
20pub fn default(ctx: &mut Context<F>) -> Self {
21let mut default_state = [F::ZERO; T];
22// from Section 4.2 of https://eprint.iacr.org/2019/458.pdf
23 // • Variable-Input-Length Hashing. The capacity value is 2^64 + (o−1) where o the output length.
24 // for our transcript use cases, o = 1
25default_state[0] = F::from_u128(1u128 << 64);
26Self { s: default_state.map(|f| ctx.load_constant(f)) }
27 }
2829/// Perform permutation on this state.
30 ///
31 /// ATTETION: inputs.len() needs to be fixed at compile time.
32 /// Assume len <= inputs.len().
33 /// `inputs` is right padded.
34 /// If `len` is `None`, treat `inputs` as a fixed length array.
35pub fn permutation(
36&mut self,
37 ctx: &mut Context<F>,
38 gate: &impl GateInstructions<F>,
39 inputs: &[AssignedValue<F>],
40 len: Option<AssignedValue<F>>,
41 spec: &OptimizedPoseidonSpec<F, T, RATE>,
42 ) {
43let r_f = spec.r_f / 2;
44let mds = &spec.mds_matrices.mds.0;
45let pre_sparse_mds = &spec.mds_matrices.pre_sparse_mds.0;
46let sparse_matrices = &spec.mds_matrices.sparse_matrices;
4748// First half of the full round
49let constants = &spec.constants.start;
50if let Some(len) = len {
51// Note: this doesn't mean `padded_inputs` is 0 padded because there is no constraints on `inputs[len..]`
52let padded_inputs: [AssignedValue<F>; RATE] =
53 core::array::from_fn(
54 |i| if i < inputs.len() { inputs[i] } else { ctx.load_zero() },
55 );
56self.absorb_var_len_with_pre_constants(ctx, gate, padded_inputs, len, &constants[0]);
57 } else {
58self.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]);
59 }
60for constants in constants.iter().skip(1).take(r_f - 1) {
61self.sbox_full(ctx, gate, constants);
62self.apply_mds(ctx, gate, mds);
63 }
64self.sbox_full(ctx, gate, constants.last().unwrap());
65self.apply_mds(ctx, gate, pre_sparse_mds);
6667// Partial rounds
68let constants = &spec.constants.partial;
69for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) {
70self.sbox_part(ctx, gate, constant);
71self.apply_sparse_mds(ctx, gate, sparse_mds);
72 }
7374// Second half of the full rounds
75let constants = &spec.constants.end;
76for constants in constants.iter() {
77self.sbox_full(ctx, gate, constants);
78self.apply_mds(ctx, gate, mds);
79 }
80self.sbox_full(ctx, gate, &[F::ZERO; T]);
81self.apply_mds(ctx, gate, mds);
82 }
8384/// Constrains and set self to a specific state if `selector` is true.
85pub fn select(
86&mut self,
87 ctx: &mut Context<F>,
88 gate: &impl GateInstructions<F>,
89 selector: SafeBool<F>,
90 set_to: &Self,
91 ) {
92for i in 0..T {
93self.s[i] = gate.select(ctx, set_to.s[i], self.s[i], *selector.as_ref());
94 }
95 }
9697fn x_power5_with_constant(
98 ctx: &mut Context<F>,
99 gate: &impl GateInstructions<F>,
100 x: AssignedValue<F>,
101 constant: &F,
102 ) -> AssignedValue<F> {
103let x2 = gate.mul(ctx, x, x);
104let x4 = gate.mul(ctx, x2, x2);
105 gate.mul_add(ctx, x, x4, Constant(*constant))
106 }
107108fn sbox_full(
109&mut self,
110 ctx: &mut Context<F>,
111 gate: &impl GateInstructions<F>,
112 constants: &[F; T],
113 ) {
114for (x, constant) in self.s.iter_mut().zip(constants.iter()) {
115*x = Self::x_power5_with_constant(ctx, gate, *x, constant);
116 }
117 }
118119fn sbox_part(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>, constant: &F) {
120let x = &mut self.s[0];
121*x = Self::x_power5_with_constant(ctx, gate, *x, constant);
122 }
123124fn absorb_with_pre_constants(
125&mut self,
126 ctx: &mut Context<F>,
127 gate: &impl GateInstructions<F>,
128 inputs: &[AssignedValue<F>],
129 pre_constants: &[F; T],
130 ) {
131assert!(inputs.len() < T);
132133// Explanation of what's going on: before each round of the poseidon permutation,
134 // two things have to be added to the state: inputs (the absorbed elements) and
135 // preconstants. Imagine the state as a list of T elements, the first of which is
136 // the capacity: |--cap--|--el1--|--el2--|--elR--|
137 // - A preconstant is added to each of all T elements (which is different for each)
138 // - The inputs are added to all elements starting from el1 (so, not to the capacity),
139 // to as many elements as inputs are available.
140 // - To the first element for which no input is left (if any), an extra 1 is added.
141142 // adding preconstant to the distinguished capacity element (only one)
143self.s[0] = gate.add(ctx, self.s[0], Constant(pre_constants[0]));
144145// adding pre-constants and inputs to the elements for which both are available
146for ((x, constant), input) in
147self.s.iter_mut().zip(pre_constants.iter()).skip(1).zip(inputs.iter())
148 {
149*x = gate.sum(ctx, [Existing(*x), Existing(*input), Constant(*constant)]);
150 }
151152let offset = inputs.len() + 1;
153// adding only pre-constants when no input is left
154for (i, (x, constant)) in
155self.s.iter_mut().zip(pre_constants.iter()).skip(offset).enumerate()
156 {
157*x = gate.add(ctx, *x, Constant(if i == 0 { F::ONE + constant } else { *constant }));
158// the if idx == 0 { F::one() } else { F::zero() } is to pad the input with a single 1 and then 0s
159 // this is the padding suggested in pg 31 of https://eprint.iacr.org/2019/458.pdf and in Section 4.2 (Variable-Input-Length Hashing. The padding consists of one field element being 1, and the remaining elements being 0.)
160}
161 }
162163/// Absorb inputs with a variable length.
164 ///
165 /// `inputs` is right padded.
166fn absorb_var_len_with_pre_constants(
167&mut self,
168 ctx: &mut Context<F>,
169 gate: &impl GateInstructions<F>,
170 inputs: [AssignedValue<F>; RATE],
171 len: AssignedValue<F>,
172 pre_constants: &[F; T],
173 ) {
174// Explanation of what's going on: before each round of the poseidon permutation,
175 // two things have to be added to the state: inputs (the absorbed elements) and
176 // preconstants. Imagine the state as a list of T elements, the first of which is
177 // the capacity: |--cap--|--el1--|--el2--|--elR--|
178 // - A preconstant is added to each of all T elements (which is different for each)
179 // - The inputs are added to all elements starting from el1 (so, not to the capacity),
180 // to as many elements as inputs are available.
181 // - To the first element for which no input is left (if any), an extra 1 is added.
182183 // Adding preconstants to the current state.
184for (i, pre_const) in pre_constants.iter().enumerate() {
185self.s[i] = gate.add(ctx, self.s[i], Constant(*pre_const));
186 }
187188// Generate a mask array where a[i] = i < len for i = 0..RATE.
189let idx = gate.dec(ctx, len);
190let len_indicator = gate.idx_to_indicator(ctx, idx, RATE);
191// inputs_mask[i] = sum(len_indicator[i..])
192let mut inputs_mask =
193 gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec();
194 inputs_mask.reverse();
195196let padded_inputs = inputs
197 .iter()
198 .zip(inputs_mask.iter())
199 .map(|(input, mask)| gate.mul(ctx, *input, *mask))
200 .collect_vec();
201for i in 0..RATE {
202// Add all inputs.
203self.s[i + 1] = gate.add(ctx, self.s[i + 1], padded_inputs[i]);
204// Add the extra 1 after inputs.
205if i + 2 < T {
206self.s[i + 2] = gate.add(ctx, self.s[i + 2], len_indicator[i]);
207 }
208 }
209// If len == 0, inputs_mask is all 0. Then the extra 1 should be added into s[1].
210let empty_extra_one = gate.not(ctx, inputs_mask[0]);
211self.s[1] = gate.add(ctx, self.s[1], empty_extra_one);
212 }
213214fn apply_mds(
215&mut self,
216 ctx: &mut Context<F>,
217 gate: &impl GateInstructions<F>,
218 mds: &[[F; T]; T],
219 ) {
220let res = mds
221 .iter()
222 .map(|row| {
223 gate.inner_product(ctx, self.s.iter().copied(), row.iter().map(|c| Constant(*c)))
224 })
225 .collect::<Vec<_>>();
226227self.s = res.try_into().unwrap();
228 }
229230fn apply_sparse_mds(
231&mut self,
232 ctx: &mut Context<F>,
233 gate: &impl GateInstructions<F>,
234 mds: &SparseMDSMatrix<F, T, RATE>,
235 ) {
236self.s = iter::once(gate.inner_product(
237 ctx,
238self.s.iter().copied(),
239 mds.row.iter().map(|c| Constant(*c)),
240 ))
241 .chain(
242 mds.col_hat
243 .iter()
244 .zip(self.s.iter().skip(1))
245 .map(|(coeff, state)| gate.mul_add(ctx, self.s[0], Constant(*coeff), *state)),
246 )
247 .collect::<Vec<_>>()
248 .try_into()
249 .unwrap();
250 }
251}