openvm_stark_sdk/config/
baby_bear_poseidon2.rs

1use std::{any::type_name, sync::Arc};
2
3use openvm_stark_backend::{
4    config::StarkConfig,
5    interaction::fri_log_up::FriLogUpPhase,
6    keygen::MultiStarkKeygenBuilder,
7    p3_challenger::DuplexChallenger,
8    p3_commit::ExtensionMmcs,
9    p3_field::{extension::BinomialExtensionField, Field, PrimeCharacteristicRing},
10    prover::{
11        cpu::{CpuBackend, CpuDevice},
12        MultiTraceStarkProver,
13    },
14};
15use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
16use p3_dft::Radix2DitParallel;
17use p3_fri::{FriParameters as P3FriParameters, TwoAdicFriPcs};
18use p3_merkle_tree::MerkleTreeMmcs;
19use p3_poseidon2::ExternalLayerConstants;
20use p3_symmetric::{CryptographicPermutation, PaddingFreeSponge, TruncatedPermutation};
21use rand::{rngs::StdRng, SeedableRng};
22use zkhash::{
23    ark_ff::PrimeField as _, fields::babybear::FpBabyBear as HorizenBabyBear,
24    poseidon2::poseidon2_instance_babybear::RC16,
25};
26
27use super::{
28    instrument::{HashStatistics, InstrumentCounter, Instrumented, StarkHashStatistics},
29    FriParameters,
30};
31use crate::{
32    assert_sc_compatible_with_serde,
33    config::fri_params::{
34        SecurityParameters, MAX_BATCH_SIZE_LOG_BLOWUP_1, MAX_BATCH_SIZE_LOG_BLOWUP_2,
35        MAX_NUM_CONSTRAINTS,
36    },
37    engine::{StarkEngine, StarkEngineWithHashInstrumentation, StarkFriEngine},
38};
39
/// Width of the permutation state, in field elements (rate + capacity).
const WIDTH: usize = 16;
/// Rate of the sponge and of the duplex challenger, in field elements.
const RATE: usize = 8;
/// Number of field elements in a hash or compression digest.
const DIGEST_WIDTH: usize = 8;
44
45type Val = BabyBear;
46type PackedVal = <Val as Field>::Packing;
47type Challenge = BinomialExtensionField<Val, 4>;
48type Perm = Poseidon2BabyBear<WIDTH>;
49type InstrPerm = Instrumented<Perm>;
50
51// Generic over P: CryptographicPermutation<[F; WIDTH]>
52type Hash<P> = PaddingFreeSponge<P, WIDTH, RATE, DIGEST_WIDTH>;
53type Compress<P> = TruncatedPermutation<P, 2, DIGEST_WIDTH, WIDTH>;
54type ValMmcs<P> =
55    MerkleTreeMmcs<PackedVal, <Val as Field>::Packing, Hash<P>, Compress<P>, DIGEST_WIDTH>;
56type ChallengeMmcs<P> = ExtensionMmcs<Val, Challenge, ValMmcs<P>>;
57pub type Challenger<P> = DuplexChallenger<Val, P, WIDTH, RATE>;
58type Dft = Radix2DitParallel<Val>;
59type Pcs<P> = TwoAdicFriPcs<Val, Dft, ValMmcs<P>, ChallengeMmcs<P>>;
60type RapPhase<P> = FriLogUpPhase<Val, Challenge, Challenger<P>>;
61
62pub type BabyBearPermutationConfig<P> = StarkConfig<Pcs<P>, RapPhase<P>, Challenge, Challenger<P>>;
63pub type BabyBearPoseidon2Config = BabyBearPermutationConfig<Perm>;
64pub type BabyBearPoseidon2Engine = BabyBearPermutationEngine<Perm>;
65
66assert_sc_compatible_with_serde!(BabyBearPoseidon2Config);
67
68pub struct BabyBearPermutationEngine<P>
69where
70    P: CryptographicPermutation<[Val; WIDTH]>
71        + CryptographicPermutation<[PackedVal; WIDTH]>
72        + Clone,
73{
74    pub fri_params: FriParameters,
75    pub device: CpuDevice<BabyBearPermutationConfig<P>>,
76    pub perm: P,
77    pub max_constraint_degree: usize,
78}
79
80impl<P> StarkEngine for BabyBearPermutationEngine<P>
81where
82    P: CryptographicPermutation<[Val; WIDTH]>
83        + CryptographicPermutation<[PackedVal; WIDTH]>
84        + Clone,
85{
86    type SC = BabyBearPermutationConfig<P>;
87    type PB = CpuBackend<Self::SC>;
88    type PD = CpuDevice<Self::SC>;
89
90    fn config(&self) -> &BabyBearPermutationConfig<P> {
91        &self.device.config
92    }
93
94    fn device(&self) -> &CpuDevice<BabyBearPermutationConfig<P>> {
95        &self.device
96    }
97
98    fn keygen_builder(&self) -> MultiStarkKeygenBuilder<'_, Self::SC> {
99        let mut builder = MultiStarkKeygenBuilder::new(self.config());
100        builder.set_max_constraint_degree(self.max_constraint_degree);
101        let max_batch_size = if self.fri_params.log_blowup == 1 {
102            MAX_BATCH_SIZE_LOG_BLOWUP_1
103        } else {
104            MAX_BATCH_SIZE_LOG_BLOWUP_2
105        };
106        builder.max_batch_size = Some(max_batch_size);
107        builder.max_num_constraints = Some(MAX_NUM_CONSTRAINTS);
108
109        builder
110    }
111
112    fn prover(&self) -> MultiTraceStarkProver<BabyBearPermutationConfig<P>> {
113        MultiTraceStarkProver::new(
114            CpuBackend::default(),
115            self.device.clone(),
116            self.new_challenger(),
117        )
118    }
119
120    fn max_constraint_degree(&self) -> Option<usize> {
121        Some(self.max_constraint_degree)
122    }
123
124    fn new_challenger(&self) -> Challenger<P> {
125        Challenger::new(self.perm.clone())
126    }
127}
128
129impl<P> StarkEngineWithHashInstrumentation for BabyBearPermutationEngine<Instrumented<P>>
130where
131    P: CryptographicPermutation<[Val; WIDTH]>
132        + CryptographicPermutation<[PackedVal; WIDTH]>
133        + Clone,
134{
135    fn clear_instruments(&mut self) {
136        self.perm.input_lens_by_type.lock().unwrap().clear();
137    }
138    fn stark_hash_statistics<T>(&self, custom: T) -> StarkHashStatistics<T> {
139        let counter = self.perm.input_lens_by_type.lock().unwrap();
140        let permutations = counter.iter().fold(0, |total, (name, lens)| {
141            if name == type_name::<[Val; WIDTH]>() {
142                let count: usize = lens.iter().sum();
143                println!("Permutation: {name}, Count: {count}");
144                total + count
145            } else {
146                panic!("Permutation type not yet supported: {}", name);
147            }
148        });
149
150        StarkHashStatistics {
151            name: type_name::<P>().to_string(),
152            stats: HashStatistics { permutations },
153            fri_params: self.fri_params,
154            custom,
155        }
156    }
157}
158
159/// `pcs_log_degree` is the upper bound on the log_2(PCS polynomial degree).
160pub fn default_engine() -> BabyBearPoseidon2Engine {
161    default_engine_impl(FriParameters::standard_fast())
162}
163
164/// `pcs_log_degree` is the upper bound on the log_2(PCS polynomial degree).
165fn default_engine_impl(fri_params: FriParameters) -> BabyBearPoseidon2Engine {
166    let perm = default_perm();
167    let security_params = SecurityParameters::new_baby_bear_100_bits(fri_params);
168    engine_from_perm(perm, security_params)
169}
170
171/// `pcs_log_degree` is the upper bound on the log_2(PCS polynomial degree).
172pub fn default_config(perm: &Perm) -> BabyBearPoseidon2Config {
173    config_from_perm(perm, SecurityParameters::standard_fast())
174}
175
176pub fn engine_from_perm<P>(
177    perm: P,
178    security_params: SecurityParameters,
179) -> BabyBearPermutationEngine<P>
180where
181    P: CryptographicPermutation<[Val; WIDTH]>
182        + CryptographicPermutation<[PackedVal; WIDTH]>
183        + Clone,
184{
185    let fri_params = security_params.fri_params;
186    let max_constraint_degree = fri_params.max_constraint_degree();
187    let config = config_from_perm(&perm, security_params);
188    BabyBearPermutationEngine {
189        device: CpuDevice::new(Arc::new(config), fri_params.log_blowup),
190        perm,
191        fri_params,
192        max_constraint_degree,
193    }
194}
195
196pub fn config_from_perm<P>(
197    perm: &P,
198    security_params: SecurityParameters,
199) -> BabyBearPermutationConfig<P>
200where
201    P: CryptographicPermutation<[Val; WIDTH]>
202        + CryptographicPermutation<[PackedVal; WIDTH]>
203        + Clone,
204{
205    let hash = Hash::new(perm.clone());
206    let compress = Compress::new(perm.clone());
207    let val_mmcs = ValMmcs::new(hash, compress);
208    let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone());
209    let dft = Dft::default();
210    let SecurityParameters {
211        fri_params,
212        log_up_params,
213        deep_ali_params,
214    } = security_params;
215    let fri_config = P3FriParameters {
216        log_blowup: fri_params.log_blowup,
217        log_final_poly_len: fri_params.log_final_poly_len,
218        num_queries: fri_params.num_queries,
219        commit_proof_of_work_bits: fri_params.commit_proof_of_work_bits,
220        query_proof_of_work_bits: fri_params.query_proof_of_work_bits,
221        mmcs: challenge_mmcs,
222    };
223    let pcs = Pcs::new(dft, val_mmcs, fri_config);
224    let challenger = Challenger::new(perm.clone());
225    let rap_phase = FriLogUpPhase::new(log_up_params, fri_params.log_blowup);
226    BabyBearPermutationConfig::new(pcs, challenger, rap_phase, deep_ali_params)
227}
228
229/// Uses HorizenLabs Poseidon2 round constants, but plonky3 Mat4 and also
230/// with a p3 Monty reduction factor.
231pub fn default_perm() -> Perm {
232    let (external_constants, internal_constants) = horizen_round_consts_16();
233    Perm::new(external_constants, internal_constants)
234}
235
236pub fn random_perm() -> Perm {
237    let seed = [42; 32];
238    let mut rng = StdRng::from_seed(seed);
239    Perm::new_from_rng_128(&mut rng)
240}
241
242pub fn random_instrumented_perm() -> InstrPerm {
243    let perm = random_perm();
244    Instrumented::new(perm)
245}
246
247fn horizen_to_p3(horizen_babybear: HorizenBabyBear) -> BabyBear {
248    BabyBear::from_u64(horizen_babybear.into_bigint().0[0])
249}
250
251pub fn horizen_round_consts_16() -> (ExternalLayerConstants<BabyBear, 16>, Vec<BabyBear>) {
252    let p3_rc16: Vec<Vec<BabyBear>> = RC16
253        .iter()
254        .map(|round| {
255            round
256                .iter()
257                .map(|babybear| horizen_to_p3(*babybear))
258                .collect()
259        })
260        .collect();
261
262    let rounds_f = 8;
263    let rounds_p = 13;
264    let rounds_f_beginning = rounds_f / 2;
265    let p_end = rounds_f_beginning + rounds_p;
266    let initial: Vec<[BabyBear; 16]> = p3_rc16[..rounds_f_beginning]
267        .iter()
268        .cloned()
269        .map(|round| round.try_into().unwrap())
270        .collect();
271    let terminal: Vec<[BabyBear; 16]> = p3_rc16[p_end..]
272        .iter()
273        .cloned()
274        .map(|round| round.try_into().unwrap())
275        .collect();
276    let internal_round_constants: Vec<BabyBear> = p3_rc16[rounds_f_beginning..p_end]
277        .iter()
278        .map(|round| round[0])
279        .collect();
280    (
281        ExternalLayerConstants::new(initial, terminal),
282        internal_round_constants,
283    )
284}
285
286/// Logs hash count statistics to stdout and returns as struct.
287/// Count of 1 corresponds to a Poseidon2 permutation with rate RATE that outputs OUT field elements
288#[allow(dead_code)]
289pub fn print_hash_counts(hash_counter: &InstrumentCounter, compress_counter: &InstrumentCounter) {
290    let hash_counter = hash_counter.lock().unwrap();
291    let mut hash_count = 0;
292    hash_counter.iter().for_each(|(name, lens)| {
293        if name == type_name::<(Val, [Val; DIGEST_WIDTH])>() {
294            let count = lens.iter().fold(0, |count, len| count + len.div_ceil(RATE));
295            println!("Hash: {name}, Count: {count}");
296            hash_count += count;
297        } else {
298            panic!("Hash type not yet supported: {}", name);
299        }
300    });
301    drop(hash_counter);
302    let compress_counter = compress_counter.lock().unwrap();
303    let mut compress_count = 0;
304    compress_counter.iter().for_each(|(name, lens)| {
305        if name == type_name::<[Val; DIGEST_WIDTH]>() {
306            let count = lens.iter().fold(0, |count, len| {
307                // len should always be N=2 for TruncatedPermutation
308                count + (DIGEST_WIDTH * len).div_ceil(WIDTH)
309            });
310            println!("Compress: {name}, Count: {count}");
311            compress_count += count;
312        } else {
313            panic!("Compress type not yet supported: {}", name);
314        }
315    });
316    let total_count = hash_count + compress_count;
317    println!("Total Count: {total_count}");
318}
319
320impl StarkFriEngine for BabyBearPoseidon2Engine {
321    fn new(fri_params: FriParameters) -> Self {
322        default_engine_impl(fri_params)
323    }
324    fn fri_params(&self) -> FriParameters {
325        self.fri_params
326    }
327}