halo2_base/poseidon/hasher/spec.rs
1use crate::{
2 ff::{FromUniformBytes, PrimeField},
3 poseidon::hasher::mds::*,
4};
5
6use getset::{CopyGetters, Getters};
7use poseidon_rs::poseidon::primitives::Spec as PoseidonSpec; use std::marker::PhantomData;
9
10#[derive(Debug)]
12pub(crate) struct Poseidon128Pow5Gen<
13 F: PrimeField,
14 const T: usize,
15 const RATE: usize,
16 const R_F: usize,
17 const R_P: usize,
18 const SECURE_MDS: usize,
19> {
20 _marker: PhantomData<F>,
21}
22
23impl<
24 F: PrimeField,
25 const T: usize,
26 const RATE: usize,
27 const R_F: usize,
28 const R_P: usize,
29 const SECURE_MDS: usize,
30 > PoseidonSpec<F, T, RATE> for Poseidon128Pow5Gen<F, T, RATE, R_F, R_P, SECURE_MDS>
31{
32 fn full_rounds() -> usize {
33 R_F
34 }
35
36 fn partial_rounds() -> usize {
37 R_P
38 }
39
40 fn sbox(val: F) -> F {
41 val.pow_vartime([5])
42 }
43
44 fn secure_mds() -> usize {
47 SECURE_MDS
48 }
49}
50
// NOTE(review): the constant/matrix transformation in this file appears to follow the optimized
// Poseidon permutation of eprint 2019/458 (Supplementary Material B) — confirm before relying on it.
/// Poseidon parameters bundled with the precomputed data used by the optimized permutation.
///
/// Built by `OptimizedPoseidonSpec::new`, which takes the raw round constants and MDS matrices
/// from `Poseidon128Pow5Gen` and converts them into optimized round constants (multiplied by the
/// inverse MDS matrix) plus one sparse MDS matrix per partial round.
#[derive(Debug, Clone, Getters, CopyGetters)]
pub struct OptimizedPoseidonSpec<F: PrimeField, const T: usize, const RATE: usize> {
    /// Number of full rounds (`R_F`); readable by value via the generated `r_f()` getter.
    #[getset(get_copy = "pub")]
    pub(crate) r_f: usize,
    /// The dense MDS matrix, the sparse matrices (one per partial round) and `pre_sparse_mds`;
    /// readable by reference via `mds_matrices()`.
    #[getset(get = "pub")]
    pub(crate) mds_matrices: MDSMatrices<F, T, RATE>,
    /// Optimized round constants; readable by reference via `constants()`.
    #[getset(get = "pub")]
    pub(crate) constants: OptimizedConstants<F, T>,
}
69
/// Round constants in the optimized layout produced by
/// `OptimizedPoseidonSpec::calculate_optimized_constants`.
///
/// Full rounds keep one width-`T` vector each (split between `start` and `end`), while each
/// partial round keeps only a single field element.
#[derive(Debug, Clone, Getters)]
pub struct OptimizedConstants<F: PrimeField, const T: usize> {
    /// First half of the full rounds: `r_f / 2 + 1` vectors. These are the raw first-round
    /// constants, then the next rounds' constants multiplied by the inverse MDS matrix, then one
    /// final vector into which the partial-round constants have been folded.
    #[getset(get = "pub")]
    pub(crate) start: Vec<[F; T]>,
    /// One scalar constant per partial round (`r_p` entries).
    #[getset(get = "pub")]
    pub(crate) partial: Vec<F>,
    /// Second half of the full rounds: `r_f / 2 - 1` vectors, each multiplied by the inverse
    /// MDS matrix.
    #[getset(get = "pub")]
    pub(crate) end: Vec<[F; T]>,
}
85
86impl<F: PrimeField, const T: usize, const RATE: usize> OptimizedPoseidonSpec<F, T, RATE> {
87 pub fn new<const R_F: usize, const R_P: usize, const SECURE_MDS: usize>() -> Self
89 where
90 F: FromUniformBytes<64> + Ord,
91 {
92 let (round_constants, mds, mds_inv) =
93 Poseidon128Pow5Gen::<F, T, RATE, R_F, R_P, SECURE_MDS>::constants();
94 let mds = MDSMatrix(mds);
95 let inverse_mds = MDSMatrix(mds_inv);
96
97 let constants =
98 Self::calculate_optimized_constants(R_F, R_P, round_constants, &inverse_mds);
99 let (sparse_matrices, pre_sparse_mds) = Self::calculate_sparse_matrices(R_P, &mds);
100
101 Self {
102 r_f: R_F,
103 constants,
104 mds_matrices: MDSMatrices { mds, sparse_matrices, pre_sparse_mds },
105 }
106 }
107
108 fn calculate_optimized_constants(
109 r_f: usize,
110 r_p: usize,
111 constants: Vec<[F; T]>,
112 inverse_mds: &MDSMatrix<F, T, RATE>,
113 ) -> OptimizedConstants<F, T> {
114 let (number_of_rounds, r_f_half) = (r_f + r_p, r_f / 2);
115 assert_eq!(constants.len(), number_of_rounds);
116
117 let mut constants_start: Vec<[F; T]> = vec![[F::ZERO; T]; r_f_half];
119 constants_start[0] = constants[0];
120 for (optimized, constants) in
121 constants_start.iter_mut().skip(1).zip(constants.iter().skip(1))
122 {
123 *optimized = inverse_mds.mul_vector(constants);
124 }
125
126 let mut acc = constants[r_f_half + r_p];
128 let mut constants_partial = vec![F::ZERO; r_p];
129 for (optimized, constants) in constants_partial
130 .iter_mut()
131 .rev()
132 .zip(constants.iter().skip(r_f_half).rev().skip(r_f_half))
133 {
134 let mut tmp = inverse_mds.mul_vector(&acc);
135 *optimized = tmp[0];
136
137 tmp[0] = F::ZERO;
138 for ((acc, tmp), constant) in acc.iter_mut().zip(tmp).zip(constants.iter()) {
139 *acc = tmp + constant
140 }
141 }
142 constants_start.push(inverse_mds.mul_vector(&acc));
143
144 let mut constants_end: Vec<[F; T]> = vec![[F::ZERO; T]; r_f_half - 1];
146 for (optimized, constants) in
147 constants_end.iter_mut().zip(constants.iter().skip(r_f_half + r_p + 1))
148 {
149 *optimized = inverse_mds.mul_vector(constants);
150 }
151
152 OptimizedConstants {
153 start: constants_start,
154 partial: constants_partial,
155 end: constants_end,
156 }
157 }
158
159 fn calculate_sparse_matrices(
160 r_p: usize,
161 mds: &MDSMatrix<F, T, RATE>,
162 ) -> (Vec<SparseMDSMatrix<F, T, RATE>>, MDSMatrix<F, T, RATE>) {
163 let mds = mds.transpose();
164 let mut acc = mds.clone();
165 let mut sparse_matrices = (0..r_p)
166 .map(|_| {
167 let (m_prime, m_prime_prime) = acc.factorise();
168 acc = mds.mul(&m_prime);
169 m_prime_prime
170 })
171 .collect::<Vec<SparseMDSMatrix<F, T, RATE>>>();
172
173 sparse_matrices.reverse();
174 (sparse_matrices, acc.transpose())
175 }
176}