1use std::sync::Arc;
2
3use openvm_stark_backend::{
4 config::StarkConfig,
5 interaction::fri_log_up::FriLogUpPhase,
6 keygen::MultiStarkKeygenBuilder,
7 p3_challenger::MultiField32Challenger,
8 p3_commit::ExtensionMmcs,
9 p3_field::extension::BinomialExtensionField,
10 prover::{
11 cpu::{CpuBackend, CpuDevice},
12 MultiTraceStarkProver,
13 },
14};
15use p3_baby_bear::BabyBear;
16use p3_bn254::{Bn254, Poseidon2Bn254};
17use p3_dft::Radix2DitParallel;
18use p3_fri::{FriParameters as P3FriParameters, TwoAdicFriPcs};
19use p3_merkle_tree::MerkleTreeMmcs;
20use p3_poseidon2::ExternalLayerConstants;
21use p3_symmetric::{CryptographicPermutation, MultiField32PaddingFreeSponge, TruncatedPermutation};
22use zkhash::{
23 ark_ff::PrimeField as _, fields::bn256::FpBN256 as ark_FpBN256,
24 poseidon2::poseidon2_instance_bn256::RC3,
25};
26
27use super::FriParameters;
28use crate::{
29 assert_sc_compatible_with_serde,
30 config::fri_params::{
31 SecurityParameters, MAX_BATCH_SIZE_LOG_BLOWUP_1, MAX_BATCH_SIZE_LOG_BLOWUP_2,
32 MAX_NUM_CONSTRAINTS,
33 },
34 engine::{StarkEngine, StarkFriEngine},
35};
36
37const WIDTH: usize = 3;
38const RATE: usize = 16;
40const DIGEST_WIDTH: usize = 1;
41
42type Val = BabyBear;
44type Challenge = BinomialExtensionField<Val, 4>;
45type Perm = Poseidon2Bn254<WIDTH>;
46type Hash<P> = MultiField32PaddingFreeSponge<Val, Bn254, P, WIDTH, RATE, DIGEST_WIDTH>;
47type Compress<P> = TruncatedPermutation<P, 2, 1, WIDTH>;
48type ValMmcs<P> = MerkleTreeMmcs<BabyBear, Bn254, Hash<P>, Compress<P>, 1>;
49type ChallengeMmcs<P> = ExtensionMmcs<Val, Challenge, ValMmcs<P>>;
50type Dft = Radix2DitParallel<Val>;
51type Challenger<P> = MultiField32Challenger<Val, Bn254, P, WIDTH, 2>;
52type Pcs<P> = TwoAdicFriPcs<Val, Dft, ValMmcs<P>, ChallengeMmcs<P>>;
53type RapPhase<P> = FriLogUpPhase<Val, Challenge, Challenger<P>>;
54
/// STARK configuration used by the root engine in this module, generic over the BN254
/// permutation `P` that backs the leaf hash, the Merkle compression and the challenger.
pub type BabyBearPermutationRootConfig<P> =
    StarkConfig<Pcs<P>, RapPhase<P>, Challenge, Challenger<P>>;
/// [`BabyBearPermutationRootConfig`] instantiated with the Poseidon2 permutation over BN254.
pub type BabyBearPoseidon2RootConfig = BabyBearPermutationRootConfig<Perm>;
/// [`BabyBearPermutationRootEngine`] instantiated with the Poseidon2 permutation over BN254.
pub type BabyBearPoseidon2RootEngine = BabyBearPermutationRootEngine<Perm>;

// Crate macro (defined elsewhere); presumably a compile-time check that the concrete
// config is serde-compatible — confirm against the macro definition.
assert_sc_compatible_with_serde!(BabyBearPoseidon2RootConfig);
61
/// STARK engine for [`BabyBearPermutationRootConfig`], generic over the BN254 permutation `P`.
///
/// It bundles the FRI parameters it was built with and the CPU device that holds the config.
/// It also keeps the permutation, which is cloned to create fresh challengers, and the
/// constraint-degree limit handed to keygen.
pub struct BabyBearPermutationRootEngine<P>
where
    P: CryptographicPermutation<[Bn254; WIDTH]> + Clone,
{
    /// FRI parameters this engine was built with.
    pub fri_params: FriParameters,
    /// CPU device holding the STARK config (returned by `StarkEngine::config`).
    pub device: CpuDevice<BabyBearPermutationRootConfig<P>>,
    /// Permutation cloned into every new challenger.
    pub perm: P,
    /// Maximum constraint degree passed to the keygen builder.
    pub max_constraint_degree: usize,
}
71
72impl<P> StarkEngine for BabyBearPermutationRootEngine<P>
73where
74 P: CryptographicPermutation<[Bn254; WIDTH]> + Clone,
75{
76 type SC = BabyBearPermutationRootConfig<P>;
77 type PB = CpuBackend<Self::SC>;
78 type PD = CpuDevice<Self::SC>;
79
80 fn config(&self) -> &BabyBearPermutationRootConfig<P> {
81 &self.device.config
82 }
83
84 fn device(&self) -> &CpuDevice<BabyBearPermutationRootConfig<P>> {
85 &self.device
86 }
87
88 fn keygen_builder(&self) -> MultiStarkKeygenBuilder<'_, Self::SC> {
89 let mut builder = MultiStarkKeygenBuilder::new(self.config());
90 builder.set_max_constraint_degree(self.max_constraint_degree);
91 let max_batch_size = if self.fri_params.log_blowup == 1 {
92 MAX_BATCH_SIZE_LOG_BLOWUP_1
93 } else {
94 MAX_BATCH_SIZE_LOG_BLOWUP_2
95 };
96 builder.max_batch_size = Some(max_batch_size);
97 builder.max_num_constraints = Some(MAX_NUM_CONSTRAINTS);
98
99 builder
100 }
101
102 fn prover(&self) -> MultiTraceStarkProver<BabyBearPermutationRootConfig<P>> {
103 MultiTraceStarkProver::new(
104 CpuBackend::default(),
105 self.device.clone(),
106 self.new_challenger(),
107 )
108 }
109
110 fn max_constraint_degree(&self) -> Option<usize> {
111 Some(self.max_constraint_degree)
112 }
113
114 fn new_challenger(&self) -> Challenger<P> {
115 Challenger::new(self.perm.clone()).unwrap()
116 }
117}
118
119pub fn default_engine() -> BabyBearPoseidon2RootEngine {
121 default_engine_impl(SecurityParameters::standard_fast())
122}
123
124fn default_engine_impl(security_params: SecurityParameters) -> BabyBearPoseidon2RootEngine {
126 let perm = root_perm();
127 engine_from_perm(perm, security_params)
128}
129
130pub fn default_config(perm: &Perm) -> BabyBearPoseidon2RootConfig {
132 config_from_perm(perm, SecurityParameters::standard_fast())
133}
134
135pub fn engine_from_perm<P>(
136 perm: P,
137 security_params: SecurityParameters,
138) -> BabyBearPermutationRootEngine<P>
139where
140 P: CryptographicPermutation<[Bn254; WIDTH]> + Clone,
141{
142 let fri_params = security_params.fri_params;
143 let max_constraint_degree = fri_params.max_constraint_degree();
144 let config = config_from_perm(&perm, security_params);
145 BabyBearPermutationRootEngine {
146 device: CpuDevice::new(Arc::new(config), fri_params.log_blowup),
147 perm,
148 fri_params,
149 max_constraint_degree,
150 }
151}
152
153pub fn config_from_perm<P>(
154 perm: &P,
155 security_params: SecurityParameters,
156) -> BabyBearPermutationRootConfig<P>
157where
158 P: CryptographicPermutation<[Bn254; WIDTH]> + Clone,
159{
160 let hash = Hash::new(perm.clone()).unwrap();
161 let compress = Compress::new(perm.clone());
162 let val_mmcs = ValMmcs::new(hash, compress);
163 let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone());
164 let dft = Dft::default();
165 let SecurityParameters {
166 fri_params,
167 log_up_params,
168 deep_ali_params,
169 } = security_params;
170 let fri_config = P3FriParameters {
171 log_blowup: fri_params.log_blowup,
172 log_final_poly_len: fri_params.log_final_poly_len,
173 num_queries: fri_params.num_queries,
174 commit_proof_of_work_bits: fri_params.commit_proof_of_work_bits,
175 query_proof_of_work_bits: fri_params.query_proof_of_work_bits,
176 mmcs: challenge_mmcs,
177 };
178 let pcs = Pcs::new(dft, val_mmcs, fri_config);
179 let challenger = Challenger::new(perm.clone()).unwrap();
180 let rap_phase = FriLogUpPhase::new(log_up_params, fri_params.log_blowup);
181 BabyBearPermutationRootConfig::new(pcs, challenger, rap_phase, deep_ali_params)
182}
183
184pub fn root_perm() -> Perm {
186 const ROUNDS_F: usize = 8;
187 const ROUNDS_P: usize = 56;
188 let mut round_constants = bn254_poseidon2_rc3();
189 let internal_end = (ROUNDS_F / 2) + ROUNDS_P;
190 let terminal = round_constants.split_off(internal_end);
191 let internal_round_constants = round_constants.split_off(ROUNDS_F / 2);
192 let internal_round_constants = internal_round_constants
193 .into_iter()
194 .map(|vec| vec[0])
195 .collect::<Vec<_>>();
196 let initial = round_constants;
197
198 let external_round_constants = ExternalLayerConstants::new(initial, terminal);
199 Perm::new(external_round_constants, internal_round_constants)
200}
201
202fn bn254_from_ark_ff(input: ark_FpBN256) -> Bn254 {
203 let limbs_le = input.into_bigint().0;
204 let bytes = limbs_le
206 .iter()
207 .flat_map(|limb| limb.to_le_bytes())
208 .collect::<Vec<_>>();
209 let big = num_bigint::BigUint::from_bytes_le(&bytes);
210 Bn254::from_biguint(big).expect("Invalid BN254 element")
211}
212
213fn bn254_poseidon2_rc3() -> Vec<[Bn254; 3]> {
214 RC3.iter()
215 .map(|vec| {
216 vec.iter()
217 .cloned()
218 .map(bn254_from_ark_ff)
219 .collect::<Vec<_>>()
220 .try_into()
221 .unwrap()
222 })
223 .collect()
224}
225
226impl StarkFriEngine for BabyBearPoseidon2RootEngine {
227 fn new(fri_params: FriParameters) -> Self {
228 let security_params = SecurityParameters::new_baby_bear_100_bits(fri_params);
229 default_engine_impl(security_params)
230 }
231 fn fri_params(&self) -> FriParameters {
232 self.fri_params
233 }
234}