openvm_ecc_circuit/extension/
hybrid.rs

1//! Prover extension for the GPU backend which still does trace generation on CPU.
2
3use openvm_algebra_circuit::Rv32ModularHybridBuilder;
4use openvm_circuit::{
5    arch::*,
6    system::{
7        cuda::{
8            extensions::{get_inventory_range_checker, get_or_create_bitwise_op_lookup},
9            SystemChipInventoryGPU,
10        },
11        memory::SharedMemoryHelper,
12    },
13};
14use openvm_cuda_backend::{
15    chip::{cpu_proving_ctx_to_gpu, get_empty_air_proving_ctx},
16    engine::GpuBabyBearPoseidon2Engine,
17    prover_backend::GpuBackend,
18    types::{F, SC},
19};
20use openvm_mod_circuit_builder::{ExprBuilderConfig, FieldExpressionMetadata};
21use openvm_rv32_adapters::{Rv32VecHeapAdapterCols, Rv32VecHeapAdapterExecutor};
22use openvm_stark_backend::{p3_air::BaseAir, prover::types::AirProvingContext, Chip};
23
24use crate::{
25    get_ec_addne_chip, get_ec_double_chip, EccRecord, Rv32WeierstrassConfig, WeierstrassAir,
26    WeierstrassChip, WeierstrassExtension,
27};
28
29#[derive(derive_new::new)]
30pub struct HybridWeierstrassChip<
31    F,
32    const NUM_READS: usize,
33    const BLOCKS: usize,
34    const BLOCK_SIZE: usize,
35> {
36    cpu: WeierstrassChip<F, NUM_READS, BLOCKS, BLOCK_SIZE>,
37}
38
39// Auto-implementation of Chip for GpuBackend for a Cpu Chip by doing conversion
40// of Dense->Matrix Record Arena, cpu tracegen, and then H2D transfer of the trace matrix.
41impl<const NUM_READS: usize, const BLOCKS: usize, const BLOCK_SIZE: usize>
42    Chip<DenseRecordArena, GpuBackend> for HybridWeierstrassChip<F, NUM_READS, BLOCKS, BLOCK_SIZE>
43{
44    fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
45        let total_input_limbs =
46            self.cpu.inner.num_inputs() * self.cpu.inner.expr.canonical_num_limbs();
47        let layout = AdapterCoreLayout::with_metadata(FieldExpressionMetadata::<
48            F,
49            Rv32VecHeapAdapterExecutor<NUM_READS, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
50        >::new(total_input_limbs));
51
52        let record_size = RecordSeeker::<
53            DenseRecordArena,
54            EccRecord<NUM_READS, BLOCKS, BLOCK_SIZE>,
55            _,
56        >::get_aligned_record_size(&layout);
57
58        let records = arena.allocated();
59        if records.is_empty() {
60            return get_empty_air_proving_ctx::<GpuBackend>();
61        }
62        debug_assert_eq!(records.len() % record_size, 0);
63
64        let num_records = records.len() / record_size;
65        let height = num_records.next_power_of_two();
66        let mut seeker = arena
67            .get_record_seeker::<EccRecord<NUM_READS, BLOCKS, BLOCK_SIZE>, AdapterCoreLayout<
68                FieldExpressionMetadata<
69                    F,
70                    Rv32VecHeapAdapterExecutor<NUM_READS, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
71                >,
72            >>();
73        let adapter_width =
74            Rv32VecHeapAdapterCols::<F, NUM_READS, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>::width();
75        let width = adapter_width + BaseAir::<F>::width(&self.cpu.inner.expr);
76        let mut matrix_arena = MatrixRecordArena::<F>::with_capacity(height, width);
77        seeker.transfer_to_matrix_arena(&mut matrix_arena, layout);
78        let ctx = self.cpu.generate_proving_ctx(matrix_arena);
79        cpu_proving_ctx_to_gpu(ctx)
80    }
81}
82
/// Marker type selecting the hybrid prover extension for Weierstrass curves.
///
/// It carries no data; its behavior lives in its `VmProverExtension` implementation.
#[derive(Default, Clone, Copy)]
pub struct EccHybridProverExt;
85
86impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, WeierstrassExtension>
87    for EccHybridProverExt
88{
89    fn extend_prover(
90        &self,
91        extension: &WeierstrassExtension,
92        inventory: &mut ChipInventory<SC, DenseRecordArena, GpuBackend>,
93    ) -> Result<(), ChipInventoryError> {
94        let range_checker_gpu = get_inventory_range_checker(inventory);
95        let timestamp_max_bits = inventory.timestamp_max_bits();
96        let pointer_max_bits = inventory.airs().pointer_max_bits();
97        let range_checker = range_checker_gpu.cpu_chip.clone().unwrap();
98        let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
99
100        let bitwise_lu_gpu = get_or_create_bitwise_op_lookup(inventory)?;
101        let bitwise_lu = bitwise_lu_gpu.cpu_chip.clone().unwrap();
102
103        for curve in extension.supported_curves.iter() {
104            let bytes = curve.modulus.bits().div_ceil(8);
105
106            if bytes <= 32 {
107                let config = ExprBuilderConfig {
108                    modulus: curve.modulus.clone(),
109                    num_limbs: 32,
110                    limb_bits: 8,
111                };
112
113                inventory.next_air::<WeierstrassAir<2, 2, 32>>()?;
114                let addne = get_ec_addne_chip::<F, 2, 32>(
115                    config.clone(),
116                    mem_helper.clone(),
117                    range_checker.clone(),
118                    bitwise_lu.clone(),
119                    pointer_max_bits,
120                );
121                inventory.add_executor_chip(HybridWeierstrassChip::new(addne));
122
123                inventory.next_air::<WeierstrassAir<1, 2, 32>>()?;
124                let double = get_ec_double_chip::<F, 2, 32>(
125                    config,
126                    mem_helper.clone(),
127                    range_checker.clone(),
128                    bitwise_lu.clone(),
129                    pointer_max_bits,
130                    curve.a.clone(),
131                );
132                inventory.add_executor_chip(HybridWeierstrassChip::new(double));
133            } else if bytes <= 48 {
134                let config = ExprBuilderConfig {
135                    modulus: curve.modulus.clone(),
136                    num_limbs: 48,
137                    limb_bits: 8,
138                };
139
140                inventory.next_air::<WeierstrassAir<2, 6, 16>>()?;
141                let addne = get_ec_addne_chip::<F, 6, 16>(
142                    config.clone(),
143                    mem_helper.clone(),
144                    range_checker.clone(),
145                    bitwise_lu.clone(),
146                    pointer_max_bits,
147                );
148                inventory.add_executor_chip(HybridWeierstrassChip::new(addne));
149
150                inventory.next_air::<WeierstrassAir<1, 6, 16>>()?;
151                let double = get_ec_double_chip::<F, 6, 16>(
152                    config,
153                    mem_helper.clone(),
154                    range_checker.clone(),
155                    bitwise_lu.clone(),
156                    pointer_max_bits,
157                    curve.a.clone(),
158                );
159                inventory.add_executor_chip(HybridWeierstrassChip::new(double));
160            } else {
161                panic!("Modulus too large");
162            }
163        }
164
165        Ok(())
166    }
167}
168
/// Hybrid VM builder: trace generation for the RV32IM extensions runs on GPU, while the
/// modular and ecc extensions generate their traces on CPU.
#[derive(Clone)]
pub struct Rv32WeierstrassHybridBuilder;
173
174type E = GpuBabyBearPoseidon2Engine;
175
176impl VmBuilder<E> for Rv32WeierstrassHybridBuilder {
177    type VmConfig = Rv32WeierstrassConfig;
178    type SystemChipInventory = SystemChipInventoryGPU;
179    type RecordArena = DenseRecordArena;
180
181    fn create_chip_complex(
182        &self,
183        config: &Rv32WeierstrassConfig,
184        circuit: AirInventory<SC>,
185    ) -> Result<
186        VmChipComplex<SC, Self::RecordArena, GpuBackend, Self::SystemChipInventory>,
187        ChipInventoryError,
188    > {
189        let mut chip_complex = VmBuilder::<E>::create_chip_complex(
190            &Rv32ModularHybridBuilder,
191            &config.modular,
192            circuit,
193        )?;
194        let inventory = &mut chip_complex.inventory;
195        VmProverExtension::<E, _, _>::extend_prover(
196            &EccHybridProverExt,
197            &config.weierstrass,
198            inventory,
199        )?;
200
201        Ok(chip_complex)
202    }
203}