1use std::{any::type_name, sync::Arc};
2
3use openvm_stark_backend::{
4 config::StarkConfig,
5 interaction::fri_log_up::FriLogUpPhase,
6 keygen::MultiStarkKeygenBuilder,
7 p3_challenger::DuplexChallenger,
8 p3_commit::ExtensionMmcs,
9 p3_field::{extension::BinomialExtensionField, Field, PrimeCharacteristicRing},
10 prover::{
11 cpu::{CpuBackend, CpuDevice},
12 MultiTraceStarkProver,
13 },
14};
15use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
16use p3_dft::Radix2DitParallel;
17use p3_fri::{FriParameters as P3FriParameters, TwoAdicFriPcs};
18use p3_merkle_tree::MerkleTreeMmcs;
19use p3_poseidon2::ExternalLayerConstants;
20use p3_symmetric::{CryptographicPermutation, PaddingFreeSponge, TruncatedPermutation};
21use rand::{rngs::StdRng, SeedableRng};
22use zkhash::{
23 ark_ff::PrimeField as _, fields::babybear::FpBabyBear as HorizenBabyBear,
24 poseidon2::poseidon2_instance_babybear::RC16,
25};
26
27use super::{
28 instrument::{HashStatistics, InstrumentCounter, Instrumented, StarkHashStatistics},
29 FriParameters,
30};
31use crate::{
32 assert_sc_compatible_with_serde,
33 config::fri_params::{
34 SecurityParameters, MAX_BATCH_SIZE_LOG_BLOWUP_1, MAX_BATCH_SIZE_LOG_BLOWUP_2,
35 MAX_NUM_CONSTRAINTS,
36 },
37 engine::{StarkEngine, StarkEngineWithHashInstrumentation, StarkFriEngine},
38};
39
/// Rate parameter supplied to the sponge hash (`Hash`) and to the duplex challenger
/// (`Challenger`); also the divisor used when converting hashed input lengths into
/// permutation counts in `print_hash_counts`.
const RATE: usize = 8;
/// Width of the permutation state: the permutation operates on `[Val; WIDTH]` and
/// `[PackedVal; WIDTH]` arrays.
const WIDTH: usize = 16;
/// Digest width parameter supplied to the sponge hash, the compression function and the
/// Merkle tree commitment scheme.
const DIGEST_WIDTH: usize = 8;
44
/// Base field used throughout this configuration.
type Val = BabyBear;
/// Packing type associated with `Val` through the `Field` trait.
type PackedVal = <Val as Field>::Packing;
/// Binomial extension field over `Val` with parameter 4 (presumably the extension
/// degree); used as the challenge type of the STARK config.
type Challenge = BinomialExtensionField<Val, 4>;
/// Poseidon2 permutation over BabyBear with a state of `WIDTH` elements.
type Perm = Poseidon2BabyBear<WIDTH>;
/// `Perm` wrapped in `Instrumented`, the wrapper whose `input_lens_by_type` counter is
/// read by the hash-instrumentation engine impl in this file.
type InstrPerm = Instrumented<Perm>;
50
/// Sponge hash built from the permutation `P`, parameterised by `WIDTH`, `RATE` and
/// `DIGEST_WIDTH`.
type Hash<P> = PaddingFreeSponge<P, WIDTH, RATE, DIGEST_WIDTH>;
/// Compression function built from the permutation `P`. The literal `2` is presumably
/// the number of digests compressed per call — confirm against `TruncatedPermutation`.
type Compress<P> = TruncatedPermutation<P, 2, DIGEST_WIDTH, WIDTH>;
/// Merkle-tree commitment scheme over base-field values. Note that
/// `<Val as Field>::Packing` is the same type as `PackedVal`.
type ValMmcs<P> =
    MerkleTreeMmcs<PackedVal, <Val as Field>::Packing, Hash<P>, Compress<P>, DIGEST_WIDTH>;
/// Commitment scheme for `Challenge` values, layered on top of `ValMmcs`.
type ChallengeMmcs<P> = ExtensionMmcs<Val, Challenge, ValMmcs<P>>;
/// Duplex challenger over `Val` driven by the permutation `P`.
pub type Challenger<P> = DuplexChallenger<Val, P, WIDTH, RATE>;
/// DFT implementation handed to the PCS.
type Dft = Radix2DitParallel<Val>;
/// FRI-based polynomial commitment scheme assembled from the aliases above.
type Pcs<P> = TwoAdicFriPcs<Val, Dft, ValMmcs<P>, ChallengeMmcs<P>>;
/// LogUp RAP phase instantiated with this file's field, challenge and challenger types.
type RapPhase<P> = FriLogUpPhase<Val, Challenge, Challenger<P>>;
61
/// STARK configuration over BabyBear, generic in the cryptographic permutation `P` used
/// for hashing, compression and challenge generation.
pub type BabyBearPermutationConfig<P> = StarkConfig<Pcs<P>, RapPhase<P>, Challenge, Challenger<P>>;
/// `BabyBearPermutationConfig` specialised to the Poseidon2 permutation `Perm`.
pub type BabyBearPoseidon2Config = BabyBearPermutationConfig<Perm>;
/// `BabyBearPermutationEngine` specialised to the Poseidon2 permutation `Perm`.
pub type BabyBearPoseidon2Engine = BabyBearPermutationEngine<Perm>;

// NOTE(review): the macro is defined elsewhere in the crate; judging by its name it
// asserts that the config is compatible with serde — confirm against its definition.
assert_sc_compatible_with_serde!(BabyBearPoseidon2Config);
67
/// STARK engine over BabyBear that is generic in the permutation `P`.
///
/// Instances are assembled by `engine_from_perm`, which derives every field from a
/// permutation and a set of security parameters.
pub struct BabyBearPermutationEngine<P>
where
    P: CryptographicPermutation<[Val; WIDTH]>
        + CryptographicPermutation<[PackedVal; WIDTH]>
        + Clone,
{
    /// FRI parameters the engine was built with; `log_blowup` also selects the maximum
    /// batch size used by `keygen_builder`.
    pub fri_params: FriParameters,
    /// CPU device holding the STARK config; `config()` borrows the config from here.
    pub device: CpuDevice<BabyBearPermutationConfig<P>>,
    /// Permutation that is cloned whenever a new challenger is created.
    pub perm: P,
    /// Maximum constraint degree passed to the keygen builder and reported by
    /// `max_constraint_degree()`.
    pub max_constraint_degree: usize,
}
79
80impl<P> StarkEngine for BabyBearPermutationEngine<P>
81where
82 P: CryptographicPermutation<[Val; WIDTH]>
83 + CryptographicPermutation<[PackedVal; WIDTH]>
84 + Clone,
85{
86 type SC = BabyBearPermutationConfig<P>;
87 type PB = CpuBackend<Self::SC>;
88 type PD = CpuDevice<Self::SC>;
89
90 fn config(&self) -> &BabyBearPermutationConfig<P> {
91 &self.device.config
92 }
93
94 fn device(&self) -> &CpuDevice<BabyBearPermutationConfig<P>> {
95 &self.device
96 }
97
98 fn keygen_builder(&self) -> MultiStarkKeygenBuilder<'_, Self::SC> {
99 let mut builder = MultiStarkKeygenBuilder::new(self.config());
100 builder.set_max_constraint_degree(self.max_constraint_degree);
101 let max_batch_size = if self.fri_params.log_blowup == 1 {
102 MAX_BATCH_SIZE_LOG_BLOWUP_1
103 } else {
104 MAX_BATCH_SIZE_LOG_BLOWUP_2
105 };
106 builder.max_batch_size = Some(max_batch_size);
107 builder.max_num_constraints = Some(MAX_NUM_CONSTRAINTS);
108
109 builder
110 }
111
112 fn prover(&self) -> MultiTraceStarkProver<BabyBearPermutationConfig<P>> {
113 MultiTraceStarkProver::new(
114 CpuBackend::default(),
115 self.device.clone(),
116 self.new_challenger(),
117 )
118 }
119
120 fn max_constraint_degree(&self) -> Option<usize> {
121 Some(self.max_constraint_degree)
122 }
123
124 fn new_challenger(&self) -> Challenger<P> {
125 Challenger::new(self.perm.clone())
126 }
127}
128
129impl<P> StarkEngineWithHashInstrumentation for BabyBearPermutationEngine<Instrumented<P>>
130where
131 P: CryptographicPermutation<[Val; WIDTH]>
132 + CryptographicPermutation<[PackedVal; WIDTH]>
133 + Clone,
134{
135 fn clear_instruments(&mut self) {
136 self.perm.input_lens_by_type.lock().unwrap().clear();
137 }
138 fn stark_hash_statistics<T>(&self, custom: T) -> StarkHashStatistics<T> {
139 let counter = self.perm.input_lens_by_type.lock().unwrap();
140 let permutations = counter.iter().fold(0, |total, (name, lens)| {
141 if name == type_name::<[Val; WIDTH]>() {
142 let count: usize = lens.iter().sum();
143 println!("Permutation: {name}, Count: {count}");
144 total + count
145 } else {
146 panic!("Permutation type not yet supported: {}", name);
147 }
148 });
149
150 StarkHashStatistics {
151 name: type_name::<P>().to_string(),
152 stats: HashStatistics { permutations },
153 fri_params: self.fri_params,
154 custom,
155 }
156 }
157}
158
159pub fn default_engine() -> BabyBearPoseidon2Engine {
161 default_engine_impl(FriParameters::standard_fast())
162}
163
164fn default_engine_impl(fri_params: FriParameters) -> BabyBearPoseidon2Engine {
166 let perm = default_perm();
167 let security_params = SecurityParameters::new_baby_bear_100_bits(fri_params);
168 engine_from_perm(perm, security_params)
169}
170
171pub fn default_config(perm: &Perm) -> BabyBearPoseidon2Config {
173 config_from_perm(perm, SecurityParameters::standard_fast())
174}
175
176pub fn engine_from_perm<P>(
177 perm: P,
178 security_params: SecurityParameters,
179) -> BabyBearPermutationEngine<P>
180where
181 P: CryptographicPermutation<[Val; WIDTH]>
182 + CryptographicPermutation<[PackedVal; WIDTH]>
183 + Clone,
184{
185 let fri_params = security_params.fri_params;
186 let max_constraint_degree = fri_params.max_constraint_degree();
187 let config = config_from_perm(&perm, security_params);
188 BabyBearPermutationEngine {
189 device: CpuDevice::new(Arc::new(config), fri_params.log_blowup),
190 perm,
191 fri_params,
192 max_constraint_degree,
193 }
194}
195
196pub fn config_from_perm<P>(
197 perm: &P,
198 security_params: SecurityParameters,
199) -> BabyBearPermutationConfig<P>
200where
201 P: CryptographicPermutation<[Val; WIDTH]>
202 + CryptographicPermutation<[PackedVal; WIDTH]>
203 + Clone,
204{
205 let hash = Hash::new(perm.clone());
206 let compress = Compress::new(perm.clone());
207 let val_mmcs = ValMmcs::new(hash, compress);
208 let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone());
209 let dft = Dft::default();
210 let SecurityParameters {
211 fri_params,
212 log_up_params,
213 deep_ali_params,
214 } = security_params;
215 let fri_config = P3FriParameters {
216 log_blowup: fri_params.log_blowup,
217 log_final_poly_len: fri_params.log_final_poly_len,
218 num_queries: fri_params.num_queries,
219 commit_proof_of_work_bits: fri_params.commit_proof_of_work_bits,
220 query_proof_of_work_bits: fri_params.query_proof_of_work_bits,
221 mmcs: challenge_mmcs,
222 };
223 let pcs = Pcs::new(dft, val_mmcs, fri_config);
224 let challenger = Challenger::new(perm.clone());
225 let rap_phase = FriLogUpPhase::new(log_up_params, fri_params.log_blowup);
226 BabyBearPermutationConfig::new(pcs, challenger, rap_phase, deep_ali_params)
227}
228
229pub fn default_perm() -> Perm {
232 let (external_constants, internal_constants) = horizen_round_consts_16();
233 Perm::new(external_constants, internal_constants)
234}
235
236pub fn random_perm() -> Perm {
237 let seed = [42; 32];
238 let mut rng = StdRng::from_seed(seed);
239 Perm::new_from_rng_128(&mut rng)
240}
241
242pub fn random_instrumented_perm() -> InstrPerm {
243 let perm = random_perm();
244 Instrumented::new(perm)
245}
246
247fn horizen_to_p3(horizen_babybear: HorizenBabyBear) -> BabyBear {
248 BabyBear::from_u64(horizen_babybear.into_bigint().0[0])
249}
250
251pub fn horizen_round_consts_16() -> (ExternalLayerConstants<BabyBear, 16>, Vec<BabyBear>) {
252 let p3_rc16: Vec<Vec<BabyBear>> = RC16
253 .iter()
254 .map(|round| {
255 round
256 .iter()
257 .map(|babybear| horizen_to_p3(*babybear))
258 .collect()
259 })
260 .collect();
261
262 let rounds_f = 8;
263 let rounds_p = 13;
264 let rounds_f_beginning = rounds_f / 2;
265 let p_end = rounds_f_beginning + rounds_p;
266 let initial: Vec<[BabyBear; 16]> = p3_rc16[..rounds_f_beginning]
267 .iter()
268 .cloned()
269 .map(|round| round.try_into().unwrap())
270 .collect();
271 let terminal: Vec<[BabyBear; 16]> = p3_rc16[p_end..]
272 .iter()
273 .cloned()
274 .map(|round| round.try_into().unwrap())
275 .collect();
276 let internal_round_constants: Vec<BabyBear> = p3_rc16[rounds_f_beginning..p_end]
277 .iter()
278 .map(|round| round[0])
279 .collect();
280 (
281 ExternalLayerConstants::new(initial, terminal),
282 internal_round_constants,
283 )
284}
285
286#[allow(dead_code)]
289pub fn print_hash_counts(hash_counter: &InstrumentCounter, compress_counter: &InstrumentCounter) {
290 let hash_counter = hash_counter.lock().unwrap();
291 let mut hash_count = 0;
292 hash_counter.iter().for_each(|(name, lens)| {
293 if name == type_name::<(Val, [Val; DIGEST_WIDTH])>() {
294 let count = lens.iter().fold(0, |count, len| count + len.div_ceil(RATE));
295 println!("Hash: {name}, Count: {count}");
296 hash_count += count;
297 } else {
298 panic!("Hash type not yet supported: {}", name);
299 }
300 });
301 drop(hash_counter);
302 let compress_counter = compress_counter.lock().unwrap();
303 let mut compress_count = 0;
304 compress_counter.iter().for_each(|(name, lens)| {
305 if name == type_name::<[Val; DIGEST_WIDTH]>() {
306 let count = lens.iter().fold(0, |count, len| {
307 count + (DIGEST_WIDTH * len).div_ceil(WIDTH)
309 });
310 println!("Compress: {name}, Count: {count}");
311 compress_count += count;
312 } else {
313 panic!("Compress type not yet supported: {}", name);
314 }
315 });
316 let total_count = hash_count + compress_count;
317 println!("Total Count: {total_count}");
318}
319
320impl StarkFriEngine for BabyBearPoseidon2Engine {
321 fn new(fri_params: FriParameters) -> Self {
322 default_engine_impl(fri_params)
323 }
324 fn fri_params(&self) -> FriParameters {
325 self.fri_params
326 }
327}