1use openvm_algebra_transpiler::{Fp2Opcode, Rv32ModularArithmeticOpcode};
2use openvm_circuit::{
3 arch::{
4 AirInventory, ChipInventory, ChipInventoryError, DenseRecordArena, VmBuilder,
5 VmChipComplex, VmProverExtension,
6 },
7 system::cuda::{
8 extensions::{
9 get_inventory_range_checker, get_or_create_bitwise_op_lookup, SystemGpuBuilder,
10 },
11 SystemChipInventoryGPU,
12 },
13};
14use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine, prover_backend::GpuBackend};
15use openvm_instructions::LocalOpcode;
16use openvm_mod_circuit_builder::ExprBuilderConfig;
17use openvm_rv32im_circuit::Rv32ImGpuProverExt;
18use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config;
19use strum::EnumCount;
20
21use crate::{
22 fp2_chip::{Fp2AddSubChipGpu, Fp2Air, Fp2MulDivChipGpu},
23 modular_chip::{
24 ModularAddSubChipGpu, ModularAir, ModularIsEqualAir, ModularIsEqualChipGpu,
25 ModularMulDivChipGpu,
26 },
27 Fp2Extension, ModularExtension, Rv32ModularConfig, Rv32ModularWithFp2Config,
28};
29
/// Marker type selecting the GPU prover extensions of the algebra crate.
///
/// It carries no state. This file implements `VmProverExtension` on it for both
/// `Fp2Extension` and `ModularExtension`, using `GpuBabyBearPoseidon2Engine` with a
/// `DenseRecordArena`.
#[derive(Clone, Debug)]
pub struct AlgebraGpuProverExt;
32
33impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Fp2Extension>
36 for AlgebraGpuProverExt
37{
38 fn extend_prover(
39 &self,
40 extension: &Fp2Extension,
41 inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
42 ) -> Result<(), ChipInventoryError> {
43 let pointer_max_bits = inventory.airs().pointer_max_bits();
44 let timestamp_max_bits = inventory.timestamp_max_bits();
45
46 let range_checker = get_inventory_range_checker(inventory);
48
49 let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
50
51 for (i, (_, modulus)) in extension.supported_moduli.iter().enumerate() {
52 let bytes = modulus.bits().div_ceil(8);
54 let start_offset = Fp2Opcode::CLASS_OFFSET + i * Fp2Opcode::COUNT;
55
56 if bytes <= 32 {
57 let config = ExprBuilderConfig {
58 modulus: modulus.clone(),
59 num_limbs: 32,
60 limb_bits: 8,
61 };
62
63 inventory.next_air::<Fp2Air<2, 32>>()?;
64 let addsub = Fp2AddSubChipGpu::<2, 32>::new(
65 range_checker.clone(),
66 bitwise_lu.clone(),
67 config.clone(),
68 start_offset,
69 pointer_max_bits as u32,
70 timestamp_max_bits as u32,
71 );
72 inventory.add_executor_chip(addsub);
73
74 inventory.next_air::<Fp2Air<2, 32>>()?;
75 let muldiv = Fp2MulDivChipGpu::<2, 32>::new(
76 range_checker.clone(),
77 bitwise_lu.clone(),
78 config,
79 start_offset,
80 pointer_max_bits as u32,
81 timestamp_max_bits as u32,
82 );
83 inventory.add_executor_chip(muldiv);
84 } else if bytes <= 48 {
85 let config = ExprBuilderConfig {
86 modulus: modulus.clone(),
87 num_limbs: 48,
88 limb_bits: 8,
89 };
90
91 inventory.next_air::<Fp2Air<6, 16>>()?;
92 let addsub = Fp2AddSubChipGpu::<6, 16>::new(
93 range_checker.clone(),
94 bitwise_lu.clone(),
95 config.clone(),
96 start_offset,
97 pointer_max_bits as u32,
98 timestamp_max_bits as u32,
99 );
100 inventory.add_executor_chip(addsub);
101
102 inventory.next_air::<Fp2Air<6, 16>>()?;
103 let muldiv = Fp2MulDivChipGpu::<6, 16>::new(
104 range_checker.clone(),
105 bitwise_lu.clone(),
106 config,
107 start_offset,
108 pointer_max_bits as u32,
109 timestamp_max_bits as u32,
110 );
111 inventory.add_executor_chip(muldiv);
112 } else {
113 panic!("Modulus too large");
114 }
115 }
116
117 Ok(())
118 }
119}
120
121impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, ModularExtension>
122 for AlgebraGpuProverExt
123{
124 fn extend_prover(
125 &self,
126 extension: &ModularExtension,
127 inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
128 ) -> Result<(), ChipInventoryError> {
129 let pointer_max_bits = inventory.airs().pointer_max_bits();
130 let timestamp_max_bits = inventory.timestamp_max_bits();
131
132 let range_checker = get_inventory_range_checker(inventory);
134
135 let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
136
137 for (i, modulus) in extension.supported_moduli.iter().enumerate() {
138 let bytes = modulus.bits().div_ceil(8);
139 let start_offset =
140 Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT;
141
142 if bytes <= 32 {
143 let config = ExprBuilderConfig {
144 modulus: modulus.clone(),
145 num_limbs: 32,
146 limb_bits: 8,
147 };
148
149 inventory.next_air::<ModularAir<1, 32>>()?;
150 let addsub = ModularAddSubChipGpu::<1, 32>::new(
151 range_checker.clone(),
152 bitwise_lu.clone(),
153 config.clone(),
154 start_offset,
155 pointer_max_bits as u32,
156 timestamp_max_bits as u32,
157 );
158 inventory.add_executor_chip(addsub);
159
160 inventory.next_air::<ModularAir<1, 32>>()?;
161 let muldiv = ModularMulDivChipGpu::<1, 32>::new(
162 range_checker.clone(),
163 bitwise_lu.clone(),
164 config,
165 start_offset,
166 pointer_max_bits as u32,
167 timestamp_max_bits as u32,
168 );
169 inventory.add_executor_chip(muldiv);
170
171 inventory.next_air::<ModularIsEqualAir<1, 32, 32>>()?;
172 let is_eq = ModularIsEqualChipGpu::<1, 32, 32>::new(
173 range_checker.clone(),
174 bitwise_lu.clone(),
175 modulus.clone(),
176 pointer_max_bits as u32,
177 timestamp_max_bits as u32,
178 );
179 inventory.add_executor_chip(is_eq);
180 } else if bytes <= 48 {
181 let config = ExprBuilderConfig {
182 modulus: modulus.clone(),
183 num_limbs: 48,
184 limb_bits: 8,
185 };
186
187 inventory.next_air::<ModularAir<3, 16>>()?;
188 let addsub = ModularAddSubChipGpu::<3, 16>::new(
189 range_checker.clone(),
190 bitwise_lu.clone(),
191 config.clone(),
192 start_offset,
193 pointer_max_bits as u32,
194 timestamp_max_bits as u32,
195 );
196 inventory.add_executor_chip(addsub);
197
198 inventory.next_air::<ModularAir<3, 16>>()?;
199 let muldiv = ModularMulDivChipGpu::<3, 16>::new(
200 range_checker.clone(),
201 bitwise_lu.clone(),
202 config,
203 start_offset,
204 pointer_max_bits as u32,
205 timestamp_max_bits as u32,
206 );
207 inventory.add_executor_chip(muldiv);
208
209 inventory.next_air::<ModularIsEqualAir<3, 16, 48>>()?;
210 let is_eq = ModularIsEqualChipGpu::<3, 16, 48>::new(
211 range_checker.clone(),
212 bitwise_lu.clone(),
213 modulus.clone(),
214 pointer_max_bits as u32,
215 timestamp_max_bits as u32,
216 );
217 inventory.add_executor_chip(is_eq);
218 } else {
219 panic!("Modulus too large");
220 }
221 }
222
223 Ok(())
224 }
225}
226
/// Stateless GPU builder for VMs configured with `Rv32ModularConfig`.
///
/// Its `VmBuilder` implementation for `GpuBabyBearPoseidon2Engine` lives in this file.
#[derive(Clone, Debug)]
pub struct Rv32ModularGpuBuilder;
229
230type E = GpuBabyBearPoseidon2Engine;
231
232impl VmBuilder<E> for Rv32ModularGpuBuilder {
233 type VmConfig = Rv32ModularConfig;
234 type SystemChipInventory = SystemChipInventoryGPU;
235 type RecordArena = DenseRecordArena;
236
237 fn create_chip_complex(
238 &self,
239 config: &Rv32ModularConfig,
240 circuit: AirInventory<BabyBearPoseidon2Config>,
241 ) -> Result<
242 VmChipComplex<
243 BabyBearPoseidon2Config,
244 Self::RecordArena,
245 GpuBackend,
246 Self::SystemChipInventory,
247 >,
248 ChipInventoryError,
249 > {
250 let mut chip_complex =
251 VmBuilder::<E>::create_chip_complex(&SystemGpuBuilder, &config.system, circuit)?;
252 let inventory = &mut chip_complex.inventory;
253 VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.base, inventory)?;
254 VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.mul, inventory)?;
255 VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.io, inventory)?;
256 VmProverExtension::<E, _, _>::extend_prover(
257 &AlgebraGpuProverExt,
258 &config.modular,
259 inventory,
260 )?;
261 Ok(chip_complex)
262 }
263}
264
/// Stateless GPU builder for VMs configured with `Rv32ModularWithFp2Config`.
///
/// Its `VmBuilder` implementation for `GpuBabyBearPoseidon2Engine` lives in this file.
#[derive(Clone, Debug)]
pub struct Rv32ModularWithFp2GpuBuilder;
267
268impl VmBuilder<E> for Rv32ModularWithFp2GpuBuilder {
269 type VmConfig = Rv32ModularWithFp2Config;
270 type SystemChipInventory = SystemChipInventoryGPU;
271 type RecordArena = DenseRecordArena;
272
273 fn create_chip_complex(
274 &self,
275 config: &Rv32ModularWithFp2Config,
276 circuit: AirInventory<BabyBearPoseidon2Config>,
277 ) -> Result<
278 VmChipComplex<
279 BabyBearPoseidon2Config,
280 Self::RecordArena,
281 GpuBackend,
282 Self::SystemChipInventory,
283 >,
284 ChipInventoryError,
285 > {
286 let mut chip_complex =
287 VmBuilder::<E>::create_chip_complex(&Rv32ModularGpuBuilder, &config.modular, circuit)?;
288 let inventory = &mut chip_complex.inventory;
289 VmProverExtension::<E, _, _>::extend_prover(&AlgebraGpuProverExt, &config.fp2, inventory)?;
290 Ok(chip_complex)
291 }
292}