1use std::sync::Arc;
2
3use num_bigint::BigUint;
4use openvm_algebra_transpiler::Fp2Opcode;
5use openvm_circuit::{
6 arch::{
7 AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge,
8 ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension,
9 VmExecutionExtension, VmProverExtension,
10 },
11 system::{memory::SharedMemoryHelper, SystemPort},
12};
13use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor};
14use openvm_circuit_primitives::{
15 bitwise_op_lookup::{
16 BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip,
17 SharedBitwiseOperationLookupChip,
18 },
19 var_range::VariableRangeCheckerBus,
20};
21use openvm_instructions::{LocalOpcode, VmOpcode};
22use openvm_mod_circuit_builder::ExprBuilderConfig;
23use openvm_stark_backend::{
24 config::{StarkGenericConfig, Val},
25 p3_field::PrimeField32,
26 prover::cpu::{CpuBackend, CpuDevice},
27};
28use openvm_stark_sdk::engine::StarkEngine;
29use serde::{Deserialize, Serialize};
30use serde_with::{serde_as, DisplayFromStr};
31use strum::EnumCount;
32
33use crate::{
34 fp2_chip::{
35 get_fp2_addsub_air, get_fp2_addsub_chip, get_fp2_addsub_step, get_fp2_muldiv_air,
36 get_fp2_muldiv_chip, get_fp2_muldiv_step, Fp2Air, Fp2Executor,
37 },
38 AlgebraCpuProverExt, ModularExtension,
39};
40
41#[serde_as]
42#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
43pub struct Fp2Extension {
44 #[serde_as(as = "Vec<(_, DisplayFromStr)>")]
47 pub supported_moduli: Vec<(String, BigUint)>,
48}
49
50impl Fp2Extension {
51 pub fn generate_complex_init(&self, modular_config: &ModularExtension) -> String {
52 fn get_index_of_modulus(modulus: &BigUint, modular_config: &ModularExtension) -> usize {
53 modular_config
54 .supported_moduli
55 .iter()
56 .position(|m| m == modulus)
57 .expect("Modulus used in Fp2Extension not found in ModularExtension")
58 }
59
60 let supported_moduli = self
61 .supported_moduli
62 .iter()
63 .map(|(name, modulus)| {
64 format!(
65 "\"{}\" {{ mod_idx = {} }}",
66 name,
67 get_index_of_modulus(modulus, modular_config)
68 )
69 })
70 .collect::<Vec<String>>()
71 .join(", ");
72
73 format!("openvm_algebra_guest::complex_macros::complex_init! {{ {supported_moduli} }}")
74 }
75}
76
/// Executors contributed by [`Fp2Extension`], one variant per (operation group, modulus size).
///
/// `extend_execution` uses the `_32` variants for moduli of at most 32 bytes and the `_48`
/// variants for moduli of at most 48 bytes. The two sizes use different `Fp2Executor` const
/// parameters (`<2, 32>` and `<6, 16>`). In both cases the product is twice the limb count
/// (64 = 2 * 32, 96 = 2 * 48), presumably the memory block layout of a two-coefficient Fp2
/// operand — TODO confirm.
#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)]
#[cfg_attr(
    feature = "aot",
    derive(
        openvm_circuit_derive::AotExecutor,
        openvm_circuit_derive::AotMeteredExecutor
    )
)]
pub enum Fp2ExtensionExecutor {
    // Moduli up to 32 bytes (32 limbs of 8 bits): add/sub group, then mul/div group.
    Fp2AddSubRv32_32(Fp2Executor<2, 32>),
    Fp2MulDivRv32_32(Fp2Executor<2, 32>),
    // Moduli up to 48 bytes (48 limbs of 8 bits): add/sub group, then mul/div group.
    Fp2AddSubRv32_48(Fp2Executor<6, 16>),
    Fp2MulDivRv32_48(Fp2Executor<6, 16>),
}
93
94impl<F: PrimeField32> VmExecutionExtension<F> for Fp2Extension {
95 type Executor = Fp2ExtensionExecutor;
96
97 fn extend_execution(
98 &self,
99 inventory: &mut ExecutorInventoryBuilder<F, Fp2ExtensionExecutor>,
100 ) -> Result<(), ExecutorInventoryError> {
101 let pointer_max_bits = inventory.pointer_max_bits();
102 let dummy_range_checker_bus = VariableRangeCheckerBus::new(u16::MAX, 16);
104 for (i, (_, modulus)) in self.supported_moduli.iter().enumerate() {
105 let bytes = modulus.bits().div_ceil(8);
107 let start_offset = Fp2Opcode::CLASS_OFFSET + i * Fp2Opcode::COUNT;
108
109 if bytes <= 32 {
110 let config = ExprBuilderConfig {
111 modulus: modulus.clone(),
112 num_limbs: 32,
113 limb_bits: 8,
114 };
115 let addsub = get_fp2_addsub_step(
116 config.clone(),
117 dummy_range_checker_bus,
118 pointer_max_bits,
119 start_offset,
120 );
121
122 inventory.add_executor(
123 Fp2ExtensionExecutor::Fp2AddSubRv32_32(addsub),
124 ((Fp2Opcode::ADD as usize)..=(Fp2Opcode::SETUP_ADDSUB as usize))
125 .map(|x| VmOpcode::from_usize(x + start_offset)),
126 )?;
127
128 let muldiv = get_fp2_muldiv_step(
129 config,
130 dummy_range_checker_bus,
131 pointer_max_bits,
132 start_offset,
133 );
134
135 inventory.add_executor(
136 Fp2ExtensionExecutor::Fp2MulDivRv32_32(muldiv),
137 ((Fp2Opcode::MUL as usize)..=(Fp2Opcode::SETUP_MULDIV as usize))
138 .map(|x| VmOpcode::from_usize(x + start_offset)),
139 )?;
140 } else if bytes <= 48 {
141 let config = ExprBuilderConfig {
142 modulus: modulus.clone(),
143 num_limbs: 48,
144 limb_bits: 8,
145 };
146 let addsub = get_fp2_addsub_step(
147 config.clone(),
148 dummy_range_checker_bus,
149 pointer_max_bits,
150 start_offset,
151 );
152
153 inventory.add_executor(
154 Fp2ExtensionExecutor::Fp2AddSubRv32_48(addsub),
155 ((Fp2Opcode::ADD as usize)..=(Fp2Opcode::SETUP_ADDSUB as usize))
156 .map(|x| VmOpcode::from_usize(x + start_offset)),
157 )?;
158
159 let muldiv = get_fp2_muldiv_step(
160 config,
161 dummy_range_checker_bus,
162 pointer_max_bits,
163 start_offset,
164 );
165
166 inventory.add_executor(
167 Fp2ExtensionExecutor::Fp2MulDivRv32_48(muldiv),
168 ((Fp2Opcode::MUL as usize)..=(Fp2Opcode::SETUP_MULDIV as usize))
169 .map(|x| VmOpcode::from_usize(x + start_offset)),
170 )?;
171 } else {
172 panic!("Modulus too large");
173 }
174 }
175 Ok(())
176 }
177}
178
179impl<SC: StarkGenericConfig> VmCircuitExtension<SC> for Fp2Extension {
180 fn extend_circuit(&self, inventory: &mut AirInventory<SC>) -> Result<(), AirInventoryError> {
181 let SystemPort {
182 execution_bus,
183 program_bus,
184 memory_bridge,
185 } = inventory.system().port();
186
187 let exec_bridge = ExecutionBridge::new(execution_bus, program_bus);
188 let range_checker_bus = inventory.range_checker().bus;
189 let pointer_max_bits = inventory.pointer_max_bits();
190
191 let bitwise_lu = {
192 let existing_air = inventory.find_air::<BitwiseOperationLookupAir<8>>().next();
194 if let Some(air) = existing_air {
195 air.bus
196 } else {
197 let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx());
198 let air = BitwiseOperationLookupAir::<8>::new(bus);
199 inventory.add_air(air);
200 air.bus
201 }
202 };
203 for (i, (_, modulus)) in self.supported_moduli.iter().enumerate() {
204 let bytes = modulus.bits().div_ceil(8);
206 let start_offset = Fp2Opcode::CLASS_OFFSET + i * Fp2Opcode::COUNT;
207
208 if bytes <= 32 {
209 let config = ExprBuilderConfig {
210 modulus: modulus.clone(),
211 num_limbs: 32,
212 limb_bits: 8,
213 };
214
215 let addsub = get_fp2_addsub_air::<2, 32>(
216 exec_bridge,
217 memory_bridge,
218 config.clone(),
219 range_checker_bus,
220 bitwise_lu,
221 pointer_max_bits,
222 start_offset,
223 );
224 inventory.add_air(addsub);
225
226 let muldiv = get_fp2_muldiv_air::<2, 32>(
227 exec_bridge,
228 memory_bridge,
229 config,
230 range_checker_bus,
231 bitwise_lu,
232 pointer_max_bits,
233 start_offset,
234 );
235 inventory.add_air(muldiv);
236 } else if bytes <= 48 {
237 let config = ExprBuilderConfig {
238 modulus: modulus.clone(),
239 num_limbs: 48,
240 limb_bits: 8,
241 };
242
243 let addsub = get_fp2_addsub_air::<6, 16>(
244 exec_bridge,
245 memory_bridge,
246 config.clone(),
247 range_checker_bus,
248 bitwise_lu,
249 pointer_max_bits,
250 start_offset,
251 );
252 inventory.add_air(addsub);
253
254 let muldiv = get_fp2_muldiv_air::<6, 16>(
255 exec_bridge,
256 memory_bridge,
257 config,
258 range_checker_bus,
259 bitwise_lu,
260 pointer_max_bits,
261 start_offset,
262 );
263 inventory.add_air(muldiv);
264 } else {
265 panic!("Modulus too large");
266 }
267 }
268
269 Ok(())
270 }
271}
272
impl<E, SC, RA> VmProverExtension<E, RA, Fp2Extension> for AlgebraCpuProverExt
where
    SC: StarkGenericConfig,
    E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
    RA: RowMajorMatrixArena<Val<SC>>,
    Val<SC>: PrimeField32,
{
    /// Builds the CPU prover chips for an [`Fp2Extension`].
    ///
    /// The sequence of `inventory.next_air` calls below mirrors the order in which this
    /// file's `VmCircuitExtension::extend_circuit` adds AIRs:
    /// 1. the bitwise lookup AIR, only when it was not already present;
    /// 2. for each modulus, an add/sub AIR and then a mul/div AIR.
    ///
    /// Presumably each chip must be added right after its AIR is consumed, so the two
    /// functions must stay in lockstep — confirm against `ChipInventory::next_air`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `range_checker()` and `next_air`.
    ///
    /// # Panics
    ///
    /// Panics with "Modulus too large" if a modulus needs more than 48 bytes.
    fn extend_prover(
        &self,
        extension: &Fp2Extension,
        inventory: &mut ChipInventory<SC, RA, CpuBackend<SC>>,
    ) -> Result<(), ChipInventoryError> {
        let range_checker = inventory.range_checker()?.clone();
        let timestamp_max_bits = inventory.timestamp_max_bits();
        let pointer_max_bits = inventory.airs().pointer_max_bits();
        // One memory helper, cloned into every chip below.
        let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
        // Reuse an existing shared 8-bit bitwise lookup chip if one is already registered.
        // Otherwise consume the next AIR, expected to be the `BitwiseOperationLookupAir<8>`
        // that `extend_circuit` adds in the same situation, and register a periphery chip on
        // its bus.
        let bitwise_lu = {
            let existing_chip = inventory
                .find_chip::<SharedBitwiseOperationLookupChip<8>>()
                .next();
            if let Some(chip) = existing_chip {
                chip.clone()
            } else {
                let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?;
                let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus));
                inventory.add_periphery_chip(chip.clone());
                chip
            }
        };
        // The opcode offset is not needed here (unlike in `extend_execution` /
        // `extend_circuit`), so the index is not enumerated.
        for (_, modulus) in extension.supported_moduli.iter() {
            let bytes = modulus.bits().div_ceil(8);

            if bytes <= 32 {
                // 32 limbs of 8 bits; matches the `<2, 32>` AIRs from `extend_circuit`.
                let config = ExprBuilderConfig {
                    modulus: modulus.clone(),
                    num_limbs: 32,
                    limb_bits: 8,
                };

                // Consume the add/sub AIR; only its type is checked, the value is discarded.
                inventory.next_air::<Fp2Air<2, 32>>()?;
                let addsub = get_fp2_addsub_chip::<Val<SC>, 2, 32>(
                    config.clone(),
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(addsub);

                // Consume the mul/div AIR that follows it.
                inventory.next_air::<Fp2Air<2, 32>>()?;
                let muldiv = get_fp2_muldiv_chip::<Val<SC>, 2, 32>(
                    config,
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(muldiv);
            } else if bytes <= 48 {
                // 48 limbs of 8 bits; matches the `<6, 16>` AIRs from `extend_circuit`.
                let config = ExprBuilderConfig {
                    modulus: modulus.clone(),
                    num_limbs: 48,
                    limb_bits: 8,
                };

                inventory.next_air::<Fp2Air<6, 16>>()?;
                let addsub = get_fp2_addsub_chip::<Val<SC>, 6, 16>(
                    config.clone(),
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(addsub);

                inventory.next_air::<Fp2Air<6, 16>>()?;
                let muldiv = get_fp2_muldiv_chip::<Val<SC>, 6, 16>(
                    config,
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(muldiv);
            } else {
                panic!("Modulus too large");
            }
        }

        Ok(())
    }
}