openvm_ecc_circuit/extension/
hybrid.rs1use openvm_algebra_circuit::Rv32ModularHybridBuilder;
4use openvm_circuit::{
5 arch::*,
6 system::{
7 cuda::{
8 extensions::{get_inventory_range_checker, get_or_create_bitwise_op_lookup},
9 SystemChipInventoryGPU,
10 },
11 memory::SharedMemoryHelper,
12 },
13};
14use openvm_cuda_backend::{
15 chip::{cpu_proving_ctx_to_gpu, get_empty_air_proving_ctx},
16 engine::GpuBabyBearPoseidon2Engine,
17 prover_backend::GpuBackend,
18 types::{F, SC},
19};
20use openvm_mod_circuit_builder::{ExprBuilderConfig, FieldExpressionMetadata};
21use openvm_rv32_adapters::{Rv32VecHeapAdapterCols, Rv32VecHeapAdapterExecutor};
22use openvm_stark_backend::{p3_air::BaseAir, prover::types::AirProvingContext, Chip};
23
24use crate::{
25 get_ec_addne_chip, get_ec_double_chip, EccRecord, Rv32WeierstrassConfig, WeierstrassAir,
26 WeierstrassChip, WeierstrassExtension,
27};
28
29#[derive(derive_new::new)]
30pub struct HybridWeierstrassChip<
31 F,
32 const NUM_READS: usize,
33 const BLOCKS: usize,
34 const BLOCK_SIZE: usize,
35> {
36 cpu: WeierstrassChip<F, NUM_READS, BLOCKS, BLOCK_SIZE>,
37}
38
39impl<const NUM_READS: usize, const BLOCKS: usize, const BLOCK_SIZE: usize>
42 Chip<DenseRecordArena, GpuBackend> for HybridWeierstrassChip<F, NUM_READS, BLOCKS, BLOCK_SIZE>
43{
44 fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
45 let total_input_limbs =
46 self.cpu.inner.num_inputs() * self.cpu.inner.expr.canonical_num_limbs();
47 let layout = AdapterCoreLayout::with_metadata(FieldExpressionMetadata::<
48 F,
49 Rv32VecHeapAdapterExecutor<NUM_READS, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
50 >::new(total_input_limbs));
51
52 let record_size = RecordSeeker::<
53 DenseRecordArena,
54 EccRecord<NUM_READS, BLOCKS, BLOCK_SIZE>,
55 _,
56 >::get_aligned_record_size(&layout);
57
58 let records = arena.allocated();
59 if records.is_empty() {
60 return get_empty_air_proving_ctx::<GpuBackend>();
61 }
62 debug_assert_eq!(records.len() % record_size, 0);
63
64 let num_records = records.len() / record_size;
65 let height = num_records.next_power_of_two();
66 let mut seeker = arena
67 .get_record_seeker::<EccRecord<NUM_READS, BLOCKS, BLOCK_SIZE>, AdapterCoreLayout<
68 FieldExpressionMetadata<
69 F,
70 Rv32VecHeapAdapterExecutor<NUM_READS, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
71 >,
72 >>();
73 let adapter_width =
74 Rv32VecHeapAdapterCols::<F, NUM_READS, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>::width();
75 let width = adapter_width + BaseAir::<F>::width(&self.cpu.inner.expr);
76 let mut matrix_arena = MatrixRecordArena::<F>::with_capacity(height, width);
77 seeker.transfer_to_matrix_arena(&mut matrix_arena, layout);
78 let ctx = self.cpu.generate_proving_ctx(matrix_arena);
79 cpu_proving_ctx_to_gpu(ctx)
80 }
81}
82
/// Prover extension that registers hybrid (CPU trace generation, GPU proving) Weierstrass
/// chips for every curve of a `WeierstrassExtension`.
#[derive(Clone, Copy, Debug, Default)]
pub struct EccHybridProverExt;
85
86impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, WeierstrassExtension>
87 for EccHybridProverExt
88{
89 fn extend_prover(
90 &self,
91 extension: &WeierstrassExtension,
92 inventory: &mut ChipInventory<SC, DenseRecordArena, GpuBackend>,
93 ) -> Result<(), ChipInventoryError> {
94 let range_checker_gpu = get_inventory_range_checker(inventory);
95 let timestamp_max_bits = inventory.timestamp_max_bits();
96 let pointer_max_bits = inventory.airs().pointer_max_bits();
97 let range_checker = range_checker_gpu.cpu_chip.clone().unwrap();
98 let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
99
100 let bitwise_lu_gpu = get_or_create_bitwise_op_lookup(inventory)?;
101 let bitwise_lu = bitwise_lu_gpu.cpu_chip.clone().unwrap();
102
103 for curve in extension.supported_curves.iter() {
104 let bytes = curve.modulus.bits().div_ceil(8);
105
106 if bytes <= 32 {
107 let config = ExprBuilderConfig {
108 modulus: curve.modulus.clone(),
109 num_limbs: 32,
110 limb_bits: 8,
111 };
112
113 inventory.next_air::<WeierstrassAir<2, 2, 32>>()?;
114 let addne = get_ec_addne_chip::<F, 2, 32>(
115 config.clone(),
116 mem_helper.clone(),
117 range_checker.clone(),
118 bitwise_lu.clone(),
119 pointer_max_bits,
120 );
121 inventory.add_executor_chip(HybridWeierstrassChip::new(addne));
122
123 inventory.next_air::<WeierstrassAir<1, 2, 32>>()?;
124 let double = get_ec_double_chip::<F, 2, 32>(
125 config,
126 mem_helper.clone(),
127 range_checker.clone(),
128 bitwise_lu.clone(),
129 pointer_max_bits,
130 curve.a.clone(),
131 );
132 inventory.add_executor_chip(HybridWeierstrassChip::new(double));
133 } else if bytes <= 48 {
134 let config = ExprBuilderConfig {
135 modulus: curve.modulus.clone(),
136 num_limbs: 48,
137 limb_bits: 8,
138 };
139
140 inventory.next_air::<WeierstrassAir<2, 6, 16>>()?;
141 let addne = get_ec_addne_chip::<F, 6, 16>(
142 config.clone(),
143 mem_helper.clone(),
144 range_checker.clone(),
145 bitwise_lu.clone(),
146 pointer_max_bits,
147 );
148 inventory.add_executor_chip(HybridWeierstrassChip::new(addne));
149
150 inventory.next_air::<WeierstrassAir<1, 6, 16>>()?;
151 let double = get_ec_double_chip::<F, 6, 16>(
152 config,
153 mem_helper.clone(),
154 range_checker.clone(),
155 bitwise_lu.clone(),
156 pointer_max_bits,
157 curve.a.clone(),
158 );
159 inventory.add_executor_chip(HybridWeierstrassChip::new(double));
160 } else {
161 panic!("Modulus too large");
162 }
163 }
164
165 Ok(())
166 }
167}
168
/// VM builder for `Rv32WeierstrassConfig` that uses hybrid (CPU trace generation, GPU proving)
/// chips.
///
/// Derives match those of `EccHybridProverExt`, so the builder can be copied and
/// default-constructed like its sibling unit type.
#[derive(Clone, Copy, Debug, Default)]
pub struct Rv32WeierstrassHybridBuilder;
173
174type E = GpuBabyBearPoseidon2Engine;
175
176impl VmBuilder<E> for Rv32WeierstrassHybridBuilder {
177 type VmConfig = Rv32WeierstrassConfig;
178 type SystemChipInventory = SystemChipInventoryGPU;
179 type RecordArena = DenseRecordArena;
180
181 fn create_chip_complex(
182 &self,
183 config: &Rv32WeierstrassConfig,
184 circuit: AirInventory<SC>,
185 ) -> Result<
186 VmChipComplex<SC, Self::RecordArena, GpuBackend, Self::SystemChipInventory>,
187 ChipInventoryError,
188 > {
189 let mut chip_complex = VmBuilder::<E>::create_chip_complex(
190 &Rv32ModularHybridBuilder,
191 &config.modular,
192 circuit,
193 )?;
194 let inventory = &mut chip_complex.inventory;
195 VmProverExtension::<E, _, _>::extend_prover(
196 &EccHybridProverExt,
197 &config.weierstrass,
198 inventory,
199 )?;
200
201 Ok(chip_complex)
202 }
203}