1use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
4use openvm_circuit::{
5 arch::*,
6 system::{
7 cuda::{
8 extensions::{
9 get_inventory_range_checker, get_or_create_bitwise_op_lookup, SystemGpuBuilder,
10 },
11 SystemChipInventoryGPU,
12 },
13 memory::SharedMemoryHelper,
14 },
15};
16use openvm_circuit_primitives::bigint::utils::big_uint_to_limbs;
17use openvm_cuda_backend::{
18 chip::{cpu_proving_ctx_to_gpu, get_empty_air_proving_ctx},
19 engine::GpuBabyBearPoseidon2Engine,
20 prover_backend::GpuBackend,
21 types::{F, SC},
22};
23use openvm_instructions::LocalOpcode;
24use openvm_mod_circuit_builder::{ExprBuilderConfig, FieldExpressionMetadata};
25use openvm_rv32_adapters::{
26 Rv32IsEqualModAdapterCols, Rv32IsEqualModAdapterExecutor, Rv32IsEqualModAdapterFiller,
27 Rv32IsEqualModAdapterRecord, Rv32VecHeapAdapterCols, Rv32VecHeapAdapterExecutor,
28};
29use openvm_rv32im_circuit::Rv32ImGpuProverExt;
30use openvm_stark_backend::{p3_air::BaseAir, prover::types::AirProvingContext, Chip};
31use strum::EnumCount;
32
33use crate::{
34 fp2_chip::{get_fp2_addsub_chip, get_fp2_muldiv_chip, Fp2Air, Fp2Chip},
35 modular_chip::*,
36 AlgebraRecord, Fp2Extension, ModularExtension, Rv32ModularConfig, Rv32ModularWithFp2Config,
37};
38
/// Hybrid modular-arithmetic chip: records are collected in a dense arena,
/// the trace is generated by the wrapped CPU chip, and the resulting proving
/// context is converted for the GPU backend (see the `Chip` impl below).
#[derive(derive_new::new)]
pub struct HybridModularChip<F, const BLOCKS: usize, const BLOCK_SIZE: usize> {
    // CPU chip that performs the actual trace generation.
    cpu: ModularChip<F, BLOCKS, BLOCK_SIZE>,
}
43
44impl<const BLOCKS: usize, const BLOCK_SIZE: usize> Chip<DenseRecordArena, GpuBackend>
47 for HybridModularChip<F, BLOCKS, BLOCK_SIZE>
48{
49 fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
50 let total_input_limbs =
51 self.cpu.inner.num_inputs() * self.cpu.inner.expr.canonical_num_limbs();
52 let layout = AdapterCoreLayout::with_metadata(FieldExpressionMetadata::<
53 F,
54 Rv32VecHeapAdapterExecutor<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
55 >::new(total_input_limbs));
56
57 let record_size = RecordSeeker::<
58 DenseRecordArena,
59 AlgebraRecord<2, BLOCKS, BLOCK_SIZE>,
60 _,
61 >::get_aligned_record_size(&layout);
62
63 let records = arena.allocated();
64 if records.is_empty() {
65 return get_empty_air_proving_ctx::<GpuBackend>();
66 }
67 debug_assert_eq!(records.len() % record_size, 0);
68
69 let num_records = records.len() / record_size;
70
71 let height = num_records.next_power_of_two();
72 let mut seeker = arena
73 .get_record_seeker::<AlgebraRecord<2, BLOCKS, BLOCK_SIZE>, AdapterCoreLayout<
74 FieldExpressionMetadata<
75 F,
76 Rv32VecHeapAdapterExecutor<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
77 >,
78 >>();
79 let adapter_width =
80 Rv32VecHeapAdapterCols::<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>::width();
81 let width = adapter_width + BaseAir::<F>::width(&self.cpu.inner.expr);
82 let mut matrix_arena = MatrixRecordArena::<F>::with_capacity(height, width);
83 seeker.transfer_to_matrix_arena(&mut matrix_arena, layout);
84 let ctx = self.cpu.generate_proving_ctx(matrix_arena);
85 cpu_proving_ctx_to_gpu(ctx)
86 }
87}
88
/// Hybrid modular is-equal chip: records are collected in a dense arena,
/// the trace is generated by the wrapped CPU chip, and the resulting proving
/// context is converted for the GPU backend (see the `Chip` impl below).
#[derive(derive_new::new)]
pub struct HybridModularIsEqualChip<
    F,
    const NUM_LANES: usize,
    const LANE_SIZE: usize,
    const TOTAL_LIMBS: usize,
> {
    // CPU chip that performs the actual trace generation.
    cpu: ModularIsEqualChip<F, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>,
}
98
99impl<const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_LIMBS: usize>
100 Chip<DenseRecordArena, GpuBackend>
101 for HybridModularIsEqualChip<F, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>
102{
103 fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
104 let record_size = size_of::<(
105 Rv32IsEqualModAdapterRecord<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>,
106 ModularIsEqualRecord<TOTAL_LIMBS>,
107 )>();
108 let trace_width = Rv32IsEqualModAdapterCols::<F, 2, NUM_LANES, LANE_SIZE>::width()
109 + ModularIsEqualCoreCols::<F, TOTAL_LIMBS>::width();
110 let records = arena.allocated();
111 if records.is_empty() {
112 return get_empty_air_proving_ctx::<GpuBackend>();
113 }
114 debug_assert_eq!(records.len() % record_size, 0);
115
116 let num_records = records.len() / record_size;
117 let height = num_records.next_power_of_two();
118 let mut seeker = arena.get_record_seeker::<(
119 &mut Rv32IsEqualModAdapterRecord<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>,
120 &mut ModularIsEqualRecord<TOTAL_LIMBS>,
121 ), EmptyAdapterCoreLayout<
122 F,
123 Rv32IsEqualModAdapterExecutor<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>,
124 >>();
125 let mut matrix_arena = MatrixRecordArena::<F>::with_capacity(height, trace_width);
126 seeker.transfer_to_matrix_arena(&mut matrix_arena, EmptyAdapterCoreLayout::new());
127 let ctx = self.cpu.generate_proving_ctx(matrix_arena);
128 cpu_proving_ctx_to_gpu(ctx)
129 }
130}
131
/// Prover extension that wires the hybrid (CPU trace-gen, GPU proving)
/// algebra chips into a chip inventory. Implements `VmProverExtension` for
/// both `ModularExtension` and `Fp2Extension` below.
#[derive(Clone, Copy, Default)]
pub struct AlgebraHybridProverExt;
134
impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, ModularExtension>
    for AlgebraHybridProverExt
{
    /// Registers modular add/sub, mul/div, and is-equal chips for every
    /// supported modulus.
    ///
    /// NOTE(review): chips are added in the same order the matching AIRs are
    /// requested via `next_air`, so this order presumably must mirror the AIR
    /// registration order of the circuit extension — confirm against
    /// `ModularExtension`'s circuit-side counterpart.
    fn extend_prover(
        &self,
        extension: &ModularExtension,
        inventory: &mut ChipInventory<SC, DenseRecordArena, GpuBackend>,
    ) -> Result<(), ChipInventoryError> {
        // Shared lookup chips. The `cpu_chip` halves are required because the
        // hybrid chips generate traces on the CPU.
        let range_checker_gpu = get_inventory_range_checker(inventory);
        let timestamp_max_bits = inventory.timestamp_max_bits();
        let pointer_max_bits = inventory.airs().pointer_max_bits();
        let range_checker = range_checker_gpu.cpu_chip.clone().unwrap();
        let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
        let bitwise_lu_gpu = get_or_create_bitwise_op_lookup(inventory)?;
        let bitwise_lu = bitwise_lu_gpu.cpu_chip.clone().unwrap();

        for (i, modulus) in extension.supported_moduli.iter().enumerate() {
            // Byte length of the modulus decides which limb configuration to use.
            let bytes = modulus.bits().div_ceil(8);
            // Each modulus gets its own contiguous opcode range.
            let start_offset =
                Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT;

            // Decompose the modulus into 8-bit limbs (little-endian).
            let modulus_limbs = big_uint_to_limbs(modulus, 8);

            if bytes <= 32 {
                // Up to 256-bit moduli: 32 limbs of 8 bits, one 32-byte block.
                let config = ExprBuilderConfig {
                    modulus: modulus.clone(),
                    num_limbs: 32,
                    limb_bits: 8,
                };

                inventory.next_air::<ModularAir<1, 32>>()?;
                let addsub = get_modular_addsub_chip::<F, 1, 32>(
                    config.clone(),
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridModularChip::new(addsub));

                inventory.next_air::<ModularAir<1, 32>>()?;
                let muldiv = get_modular_muldiv_chip::<F, 1, 32>(
                    config,
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridModularChip::new(muldiv));

                // Zero-pad the limb vector to the fixed 32-limb array.
                let modulus_limbs = std::array::from_fn(|i| {
                    if i < modulus_limbs.len() {
                        modulus_limbs[i] as u8
                    } else {
                        0
                    }
                });
                inventory.next_air::<ModularIsEqualAir<1, 32, 32>>()?;
                let is_eq = ModularIsEqualChip::<F, 1, 32, 32>::new(
                    ModularIsEqualFiller::new(
                        Rv32IsEqualModAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()),
                        start_offset,
                        modulus_limbs,
                        bitwise_lu.clone(),
                    ),
                    mem_helper.clone(),
                );
                inventory.add_executor_chip(HybridModularIsEqualChip::new(is_eq));
            } else if bytes <= 48 {
                // Up to 384-bit moduli: 48 limbs of 8 bits, three 16-byte blocks.
                let config = ExprBuilderConfig {
                    modulus: modulus.clone(),
                    num_limbs: 48,
                    limb_bits: 8,
                };

                inventory.next_air::<ModularAir<3, 16>>()?;
                let addsub = get_modular_addsub_chip::<F, 3, 16>(
                    config.clone(),
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridModularChip::new(addsub));

                inventory.next_air::<ModularAir<3, 16>>()?;
                let muldiv = get_modular_muldiv_chip::<F, 3, 16>(
                    config,
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridModularChip::new(muldiv));

                // Zero-pad the limb vector to the fixed 48-limb array.
                let modulus_limbs = std::array::from_fn(|i| {
                    if i < modulus_limbs.len() {
                        modulus_limbs[i] as u8
                    } else {
                        0
                    }
                });
                inventory.next_air::<ModularIsEqualAir<3, 16, 48>>()?;
                let is_eq = ModularIsEqualChip::<F, 3, 16, 48>::new(
                    ModularIsEqualFiller::new(
                        Rv32IsEqualModAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()),
                        start_offset,
                        modulus_limbs,
                        bitwise_lu.clone(),
                    ),
                    mem_helper.clone(),
                );
                inventory.add_executor_chip(HybridModularIsEqualChip::new(is_eq));
            } else {
                // Only moduli up to 48 bytes (384 bits) are supported.
                panic!("Modulus too large");
            }
        }

        Ok(())
    }
}
257
/// Hybrid Fp2 arithmetic chip: records are collected in a dense arena,
/// the trace is generated by the wrapped CPU chip, and the resulting proving
/// context is converted for the GPU backend (see the `Chip` impl below).
#[derive(derive_new::new)]
pub struct HybridFp2Chip<F, const BLOCKS: usize, const BLOCK_SIZE: usize> {
    // CPU chip that performs the actual trace generation.
    cpu: Fp2Chip<F, BLOCKS, BLOCK_SIZE>,
}
262
263impl<const BLOCKS: usize, const BLOCK_SIZE: usize> Chip<DenseRecordArena, GpuBackend>
264 for HybridFp2Chip<F, BLOCKS, BLOCK_SIZE>
265{
266 fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
267 let total_input_limbs =
268 self.cpu.inner.num_inputs() * self.cpu.inner.expr.canonical_num_limbs();
269 let layout = AdapterCoreLayout::with_metadata(FieldExpressionMetadata::<
270 F,
271 Rv32VecHeapAdapterExecutor<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
272 >::new(total_input_limbs));
273
274 let record_size = RecordSeeker::<
275 DenseRecordArena,
276 AlgebraRecord<2, BLOCKS, BLOCK_SIZE>,
277 _,
278 >::get_aligned_record_size(&layout);
279
280 let records = arena.allocated();
281 if records.is_empty() {
282 return get_empty_air_proving_ctx::<GpuBackend>();
283 }
284 debug_assert_eq!(records.len() % record_size, 0);
285
286 let num_records = records.len() / record_size;
287 let height = num_records.next_power_of_two();
288 let mut seeker = arena
289 .get_record_seeker::<AlgebraRecord<2, BLOCKS, BLOCK_SIZE>, AdapterCoreLayout<
290 FieldExpressionMetadata<
291 F,
292 Rv32VecHeapAdapterExecutor<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
293 >,
294 >>();
295 let adapter_width =
296 Rv32VecHeapAdapterCols::<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>::width();
297 let width = adapter_width + BaseAir::<F>::width(&self.cpu.inner.expr);
298 let mut matrix_arena = MatrixRecordArena::<F>::with_capacity(height, width);
299 seeker.transfer_to_matrix_arena(&mut matrix_arena, layout);
300 let ctx = self.cpu.generate_proving_ctx(matrix_arena);
301 cpu_proving_ctx_to_gpu(ctx)
302 }
303}
304
impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Fp2Extension>
    for AlgebraHybridProverExt
{
    /// Registers Fp2 add/sub and mul/div chips for every supported modulus.
    ///
    /// NOTE(review): chips are added in the same order the matching AIRs are
    /// requested via `next_air`, so this order presumably must mirror the AIR
    /// registration order of the circuit extension — confirm against
    /// `Fp2Extension`'s circuit-side counterpart.
    fn extend_prover(
        &self,
        extension: &Fp2Extension,
        inventory: &mut ChipInventory<SC, DenseRecordArena, GpuBackend>,
    ) -> Result<(), ChipInventoryError> {
        // Shared lookup chips. The `cpu_chip` halves are required because the
        // hybrid chips generate traces on the CPU.
        let range_checker_gpu = get_inventory_range_checker(inventory);
        let timestamp_max_bits = inventory.timestamp_max_bits();
        let pointer_max_bits = inventory.airs().pointer_max_bits();
        let range_checker = range_checker_gpu.cpu_chip.clone().unwrap();
        let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
        let bitwise_lu_gpu = get_or_create_bitwise_op_lookup(inventory)?;
        let bitwise_lu = bitwise_lu_gpu.cpu_chip.clone().unwrap();

        // Each supported modulus entry carries a name (ignored here) and the
        // modulus value itself.
        for (_, modulus) in extension.supported_moduli.iter() {
            // Byte length of the modulus decides which limb configuration to use.
            let bytes = modulus.bits().div_ceil(8);

            if bytes <= 32 {
                // Up to 256-bit moduli: 32 limbs of 8 bits; an Fp2 element is
                // two coordinates, hence 2 blocks of 32 bytes.
                let config = ExprBuilderConfig {
                    modulus: modulus.clone(),
                    num_limbs: 32,
                    limb_bits: 8,
                };

                inventory.next_air::<Fp2Air<2, 32>>()?;
                let addsub = get_fp2_addsub_chip::<F, 2, 32>(
                    config.clone(),
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridFp2Chip::new(addsub));

                inventory.next_air::<Fp2Air<2, 32>>()?;
                let muldiv = get_fp2_muldiv_chip::<F, 2, 32>(
                    config,
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridFp2Chip::new(muldiv));
            } else if bytes <= 48 {
                // Up to 384-bit moduli: 48 limbs of 8 bits; two coordinates of
                // three 16-byte blocks each, hence 6 blocks of 16 bytes.
                let config = ExprBuilderConfig {
                    modulus: modulus.clone(),
                    num_limbs: 48,
                    limb_bits: 8,
                };

                inventory.next_air::<Fp2Air<6, 16>>()?;
                let addsub = get_fp2_addsub_chip::<F, 6, 16>(
                    config.clone(),
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridFp2Chip::new(addsub));

                inventory.next_air::<Fp2Air<6, 16>>()?;
                let muldiv = get_fp2_muldiv_chip::<F, 6, 16>(
                    config,
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(HybridFp2Chip::new(muldiv));
            } else {
                // Only moduli up to 48 bytes (384 bits) are supported.
                panic!("Modulus too large");
            }
        }

        Ok(())
    }
}
385
/// VM builder for [`Rv32ModularConfig`] that proves on the GPU while the
/// algebra chips generate traces on the CPU (hybrid mode).
#[derive(Clone)]
pub struct Rv32ModularHybridBuilder;

// Shorthand for the proving engine used by the builders in this module.
type E = GpuBabyBearPoseidon2Engine;
392
393impl VmBuilder<E> for Rv32ModularHybridBuilder {
394 type VmConfig = Rv32ModularConfig;
395 type SystemChipInventory = SystemChipInventoryGPU;
396 type RecordArena = DenseRecordArena;
397
398 fn create_chip_complex(
399 &self,
400 config: &Rv32ModularConfig,
401 circuit: AirInventory<SC>,
402 ) -> Result<
403 VmChipComplex<SC, Self::RecordArena, GpuBackend, Self::SystemChipInventory>,
404 ChipInventoryError,
405 > {
406 let mut chip_complex =
407 VmBuilder::<E>::create_chip_complex(&SystemGpuBuilder, &config.system, circuit)?;
408 let inventory = &mut chip_complex.inventory;
409 VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.base, inventory)?;
410 VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.mul, inventory)?;
411 VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.io, inventory)?;
412 VmProverExtension::<E, _, _>::extend_prover(
413 &AlgebraHybridProverExt,
414 &config.modular,
415 inventory,
416 )?;
417 Ok(chip_complex)
418 }
419}
420
/// VM builder for [`Rv32ModularWithFp2Config`] that proves on the GPU while
/// the algebra chips generate traces on the CPU (hybrid mode).
#[derive(Clone)]
pub struct Rv32ModularWithFp2HybridBuilder;
425
426impl VmBuilder<E> for Rv32ModularWithFp2HybridBuilder {
427 type VmConfig = Rv32ModularWithFp2Config;
428 type SystemChipInventory = SystemChipInventoryGPU;
429 type RecordArena = DenseRecordArena;
430
431 fn create_chip_complex(
432 &self,
433 config: &Rv32ModularWithFp2Config,
434 circuit: AirInventory<SC>,
435 ) -> Result<
436 VmChipComplex<SC, Self::RecordArena, GpuBackend, Self::SystemChipInventory>,
437 ChipInventoryError,
438 > {
439 let mut chip_complex = VmBuilder::<E>::create_chip_complex(
440 &Rv32ModularHybridBuilder,
441 &config.modular,
442 circuit,
443 )?;
444 let inventory = &mut chip_complex.inventory;
445 VmProverExtension::<E, _, _>::extend_prover(
446 &AlgebraHybridProverExt,
447 &config.fp2,
448 inventory,
449 )?;
450 Ok(chip_complex)
451 }
452}