openvm_algebra_circuit/extension/hybrid.rs

1//! Prover extension for the GPU backend which still does trace generation on CPU.
2
3use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
4use openvm_circuit::{
5    arch::*,
6    system::{
7        cuda::{
8            extensions::{
9                get_inventory_range_checker, get_or_create_bitwise_op_lookup, SystemGpuBuilder,
10            },
11            SystemChipInventoryGPU,
12        },
13        memory::SharedMemoryHelper,
14    },
15};
16use openvm_circuit_primitives::bigint::utils::big_uint_to_limbs;
17use openvm_cuda_backend::{
18    chip::{cpu_proving_ctx_to_gpu, get_empty_air_proving_ctx},
19    engine::GpuBabyBearPoseidon2Engine,
20    prover_backend::GpuBackend,
21    types::{F, SC},
22};
23use openvm_instructions::LocalOpcode;
24use openvm_mod_circuit_builder::{ExprBuilderConfig, FieldExpressionMetadata};
25use openvm_rv32_adapters::{
26    Rv32IsEqualModAdapterCols, Rv32IsEqualModAdapterExecutor, Rv32IsEqualModAdapterFiller,
27    Rv32IsEqualModAdapterRecord, Rv32VecHeapAdapterCols, Rv32VecHeapAdapterExecutor,
28};
29use openvm_rv32im_circuit::Rv32ImGpuProverExt;
30use openvm_stark_backend::{p3_air::BaseAir, prover::types::AirProvingContext, Chip};
31use strum::EnumCount;
32
33use crate::{
34    fp2_chip::{get_fp2_addsub_chip, get_fp2_muldiv_chip, Fp2Air, Fp2Chip},
35    modular_chip::*,
36    AlgebraRecord, Fp2Extension, ModularExtension, Rv32ModularConfig, Rv32ModularWithFp2Config,
37};
38
/// Wraps a CPU [`ModularChip`] so it can be used with the GPU prover backend:
/// records are collected in a [`DenseRecordArena`], tracegen runs on the CPU,
/// and the finished trace matrix is transferred host-to-device.
#[derive(derive_new::new)]
pub struct HybridModularChip<F, const BLOCKS: usize, const BLOCK_SIZE: usize> {
    // The underlying CPU chip that performs the actual trace generation.
    cpu: ModularChip<F, BLOCKS, BLOCK_SIZE>,
}
43
44// Auto-implementation of Chip for GpuBackend for a Cpu Chip by doing conversion
45// of Dense->Matrix Record Arena, cpu tracegen, and then H2D transfer of the trace matrix.
46impl<const BLOCKS: usize, const BLOCK_SIZE: usize> Chip<DenseRecordArena, GpuBackend>
47    for HybridModularChip<F, BLOCKS, BLOCK_SIZE>
48{
49    fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
50        let total_input_limbs =
51            self.cpu.inner.num_inputs() * self.cpu.inner.expr.canonical_num_limbs();
52        let layout = AdapterCoreLayout::with_metadata(FieldExpressionMetadata::<
53            F,
54            Rv32VecHeapAdapterExecutor<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
55        >::new(total_input_limbs));
56
57        let record_size = RecordSeeker::<
58            DenseRecordArena,
59            AlgebraRecord<2, BLOCKS, BLOCK_SIZE>,
60            _,
61        >::get_aligned_record_size(&layout);
62
63        let records = arena.allocated();
64        if records.is_empty() {
65            return get_empty_air_proving_ctx::<GpuBackend>();
66        }
67        debug_assert_eq!(records.len() % record_size, 0);
68
69        let num_records = records.len() / record_size;
70
71        let height = num_records.next_power_of_two();
72        let mut seeker = arena
73            .get_record_seeker::<AlgebraRecord<2, BLOCKS, BLOCK_SIZE>, AdapterCoreLayout<
74                FieldExpressionMetadata<
75                    F,
76                    Rv32VecHeapAdapterExecutor<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
77                >,
78            >>();
79        let adapter_width =
80            Rv32VecHeapAdapterCols::<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>::width();
81        let width = adapter_width + BaseAir::<F>::width(&self.cpu.inner.expr);
82        let mut matrix_arena = MatrixRecordArena::<F>::with_capacity(height, width);
83        seeker.transfer_to_matrix_arena(&mut matrix_arena, layout);
84        let ctx = self.cpu.generate_proving_ctx(matrix_arena);
85        cpu_proving_ctx_to_gpu(ctx)
86    }
87}
88
/// Wraps a CPU [`ModularIsEqualChip`] for use with the GPU prover backend:
/// records are collected in a [`DenseRecordArena`], tracegen runs on the CPU,
/// and the trace matrix is transferred host-to-device.
#[derive(derive_new::new)]
pub struct HybridModularIsEqualChip<
    F,
    const NUM_LANES: usize,
    const LANE_SIZE: usize,
    const TOTAL_LIMBS: usize,
> {
    // The underlying CPU chip that performs the actual trace generation.
    cpu: ModularIsEqualChip<F, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>,
}
98
99impl<const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_LIMBS: usize>
100    Chip<DenseRecordArena, GpuBackend>
101    for HybridModularIsEqualChip<F, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>
102{
103    fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
104        let record_size = size_of::<(
105            Rv32IsEqualModAdapterRecord<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>,
106            ModularIsEqualRecord<TOTAL_LIMBS>,
107        )>();
108        let trace_width = Rv32IsEqualModAdapterCols::<F, 2, NUM_LANES, LANE_SIZE>::width()
109            + ModularIsEqualCoreCols::<F, TOTAL_LIMBS>::width();
110        let records = arena.allocated();
111        if records.is_empty() {
112            return get_empty_air_proving_ctx::<GpuBackend>();
113        }
114        debug_assert_eq!(records.len() % record_size, 0);
115
116        let num_records = records.len() / record_size;
117        let height = num_records.next_power_of_two();
118        let mut seeker = arena.get_record_seeker::<(
119            &mut Rv32IsEqualModAdapterRecord<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>,
120            &mut ModularIsEqualRecord<TOTAL_LIMBS>,
121        ), EmptyAdapterCoreLayout<
122            F,
123            Rv32IsEqualModAdapterExecutor<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>,
124        >>();
125        let mut matrix_arena = MatrixRecordArena::<F>::with_capacity(height, trace_width);
126        seeker.transfer_to_matrix_arena(&mut matrix_arena, EmptyAdapterCoreLayout::new());
127        let ctx = self.cpu.generate_proving_ctx(matrix_arena);
128        cpu_proving_ctx_to_gpu(ctx)
129    }
130}
131
/// Prover extension that registers the algebra (modular / Fp2) chips using CPU
/// tracegen while the rest of the VM proves on the GPU backend.
#[derive(Clone, Copy, Default)]
pub struct AlgebraHybridProverExt;
134
impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, ModularExtension>
    for AlgebraHybridProverExt
{
    // Registers the modular add/sub, mul/div, and is-equal chips for every
    // supported modulus, using CPU tracegen with the GPU proving backend.
    // NOTE(review): each `next_air` call advances the AIR cursor, so chip
    // registration below must mirror the AIR ordering — do not reorder.
    fn extend_prover(
        &self,
        extension: &ModularExtension,
        inventory: &mut ChipInventory<SC, DenseRecordArena, GpuBackend>,
    ) -> Result<(), ChipInventoryError> {
        // Reuse the CPU halves of the shared range-checker and bitwise lookup
        // chips so the hybrid chips can generate their traces on the host.
        let range_checker_gpu = get_inventory_range_checker(inventory);
        let timestamp_max_bits = inventory.timestamp_max_bits();
        let pointer_max_bits = inventory.airs().pointer_max_bits();
        let range_checker = range_checker_gpu.cpu_chip.clone().unwrap();
        let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
        let bitwise_lu_gpu = get_or_create_bitwise_op_lookup(inventory)?;
        let bitwise_lu = bitwise_lu_gpu.cpu_chip.clone().unwrap();

        for (i, modulus) in extension.supported_moduli.iter().enumerate() {
            // determine the number of bytes needed to represent a prime field element
            let bytes = modulus.bits().div_ceil(8);
            // Each modulus gets its own contiguous opcode block of size `COUNT`.
            let start_offset =
                Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT;

            // Modulus decomposed into 8-bit limbs for the is-equal chip below.
            let modulus_limbs = big_uint_to_limbs(modulus, 8);

            if bytes <= 32 {
                // Moduli up to 32 bytes: a single 32-limb block (8-bit limbs).
                let config = ExprBuilderConfig {
                    modulus: modulus.clone(),
                    num_limbs: 32,
                    limb_bits: 8,
                };

                inventory.next_air::<ModularAir<1, 32>>()?;
                let addsub = get_modular_addsub_chip::<F, 1, 32>(
                    config.clone(),
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridModularChip::new(addsub));

                inventory.next_air::<ModularAir<1, 32>>()?;
                let muldiv = get_modular_muldiv_chip::<F, 1, 32>(
                    config,
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridModularChip::new(muldiv));

                // Zero-pad the limbs up to the fixed array width the chip expects.
                let modulus_limbs = std::array::from_fn(|i| {
                    if i < modulus_limbs.len() {
                        modulus_limbs[i] as u8
                    } else {
                        0
                    }
                });
                inventory.next_air::<ModularIsEqualAir<1, 32, 32>>()?;
                let is_eq = ModularIsEqualChip::<F, 1, 32, 32>::new(
                    ModularIsEqualFiller::new(
                        Rv32IsEqualModAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()),
                        start_offset,
                        modulus_limbs,
                        bitwise_lu.clone(),
                    ),
                    mem_helper.clone(),
                );
                inventory.add_executor_chip(HybridModularIsEqualChip::new(is_eq));
            } else if bytes <= 48 {
                // Larger moduli (33..=48 bytes): three 16-limb blocks (48 limbs).
                let config = ExprBuilderConfig {
                    modulus: modulus.clone(),
                    num_limbs: 48,
                    limb_bits: 8,
                };

                inventory.next_air::<ModularAir<3, 16>>()?;
                let addsub = get_modular_addsub_chip::<F, 3, 16>(
                    config.clone(),
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridModularChip::new(addsub));

                inventory.next_air::<ModularAir<3, 16>>()?;
                let muldiv = get_modular_muldiv_chip::<F, 3, 16>(
                    config,
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridModularChip::new(muldiv));

                // Zero-pad the limbs up to the fixed array width the chip expects.
                let modulus_limbs = std::array::from_fn(|i| {
                    if i < modulus_limbs.len() {
                        modulus_limbs[i] as u8
                    } else {
                        0
                    }
                });
                inventory.next_air::<ModularIsEqualAir<3, 16, 48>>()?;
                let is_eq = ModularIsEqualChip::<F, 3, 16, 48>::new(
                    ModularIsEqualFiller::new(
                        Rv32IsEqualModAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()),
                        start_offset,
                        modulus_limbs,
                        bitwise_lu.clone(),
                    ),
                    mem_helper.clone(),
                );
                inventory.add_executor_chip(HybridModularIsEqualChip::new(is_eq));
            } else {
                panic!("Modulus too large");
            }
        }

        Ok(())
    }
}
257
/// Wraps a CPU [`Fp2Chip`] for use with the GPU prover backend (CPU tracegen,
/// then host-to-device transfer of the trace matrix).
#[derive(derive_new::new)]
pub struct HybridFp2Chip<F, const BLOCKS: usize, const BLOCK_SIZE: usize> {
    // The underlying CPU chip that performs the actual trace generation.
    cpu: Fp2Chip<F, BLOCKS, BLOCK_SIZE>,
}
262
263impl<const BLOCKS: usize, const BLOCK_SIZE: usize> Chip<DenseRecordArena, GpuBackend>
264    for HybridFp2Chip<F, BLOCKS, BLOCK_SIZE>
265{
266    fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
267        let total_input_limbs =
268            self.cpu.inner.num_inputs() * self.cpu.inner.expr.canonical_num_limbs();
269        let layout = AdapterCoreLayout::with_metadata(FieldExpressionMetadata::<
270            F,
271            Rv32VecHeapAdapterExecutor<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
272        >::new(total_input_limbs));
273
274        let record_size = RecordSeeker::<
275            DenseRecordArena,
276            AlgebraRecord<2, BLOCKS, BLOCK_SIZE>,
277            _,
278        >::get_aligned_record_size(&layout);
279
280        let records = arena.allocated();
281        if records.is_empty() {
282            return get_empty_air_proving_ctx::<GpuBackend>();
283        }
284        debug_assert_eq!(records.len() % record_size, 0);
285
286        let num_records = records.len() / record_size;
287        let height = num_records.next_power_of_two();
288        let mut seeker = arena
289            .get_record_seeker::<AlgebraRecord<2, BLOCKS, BLOCK_SIZE>, AdapterCoreLayout<
290                FieldExpressionMetadata<
291                    F,
292                    Rv32VecHeapAdapterExecutor<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
293                >,
294            >>();
295        let adapter_width =
296            Rv32VecHeapAdapterCols::<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>::width();
297        let width = adapter_width + BaseAir::<F>::width(&self.cpu.inner.expr);
298        let mut matrix_arena = MatrixRecordArena::<F>::with_capacity(height, width);
299        seeker.transfer_to_matrix_arena(&mut matrix_arena, layout);
300        let ctx = self.cpu.generate_proving_ctx(matrix_arena);
301        cpu_proving_ctx_to_gpu(ctx)
302    }
303}
304
impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Fp2Extension>
    for AlgebraHybridProverExt
{
    // Registers the Fp2 add/sub and mul/div chips for every supported modulus,
    // using CPU tracegen with the GPU proving backend. As with the modular
    // extension, chip registration must follow the AIR order consumed by
    // `next_air` — do not reorder.
    fn extend_prover(
        &self,
        extension: &Fp2Extension,
        inventory: &mut ChipInventory<SC, DenseRecordArena, GpuBackend>,
    ) -> Result<(), ChipInventoryError> {
        // CPU halves of the shared lookup chips, needed for host-side tracegen.
        let range_checker_gpu = get_inventory_range_checker(inventory);
        let timestamp_max_bits = inventory.timestamp_max_bits();
        let pointer_max_bits = inventory.airs().pointer_max_bits();
        let range_checker = range_checker_gpu.cpu_chip.clone().unwrap();
        let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
        let bitwise_lu_gpu = get_or_create_bitwise_op_lookup(inventory)?;
        let bitwise_lu = bitwise_lu_gpu.cpu_chip.clone().unwrap();

        // `supported_moduli` entries are pairs; only the modulus is needed here
        // (the first element is presumably a name/identifier — verify at caller).
        for (_, modulus) in extension.supported_moduli.iter() {
            // determine the number of bytes needed to represent a prime field element
            let bytes = modulus.bits().div_ceil(8);

            if bytes <= 32 {
                // Moduli up to 32 bytes: two 32-limb blocks (one per Fp2 coordinate).
                let config = ExprBuilderConfig {
                    modulus: modulus.clone(),
                    num_limbs: 32,
                    limb_bits: 8,
                };

                inventory.next_air::<Fp2Air<2, 32>>()?;
                let addsub = get_fp2_addsub_chip::<F, 2, 32>(
                    config.clone(),
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridFp2Chip::new(addsub));

                inventory.next_air::<Fp2Air<2, 32>>()?;
                let muldiv = get_fp2_muldiv_chip::<F, 2, 32>(
                    config,
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridFp2Chip::new(muldiv));
            } else if bytes <= 48 {
                // Larger moduli (33..=48 bytes): six 16-limb blocks (48 limbs per coordinate).
                let config = ExprBuilderConfig {
                    modulus: modulus.clone(),
                    num_limbs: 48,
                    limb_bits: 8,
                };

                inventory.next_air::<Fp2Air<6, 16>>()?;
                let addsub = get_fp2_addsub_chip::<F, 6, 16>(
                    config.clone(),
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridFp2Chip::new(addsub));

                inventory.next_air::<Fp2Air<6, 16>>()?;
                let muldiv = get_fp2_muldiv_chip::<F, 6, 16>(
                    config,
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridFp2Chip::new(muldiv));
            } else {
                panic!("Modulus too large");
            }
        }

        Ok(())
    }
}
385
/// This builder will do tracegen for the RV32IM extensions on GPU but the modular extensions on
/// CPU.
#[derive(Clone)]
pub struct Rv32ModularHybridBuilder;

/// Engine alias used throughout this module: the GPU BabyBear Poseidon2 engine.
type E = GpuBabyBearPoseidon2Engine;
392
393impl VmBuilder<E> for Rv32ModularHybridBuilder {
394    type VmConfig = Rv32ModularConfig;
395    type SystemChipInventory = SystemChipInventoryGPU;
396    type RecordArena = DenseRecordArena;
397
398    fn create_chip_complex(
399        &self,
400        config: &Rv32ModularConfig,
401        circuit: AirInventory<SC>,
402    ) -> Result<
403        VmChipComplex<SC, Self::RecordArena, GpuBackend, Self::SystemChipInventory>,
404        ChipInventoryError,
405    > {
406        let mut chip_complex =
407            VmBuilder::<E>::create_chip_complex(&SystemGpuBuilder, &config.system, circuit)?;
408        let inventory = &mut chip_complex.inventory;
409        VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.base, inventory)?;
410        VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.mul, inventory)?;
411        VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.io, inventory)?;
412        VmProverExtension::<E, _, _>::extend_prover(
413            &AlgebraHybridProverExt,
414            &config.modular,
415            inventory,
416        )?;
417        Ok(chip_complex)
418    }
419}
420
/// This builder will do tracegen for the RV32IM extensions on GPU but the modular and complex
/// extensions on CPU.
///
/// Delegates the system/RV32IM/modular setup to [`Rv32ModularHybridBuilder`]
/// and then appends the Fp2 chips via [`AlgebraHybridProverExt`].
#[derive(Clone)]
pub struct Rv32ModularWithFp2HybridBuilder;
425
426impl VmBuilder<E> for Rv32ModularWithFp2HybridBuilder {
427    type VmConfig = Rv32ModularWithFp2Config;
428    type SystemChipInventory = SystemChipInventoryGPU;
429    type RecordArena = DenseRecordArena;
430
431    fn create_chip_complex(
432        &self,
433        config: &Rv32ModularWithFp2Config,
434        circuit: AirInventory<SC>,
435    ) -> Result<
436        VmChipComplex<SC, Self::RecordArena, GpuBackend, Self::SystemChipInventory>,
437        ChipInventoryError,
438    > {
439        let mut chip_complex = VmBuilder::<E>::create_chip_complex(
440            &Rv32ModularHybridBuilder,
441            &config.modular,
442            circuit,
443        )?;
444        let inventory = &mut chip_complex.inventory;
445        VmProverExtension::<E, _, _>::extend_prover(
446            &AlgebraHybridProverExt,
447            &config.fp2,
448            inventory,
449        )?;
450        Ok(chip_complex)
451    }
452}