openvm_algebra_circuit/extension/
fp2.rs

1use std::sync::Arc;
2
3use num_bigint::BigUint;
4use openvm_algebra_transpiler::Fp2Opcode;
5use openvm_circuit::{
6    arch::{
7        AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge,
8        ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension,
9        VmExecutionExtension, VmProverExtension,
10    },
11    system::{memory::SharedMemoryHelper, SystemPort},
12};
13use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor};
14use openvm_circuit_primitives::{
15    bitwise_op_lookup::{
16        BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip,
17        SharedBitwiseOperationLookupChip,
18    },
19    var_range::VariableRangeCheckerBus,
20};
21use openvm_instructions::{LocalOpcode, VmOpcode};
22use openvm_mod_circuit_builder::ExprBuilderConfig;
23use openvm_stark_backend::{
24    config::{StarkGenericConfig, Val},
25    p3_field::PrimeField32,
26    prover::cpu::{CpuBackend, CpuDevice},
27};
28use openvm_stark_sdk::engine::StarkEngine;
29use serde::{Deserialize, Serialize};
30use serde_with::{serde_as, DisplayFromStr};
31use strum::EnumCount;
32
33use crate::{
34    fp2_chip::{
35        get_fp2_addsub_air, get_fp2_addsub_chip, get_fp2_addsub_step, get_fp2_muldiv_air,
36        get_fp2_muldiv_chip, get_fp2_muldiv_step, Fp2Air, Fp2Executor,
37    },
38    AlgebraCpuProverExt, ModularExtension,
39};
40
41#[serde_as]
42#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
43pub struct Fp2Extension {
44    // (name, modulus)
45    // name must match the struct name defined by complex_declare
46    #[serde_as(as = "Vec<(_, DisplayFromStr)>")]
47    pub supported_moduli: Vec<(String, BigUint)>,
48}
49
50impl Fp2Extension {
51    pub fn generate_complex_init(&self, modular_config: &ModularExtension) -> String {
52        fn get_index_of_modulus(modulus: &BigUint, modular_config: &ModularExtension) -> usize {
53            modular_config
54                .supported_moduli
55                .iter()
56                .position(|m| m == modulus)
57                .expect("Modulus used in Fp2Extension not found in ModularExtension")
58        }
59
60        let supported_moduli = self
61            .supported_moduli
62            .iter()
63            .map(|(name, modulus)| {
64                format!(
65                    "\"{}\" {{ mod_idx = {} }}",
66                    name,
67                    get_index_of_modulus(modulus, modular_config)
68                )
69            })
70            .collect::<Vec<String>>()
71            .join(", ");
72
73        format!("openvm_algebra_guest::complex_macros::complex_init! {{ {supported_moduli} }}")
74    }
75}
76
77#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)]
78pub enum Fp2ExtensionExecutor {
79    // 32 limbs prime
80    Fp2AddSubRv32_32(Fp2Executor<2, 32>), // Fp2AddSub
81    Fp2MulDivRv32_32(Fp2Executor<2, 32>), // Fp2MulDiv
82    // 48 limbs prime
83    Fp2AddSubRv32_48(Fp2Executor<6, 16>), // Fp2AddSub
84    Fp2MulDivRv32_48(Fp2Executor<6, 16>), // Fp2MulDiv
85}
86
87impl<F: PrimeField32> VmExecutionExtension<F> for Fp2Extension {
88    type Executor = Fp2ExtensionExecutor;
89
90    fn extend_execution(
91        &self,
92        inventory: &mut ExecutorInventoryBuilder<F, Fp2ExtensionExecutor>,
93    ) -> Result<(), ExecutorInventoryError> {
94        let pointer_max_bits = inventory.pointer_max_bits();
95        // TODO: somehow get the range checker bus from `ExecutorInventory`
96        let dummy_range_checker_bus = VariableRangeCheckerBus::new(u16::MAX, 16);
97        for (i, (_, modulus)) in self.supported_moduli.iter().enumerate() {
98            // determine the number of bytes needed to represent a prime field element
99            let bytes = modulus.bits().div_ceil(8);
100            let start_offset = Fp2Opcode::CLASS_OFFSET + i * Fp2Opcode::COUNT;
101
102            if bytes <= 32 {
103                let config = ExprBuilderConfig {
104                    modulus: modulus.clone(),
105                    num_limbs: 32,
106                    limb_bits: 8,
107                };
108                let addsub = get_fp2_addsub_step(
109                    config.clone(),
110                    dummy_range_checker_bus,
111                    pointer_max_bits,
112                    start_offset,
113                );
114
115                inventory.add_executor(
116                    Fp2ExtensionExecutor::Fp2AddSubRv32_32(addsub),
117                    ((Fp2Opcode::ADD as usize)..=(Fp2Opcode::SETUP_ADDSUB as usize))
118                        .map(|x| VmOpcode::from_usize(x + start_offset)),
119                )?;
120
121                let muldiv = get_fp2_muldiv_step(
122                    config,
123                    dummy_range_checker_bus,
124                    pointer_max_bits,
125                    start_offset,
126                );
127
128                inventory.add_executor(
129                    Fp2ExtensionExecutor::Fp2MulDivRv32_32(muldiv),
130                    ((Fp2Opcode::MUL as usize)..=(Fp2Opcode::SETUP_MULDIV as usize))
131                        .map(|x| VmOpcode::from_usize(x + start_offset)),
132                )?;
133            } else if bytes <= 48 {
134                let config = ExprBuilderConfig {
135                    modulus: modulus.clone(),
136                    num_limbs: 48,
137                    limb_bits: 8,
138                };
139                let addsub = get_fp2_addsub_step(
140                    config.clone(),
141                    dummy_range_checker_bus,
142                    pointer_max_bits,
143                    start_offset,
144                );
145
146                inventory.add_executor(
147                    Fp2ExtensionExecutor::Fp2AddSubRv32_48(addsub),
148                    ((Fp2Opcode::ADD as usize)..=(Fp2Opcode::SETUP_ADDSUB as usize))
149                        .map(|x| VmOpcode::from_usize(x + start_offset)),
150                )?;
151
152                let muldiv = get_fp2_muldiv_step(
153                    config,
154                    dummy_range_checker_bus,
155                    pointer_max_bits,
156                    start_offset,
157                );
158
159                inventory.add_executor(
160                    Fp2ExtensionExecutor::Fp2MulDivRv32_48(muldiv),
161                    ((Fp2Opcode::MUL as usize)..=(Fp2Opcode::SETUP_MULDIV as usize))
162                        .map(|x| VmOpcode::from_usize(x + start_offset)),
163                )?;
164            } else {
165                panic!("Modulus too large");
166            }
167        }
168        Ok(())
169    }
170}
171
172impl<SC: StarkGenericConfig> VmCircuitExtension<SC> for Fp2Extension {
173    fn extend_circuit(&self, inventory: &mut AirInventory<SC>) -> Result<(), AirInventoryError> {
174        let SystemPort {
175            execution_bus,
176            program_bus,
177            memory_bridge,
178        } = inventory.system().port();
179
180        let exec_bridge = ExecutionBridge::new(execution_bus, program_bus);
181        let range_checker_bus = inventory.range_checker().bus;
182        let pointer_max_bits = inventory.pointer_max_bits();
183
184        let bitwise_lu = {
185            // A trick to get around Rust's borrow rules
186            let existing_air = inventory.find_air::<BitwiseOperationLookupAir<8>>().next();
187            if let Some(air) = existing_air {
188                air.bus
189            } else {
190                let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx());
191                let air = BitwiseOperationLookupAir::<8>::new(bus);
192                inventory.add_air(air);
193                air.bus
194            }
195        };
196        for (i, (_, modulus)) in self.supported_moduli.iter().enumerate() {
197            // determine the number of bytes needed to represent a prime field element
198            let bytes = modulus.bits().div_ceil(8);
199            let start_offset = Fp2Opcode::CLASS_OFFSET + i * Fp2Opcode::COUNT;
200
201            if bytes <= 32 {
202                let config = ExprBuilderConfig {
203                    modulus: modulus.clone(),
204                    num_limbs: 32,
205                    limb_bits: 8,
206                };
207
208                let addsub = get_fp2_addsub_air::<2, 32>(
209                    exec_bridge,
210                    memory_bridge,
211                    config.clone(),
212                    range_checker_bus,
213                    bitwise_lu,
214                    pointer_max_bits,
215                    start_offset,
216                );
217                inventory.add_air(addsub);
218
219                let muldiv = get_fp2_muldiv_air::<2, 32>(
220                    exec_bridge,
221                    memory_bridge,
222                    config,
223                    range_checker_bus,
224                    bitwise_lu,
225                    pointer_max_bits,
226                    start_offset,
227                );
228                inventory.add_air(muldiv);
229            } else if bytes <= 48 {
230                let config = ExprBuilderConfig {
231                    modulus: modulus.clone(),
232                    num_limbs: 48,
233                    limb_bits: 8,
234                };
235
236                let addsub = get_fp2_addsub_air::<6, 16>(
237                    exec_bridge,
238                    memory_bridge,
239                    config.clone(),
240                    range_checker_bus,
241                    bitwise_lu,
242                    pointer_max_bits,
243                    start_offset,
244                );
245                inventory.add_air(addsub);
246
247                let muldiv = get_fp2_muldiv_air::<6, 16>(
248                    exec_bridge,
249                    memory_bridge,
250                    config,
251                    range_checker_bus,
252                    bitwise_lu,
253                    pointer_max_bits,
254                    start_offset,
255                );
256                inventory.add_air(muldiv);
257            } else {
258                panic!("Modulus too large");
259            }
260        }
261
262        Ok(())
263    }
264}
265
266// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker,
267// BitwiseOperationLookupChip) are specific to CpuBackend.
268impl<E, SC, RA> VmProverExtension<E, RA, Fp2Extension> for AlgebraCpuProverExt
269where
270    SC: StarkGenericConfig,
271    E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
272    RA: RowMajorMatrixArena<Val<SC>>,
273    Val<SC>: PrimeField32,
274{
275    fn extend_prover(
276        &self,
277        extension: &Fp2Extension,
278        inventory: &mut ChipInventory<SC, RA, CpuBackend<SC>>,
279    ) -> Result<(), ChipInventoryError> {
280        let range_checker = inventory.range_checker()?.clone();
281        let timestamp_max_bits = inventory.timestamp_max_bits();
282        let pointer_max_bits = inventory.airs().pointer_max_bits();
283        let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
284        let bitwise_lu = {
285            let existing_chip = inventory
286                .find_chip::<SharedBitwiseOperationLookupChip<8>>()
287                .next();
288            if let Some(chip) = existing_chip {
289                chip.clone()
290            } else {
291                let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?;
292                let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus));
293                inventory.add_periphery_chip(chip.clone());
294                chip
295            }
296        };
297        for (_, modulus) in extension.supported_moduli.iter() {
298            // determine the number of bytes needed to represent a prime field element
299            let bytes = modulus.bits().div_ceil(8);
300
301            if bytes <= 32 {
302                let config = ExprBuilderConfig {
303                    modulus: modulus.clone(),
304                    num_limbs: 32,
305                    limb_bits: 8,
306                };
307
308                inventory.next_air::<Fp2Air<2, 32>>()?;
309                let addsub = get_fp2_addsub_chip::<Val<SC>, 2, 32>(
310                    config.clone(),
311                    mem_helper.clone(),
312                    range_checker.clone(),
313                    bitwise_lu.clone(),
314                    pointer_max_bits,
315                );
316                inventory.add_executor_chip(addsub);
317
318                inventory.next_air::<Fp2Air<2, 32>>()?;
319                let muldiv = get_fp2_muldiv_chip::<Val<SC>, 2, 32>(
320                    config,
321                    mem_helper.clone(),
322                    range_checker.clone(),
323                    bitwise_lu.clone(),
324                    pointer_max_bits,
325                );
326                inventory.add_executor_chip(muldiv);
327            } else if bytes <= 48 {
328                let config = ExprBuilderConfig {
329                    modulus: modulus.clone(),
330                    num_limbs: 48,
331                    limb_bits: 8,
332                };
333
334                inventory.next_air::<Fp2Air<6, 16>>()?;
335                let addsub = get_fp2_addsub_chip::<Val<SC>, 6, 16>(
336                    config.clone(),
337                    mem_helper.clone(),
338                    range_checker.clone(),
339                    bitwise_lu.clone(),
340                    pointer_max_bits,
341                );
342                inventory.add_executor_chip(addsub);
343
344                inventory.next_air::<Fp2Air<6, 16>>()?;
345                let muldiv = get_fp2_muldiv_chip::<Val<SC>, 6, 16>(
346                    config,
347                    mem_helper.clone(),
348                    range_checker.clone(),
349                    bitwise_lu.clone(),
350                    pointer_max_bits,
351                );
352                inventory.add_executor_chip(muldiv);
353            } else {
354                panic!("Modulus too large");
355            }
356        }
357
358        Ok(())
359    }
360}