snark_verifier/util/hash/poseidon.rs

1//! Trait based implementation of Poseidon permutation
2
3use halo2_base::poseidon::hasher::{mds::SparseMDSMatrix, spec::OptimizedPoseidonSpec};
4
5use crate::{
6    loader::{LoadedScalar, ScalarLoader},
7    util::{
8        arithmetic::{FieldExt, PrimeField},
9        Itertools,
10    },
11};
12use std::{iter, marker::PhantomData, mem};
13
14#[cfg(test)]
15mod tests;
16
17// this works for any loader, where the two loaders used are NativeLoader (native rust) and Halo2Loader (ZK circuit)
18#[derive(Clone, Debug)]
19struct State<F: PrimeField, L, const T: usize, const RATE: usize> {
20    inner: [L; T],
21    _marker: PhantomData<F>,
22}
23
24// the transcript hash implementation is the one suggested in the original paper https://eprint.iacr.org/2019/458.pdf
25// another reference implementation is https://github.com/privacy-scaling-explorations/halo2wrong/tree/master/transcript/src
26impl<F: PrimeField, L: LoadedScalar<F>, const T: usize, const RATE: usize> State<F, L, T, RATE> {
27    fn new(inner: [L; T]) -> Self {
28        Self { inner, _marker: PhantomData }
29    }
30
31    fn default(loader: &L::Loader) -> Self {
32        let mut default_state = [F::ZERO; T];
33        // from Section 4.2 of https://eprint.iacr.org/2019/458.pdf
34        // • Variable-Input-Length Hashing. The capacity value is 2^64 + (o−1) where o the output length.
35        // for our transcript use cases, o = 1
36        default_state[0] = F::from_u128(1u128 << 64);
37        Self::new(default_state.map(|state| loader.load_const(&state)))
38    }
39
40    fn loader(&self) -> &L::Loader {
41        self.inner[0].loader()
42    }
43
44    fn power5_with_constant(value: &L, constant: &F) -> L {
45        value.loader().sum_products_with_const(&[(value, &value.square().square())], *constant)
46    }
47
48    fn sbox_full(&mut self, constants: &[F; T]) {
49        for (state, constant) in self.inner.iter_mut().zip(constants.iter()) {
50            *state = Self::power5_with_constant(state, constant);
51        }
52    }
53
54    fn sbox_part(&mut self, constant: &F) {
55        self.inner[0] = Self::power5_with_constant(&self.inner[0], constant);
56    }
57
58    fn absorb_with_pre_constants(&mut self, inputs: &[L], pre_constants: &[F; T]) {
59        assert!(inputs.len() < T);
60
61        self.inner[0] = self.loader().sum_with_const(&[&self.inner[0]], pre_constants[0]);
62        self.inner.iter_mut().zip(pre_constants.iter()).skip(1).zip(inputs).for_each(
63            |((state, constant), input)| {
64                *state = state.loader().sum_with_const(&[state, input], *constant);
65            },
66        );
67        self.inner
68            .iter_mut()
69            .zip(pre_constants.iter())
70            .skip(1 + inputs.len())
71            .enumerate()
72            .for_each(|(idx, (state, constant))| {
73                *state = state.loader().sum_with_const(
74                    &[state],
75                    if idx == 0 { F::ONE + constant } else { *constant },
76                    // the if idx == 0 { F::ONE } else { F::ZERO } is to pad the input with a single 1 and then 0s
77                    // this is the padding suggested in pg 31 of https://eprint.iacr.org/2019/458.pdf and in Section 4.2 (Variable-Input-Length Hashing. The padding consists of one field element being 1, and the remaining elements being 0.)
78                );
79            });
80    }
81
82    fn apply_mds(&mut self, mds: &[[F; T]; T]) {
83        self.inner = mds
84            .iter()
85            .map(|row| {
86                self.loader()
87                    .sum_with_coeff(&row.iter().cloned().zip(self.inner.iter()).collect_vec())
88            })
89            .collect_vec()
90            .try_into()
91            .unwrap();
92    }
93
94    fn apply_sparse_mds(&mut self, mds: &SparseMDSMatrix<F, T, RATE>) {
95        self.inner = iter::once(
96            self.loader()
97                .sum_with_coeff(&mds.row().iter().cloned().zip(self.inner.iter()).collect_vec()),
98        )
99        .chain(mds.col_hat().iter().zip(self.inner.iter().skip(1)).map(|(coeff, state)| {
100            self.loader().sum_with_coeff(&[(*coeff, &self.inner[0]), (F::ONE, state)])
101        }))
102        .collect_vec()
103        .try_into()
104        .unwrap();
105    }
106}
107
108/// Poseidon hasher with configurable `RATE`.
109#[derive(Debug)]
110pub struct Poseidon<F: PrimeField, L, const T: usize, const RATE: usize> {
111    spec: OptimizedPoseidonSpec<F, T, RATE>,
112    default_state: State<F, L, T, RATE>,
113    state: State<F, L, T, RATE>,
114    buf: Vec<L>,
115}
116
117impl<F: PrimeField, L: LoadedScalar<F>, const T: usize, const RATE: usize> Poseidon<F, L, T, RATE> {
118    /// Initialize a poseidon hasher.
119    /// Generates a new spec with specific number of full and partial rounds. `SECURE_MDS` is usually 0, but may need to be specified because insecure matrices may sometimes be generated
120    pub fn new<const R_F: usize, const R_P: usize, const SECURE_MDS: usize>(
121        loader: &L::Loader,
122    ) -> Self
123    where
124        F: FieldExt,
125    {
126        let default_state = State::default(loader);
127        Self {
128            spec: OptimizedPoseidonSpec::new::<R_F, R_P, SECURE_MDS>(),
129            state: default_state.clone(),
130            default_state,
131            buf: Vec::new(),
132        }
133    }
134
135    /// Initialize a poseidon hasher from an existing spec.
136    pub fn from_spec(loader: &L::Loader, spec: OptimizedPoseidonSpec<F, T, RATE>) -> Self {
137        let default_state = State::default(loader);
138        Self { spec, state: default_state.clone(), default_state, buf: Vec::new() }
139    }
140
141    /// Reset state to default and clear the buffer.
142    pub fn clear(&mut self) {
143        self.state = self.default_state.clone();
144        self.buf.clear();
145    }
146
147    /// Store given `elements` into buffer.
148    pub fn update(&mut self, elements: &[L]) {
149        self.buf.extend_from_slice(elements);
150    }
151
152    /// Consume buffer and perform permutation, then output second element of
153    /// state.
154    pub fn squeeze(&mut self) -> L {
155        let buf = mem::take(&mut self.buf);
156        let exact = buf.len() % RATE == 0;
157
158        for chunk in buf.chunks(RATE) {
159            self.permutation(chunk);
160        }
161        if exact {
162            self.permutation(&[]);
163        }
164
165        self.state.inner[1].clone()
166    }
167
168    fn permutation(&mut self, inputs: &[L]) {
169        let r_f = self.spec.r_f() / 2;
170        let mds = self.spec.mds_matrices().mds().as_ref();
171        let pre_sparse_mds = self.spec.mds_matrices().pre_sparse_mds().as_ref();
172        let sparse_matrices = &self.spec.mds_matrices().sparse_matrices();
173
174        // First half of the full rounds
175        let constants = self.spec.constants().start();
176        self.state.absorb_with_pre_constants(inputs, &constants[0]);
177        for constants in constants.iter().skip(1).take(r_f - 1) {
178            self.state.sbox_full(constants);
179            self.state.apply_mds(mds);
180        }
181        self.state.sbox_full(constants.last().unwrap());
182        self.state.apply_mds(pre_sparse_mds);
183
184        // Partial rounds
185        let constants = self.spec.constants().partial();
186        for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) {
187            self.state.sbox_part(constant);
188            self.state.apply_sparse_mds(sparse_mds);
189        }
190
191        // Second half of the full rounds
192        let constants = self.spec.constants().end();
193        for constants in constants.iter() {
194            self.state.sbox_full(constants);
195            self.state.apply_mds(mds);
196        }
197        self.state.sbox_full(&[F::ZERO; T]);
198        self.state.apply_mds(mds);
199    }
200}