// openvm_stark_sdk/config/baby_bear_poseidon2.rs

1use std::{any::type_name, sync::Arc};
2
3use openvm_stark_backend::{
4    config::StarkConfig,
5    interaction::fri_log_up::FriLogUpPhase,
6    p3_challenger::DuplexChallenger,
7    p3_commit::ExtensionMmcs,
8    p3_field::{extension::BinomialExtensionField, Field, FieldAlgebra},
9    prover::{
10        cpu::{CpuBackend, CpuDevice},
11        MultiTraceStarkProver,
12    },
13};
14use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
15use p3_dft::Radix2DitParallel;
16use p3_fri::{FriConfig, TwoAdicFriPcs};
17use p3_merkle_tree::MerkleTreeMmcs;
18use p3_poseidon2::ExternalLayerConstants;
19use p3_symmetric::{CryptographicPermutation, PaddingFreeSponge, TruncatedPermutation};
20use rand::{rngs::StdRng, SeedableRng};
21use zkhash::{
22    ark_ff::PrimeField as _, fields::babybear::FpBabyBear as HorizenBabyBear,
23    poseidon2::poseidon2_instance_babybear::RC16,
24};
25
26use super::{
27    instrument::{HashStatistics, InstrumentCounter, Instrumented, StarkHashStatistics},
28    FriParameters,
29};
30use crate::{
31    assert_sc_compatible_with_serde,
32    config::{
33        fri_params::SecurityParameters, log_up_params::log_up_security_params_baby_bear_100_bits,
34    },
35    engine::{StarkEngine, StarkEngineWithHashInstrumentation, StarkFriEngine},
36};
37
38const RATE: usize = 8;
39// permutation width
40const WIDTH: usize = 16; // rate + capacity
41const DIGEST_WIDTH: usize = 8;
42
43type Val = BabyBear;
44type PackedVal = <Val as Field>::Packing;
45type Challenge = BinomialExtensionField<Val, 4>;
46type Perm = Poseidon2BabyBear<WIDTH>;
47type InstrPerm = Instrumented<Perm>;
48
49// Generic over P: CryptographicPermutation<[F; WIDTH]>
50type Hash<P> = PaddingFreeSponge<P, WIDTH, RATE, DIGEST_WIDTH>;
51type Compress<P> = TruncatedPermutation<P, 2, DIGEST_WIDTH, WIDTH>;
52type ValMmcs<P> =
53    MerkleTreeMmcs<PackedVal, <Val as Field>::Packing, Hash<P>, Compress<P>, DIGEST_WIDTH>;
54type ChallengeMmcs<P> = ExtensionMmcs<Val, Challenge, ValMmcs<P>>;
55pub type Challenger<P> = DuplexChallenger<Val, P, WIDTH, RATE>;
56type Dft = Radix2DitParallel<Val>;
57type Pcs<P> = TwoAdicFriPcs<Val, Dft, ValMmcs<P>, ChallengeMmcs<P>>;
58type RapPhase<P> = FriLogUpPhase<Val, Challenge, Challenger<P>>;
59
60pub type BabyBearPermutationConfig<P> = StarkConfig<Pcs<P>, RapPhase<P>, Challenge, Challenger<P>>;
61pub type BabyBearPoseidon2Config = BabyBearPermutationConfig<Perm>;
62pub type BabyBearPoseidon2Engine = BabyBearPermutationEngine<Perm>;
63
64assert_sc_compatible_with_serde!(BabyBearPoseidon2Config);
65
/// STARK engine over BabyBear that proves with the CPU backend.
///
/// It is generic over the permutation `P`, which drives Merkle leaf hashing and node
/// compression (through the config) and seeds the Fiat-Shamir challenger.
pub struct BabyBearPermutationEngine<P>
where
    P: CryptographicPermutation<[Val; WIDTH]>
        + CryptographicPermutation<[PackedVal; WIDTH]>
        + Clone,
{
    /// FRI parameters the engine was configured with.
    pub fri_params: FriParameters,
    /// CPU device; it owns the STARK config returned by `StarkEngine::config`.
    pub device: CpuDevice<BabyBearPermutationConfig<P>>,
    /// Permutation instance; a clone of it seeds each challenger from `new_challenger`.
    pub perm: P,
    /// Value reported by `StarkEngine::max_constraint_degree`.
    pub max_constraint_degree: usize,
}
77
78impl<P> StarkEngine for BabyBearPermutationEngine<P>
79where
80    P: CryptographicPermutation<[Val; WIDTH]>
81        + CryptographicPermutation<[PackedVal; WIDTH]>
82        + Clone,
83{
84    type SC = BabyBearPermutationConfig<P>;
85    type PB = CpuBackend<Self::SC>;
86    type PD = CpuDevice<Self::SC>;
87
88    fn config(&self) -> &BabyBearPermutationConfig<P> {
89        &self.device.config
90    }
91
92    fn device(&self) -> &CpuDevice<BabyBearPermutationConfig<P>> {
93        &self.device
94    }
95
96    fn prover(&self) -> MultiTraceStarkProver<BabyBearPermutationConfig<P>> {
97        MultiTraceStarkProver::new(
98            CpuBackend::default(),
99            self.device.clone(),
100            self.new_challenger(),
101        )
102    }
103
104    fn max_constraint_degree(&self) -> Option<usize> {
105        Some(self.max_constraint_degree)
106    }
107
108    fn new_challenger(&self) -> Challenger<P> {
109        Challenger::new(self.perm.clone())
110    }
111}
112
113impl<P> StarkEngineWithHashInstrumentation for BabyBearPermutationEngine<Instrumented<P>>
114where
115    P: CryptographicPermutation<[Val; WIDTH]>
116        + CryptographicPermutation<[PackedVal; WIDTH]>
117        + Clone,
118{
119    fn clear_instruments(&mut self) {
120        self.perm.input_lens_by_type.lock().unwrap().clear();
121    }
122    fn stark_hash_statistics<T>(&self, custom: T) -> StarkHashStatistics<T> {
123        let counter = self.perm.input_lens_by_type.lock().unwrap();
124        let permutations = counter.iter().fold(0, |total, (name, lens)| {
125            if name == type_name::<[Val; WIDTH]>() {
126                let count: usize = lens.iter().sum();
127                println!("Permutation: {name}, Count: {count}");
128                total + count
129            } else {
130                panic!("Permutation type not yet supported: {}", name);
131            }
132        });
133
134        StarkHashStatistics {
135            name: type_name::<P>().to_string(),
136            stats: HashStatistics { permutations },
137            fri_params: self.fri_params,
138            custom,
139        }
140    }
141}
142
143/// `pcs_log_degree` is the upper bound on the log_2(PCS polynomial degree).
144pub fn default_engine() -> BabyBearPoseidon2Engine {
145    default_engine_impl(FriParameters::standard_fast())
146}
147
148/// `pcs_log_degree` is the upper bound on the log_2(PCS polynomial degree).
149fn default_engine_impl(fri_params: FriParameters) -> BabyBearPoseidon2Engine {
150    let perm = default_perm();
151    let security_params = SecurityParameters {
152        fri_params,
153        log_up_params: log_up_security_params_baby_bear_100_bits(),
154    };
155    engine_from_perm(perm, security_params)
156}
157
158/// `pcs_log_degree` is the upper bound on the log_2(PCS polynomial degree).
159pub fn default_config(perm: &Perm) -> BabyBearPoseidon2Config {
160    config_from_perm(perm, SecurityParameters::standard_fast())
161}
162
163pub fn engine_from_perm<P>(
164    perm: P,
165    security_params: SecurityParameters,
166) -> BabyBearPermutationEngine<P>
167where
168    P: CryptographicPermutation<[Val; WIDTH]>
169        + CryptographicPermutation<[PackedVal; WIDTH]>
170        + Clone,
171{
172    let fri_params = security_params.fri_params;
173    let max_constraint_degree = fri_params.max_constraint_degree();
174    let config = config_from_perm(&perm, security_params);
175    BabyBearPermutationEngine {
176        device: CpuDevice::new(Arc::new(config), fri_params.log_blowup),
177        perm,
178        fri_params,
179        max_constraint_degree,
180    }
181}
182
183pub fn config_from_perm<P>(
184    perm: &P,
185    security_params: SecurityParameters,
186) -> BabyBearPermutationConfig<P>
187where
188    P: CryptographicPermutation<[Val; WIDTH]>
189        + CryptographicPermutation<[PackedVal; WIDTH]>
190        + Clone,
191{
192    let hash = Hash::new(perm.clone());
193    let compress = Compress::new(perm.clone());
194    let val_mmcs = ValMmcs::new(hash, compress);
195    let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone());
196    let dft = Dft::default();
197    let SecurityParameters {
198        fri_params,
199        log_up_params,
200    } = security_params;
201    let fri_config = FriConfig {
202        log_blowup: fri_params.log_blowup,
203        log_final_poly_len: fri_params.log_final_poly_len,
204        num_queries: fri_params.num_queries,
205        proof_of_work_bits: fri_params.proof_of_work_bits,
206        mmcs: challenge_mmcs,
207    };
208    let pcs = Pcs::new(dft, val_mmcs, fri_config);
209    let rap_phase = FriLogUpPhase::new(log_up_params, fri_params.log_blowup);
210    BabyBearPermutationConfig::new(pcs, rap_phase)
211}
212
213/// Uses HorizenLabs Poseidon2 round constants, but plonky3 Mat4 and also
214/// with a p3 Monty reduction factor.
215pub fn default_perm() -> Perm {
216    let (external_constants, internal_constants) = horizen_round_consts_16();
217    Perm::new(external_constants, internal_constants)
218}
219
220pub fn random_perm() -> Perm {
221    let seed = [42; 32];
222    let mut rng = StdRng::from_seed(seed);
223    Perm::new_from_rng_128(&mut rng)
224}
225
226pub fn random_instrumented_perm() -> InstrPerm {
227    let perm = random_perm();
228    Instrumented::new(perm)
229}
230
231fn horizen_to_p3(horizen_babybear: HorizenBabyBear) -> BabyBear {
232    BabyBear::from_canonical_u64(horizen_babybear.into_bigint().0[0])
233}
234
235pub fn horizen_round_consts_16() -> (ExternalLayerConstants<BabyBear, 16>, Vec<BabyBear>) {
236    let p3_rc16: Vec<Vec<BabyBear>> = RC16
237        .iter()
238        .map(|round| {
239            round
240                .iter()
241                .map(|babybear| horizen_to_p3(*babybear))
242                .collect()
243        })
244        .collect();
245
246    let rounds_f = 8;
247    let rounds_p = 13;
248    let rounds_f_beginning = rounds_f / 2;
249    let p_end = rounds_f_beginning + rounds_p;
250    let initial: Vec<[BabyBear; 16]> = p3_rc16[..rounds_f_beginning]
251        .iter()
252        .cloned()
253        .map(|round| round.try_into().unwrap())
254        .collect();
255    let terminal: Vec<[BabyBear; 16]> = p3_rc16[p_end..]
256        .iter()
257        .cloned()
258        .map(|round| round.try_into().unwrap())
259        .collect();
260    let internal_round_constants: Vec<BabyBear> = p3_rc16[rounds_f_beginning..p_end]
261        .iter()
262        .map(|round| round[0])
263        .collect();
264    (
265        ExternalLayerConstants::new(initial, terminal),
266        internal_round_constants,
267    )
268}
269
270/// Logs hash count statistics to stdout and returns as struct.
271/// Count of 1 corresponds to a Poseidon2 permutation with rate RATE that outputs OUT field elements
272#[allow(dead_code)]
273pub fn print_hash_counts(hash_counter: &InstrumentCounter, compress_counter: &InstrumentCounter) {
274    let hash_counter = hash_counter.lock().unwrap();
275    let mut hash_count = 0;
276    hash_counter.iter().for_each(|(name, lens)| {
277        if name == type_name::<(Val, [Val; DIGEST_WIDTH])>() {
278            let count = lens.iter().fold(0, |count, len| count + len.div_ceil(RATE));
279            println!("Hash: {name}, Count: {count}");
280            hash_count += count;
281        } else {
282            panic!("Hash type not yet supported: {}", name);
283        }
284    });
285    drop(hash_counter);
286    let compress_counter = compress_counter.lock().unwrap();
287    let mut compress_count = 0;
288    compress_counter.iter().for_each(|(name, lens)| {
289        if name == type_name::<[Val; DIGEST_WIDTH]>() {
290            let count = lens.iter().fold(0, |count, len| {
291                // len should always be N=2 for TruncatedPermutation
292                count + (DIGEST_WIDTH * len).div_ceil(WIDTH)
293            });
294            println!("Compress: {name}, Count: {count}");
295            compress_count += count;
296        } else {
297            panic!("Compress type not yet supported: {}", name);
298        }
299    });
300    let total_count = hash_count + compress_count;
301    println!("Total Count: {total_count}");
302}
303
304impl StarkFriEngine for BabyBearPoseidon2Engine {
305    fn new(fri_params: FriParameters) -> Self {
306        default_engine_impl(fri_params)
307    }
308    fn fri_params(&self) -> FriParameters {
309        self.fri_params
310    }
311}