1use std::{any::type_name, sync::Arc};
2
3use openvm_stark_backend::{
4 config::StarkConfig,
5 interaction::fri_log_up::FriLogUpPhase,
6 p3_challenger::DuplexChallenger,
7 p3_commit::ExtensionMmcs,
8 p3_field::{extension::BinomialExtensionField, Field, FieldAlgebra},
9 prover::{
10 cpu::{CpuBackend, CpuDevice},
11 MultiTraceStarkProver,
12 },
13};
14use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
15use p3_dft::Radix2DitParallel;
16use p3_fri::{FriConfig, TwoAdicFriPcs};
17use p3_merkle_tree::MerkleTreeMmcs;
18use p3_poseidon2::ExternalLayerConstants;
19use p3_symmetric::{CryptographicPermutation, PaddingFreeSponge, TruncatedPermutation};
20use rand::{rngs::StdRng, SeedableRng};
21use zkhash::{
22 ark_ff::PrimeField as _, fields::babybear::FpBabyBear as HorizenBabyBear,
23 poseidon2::poseidon2_instance_babybear::RC16,
24};
25
26use super::{
27 instrument::{HashStatistics, InstrumentCounter, Instrumented, StarkHashStatistics},
28 FriParameters,
29};
30use crate::{
31 assert_sc_compatible_with_serde,
32 config::{
33 fri_params::SecurityParameters, log_up_params::log_up_security_params_baby_bear_100_bits,
34 },
35 engine::{StarkEngine, StarkEngineWithHashInstrumentation, StarkFriEngine},
36};
37
/// Number of field elements absorbed per permutation call by the sponge hash and the
/// duplex challenger (the sponge rate).
const RATE: usize = 8;
/// State width, in field elements, of the Poseidon2 permutation.
const WIDTH: usize = 16;
/// Number of field elements in a Merkle digest produced by the sponge hash and the
/// 2-to-1 compression function.
const DIGEST_WIDTH: usize = 8;

// Compile-time sanity checks on the hashing geometry, so that editing the constants above
// into an inconsistent state fails the build instead of producing a misconfigured hasher.
// The sponge must keep a non-zero capacity (`WIDTH - RATE`).
const _: () = assert!(RATE < WIDTH);
// The 2-to-1 compression places two digests side by side in a single permutation state.
const _: () = assert!(2 * DIGEST_WIDTH <= WIDTH);
42
43type Val = BabyBear;
44type PackedVal = <Val as Field>::Packing;
45type Challenge = BinomialExtensionField<Val, 4>;
46type Perm = Poseidon2BabyBear<WIDTH>;
47type InstrPerm = Instrumented<Perm>;
48
49type Hash<P> = PaddingFreeSponge<P, WIDTH, RATE, DIGEST_WIDTH>;
51type Compress<P> = TruncatedPermutation<P, 2, DIGEST_WIDTH, WIDTH>;
52type ValMmcs<P> =
53 MerkleTreeMmcs<PackedVal, <Val as Field>::Packing, Hash<P>, Compress<P>, DIGEST_WIDTH>;
54type ChallengeMmcs<P> = ExtensionMmcs<Val, Challenge, ValMmcs<P>>;
55pub type Challenger<P> = DuplexChallenger<Val, P, WIDTH, RATE>;
56type Dft = Radix2DitParallel<Val>;
57type Pcs<P> = TwoAdicFriPcs<Val, Dft, ValMmcs<P>, ChallengeMmcs<P>>;
58type RapPhase<P> = FriLogUpPhase<Val, Challenge, Challenger<P>>;
59
60pub type BabyBearPermutationConfig<P> = StarkConfig<Pcs<P>, RapPhase<P>, Challenge, Challenger<P>>;
61pub type BabyBearPoseidon2Config = BabyBearPermutationConfig<Perm>;
62pub type BabyBearPoseidon2Engine = BabyBearPermutationEngine<Perm>;
63
64assert_sc_compatible_with_serde!(BabyBearPoseidon2Config);
65
66pub struct BabyBearPermutationEngine<P>
67where
68 P: CryptographicPermutation<[Val; WIDTH]>
69 + CryptographicPermutation<[PackedVal; WIDTH]>
70 + Clone,
71{
72 pub fri_params: FriParameters,
73 pub device: CpuDevice<BabyBearPermutationConfig<P>>,
74 pub perm: P,
75 pub max_constraint_degree: usize,
76}
77
78impl<P> StarkEngine for BabyBearPermutationEngine<P>
79where
80 P: CryptographicPermutation<[Val; WIDTH]>
81 + CryptographicPermutation<[PackedVal; WIDTH]>
82 + Clone,
83{
84 type SC = BabyBearPermutationConfig<P>;
85 type PB = CpuBackend<Self::SC>;
86 type PD = CpuDevice<Self::SC>;
87
88 fn config(&self) -> &BabyBearPermutationConfig<P> {
89 &self.device.config
90 }
91
92 fn device(&self) -> &CpuDevice<BabyBearPermutationConfig<P>> {
93 &self.device
94 }
95
96 fn prover(&self) -> MultiTraceStarkProver<BabyBearPermutationConfig<P>> {
97 MultiTraceStarkProver::new(
98 CpuBackend::default(),
99 self.device.clone(),
100 self.new_challenger(),
101 )
102 }
103
104 fn max_constraint_degree(&self) -> Option<usize> {
105 Some(self.max_constraint_degree)
106 }
107
108 fn new_challenger(&self) -> Challenger<P> {
109 Challenger::new(self.perm.clone())
110 }
111}
112
113impl<P> StarkEngineWithHashInstrumentation for BabyBearPermutationEngine<Instrumented<P>>
114where
115 P: CryptographicPermutation<[Val; WIDTH]>
116 + CryptographicPermutation<[PackedVal; WIDTH]>
117 + Clone,
118{
119 fn clear_instruments(&mut self) {
120 self.perm.input_lens_by_type.lock().unwrap().clear();
121 }
122 fn stark_hash_statistics<T>(&self, custom: T) -> StarkHashStatistics<T> {
123 let counter = self.perm.input_lens_by_type.lock().unwrap();
124 let permutations = counter.iter().fold(0, |total, (name, lens)| {
125 if name == type_name::<[Val; WIDTH]>() {
126 let count: usize = lens.iter().sum();
127 println!("Permutation: {name}, Count: {count}");
128 total + count
129 } else {
130 panic!("Permutation type not yet supported: {}", name);
131 }
132 });
133
134 StarkHashStatistics {
135 name: type_name::<P>().to_string(),
136 stats: HashStatistics { permutations },
137 fri_params: self.fri_params,
138 custom,
139 }
140 }
141}
142
143pub fn default_engine() -> BabyBearPoseidon2Engine {
145 default_engine_impl(FriParameters::standard_fast())
146}
147
148fn default_engine_impl(fri_params: FriParameters) -> BabyBearPoseidon2Engine {
150 let perm = default_perm();
151 let security_params = SecurityParameters {
152 fri_params,
153 log_up_params: log_up_security_params_baby_bear_100_bits(),
154 };
155 engine_from_perm(perm, security_params)
156}
157
158pub fn default_config(perm: &Perm) -> BabyBearPoseidon2Config {
160 config_from_perm(perm, SecurityParameters::standard_fast())
161}
162
163pub fn engine_from_perm<P>(
164 perm: P,
165 security_params: SecurityParameters,
166) -> BabyBearPermutationEngine<P>
167where
168 P: CryptographicPermutation<[Val; WIDTH]>
169 + CryptographicPermutation<[PackedVal; WIDTH]>
170 + Clone,
171{
172 let fri_params = security_params.fri_params;
173 let max_constraint_degree = fri_params.max_constraint_degree();
174 let config = config_from_perm(&perm, security_params);
175 BabyBearPermutationEngine {
176 device: CpuDevice::new(Arc::new(config), fri_params.log_blowup),
177 perm,
178 fri_params,
179 max_constraint_degree,
180 }
181}
182
183pub fn config_from_perm<P>(
184 perm: &P,
185 security_params: SecurityParameters,
186) -> BabyBearPermutationConfig<P>
187where
188 P: CryptographicPermutation<[Val; WIDTH]>
189 + CryptographicPermutation<[PackedVal; WIDTH]>
190 + Clone,
191{
192 let hash = Hash::new(perm.clone());
193 let compress = Compress::new(perm.clone());
194 let val_mmcs = ValMmcs::new(hash, compress);
195 let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone());
196 let dft = Dft::default();
197 let SecurityParameters {
198 fri_params,
199 log_up_params,
200 } = security_params;
201 let fri_config = FriConfig {
202 log_blowup: fri_params.log_blowup,
203 log_final_poly_len: fri_params.log_final_poly_len,
204 num_queries: fri_params.num_queries,
205 proof_of_work_bits: fri_params.proof_of_work_bits,
206 mmcs: challenge_mmcs,
207 };
208 let pcs = Pcs::new(dft, val_mmcs, fri_config);
209 let rap_phase = FriLogUpPhase::new(log_up_params, fri_params.log_blowup);
210 BabyBearPermutationConfig::new(pcs, rap_phase)
211}
212
213pub fn default_perm() -> Perm {
216 let (external_constants, internal_constants) = horizen_round_consts_16();
217 Perm::new(external_constants, internal_constants)
218}
219
220pub fn random_perm() -> Perm {
221 let seed = [42; 32];
222 let mut rng = StdRng::from_seed(seed);
223 Perm::new_from_rng_128(&mut rng)
224}
225
226pub fn random_instrumented_perm() -> InstrPerm {
227 let perm = random_perm();
228 Instrumented::new(perm)
229}
230
231fn horizen_to_p3(horizen_babybear: HorizenBabyBear) -> BabyBear {
232 BabyBear::from_canonical_u64(horizen_babybear.into_bigint().0[0])
233}
234
235pub fn horizen_round_consts_16() -> (ExternalLayerConstants<BabyBear, 16>, Vec<BabyBear>) {
236 let p3_rc16: Vec<Vec<BabyBear>> = RC16
237 .iter()
238 .map(|round| {
239 round
240 .iter()
241 .map(|babybear| horizen_to_p3(*babybear))
242 .collect()
243 })
244 .collect();
245
246 let rounds_f = 8;
247 let rounds_p = 13;
248 let rounds_f_beginning = rounds_f / 2;
249 let p_end = rounds_f_beginning + rounds_p;
250 let initial: Vec<[BabyBear; 16]> = p3_rc16[..rounds_f_beginning]
251 .iter()
252 .cloned()
253 .map(|round| round.try_into().unwrap())
254 .collect();
255 let terminal: Vec<[BabyBear; 16]> = p3_rc16[p_end..]
256 .iter()
257 .cloned()
258 .map(|round| round.try_into().unwrap())
259 .collect();
260 let internal_round_constants: Vec<BabyBear> = p3_rc16[rounds_f_beginning..p_end]
261 .iter()
262 .map(|round| round[0])
263 .collect();
264 (
265 ExternalLayerConstants::new(initial, terminal),
266 internal_round_constants,
267 )
268}
269
270#[allow(dead_code)]
273pub fn print_hash_counts(hash_counter: &InstrumentCounter, compress_counter: &InstrumentCounter) {
274 let hash_counter = hash_counter.lock().unwrap();
275 let mut hash_count = 0;
276 hash_counter.iter().for_each(|(name, lens)| {
277 if name == type_name::<(Val, [Val; DIGEST_WIDTH])>() {
278 let count = lens.iter().fold(0, |count, len| count + len.div_ceil(RATE));
279 println!("Hash: {name}, Count: {count}");
280 hash_count += count;
281 } else {
282 panic!("Hash type not yet supported: {}", name);
283 }
284 });
285 drop(hash_counter);
286 let compress_counter = compress_counter.lock().unwrap();
287 let mut compress_count = 0;
288 compress_counter.iter().for_each(|(name, lens)| {
289 if name == type_name::<[Val; DIGEST_WIDTH]>() {
290 let count = lens.iter().fold(0, |count, len| {
291 count + (DIGEST_WIDTH * len).div_ceil(WIDTH)
293 });
294 println!("Compress: {name}, Count: {count}");
295 compress_count += count;
296 } else {
297 panic!("Compress type not yet supported: {}", name);
298 }
299 });
300 let total_count = hash_count + compress_count;
301 println!("Total Count: {total_count}");
302}
303
304impl StarkFriEngine for BabyBearPoseidon2Engine {
305 fn new(fri_params: FriParameters) -> Self {
306 default_engine_impl(fri_params)
307 }
308 fn fri_params(&self) -> FriParameters {
309 self.fri_params
310 }
311}