halo2_base/poseidon/hasher/
state.rs1use std::iter;
2
3use itertools::Itertools;
4
5use crate::{
6 gates::GateInstructions,
7 poseidon::hasher::{mds::SparseMDSMatrix, spec::OptimizedPoseidonSpec},
8 safe_types::SafeBool,
9 utils::ScalarField,
10 AssignedValue, Context,
11 QuantumCell::{Constant, Existing},
12};
13
14#[derive(Clone, Debug)]
15pub(crate) struct PoseidonState<F: ScalarField, const T: usize, const RATE: usize> {
16 pub(crate) s: [AssignedValue<F>; T],
17}
18
19impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonState<F, T, RATE> {
20 pub fn default(ctx: &mut Context<F>) -> Self {
21 let mut default_state = [F::ZERO; T];
22 default_state[0] = F::from_u128(1u128 << 64);
26 Self { s: default_state.map(|f| ctx.load_constant(f)) }
27 }
28
29 pub fn permutation(
36 &mut self,
37 ctx: &mut Context<F>,
38 gate: &impl GateInstructions<F>,
39 inputs: &[AssignedValue<F>],
40 len: Option<AssignedValue<F>>,
41 spec: &OptimizedPoseidonSpec<F, T, RATE>,
42 ) {
43 let r_f = spec.r_f / 2;
44 let mds = &spec.mds_matrices.mds.0;
45 let pre_sparse_mds = &spec.mds_matrices.pre_sparse_mds.0;
46 let sparse_matrices = &spec.mds_matrices.sparse_matrices;
47
48 let constants = &spec.constants.start;
50 if let Some(len) = len {
51 let padded_inputs: [AssignedValue<F>; RATE] =
53 core::array::from_fn(
54 |i| if i < inputs.len() { inputs[i] } else { ctx.load_zero() },
55 );
56 self.absorb_var_len_with_pre_constants(ctx, gate, padded_inputs, len, &constants[0]);
57 } else {
58 self.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]);
59 }
60 for constants in constants.iter().skip(1).take(r_f - 1) {
61 self.sbox_full(ctx, gate, constants);
62 self.apply_mds(ctx, gate, mds);
63 }
64 self.sbox_full(ctx, gate, constants.last().unwrap());
65 self.apply_mds(ctx, gate, pre_sparse_mds);
66
67 let constants = &spec.constants.partial;
69 for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) {
70 self.sbox_part(ctx, gate, constant);
71 self.apply_sparse_mds(ctx, gate, sparse_mds);
72 }
73
74 let constants = &spec.constants.end;
76 for constants in constants.iter() {
77 self.sbox_full(ctx, gate, constants);
78 self.apply_mds(ctx, gate, mds);
79 }
80 self.sbox_full(ctx, gate, &[F::ZERO; T]);
81 self.apply_mds(ctx, gate, mds);
82 }
83
84 pub fn select(
86 &mut self,
87 ctx: &mut Context<F>,
88 gate: &impl GateInstructions<F>,
89 selector: SafeBool<F>,
90 set_to: &Self,
91 ) {
92 for i in 0..T {
93 self.s[i] = gate.select(ctx, set_to.s[i], self.s[i], *selector.as_ref());
94 }
95 }
96
97 fn x_power5_with_constant(
98 ctx: &mut Context<F>,
99 gate: &impl GateInstructions<F>,
100 x: AssignedValue<F>,
101 constant: &F,
102 ) -> AssignedValue<F> {
103 let x2 = gate.mul(ctx, x, x);
104 let x4 = gate.mul(ctx, x2, x2);
105 gate.mul_add(ctx, x, x4, Constant(*constant))
106 }
107
108 fn sbox_full(
109 &mut self,
110 ctx: &mut Context<F>,
111 gate: &impl GateInstructions<F>,
112 constants: &[F; T],
113 ) {
114 for (x, constant) in self.s.iter_mut().zip(constants.iter()) {
115 *x = Self::x_power5_with_constant(ctx, gate, *x, constant);
116 }
117 }
118
119 fn sbox_part(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>, constant: &F) {
120 let x = &mut self.s[0];
121 *x = Self::x_power5_with_constant(ctx, gate, *x, constant);
122 }
123
124 fn absorb_with_pre_constants(
125 &mut self,
126 ctx: &mut Context<F>,
127 gate: &impl GateInstructions<F>,
128 inputs: &[AssignedValue<F>],
129 pre_constants: &[F; T],
130 ) {
131 assert!(inputs.len() < T);
132
133 self.s[0] = gate.add(ctx, self.s[0], Constant(pre_constants[0]));
144
145 for ((x, constant), input) in
147 self.s.iter_mut().zip(pre_constants.iter()).skip(1).zip(inputs.iter())
148 {
149 *x = gate.sum(ctx, [Existing(*x), Existing(*input), Constant(*constant)]);
150 }
151
152 let offset = inputs.len() + 1;
153 for (i, (x, constant)) in
155 self.s.iter_mut().zip(pre_constants.iter()).skip(offset).enumerate()
156 {
157 *x = gate.add(ctx, *x, Constant(if i == 0 { F::ONE + constant } else { *constant }));
158 }
161 }
162
163 fn absorb_var_len_with_pre_constants(
167 &mut self,
168 ctx: &mut Context<F>,
169 gate: &impl GateInstructions<F>,
170 inputs: [AssignedValue<F>; RATE],
171 len: AssignedValue<F>,
172 pre_constants: &[F; T],
173 ) {
174 for (i, pre_const) in pre_constants.iter().enumerate() {
185 self.s[i] = gate.add(ctx, self.s[i], Constant(*pre_const));
186 }
187
188 let idx = gate.dec(ctx, len);
190 let len_indicator = gate.idx_to_indicator(ctx, idx, RATE);
191 let mut inputs_mask =
193 gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec();
194 inputs_mask.reverse();
195
196 let padded_inputs = inputs
197 .iter()
198 .zip(inputs_mask.iter())
199 .map(|(input, mask)| gate.mul(ctx, *input, *mask))
200 .collect_vec();
201 for i in 0..RATE {
202 self.s[i + 1] = gate.add(ctx, self.s[i + 1], padded_inputs[i]);
204 if i + 2 < T {
206 self.s[i + 2] = gate.add(ctx, self.s[i + 2], len_indicator[i]);
207 }
208 }
209 let empty_extra_one = gate.not(ctx, inputs_mask[0]);
211 self.s[1] = gate.add(ctx, self.s[1], empty_extra_one);
212 }
213
214 fn apply_mds(
215 &mut self,
216 ctx: &mut Context<F>,
217 gate: &impl GateInstructions<F>,
218 mds: &[[F; T]; T],
219 ) {
220 let res = mds
221 .iter()
222 .map(|row| {
223 gate.inner_product(ctx, self.s.iter().copied(), row.iter().map(|c| Constant(*c)))
224 })
225 .collect::<Vec<_>>();
226
227 self.s = res.try_into().unwrap();
228 }
229
230 fn apply_sparse_mds(
231 &mut self,
232 ctx: &mut Context<F>,
233 gate: &impl GateInstructions<F>,
234 mds: &SparseMDSMatrix<F, T, RATE>,
235 ) {
236 self.s = iter::once(gate.inner_product(
237 ctx,
238 self.s.iter().copied(),
239 mds.row.iter().map(|c| Constant(*c)),
240 ))
241 .chain(
242 mds.col_hat
243 .iter()
244 .zip(self.s.iter().skip(1))
245 .map(|(coeff, state)| gate.mul_add(ctx, self.s[0], Constant(*coeff), *state)),
246 )
247 .collect::<Vec<_>>()
248 .try_into()
249 .unwrap();
250 }
251}