openvm_stark_sdk/config/
baby_bear_poseidon2.rs

1use std::any::type_name;
2
3use openvm_stark_backend::{
4    config::StarkConfig,
5    interaction::fri_log_up::FriLogUpPhase,
6    p3_challenger::DuplexChallenger,
7    p3_commit::ExtensionMmcs,
8    p3_field::{extension::BinomialExtensionField, Field, FieldAlgebra},
9};
10use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
11use p3_dft::Radix2DitParallel;
12use p3_fri::{FriConfig, TwoAdicFriPcs};
13use p3_merkle_tree::MerkleTreeMmcs;
14use p3_poseidon2::ExternalLayerConstants;
15use p3_symmetric::{CryptographicPermutation, PaddingFreeSponge, TruncatedPermutation};
16use rand::{rngs::StdRng, SeedableRng};
17use zkhash::{
18    ark_ff::PrimeField as _, fields::babybear::FpBabyBear as HorizenBabyBear,
19    poseidon2::poseidon2_instance_babybear::RC16,
20};
21
22use super::{
23    instrument::{HashStatistics, InstrumentCounter, Instrumented, StarkHashStatistics},
24    FriParameters,
25};
26use crate::{
27    assert_sc_compatible_with_serde,
28    config::{
29        fri_params::SecurityParameters, log_up_params::log_up_security_params_baby_bear_100_bits,
30    },
31    engine::{StarkEngine, StarkEngineWithHashInstrumentation, StarkFriEngine},
32};
33
/// Rate parameter handed to the sponge hash and the duplex challenger.
const RATE: usize = 8;
/// Width of the permutation state (rate + capacity).
const WIDTH: usize = 16;
/// Digest size, in field elements, used by the sponge hash, the compression
/// function and the Merkle tree.
const DIGEST_WIDTH: usize = 8;
38
39type Val = BabyBear;
40type PackedVal = <Val as Field>::Packing;
41type Challenge = BinomialExtensionField<Val, 4>;
42type Perm = Poseidon2BabyBear<WIDTH>;
43type InstrPerm = Instrumented<Perm>;
44
45// Generic over P: CryptographicPermutation<[F; WIDTH]>
46type Hash<P> = PaddingFreeSponge<P, WIDTH, RATE, DIGEST_WIDTH>;
47type Compress<P> = TruncatedPermutation<P, 2, DIGEST_WIDTH, WIDTH>;
48type ValMmcs<P> =
49    MerkleTreeMmcs<PackedVal, <Val as Field>::Packing, Hash<P>, Compress<P>, DIGEST_WIDTH>;
50type ChallengeMmcs<P> = ExtensionMmcs<Val, Challenge, ValMmcs<P>>;
51pub type Challenger<P> = DuplexChallenger<Val, P, WIDTH, RATE>;
52type Dft = Radix2DitParallel<Val>;
53type Pcs<P> = TwoAdicFriPcs<Val, Dft, ValMmcs<P>, ChallengeMmcs<P>>;
54type RapPhase<P> = FriLogUpPhase<Val, Challenge, Challenger<P>>;
55
56pub type BabyBearPermutationConfig<P> = StarkConfig<Pcs<P>, RapPhase<P>, Challenge, Challenger<P>>;
57pub type BabyBearPoseidon2Config = BabyBearPermutationConfig<Perm>;
58pub type BabyBearPoseidon2Engine = BabyBearPermutationEngine<Perm>;
59
60assert_sc_compatible_with_serde!(BabyBearPoseidon2Config);
61
62pub struct BabyBearPermutationEngine<P>
63where
64    P: CryptographicPermutation<[Val; WIDTH]>
65        + CryptographicPermutation<[PackedVal; WIDTH]>
66        + Clone,
67{
68    pub fri_params: FriParameters,
69    pub config: BabyBearPermutationConfig<P>,
70    pub perm: P,
71    pub max_constraint_degree: usize,
72}
73
74impl<P> StarkEngine<BabyBearPermutationConfig<P>> for BabyBearPermutationEngine<P>
75where
76    P: CryptographicPermutation<[Val; WIDTH]>
77        + CryptographicPermutation<[PackedVal; WIDTH]>
78        + Clone,
79{
80    fn config(&self) -> &BabyBearPermutationConfig<P> {
81        &self.config
82    }
83
84    fn max_constraint_degree(&self) -> Option<usize> {
85        Some(self.max_constraint_degree)
86    }
87
88    fn new_challenger(&self) -> Challenger<P> {
89        Challenger::new(self.perm.clone())
90    }
91}
92
93impl<P> StarkEngineWithHashInstrumentation<BabyBearPermutationConfig<Instrumented<P>>>
94    for BabyBearPermutationEngine<Instrumented<P>>
95where
96    P: CryptographicPermutation<[Val; WIDTH]>
97        + CryptographicPermutation<[PackedVal; WIDTH]>
98        + Clone,
99{
100    fn clear_instruments(&mut self) {
101        self.perm.input_lens_by_type.lock().unwrap().clear();
102    }
103    fn stark_hash_statistics<T>(&self, custom: T) -> StarkHashStatistics<T> {
104        let counter = self.perm.input_lens_by_type.lock().unwrap();
105        let permutations = counter.iter().fold(0, |total, (name, lens)| {
106            if name == type_name::<[Val; WIDTH]>() {
107                let count: usize = lens.iter().sum();
108                println!("Permutation: {name}, Count: {count}");
109                total + count
110            } else {
111                panic!("Permutation type not yet supported: {}", name);
112            }
113        });
114
115        StarkHashStatistics {
116            name: type_name::<P>().to_string(),
117            stats: HashStatistics { permutations },
118            fri_params: self.fri_params,
119            custom,
120        }
121    }
122}
123
124/// `pcs_log_degree` is the upper bound on the log_2(PCS polynomial degree).
125pub fn default_engine() -> BabyBearPoseidon2Engine {
126    default_engine_impl(FriParameters::standard_fast())
127}
128
129/// `pcs_log_degree` is the upper bound on the log_2(PCS polynomial degree).
130fn default_engine_impl(fri_params: FriParameters) -> BabyBearPoseidon2Engine {
131    let perm = default_perm();
132    let security_params = SecurityParameters {
133        fri_params,
134        log_up_params: log_up_security_params_baby_bear_100_bits(),
135    };
136    engine_from_perm(perm, security_params)
137}
138
139/// `pcs_log_degree` is the upper bound on the log_2(PCS polynomial degree).
140pub fn default_config(perm: &Perm) -> BabyBearPoseidon2Config {
141    config_from_perm(perm, SecurityParameters::standard_fast())
142}
143
144pub fn engine_from_perm<P>(
145    perm: P,
146    security_params: SecurityParameters,
147) -> BabyBearPermutationEngine<P>
148where
149    P: CryptographicPermutation<[Val; WIDTH]>
150        + CryptographicPermutation<[PackedVal; WIDTH]>
151        + Clone,
152{
153    let fri_params = security_params.fri_params;
154    let max_constraint_degree = fri_params.max_constraint_degree();
155    let config = config_from_perm(&perm, security_params);
156    BabyBearPermutationEngine {
157        config,
158        perm,
159        fri_params,
160        max_constraint_degree,
161    }
162}
163
164pub fn config_from_perm<P>(
165    perm: &P,
166    security_params: SecurityParameters,
167) -> BabyBearPermutationConfig<P>
168where
169    P: CryptographicPermutation<[Val; WIDTH]>
170        + CryptographicPermutation<[PackedVal; WIDTH]>
171        + Clone,
172{
173    let hash = Hash::new(perm.clone());
174    let compress = Compress::new(perm.clone());
175    let val_mmcs = ValMmcs::new(hash, compress);
176    let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone());
177    let dft = Dft::default();
178    let SecurityParameters {
179        fri_params,
180        log_up_params,
181    } = security_params;
182    let fri_config = FriConfig {
183        log_blowup: fri_params.log_blowup,
184        log_final_poly_len: fri_params.log_final_poly_len,
185        num_queries: fri_params.num_queries,
186        proof_of_work_bits: fri_params.proof_of_work_bits,
187        mmcs: challenge_mmcs,
188    };
189    let pcs = Pcs::new(dft, val_mmcs, fri_config);
190    let rap_phase = FriLogUpPhase::new(log_up_params);
191    BabyBearPermutationConfig::new(pcs, rap_phase)
192}
193
194/// Uses HorizenLabs Poseidon2 round constants, but plonky3 Mat4 and also
195/// with a p3 Monty reduction factor.
196pub fn default_perm() -> Perm {
197    let (external_constants, internal_constants) = horizen_round_consts_16();
198    Perm::new(external_constants, internal_constants)
199}
200
201pub fn random_perm() -> Perm {
202    let seed = [42; 32];
203    let mut rng = StdRng::from_seed(seed);
204    Perm::new_from_rng_128(&mut rng)
205}
206
207pub fn random_instrumented_perm() -> InstrPerm {
208    let perm = random_perm();
209    Instrumented::new(perm)
210}
211
212fn horizen_to_p3(horizen_babybear: HorizenBabyBear) -> BabyBear {
213    BabyBear::from_canonical_u64(horizen_babybear.into_bigint().0[0])
214}
215
216pub fn horizen_round_consts_16() -> (ExternalLayerConstants<BabyBear, 16>, Vec<BabyBear>) {
217    let p3_rc16: Vec<Vec<BabyBear>> = RC16
218        .iter()
219        .map(|round| {
220            round
221                .iter()
222                .map(|babybear| horizen_to_p3(*babybear))
223                .collect()
224        })
225        .collect();
226
227    let rounds_f = 8;
228    let rounds_p = 13;
229    let rounds_f_beginning = rounds_f / 2;
230    let p_end = rounds_f_beginning + rounds_p;
231    let initial: Vec<[BabyBear; 16]> = p3_rc16[..rounds_f_beginning]
232        .iter()
233        .cloned()
234        .map(|round| round.try_into().unwrap())
235        .collect();
236    let terminal: Vec<[BabyBear; 16]> = p3_rc16[p_end..]
237        .iter()
238        .cloned()
239        .map(|round| round.try_into().unwrap())
240        .collect();
241    let internal_round_constants: Vec<BabyBear> = p3_rc16[rounds_f_beginning..p_end]
242        .iter()
243        .map(|round| round[0])
244        .collect();
245    (
246        ExternalLayerConstants::new(initial, terminal),
247        internal_round_constants,
248    )
249}
250
251/// Logs hash count statistics to stdout and returns as struct.
252/// Count of 1 corresponds to a Poseidon2 permutation with rate RATE that outputs OUT field elements
253#[allow(dead_code)]
254pub fn print_hash_counts(hash_counter: &InstrumentCounter, compress_counter: &InstrumentCounter) {
255    let hash_counter = hash_counter.lock().unwrap();
256    let mut hash_count = 0;
257    hash_counter.iter().for_each(|(name, lens)| {
258        if name == type_name::<(Val, [Val; DIGEST_WIDTH])>() {
259            let count = lens.iter().fold(0, |count, len| count + len.div_ceil(RATE));
260            println!("Hash: {name}, Count: {count}");
261            hash_count += count;
262        } else {
263            panic!("Hash type not yet supported: {}", name);
264        }
265    });
266    drop(hash_counter);
267    let compress_counter = compress_counter.lock().unwrap();
268    let mut compress_count = 0;
269    compress_counter.iter().for_each(|(name, lens)| {
270        if name == type_name::<[Val; DIGEST_WIDTH]>() {
271            let count = lens.iter().fold(0, |count, len| {
272                // len should always be N=2 for TruncatedPermutation
273                count + (DIGEST_WIDTH * len).div_ceil(WIDTH)
274            });
275            println!("Compress: {name}, Count: {count}");
276            compress_count += count;
277        } else {
278            panic!("Compress type not yet supported: {}", name);
279        }
280    });
281    let total_count = hash_count + compress_count;
282    println!("Total Count: {total_count}");
283}
284
285impl StarkFriEngine<BabyBearPoseidon2Config> for BabyBearPoseidon2Engine {
286    fn new(fri_params: FriParameters) -> Self {
287        default_engine_impl(fri_params)
288    }
289    fn fri_params(&self) -> FriParameters {
290        self.fri_params
291    }
292}