openvm_bigint_circuit/extension/
cuda.rs1use openvm_circuit::{
2 arch::DenseRecordArena,
3 system::cuda::{
4 extensions::{
5 get_inventory_range_checker, get_or_create_bitwise_op_lookup, SystemGpuBuilder,
6 },
7 SystemChipInventoryGPU,
8 },
9};
10use openvm_circuit_primitives::range_tuple::RangeTupleCheckerChipGPU;
11use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine, prover_backend::GpuBackend};
12use openvm_rv32im_circuit::Rv32ImGpuProverExt;
13use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config;
14
15use super::*;
16
17pub struct Int256GpuProverExt;
18
19impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Int256>
22 for Int256GpuProverExt
23{
24 fn extend_prover(
25 &self,
26 extension: &Int256,
27 inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
28 ) -> Result<(), ChipInventoryError> {
29 let pointer_max_bits = inventory.airs().pointer_max_bits();
30 let timestamp_max_bits = inventory.timestamp_max_bits();
31
32 let range_checker = get_inventory_range_checker(inventory);
33 let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
34
35 let range_tuple_checker = {
36 let existing_chip = inventory
37 .find_chip::<Arc<RangeTupleCheckerChipGPU<2>>>()
38 .find(|c| {
39 c.sizes[0] >= extension.range_tuple_checker_sizes[0]
40 && c.sizes[1] >= extension.range_tuple_checker_sizes[1]
41 });
42 if let Some(chip) = existing_chip {
43 chip.clone()
44 } else {
45 inventory.next_air::<RangeTupleCheckerAir<2>>()?;
46 let chip = Arc::new(RangeTupleCheckerChipGPU::new(
47 extension.range_tuple_checker_sizes,
48 ));
49 inventory.add_periphery_chip(chip.clone());
50 chip
51 }
52 };
53
54 inventory.next_air::<Rv32BaseAlu256Air>()?;
55 let base_alu = BaseAlu256ChipGpu::new(
56 range_checker.clone(),
57 bitwise_lu.clone(),
58 pointer_max_bits,
59 timestamp_max_bits,
60 );
61 inventory.add_executor_chip(base_alu);
62
63 inventory.next_air::<Rv32LessThan256Air>()?;
64 let lt = LessThan256ChipGpu::new(
65 range_checker.clone(),
66 bitwise_lu.clone(),
67 pointer_max_bits,
68 timestamp_max_bits,
69 );
70 inventory.add_executor_chip(lt);
71
72 inventory.next_air::<Rv32BranchEqual256Air>()?;
73 let beq = BranchEqual256ChipGpu::new(
74 range_checker.clone(),
75 bitwise_lu.clone(),
76 pointer_max_bits,
77 timestamp_max_bits,
78 );
79 inventory.add_executor_chip(beq);
80
81 inventory.next_air::<Rv32BranchLessThan256Air>()?;
82 let blt = BranchLessThan256ChipGpu::new(
83 range_checker.clone(),
84 bitwise_lu.clone(),
85 pointer_max_bits,
86 timestamp_max_bits,
87 );
88 inventory.add_executor_chip(blt);
89
90 inventory.next_air::<Rv32Multiplication256Air>()?;
91 let mult = Multiplication256ChipGpu::new(
92 range_checker.clone(),
93 bitwise_lu.clone(),
94 range_tuple_checker.clone(),
95 pointer_max_bits,
96 timestamp_max_bits,
97 );
98 inventory.add_executor_chip(mult);
99
100 inventory.next_air::<Rv32Shift256Air>()?;
101 let shift = Shift256ChipGpu::new(
102 range_checker.clone(),
103 bitwise_lu.clone(),
104 pointer_max_bits,
105 timestamp_max_bits,
106 );
107 inventory.add_executor_chip(shift);
108
109 Ok(())
110 }
111}
112
113#[derive(Clone)]
114pub struct Int256Rv32GpuBuilder;
115
116type E = GpuBabyBearPoseidon2Engine;
117
118impl VmBuilder<E> for Int256Rv32GpuBuilder {
119 type VmConfig = Int256Rv32Config;
120 type SystemChipInventory = SystemChipInventoryGPU;
121 type RecordArena = DenseRecordArena;
122
123 fn create_chip_complex(
124 &self,
125 config: &Int256Rv32Config,
126 circuit: AirInventory<<E as StarkEngine>::SC>,
127 ) -> Result<
128 VmChipComplex<
129 <E as StarkEngine>::SC,
130 Self::RecordArena,
131 <E as StarkEngine>::PB,
132 Self::SystemChipInventory,
133 >,
134 ChipInventoryError,
135 > {
136 let mut chip_complex =
137 VmBuilder::<E>::create_chip_complex(&SystemGpuBuilder, &config.system, circuit)?;
138 let inventory = &mut chip_complex.inventory;
139 VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.rv32i, inventory)?;
140 VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.rv32m, inventory)?;
141 VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.io, inventory)?;
142 VmProverExtension::<E, _, _>::extend_prover(
143 &Int256GpuProverExt,
144 &config.bigint,
145 inventory,
146 )?;
147 Ok(chip_complex)
148 }
149}