// halo2_base/poseidon/hasher/spec.rs

1use crate::{
2    ff::{FromUniformBytes, PrimeField},
3    poseidon::hasher::mds::*,
4};
5
6use getset::{CopyGetters, Getters};
7use poseidon_rs::poseidon::primitives::Spec as PoseidonSpec; // trait
8use std::marker::PhantomData;
9
10// struct so we can use PoseidonSpec trait to generate round constants and MDS matrix
11#[derive(Debug)]
12pub(crate) struct Poseidon128Pow5Gen<
13    F: PrimeField,
14    const T: usize,
15    const RATE: usize,
16    const R_F: usize,
17    const R_P: usize,
18    const SECURE_MDS: usize,
19> {
20    _marker: PhantomData<F>,
21}
22
23impl<
24        F: PrimeField,
25        const T: usize,
26        const RATE: usize,
27        const R_F: usize,
28        const R_P: usize,
29        const SECURE_MDS: usize,
30    > PoseidonSpec<F, T, RATE> for Poseidon128Pow5Gen<F, T, RATE, R_F, R_P, SECURE_MDS>
31{
32    fn full_rounds() -> usize {
33        R_F
34    }
35
36    fn partial_rounds() -> usize {
37        R_P
38    }
39
40    fn sbox(val: F) -> F {
41        val.pow_vartime([5])
42    }
43
44    // see "Avoiding insecure matrices" in Section 2.3 of https://eprint.iacr.org/2019/458.pdf
45    // most Specs used in practice have SECURE_MDS = 0
46    fn secure_mds() -> usize {
47        SECURE_MDS
48    }
49}
50
// We use the optimized Poseidon implementation described in Supplementary Material Section B of https://eprint.iacr.org/2019/458.pdf
// This involves some further computation of optimized constants and sparse MDS matrices beyond what the Scroll PoseidonSpec generates
// The implementation below is adapted from https://github.com/privacy-scaling-explorations/poseidon
54
55/// `OptimizedPoseidonSpec` holds construction parameters as well as constants that are used in
56/// permutation step.
57#[derive(Debug, Clone, Getters, CopyGetters)]
58pub struct OptimizedPoseidonSpec<F: PrimeField, const T: usize, const RATE: usize> {
59    /// Number of full rounds
60    #[getset(get_copy = "pub")]
61    pub(crate) r_f: usize,
62    /// MDS matrices
63    #[getset(get = "pub")]
64    pub(crate) mds_matrices: MDSMatrices<F, T, RATE>,
65    /// Round constants
66    #[getset(get = "pub")]
67    pub(crate) constants: OptimizedConstants<F, T>,
68}
69
70/// `OptimizedConstants` has round constants that are added each round. While
71/// full rounds has T sized constants there is a single constant for each
72/// partial round
73#[derive(Debug, Clone, Getters)]
74pub struct OptimizedConstants<F: PrimeField, const T: usize> {
75    /// start
76    #[getset(get = "pub")]
77    pub(crate) start: Vec<[F; T]>,
78    /// partial
79    #[getset(get = "pub")]
80    pub(crate) partial: Vec<F>,
81    /// end
82    #[getset(get = "pub")]
83    pub(crate) end: Vec<[F; T]>,
84}
85
86impl<F: PrimeField, const T: usize, const RATE: usize> OptimizedPoseidonSpec<F, T, RATE> {
87    /// Generate new spec with specific number of full and partial rounds. `SECURE_MDS` is usually 0, but may need to be specified because insecure matrices may sometimes be generated
88    pub fn new<const R_F: usize, const R_P: usize, const SECURE_MDS: usize>() -> Self
89    where
90        F: FromUniformBytes<64> + Ord,
91    {
92        let (round_constants, mds, mds_inv) =
93            Poseidon128Pow5Gen::<F, T, RATE, R_F, R_P, SECURE_MDS>::constants();
94        let mds = MDSMatrix(mds);
95        let inverse_mds = MDSMatrix(mds_inv);
96
97        let constants =
98            Self::calculate_optimized_constants(R_F, R_P, round_constants, &inverse_mds);
99        let (sparse_matrices, pre_sparse_mds) = Self::calculate_sparse_matrices(R_P, &mds);
100
101        Self {
102            r_f: R_F,
103            constants,
104            mds_matrices: MDSMatrices { mds, sparse_matrices, pre_sparse_mds },
105        }
106    }
107
108    fn calculate_optimized_constants(
109        r_f: usize,
110        r_p: usize,
111        constants: Vec<[F; T]>,
112        inverse_mds: &MDSMatrix<F, T, RATE>,
113    ) -> OptimizedConstants<F, T> {
114        let (number_of_rounds, r_f_half) = (r_f + r_p, r_f / 2);
115        assert_eq!(constants.len(), number_of_rounds);
116
117        // Calculate optimized constants for first half of the full rounds
118        let mut constants_start: Vec<[F; T]> = vec![[F::ZERO; T]; r_f_half];
119        constants_start[0] = constants[0];
120        for (optimized, constants) in
121            constants_start.iter_mut().skip(1).zip(constants.iter().skip(1))
122        {
123            *optimized = inverse_mds.mul_vector(constants);
124        }
125
126        // Calculate constants for partial rounds
127        let mut acc = constants[r_f_half + r_p];
128        let mut constants_partial = vec![F::ZERO; r_p];
129        for (optimized, constants) in constants_partial
130            .iter_mut()
131            .rev()
132            .zip(constants.iter().skip(r_f_half).rev().skip(r_f_half))
133        {
134            let mut tmp = inverse_mds.mul_vector(&acc);
135            *optimized = tmp[0];
136
137            tmp[0] = F::ZERO;
138            for ((acc, tmp), constant) in acc.iter_mut().zip(tmp).zip(constants.iter()) {
139                *acc = tmp + constant
140            }
141        }
142        constants_start.push(inverse_mds.mul_vector(&acc));
143
144        // Calculate optimized constants for ending half of the full rounds
145        let mut constants_end: Vec<[F; T]> = vec![[F::ZERO; T]; r_f_half - 1];
146        for (optimized, constants) in
147            constants_end.iter_mut().zip(constants.iter().skip(r_f_half + r_p + 1))
148        {
149            *optimized = inverse_mds.mul_vector(constants);
150        }
151
152        OptimizedConstants {
153            start: constants_start,
154            partial: constants_partial,
155            end: constants_end,
156        }
157    }
158
159    fn calculate_sparse_matrices(
160        r_p: usize,
161        mds: &MDSMatrix<F, T, RATE>,
162    ) -> (Vec<SparseMDSMatrix<F, T, RATE>>, MDSMatrix<F, T, RATE>) {
163        let mds = mds.transpose();
164        let mut acc = mds.clone();
165        let mut sparse_matrices = (0..r_p)
166            .map(|_| {
167                let (m_prime, m_prime_prime) = acc.factorise();
168                acc = mds.mul(&m_prime);
169                m_prime_prime
170            })
171            .collect::<Vec<SparseMDSMatrix<F, T, RATE>>>();
172
173        sparse_matrices.reverse();
174        (sparse_matrices, acc.transpose())
175    }
176}