openvm_algebra_circuit/extension/
fp2.rs

1use std::sync::Arc;
2
3use num_bigint::BigUint;
4use openvm_algebra_transpiler::Fp2Opcode;
5use openvm_circuit::{
6    arch::{
7        AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge,
8        ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension,
9        VmExecutionExtension, VmProverExtension,
10    },
11    system::{memory::SharedMemoryHelper, SystemPort},
12};
13use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor};
14use openvm_circuit_primitives::{
15    bitwise_op_lookup::{
16        BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip,
17        SharedBitwiseOperationLookupChip,
18    },
19    var_range::VariableRangeCheckerBus,
20};
21use openvm_instructions::{LocalOpcode, VmOpcode};
22use openvm_mod_circuit_builder::ExprBuilderConfig;
23use openvm_stark_backend::{
24    config::{StarkGenericConfig, Val},
25    p3_field::PrimeField32,
26    prover::cpu::{CpuBackend, CpuDevice},
27};
28use openvm_stark_sdk::engine::StarkEngine;
29use serde::{Deserialize, Serialize};
30use serde_with::{serde_as, DisplayFromStr};
31use strum::EnumCount;
32
33use crate::{
34    fp2_chip::{
35        get_fp2_addsub_air, get_fp2_addsub_chip, get_fp2_addsub_step, get_fp2_muldiv_air,
36        get_fp2_muldiv_chip, get_fp2_muldiv_step, Fp2Air, Fp2Executor,
37    },
38    AlgebraCpuProverExt, ModularExtension,
39};
40
41#[serde_as]
42#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
43pub struct Fp2Extension {
44    // (name, modulus)
45    // name must match the struct name defined by complex_declare
46    #[serde_as(as = "Vec<(_, DisplayFromStr)>")]
47    pub supported_moduli: Vec<(String, BigUint)>,
48}
49
50impl Fp2Extension {
51    pub fn generate_complex_init(&self, modular_config: &ModularExtension) -> String {
52        fn get_index_of_modulus(modulus: &BigUint, modular_config: &ModularExtension) -> usize {
53            modular_config
54                .supported_moduli
55                .iter()
56                .position(|m| m == modulus)
57                .expect("Modulus used in Fp2Extension not found in ModularExtension")
58        }
59
60        let supported_moduli = self
61            .supported_moduli
62            .iter()
63            .map(|(name, modulus)| {
64                format!(
65                    "\"{}\" {{ mod_idx = {} }}",
66                    name,
67                    get_index_of_modulus(modulus, modular_config)
68                )
69            })
70            .collect::<Vec<String>>()
71            .join(", ");
72
73        format!("openvm_algebra_guest::complex_macros::complex_init! {{ {supported_moduli} }}")
74    }
75}
76
77#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)]
78#[cfg_attr(
79    feature = "aot",
80    derive(
81        openvm_circuit_derive::AotExecutor,
82        openvm_circuit_derive::AotMeteredExecutor
83    )
84)]
85pub enum Fp2ExtensionExecutor {
86    // 32 limbs prime
87    Fp2AddSubRv32_32(Fp2Executor<2, 32>), // Fp2AddSub
88    Fp2MulDivRv32_32(Fp2Executor<2, 32>), // Fp2MulDiv
89    // 48 limbs prime
90    Fp2AddSubRv32_48(Fp2Executor<6, 16>), // Fp2AddSub
91    Fp2MulDivRv32_48(Fp2Executor<6, 16>), // Fp2MulDiv
92}
93
94impl<F: PrimeField32> VmExecutionExtension<F> for Fp2Extension {
95    type Executor = Fp2ExtensionExecutor;
96
97    fn extend_execution(
98        &self,
99        inventory: &mut ExecutorInventoryBuilder<F, Fp2ExtensionExecutor>,
100    ) -> Result<(), ExecutorInventoryError> {
101        let pointer_max_bits = inventory.pointer_max_bits();
102        // TODO: somehow get the range checker bus from `ExecutorInventory`
103        let dummy_range_checker_bus = VariableRangeCheckerBus::new(u16::MAX, 16);
104        for (i, (_, modulus)) in self.supported_moduli.iter().enumerate() {
105            // determine the number of bytes needed to represent a prime field element
106            let bytes = modulus.bits().div_ceil(8);
107            let start_offset = Fp2Opcode::CLASS_OFFSET + i * Fp2Opcode::COUNT;
108
109            if bytes <= 32 {
110                let config = ExprBuilderConfig {
111                    modulus: modulus.clone(),
112                    num_limbs: 32,
113                    limb_bits: 8,
114                };
115                let addsub = get_fp2_addsub_step(
116                    config.clone(),
117                    dummy_range_checker_bus,
118                    pointer_max_bits,
119                    start_offset,
120                );
121
122                inventory.add_executor(
123                    Fp2ExtensionExecutor::Fp2AddSubRv32_32(addsub),
124                    ((Fp2Opcode::ADD as usize)..=(Fp2Opcode::SETUP_ADDSUB as usize))
125                        .map(|x| VmOpcode::from_usize(x + start_offset)),
126                )?;
127
128                let muldiv = get_fp2_muldiv_step(
129                    config,
130                    dummy_range_checker_bus,
131                    pointer_max_bits,
132                    start_offset,
133                );
134
135                inventory.add_executor(
136                    Fp2ExtensionExecutor::Fp2MulDivRv32_32(muldiv),
137                    ((Fp2Opcode::MUL as usize)..=(Fp2Opcode::SETUP_MULDIV as usize))
138                        .map(|x| VmOpcode::from_usize(x + start_offset)),
139                )?;
140            } else if bytes <= 48 {
141                let config = ExprBuilderConfig {
142                    modulus: modulus.clone(),
143                    num_limbs: 48,
144                    limb_bits: 8,
145                };
146                let addsub = get_fp2_addsub_step(
147                    config.clone(),
148                    dummy_range_checker_bus,
149                    pointer_max_bits,
150                    start_offset,
151                );
152
153                inventory.add_executor(
154                    Fp2ExtensionExecutor::Fp2AddSubRv32_48(addsub),
155                    ((Fp2Opcode::ADD as usize)..=(Fp2Opcode::SETUP_ADDSUB as usize))
156                        .map(|x| VmOpcode::from_usize(x + start_offset)),
157                )?;
158
159                let muldiv = get_fp2_muldiv_step(
160                    config,
161                    dummy_range_checker_bus,
162                    pointer_max_bits,
163                    start_offset,
164                );
165
166                inventory.add_executor(
167                    Fp2ExtensionExecutor::Fp2MulDivRv32_48(muldiv),
168                    ((Fp2Opcode::MUL as usize)..=(Fp2Opcode::SETUP_MULDIV as usize))
169                        .map(|x| VmOpcode::from_usize(x + start_offset)),
170                )?;
171            } else {
172                panic!("Modulus too large");
173            }
174        }
175        Ok(())
176    }
177}
178
179impl<SC: StarkGenericConfig> VmCircuitExtension<SC> for Fp2Extension {
180    fn extend_circuit(&self, inventory: &mut AirInventory<SC>) -> Result<(), AirInventoryError> {
181        let SystemPort {
182            execution_bus,
183            program_bus,
184            memory_bridge,
185        } = inventory.system().port();
186
187        let exec_bridge = ExecutionBridge::new(execution_bus, program_bus);
188        let range_checker_bus = inventory.range_checker().bus;
189        let pointer_max_bits = inventory.pointer_max_bits();
190
191        let bitwise_lu = {
192            // A trick to get around Rust's borrow rules
193            let existing_air = inventory.find_air::<BitwiseOperationLookupAir<8>>().next();
194            if let Some(air) = existing_air {
195                air.bus
196            } else {
197                let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx());
198                let air = BitwiseOperationLookupAir::<8>::new(bus);
199                inventory.add_air(air);
200                air.bus
201            }
202        };
203        for (i, (_, modulus)) in self.supported_moduli.iter().enumerate() {
204            // determine the number of bytes needed to represent a prime field element
205            let bytes = modulus.bits().div_ceil(8);
206            let start_offset = Fp2Opcode::CLASS_OFFSET + i * Fp2Opcode::COUNT;
207
208            if bytes <= 32 {
209                let config = ExprBuilderConfig {
210                    modulus: modulus.clone(),
211                    num_limbs: 32,
212                    limb_bits: 8,
213                };
214
215                let addsub = get_fp2_addsub_air::<2, 32>(
216                    exec_bridge,
217                    memory_bridge,
218                    config.clone(),
219                    range_checker_bus,
220                    bitwise_lu,
221                    pointer_max_bits,
222                    start_offset,
223                );
224                inventory.add_air(addsub);
225
226                let muldiv = get_fp2_muldiv_air::<2, 32>(
227                    exec_bridge,
228                    memory_bridge,
229                    config,
230                    range_checker_bus,
231                    bitwise_lu,
232                    pointer_max_bits,
233                    start_offset,
234                );
235                inventory.add_air(muldiv);
236            } else if bytes <= 48 {
237                let config = ExprBuilderConfig {
238                    modulus: modulus.clone(),
239                    num_limbs: 48,
240                    limb_bits: 8,
241                };
242
243                let addsub = get_fp2_addsub_air::<6, 16>(
244                    exec_bridge,
245                    memory_bridge,
246                    config.clone(),
247                    range_checker_bus,
248                    bitwise_lu,
249                    pointer_max_bits,
250                    start_offset,
251                );
252                inventory.add_air(addsub);
253
254                let muldiv = get_fp2_muldiv_air::<6, 16>(
255                    exec_bridge,
256                    memory_bridge,
257                    config,
258                    range_checker_bus,
259                    bitwise_lu,
260                    pointer_max_bits,
261                    start_offset,
262                );
263                inventory.add_air(muldiv);
264            } else {
265                panic!("Modulus too large");
266            }
267        }
268
269        Ok(())
270    }
271}
272
273// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker,
274// BitwiseOperationLookupChip) are specific to CpuBackend.
275impl<E, SC, RA> VmProverExtension<E, RA, Fp2Extension> for AlgebraCpuProverExt
276where
277    SC: StarkGenericConfig,
278    E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
279    RA: RowMajorMatrixArena<Val<SC>>,
280    Val<SC>: PrimeField32,
281{
282    fn extend_prover(
283        &self,
284        extension: &Fp2Extension,
285        inventory: &mut ChipInventory<SC, RA, CpuBackend<SC>>,
286    ) -> Result<(), ChipInventoryError> {
287        let range_checker = inventory.range_checker()?.clone();
288        let timestamp_max_bits = inventory.timestamp_max_bits();
289        let pointer_max_bits = inventory.airs().pointer_max_bits();
290        let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
291        let bitwise_lu = {
292            let existing_chip = inventory
293                .find_chip::<SharedBitwiseOperationLookupChip<8>>()
294                .next();
295            if let Some(chip) = existing_chip {
296                chip.clone()
297            } else {
298                let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?;
299                let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus));
300                inventory.add_periphery_chip(chip.clone());
301                chip
302            }
303        };
304        for (_, modulus) in extension.supported_moduli.iter() {
305            // determine the number of bytes needed to represent a prime field element
306            let bytes = modulus.bits().div_ceil(8);
307
308            if bytes <= 32 {
309                let config = ExprBuilderConfig {
310                    modulus: modulus.clone(),
311                    num_limbs: 32,
312                    limb_bits: 8,
313                };
314
315                inventory.next_air::<Fp2Air<2, 32>>()?;
316                let addsub = get_fp2_addsub_chip::<Val<SC>, 2, 32>(
317                    config.clone(),
318                    mem_helper.clone(),
319                    range_checker.clone(),
320                    bitwise_lu.clone(),
321                    pointer_max_bits,
322                );
323                inventory.add_executor_chip(addsub);
324
325                inventory.next_air::<Fp2Air<2, 32>>()?;
326                let muldiv = get_fp2_muldiv_chip::<Val<SC>, 2, 32>(
327                    config,
328                    mem_helper.clone(),
329                    range_checker.clone(),
330                    bitwise_lu.clone(),
331                    pointer_max_bits,
332                );
333                inventory.add_executor_chip(muldiv);
334            } else if bytes <= 48 {
335                let config = ExprBuilderConfig {
336                    modulus: modulus.clone(),
337                    num_limbs: 48,
338                    limb_bits: 8,
339                };
340
341                inventory.next_air::<Fp2Air<6, 16>>()?;
342                let addsub = get_fp2_addsub_chip::<Val<SC>, 6, 16>(
343                    config.clone(),
344                    mem_helper.clone(),
345                    range_checker.clone(),
346                    bitwise_lu.clone(),
347                    pointer_max_bits,
348                );
349                inventory.add_executor_chip(addsub);
350
351                inventory.next_air::<Fp2Air<6, 16>>()?;
352                let muldiv = get_fp2_muldiv_chip::<Val<SC>, 6, 16>(
353                    config,
354                    mem_helper.clone(),
355                    range_checker.clone(),
356                    bitwise_lu.clone(),
357                    pointer_max_bits,
358                );
359                inventory.add_executor_chip(muldiv);
360            } else {
361                panic!("Modulus too large");
362            }
363        }
364
365        Ok(())
366    }
367}