// snark_verifier/util/hash/poseidon.rs
1use halo2_base::poseidon::hasher::{mds::SparseMDSMatrix, spec::OptimizedPoseidonSpec};
4
5use crate::{
6 loader::{LoadedScalar, ScalarLoader},
7 util::{
8 arithmetic::{FieldExt, PrimeField},
9 Itertools,
10 },
11};
12use std::{iter, marker::PhantomData, mem};
13
// Unit tests for this module. `#[cfg(test)]` compiles them only in test builds;
// the module body lives in a separate file (presumably `poseidon/tests.rs`,
// given this file is `poseidon.rs` — confirm).
#[cfg(test)]
mod tests;
16
17#[derive(Clone, Debug)]
19struct State<F: PrimeField, L, const T: usize, const RATE: usize> {
20 inner: [L; T],
21 _marker: PhantomData<F>,
22}
23
24impl<F: PrimeField, L: LoadedScalar<F>, const T: usize, const RATE: usize> State<F, L, T, RATE> {
27 fn new(inner: [L; T]) -> Self {
28 Self { inner, _marker: PhantomData }
29 }
30
31 fn default(loader: &L::Loader) -> Self {
32 let mut default_state = [F::ZERO; T];
33 default_state[0] = F::from_u128(1u128 << 64);
37 Self::new(default_state.map(|state| loader.load_const(&state)))
38 }
39
40 fn loader(&self) -> &L::Loader {
41 self.inner[0].loader()
42 }
43
44 fn power5_with_constant(value: &L, constant: &F) -> L {
45 value.loader().sum_products_with_const(&[(value, &value.square().square())], *constant)
46 }
47
48 fn sbox_full(&mut self, constants: &[F; T]) {
49 for (state, constant) in self.inner.iter_mut().zip(constants.iter()) {
50 *state = Self::power5_with_constant(state, constant);
51 }
52 }
53
54 fn sbox_part(&mut self, constant: &F) {
55 self.inner[0] = Self::power5_with_constant(&self.inner[0], constant);
56 }
57
58 fn absorb_with_pre_constants(&mut self, inputs: &[L], pre_constants: &[F; T]) {
59 assert!(inputs.len() < T);
60
61 self.inner[0] = self.loader().sum_with_const(&[&self.inner[0]], pre_constants[0]);
62 self.inner.iter_mut().zip(pre_constants.iter()).skip(1).zip(inputs).for_each(
63 |((state, constant), input)| {
64 *state = state.loader().sum_with_const(&[state, input], *constant);
65 },
66 );
67 self.inner
68 .iter_mut()
69 .zip(pre_constants.iter())
70 .skip(1 + inputs.len())
71 .enumerate()
72 .for_each(|(idx, (state, constant))| {
73 *state = state.loader().sum_with_const(
74 &[state],
75 if idx == 0 { F::ONE + constant } else { *constant },
76 );
79 });
80 }
81
82 fn apply_mds(&mut self, mds: &[[F; T]; T]) {
83 self.inner = mds
84 .iter()
85 .map(|row| {
86 self.loader()
87 .sum_with_coeff(&row.iter().cloned().zip(self.inner.iter()).collect_vec())
88 })
89 .collect_vec()
90 .try_into()
91 .unwrap();
92 }
93
94 fn apply_sparse_mds(&mut self, mds: &SparseMDSMatrix<F, T, RATE>) {
95 self.inner = iter::once(
96 self.loader()
97 .sum_with_coeff(&mds.row().iter().cloned().zip(self.inner.iter()).collect_vec()),
98 )
99 .chain(mds.col_hat().iter().zip(self.inner.iter().skip(1)).map(|(coeff, state)| {
100 self.loader().sum_with_coeff(&[(*coeff, &self.inner[0]), (F::ONE, state)])
101 }))
102 .collect_vec()
103 .try_into()
104 .unwrap();
105 }
106}
107
108#[derive(Debug)]
110pub struct Poseidon<F: PrimeField, L, const T: usize, const RATE: usize> {
111 spec: OptimizedPoseidonSpec<F, T, RATE>,
112 default_state: State<F, L, T, RATE>,
113 state: State<F, L, T, RATE>,
114 buf: Vec<L>,
115}
116
117impl<F: PrimeField, L: LoadedScalar<F>, const T: usize, const RATE: usize> Poseidon<F, L, T, RATE> {
118 pub fn new<const R_F: usize, const R_P: usize, const SECURE_MDS: usize>(
121 loader: &L::Loader,
122 ) -> Self
123 where
124 F: FieldExt,
125 {
126 let default_state = State::default(loader);
127 Self {
128 spec: OptimizedPoseidonSpec::new::<R_F, R_P, SECURE_MDS>(),
129 state: default_state.clone(),
130 default_state,
131 buf: Vec::new(),
132 }
133 }
134
135 pub fn from_spec(loader: &L::Loader, spec: OptimizedPoseidonSpec<F, T, RATE>) -> Self {
137 let default_state = State::default(loader);
138 Self { spec, state: default_state.clone(), default_state, buf: Vec::new() }
139 }
140
141 pub fn clear(&mut self) {
143 self.state = self.default_state.clone();
144 self.buf.clear();
145 }
146
147 pub fn update(&mut self, elements: &[L]) {
149 self.buf.extend_from_slice(elements);
150 }
151
152 pub fn squeeze(&mut self) -> L {
155 let buf = mem::take(&mut self.buf);
156 let exact = buf.len() % RATE == 0;
157
158 for chunk in buf.chunks(RATE) {
159 self.permutation(chunk);
160 }
161 if exact {
162 self.permutation(&[]);
163 }
164
165 self.state.inner[1].clone()
166 }
167
168 fn permutation(&mut self, inputs: &[L]) {
169 let r_f = self.spec.r_f() / 2;
170 let mds = self.spec.mds_matrices().mds().as_ref();
171 let pre_sparse_mds = self.spec.mds_matrices().pre_sparse_mds().as_ref();
172 let sparse_matrices = &self.spec.mds_matrices().sparse_matrices();
173
174 let constants = self.spec.constants().start();
176 self.state.absorb_with_pre_constants(inputs, &constants[0]);
177 for constants in constants.iter().skip(1).take(r_f - 1) {
178 self.state.sbox_full(constants);
179 self.state.apply_mds(mds);
180 }
181 self.state.sbox_full(constants.last().unwrap());
182 self.state.apply_mds(pre_sparse_mds);
183
184 let constants = self.spec.constants().partial();
186 for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) {
187 self.state.sbox_part(constant);
188 self.state.apply_sparse_mds(sparse_mds);
189 }
190
191 let constants = self.spec.constants().end();
193 for constants in constants.iter() {
194 self.state.sbox_full(constants);
195 self.state.apply_mds(mds);
196 }
197 self.state.sbox_full(&[F::ZERO; T]);
198 self.state.apply_mds(mds);
199 }
200}