1use std::sync::Arc;
2
3use num_bigint::BigUint;
4use openvm_algebra_transpiler::Fp2Opcode;
5use openvm_circuit::{
6 arch::{
7 AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge,
8 ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension,
9 VmExecutionExtension, VmProverExtension,
10 },
11 system::{memory::SharedMemoryHelper, SystemPort},
12};
13use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor};
14use openvm_circuit_primitives::{
15 bitwise_op_lookup::{
16 BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip,
17 SharedBitwiseOperationLookupChip,
18 },
19 var_range::VariableRangeCheckerBus,
20};
21use openvm_instructions::{LocalOpcode, VmOpcode};
22use openvm_mod_circuit_builder::ExprBuilderConfig;
23use openvm_stark_backend::{
24 config::{StarkGenericConfig, Val},
25 p3_field::PrimeField32,
26 prover::cpu::{CpuBackend, CpuDevice},
27};
28use openvm_stark_sdk::engine::StarkEngine;
29use serde::{Deserialize, Serialize};
30use serde_with::{serde_as, DisplayFromStr};
31use strum::EnumCount;
32
33use crate::{
34 fp2_chip::{
35 get_fp2_addsub_air, get_fp2_addsub_chip, get_fp2_addsub_step, get_fp2_muldiv_air,
36 get_fp2_muldiv_chip, get_fp2_muldiv_step, Fp2Air, Fp2Executor,
37 },
38 AlgebraCpuProverExt, ModularExtension,
39};
40
41#[serde_as]
42#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
43pub struct Fp2Extension {
44 #[serde_as(as = "Vec<(_, DisplayFromStr)>")]
47 pub supported_moduli: Vec<(String, BigUint)>,
48}
49
50impl Fp2Extension {
51 pub fn generate_complex_init(&self, modular_config: &ModularExtension) -> String {
52 fn get_index_of_modulus(modulus: &BigUint, modular_config: &ModularExtension) -> usize {
53 modular_config
54 .supported_moduli
55 .iter()
56 .position(|m| m == modulus)
57 .expect("Modulus used in Fp2Extension not found in ModularExtension")
58 }
59
60 let supported_moduli = self
61 .supported_moduli
62 .iter()
63 .map(|(name, modulus)| {
64 format!(
65 "\"{}\" {{ mod_idx = {} }}",
66 name,
67 get_index_of_modulus(modulus, modular_config)
68 )
69 })
70 .collect::<Vec<String>>()
71 .join(", ");
72
73 format!("openvm_algebra_guest::complex_macros::complex_init! {{ {supported_moduli} }}")
74 }
75}
76
77#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)]
78pub enum Fp2ExtensionExecutor {
79 Fp2AddSubRv32_32(Fp2Executor<2, 32>), Fp2MulDivRv32_32(Fp2Executor<2, 32>), Fp2AddSubRv32_48(Fp2Executor<6, 16>), Fp2MulDivRv32_48(Fp2Executor<6, 16>), }
86
87impl<F: PrimeField32> VmExecutionExtension<F> for Fp2Extension {
88 type Executor = Fp2ExtensionExecutor;
89
90 fn extend_execution(
91 &self,
92 inventory: &mut ExecutorInventoryBuilder<F, Fp2ExtensionExecutor>,
93 ) -> Result<(), ExecutorInventoryError> {
94 let pointer_max_bits = inventory.pointer_max_bits();
95 let dummy_range_checker_bus = VariableRangeCheckerBus::new(u16::MAX, 16);
97 for (i, (_, modulus)) in self.supported_moduli.iter().enumerate() {
98 let bytes = modulus.bits().div_ceil(8);
100 let start_offset = Fp2Opcode::CLASS_OFFSET + i * Fp2Opcode::COUNT;
101
102 if bytes <= 32 {
103 let config = ExprBuilderConfig {
104 modulus: modulus.clone(),
105 num_limbs: 32,
106 limb_bits: 8,
107 };
108 let addsub = get_fp2_addsub_step(
109 config.clone(),
110 dummy_range_checker_bus,
111 pointer_max_bits,
112 start_offset,
113 );
114
115 inventory.add_executor(
116 Fp2ExtensionExecutor::Fp2AddSubRv32_32(addsub),
117 ((Fp2Opcode::ADD as usize)..=(Fp2Opcode::SETUP_ADDSUB as usize))
118 .map(|x| VmOpcode::from_usize(x + start_offset)),
119 )?;
120
121 let muldiv = get_fp2_muldiv_step(
122 config,
123 dummy_range_checker_bus,
124 pointer_max_bits,
125 start_offset,
126 );
127
128 inventory.add_executor(
129 Fp2ExtensionExecutor::Fp2MulDivRv32_32(muldiv),
130 ((Fp2Opcode::MUL as usize)..=(Fp2Opcode::SETUP_MULDIV as usize))
131 .map(|x| VmOpcode::from_usize(x + start_offset)),
132 )?;
133 } else if bytes <= 48 {
134 let config = ExprBuilderConfig {
135 modulus: modulus.clone(),
136 num_limbs: 48,
137 limb_bits: 8,
138 };
139 let addsub = get_fp2_addsub_step(
140 config.clone(),
141 dummy_range_checker_bus,
142 pointer_max_bits,
143 start_offset,
144 );
145
146 inventory.add_executor(
147 Fp2ExtensionExecutor::Fp2AddSubRv32_48(addsub),
148 ((Fp2Opcode::ADD as usize)..=(Fp2Opcode::SETUP_ADDSUB as usize))
149 .map(|x| VmOpcode::from_usize(x + start_offset)),
150 )?;
151
152 let muldiv = get_fp2_muldiv_step(
153 config,
154 dummy_range_checker_bus,
155 pointer_max_bits,
156 start_offset,
157 );
158
159 inventory.add_executor(
160 Fp2ExtensionExecutor::Fp2MulDivRv32_48(muldiv),
161 ((Fp2Opcode::MUL as usize)..=(Fp2Opcode::SETUP_MULDIV as usize))
162 .map(|x| VmOpcode::from_usize(x + start_offset)),
163 )?;
164 } else {
165 panic!("Modulus too large");
166 }
167 }
168 Ok(())
169 }
170}
171
172impl<SC: StarkGenericConfig> VmCircuitExtension<SC> for Fp2Extension {
173 fn extend_circuit(&self, inventory: &mut AirInventory<SC>) -> Result<(), AirInventoryError> {
174 let SystemPort {
175 execution_bus,
176 program_bus,
177 memory_bridge,
178 } = inventory.system().port();
179
180 let exec_bridge = ExecutionBridge::new(execution_bus, program_bus);
181 let range_checker_bus = inventory.range_checker().bus;
182 let pointer_max_bits = inventory.pointer_max_bits();
183
184 let bitwise_lu = {
185 let existing_air = inventory.find_air::<BitwiseOperationLookupAir<8>>().next();
187 if let Some(air) = existing_air {
188 air.bus
189 } else {
190 let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx());
191 let air = BitwiseOperationLookupAir::<8>::new(bus);
192 inventory.add_air(air);
193 air.bus
194 }
195 };
196 for (i, (_, modulus)) in self.supported_moduli.iter().enumerate() {
197 let bytes = modulus.bits().div_ceil(8);
199 let start_offset = Fp2Opcode::CLASS_OFFSET + i * Fp2Opcode::COUNT;
200
201 if bytes <= 32 {
202 let config = ExprBuilderConfig {
203 modulus: modulus.clone(),
204 num_limbs: 32,
205 limb_bits: 8,
206 };
207
208 let addsub = get_fp2_addsub_air::<2, 32>(
209 exec_bridge,
210 memory_bridge,
211 config.clone(),
212 range_checker_bus,
213 bitwise_lu,
214 pointer_max_bits,
215 start_offset,
216 );
217 inventory.add_air(addsub);
218
219 let muldiv = get_fp2_muldiv_air::<2, 32>(
220 exec_bridge,
221 memory_bridge,
222 config,
223 range_checker_bus,
224 bitwise_lu,
225 pointer_max_bits,
226 start_offset,
227 );
228 inventory.add_air(muldiv);
229 } else if bytes <= 48 {
230 let config = ExprBuilderConfig {
231 modulus: modulus.clone(),
232 num_limbs: 48,
233 limb_bits: 8,
234 };
235
236 let addsub = get_fp2_addsub_air::<6, 16>(
237 exec_bridge,
238 memory_bridge,
239 config.clone(),
240 range_checker_bus,
241 bitwise_lu,
242 pointer_max_bits,
243 start_offset,
244 );
245 inventory.add_air(addsub);
246
247 let muldiv = get_fp2_muldiv_air::<6, 16>(
248 exec_bridge,
249 memory_bridge,
250 config,
251 range_checker_bus,
252 bitwise_lu,
253 pointer_max_bits,
254 start_offset,
255 );
256 inventory.add_air(muldiv);
257 } else {
258 panic!("Modulus too large");
259 }
260 }
261
262 Ok(())
263 }
264}
265
266impl<E, SC, RA> VmProverExtension<E, RA, Fp2Extension> for AlgebraCpuProverExt
269where
270 SC: StarkGenericConfig,
271 E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
272 RA: RowMajorMatrixArena<Val<SC>>,
273 Val<SC>: PrimeField32,
274{
275 fn extend_prover(
276 &self,
277 extension: &Fp2Extension,
278 inventory: &mut ChipInventory<SC, RA, CpuBackend<SC>>,
279 ) -> Result<(), ChipInventoryError> {
280 let range_checker = inventory.range_checker()?.clone();
281 let timestamp_max_bits = inventory.timestamp_max_bits();
282 let pointer_max_bits = inventory.airs().pointer_max_bits();
283 let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
284 let bitwise_lu = {
285 let existing_chip = inventory
286 .find_chip::<SharedBitwiseOperationLookupChip<8>>()
287 .next();
288 if let Some(chip) = existing_chip {
289 chip.clone()
290 } else {
291 let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?;
292 let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus));
293 inventory.add_periphery_chip(chip.clone());
294 chip
295 }
296 };
297 for (_, modulus) in extension.supported_moduli.iter() {
298 let bytes = modulus.bits().div_ceil(8);
300
301 if bytes <= 32 {
302 let config = ExprBuilderConfig {
303 modulus: modulus.clone(),
304 num_limbs: 32,
305 limb_bits: 8,
306 };
307
308 inventory.next_air::<Fp2Air<2, 32>>()?;
309 let addsub = get_fp2_addsub_chip::<Val<SC>, 2, 32>(
310 config.clone(),
311 mem_helper.clone(),
312 range_checker.clone(),
313 bitwise_lu.clone(),
314 pointer_max_bits,
315 );
316 inventory.add_executor_chip(addsub);
317
318 inventory.next_air::<Fp2Air<2, 32>>()?;
319 let muldiv = get_fp2_muldiv_chip::<Val<SC>, 2, 32>(
320 config,
321 mem_helper.clone(),
322 range_checker.clone(),
323 bitwise_lu.clone(),
324 pointer_max_bits,
325 );
326 inventory.add_executor_chip(muldiv);
327 } else if bytes <= 48 {
328 let config = ExprBuilderConfig {
329 modulus: modulus.clone(),
330 num_limbs: 48,
331 limb_bits: 8,
332 };
333
334 inventory.next_air::<Fp2Air<6, 16>>()?;
335 let addsub = get_fp2_addsub_chip::<Val<SC>, 6, 16>(
336 config.clone(),
337 mem_helper.clone(),
338 range_checker.clone(),
339 bitwise_lu.clone(),
340 pointer_max_bits,
341 );
342 inventory.add_executor_chip(addsub);
343
344 inventory.next_air::<Fp2Air<6, 16>>()?;
345 let muldiv = get_fp2_muldiv_chip::<Val<SC>, 6, 16>(
346 config,
347 mem_helper.clone(),
348 range_checker.clone(),
349 bitwise_lu.clone(),
350 pointer_max_bits,
351 );
352 inventory.add_executor_chip(muldiv);
353 } else {
354 panic!("Modulus too large");
355 }
356 }
357
358 Ok(())
359 }
360}