openvm_algebra_circuit/extension/
cuda.rs

1use openvm_algebra_transpiler::{Fp2Opcode, Rv32ModularArithmeticOpcode};
2use openvm_circuit::{
3    arch::{
4        AirInventory, ChipInventory, ChipInventoryError, DenseRecordArena, VmBuilder,
5        VmChipComplex, VmProverExtension,
6    },
7    system::cuda::{
8        extensions::{
9            get_inventory_range_checker, get_or_create_bitwise_op_lookup, SystemGpuBuilder,
10        },
11        SystemChipInventoryGPU,
12    },
13};
14use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine, prover_backend::GpuBackend};
15use openvm_instructions::LocalOpcode;
16use openvm_mod_circuit_builder::ExprBuilderConfig;
17use openvm_rv32im_circuit::Rv32ImGpuProverExt;
18use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config;
19use strum::EnumCount;
20
21use crate::{
22    fp2_chip::{Fp2AddSubChipGpu, Fp2Air, Fp2MulDivChipGpu},
23    modular_chip::{
24        ModularAddSubChipGpu, ModularAir, ModularIsEqualAir, ModularIsEqualChipGpu,
25        ModularMulDivChipGpu,
26    },
27    Fp2Extension, ModularExtension, Rv32ModularConfig, Rv32ModularWithFp2Config,
28};
29
/// Zero-sized marker type that carries the GPU `VmProverExtension` implementations
/// for the algebra extensions (`ModularExtension` and `Fp2Extension`) in this file.
#[derive(Clone, Debug)]
pub struct AlgebraGpuProverExt;
32
33// This implementation is specific to GpuBackend because the lookup chips
34// (VariableRangeCheckerChipGPU, BitwiseOperationLookupChipGPU) are specific to GpuBackend.
35impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Fp2Extension>
36    for AlgebraGpuProverExt
37{
38    fn extend_prover(
39        &self,
40        extension: &Fp2Extension,
41        inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
42    ) -> Result<(), ChipInventoryError> {
43        let pointer_max_bits = inventory.airs().pointer_max_bits();
44        let timestamp_max_bits = inventory.timestamp_max_bits();
45
46        // Range checker should always exist in inventory
47        let range_checker = get_inventory_range_checker(inventory);
48
49        let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
50
51        for (i, (_, modulus)) in extension.supported_moduli.iter().enumerate() {
52            // Determine the number of bytes needed to represent a prime field element
53            let bytes = modulus.bits().div_ceil(8);
54            let start_offset = Fp2Opcode::CLASS_OFFSET + i * Fp2Opcode::COUNT;
55
56            if bytes <= 32 {
57                let config = ExprBuilderConfig {
58                    modulus: modulus.clone(),
59                    num_limbs: 32,
60                    limb_bits: 8,
61                };
62
63                inventory.next_air::<Fp2Air<2, 32>>()?;
64                let addsub = Fp2AddSubChipGpu::<2, 32>::new(
65                    range_checker.clone(),
66                    bitwise_lu.clone(),
67                    config.clone(),
68                    start_offset,
69                    pointer_max_bits as u32,
70                    timestamp_max_bits as u32,
71                );
72                inventory.add_executor_chip(addsub);
73
74                inventory.next_air::<Fp2Air<2, 32>>()?;
75                let muldiv = Fp2MulDivChipGpu::<2, 32>::new(
76                    range_checker.clone(),
77                    bitwise_lu.clone(),
78                    config,
79                    start_offset,
80                    pointer_max_bits as u32,
81                    timestamp_max_bits as u32,
82                );
83                inventory.add_executor_chip(muldiv);
84            } else if bytes <= 48 {
85                let config = ExprBuilderConfig {
86                    modulus: modulus.clone(),
87                    num_limbs: 48,
88                    limb_bits: 8,
89                };
90
91                inventory.next_air::<Fp2Air<6, 16>>()?;
92                let addsub = Fp2AddSubChipGpu::<6, 16>::new(
93                    range_checker.clone(),
94                    bitwise_lu.clone(),
95                    config.clone(),
96                    start_offset,
97                    pointer_max_bits as u32,
98                    timestamp_max_bits as u32,
99                );
100                inventory.add_executor_chip(addsub);
101
102                inventory.next_air::<Fp2Air<6, 16>>()?;
103                let muldiv = Fp2MulDivChipGpu::<6, 16>::new(
104                    range_checker.clone(),
105                    bitwise_lu.clone(),
106                    config,
107                    start_offset,
108                    pointer_max_bits as u32,
109                    timestamp_max_bits as u32,
110                );
111                inventory.add_executor_chip(muldiv);
112            } else {
113                panic!("Modulus too large");
114            }
115        }
116
117        Ok(())
118    }
119}
120
121impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, ModularExtension>
122    for AlgebraGpuProverExt
123{
124    fn extend_prover(
125        &self,
126        extension: &ModularExtension,
127        inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
128    ) -> Result<(), ChipInventoryError> {
129        let pointer_max_bits = inventory.airs().pointer_max_bits();
130        let timestamp_max_bits = inventory.timestamp_max_bits();
131
132        // Range checker should always exist in inventory
133        let range_checker = get_inventory_range_checker(inventory);
134
135        let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
136
137        for (i, modulus) in extension.supported_moduli.iter().enumerate() {
138            let bytes = modulus.bits().div_ceil(8);
139            let start_offset =
140                Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT;
141
142            if bytes <= 32 {
143                let config = ExprBuilderConfig {
144                    modulus: modulus.clone(),
145                    num_limbs: 32,
146                    limb_bits: 8,
147                };
148
149                inventory.next_air::<ModularAir<1, 32>>()?;
150                let addsub = ModularAddSubChipGpu::<1, 32>::new(
151                    range_checker.clone(),
152                    bitwise_lu.clone(),
153                    config.clone(),
154                    start_offset,
155                    pointer_max_bits as u32,
156                    timestamp_max_bits as u32,
157                );
158                inventory.add_executor_chip(addsub);
159
160                inventory.next_air::<ModularAir<1, 32>>()?;
161                let muldiv = ModularMulDivChipGpu::<1, 32>::new(
162                    range_checker.clone(),
163                    bitwise_lu.clone(),
164                    config,
165                    start_offset,
166                    pointer_max_bits as u32,
167                    timestamp_max_bits as u32,
168                );
169                inventory.add_executor_chip(muldiv);
170
171                inventory.next_air::<ModularIsEqualAir<1, 32, 32>>()?;
172                let is_eq = ModularIsEqualChipGpu::<1, 32, 32>::new(
173                    range_checker.clone(),
174                    bitwise_lu.clone(),
175                    modulus.clone(),
176                    pointer_max_bits as u32,
177                    timestamp_max_bits as u32,
178                );
179                inventory.add_executor_chip(is_eq);
180            } else if bytes <= 48 {
181                let config = ExprBuilderConfig {
182                    modulus: modulus.clone(),
183                    num_limbs: 48,
184                    limb_bits: 8,
185                };
186
187                inventory.next_air::<ModularAir<3, 16>>()?;
188                let addsub = ModularAddSubChipGpu::<3, 16>::new(
189                    range_checker.clone(),
190                    bitwise_lu.clone(),
191                    config.clone(),
192                    start_offset,
193                    pointer_max_bits as u32,
194                    timestamp_max_bits as u32,
195                );
196                inventory.add_executor_chip(addsub);
197
198                inventory.next_air::<ModularAir<3, 16>>()?;
199                let muldiv = ModularMulDivChipGpu::<3, 16>::new(
200                    range_checker.clone(),
201                    bitwise_lu.clone(),
202                    config,
203                    start_offset,
204                    pointer_max_bits as u32,
205                    timestamp_max_bits as u32,
206                );
207                inventory.add_executor_chip(muldiv);
208
209                inventory.next_air::<ModularIsEqualAir<3, 16, 48>>()?;
210                let is_eq = ModularIsEqualChipGpu::<3, 16, 48>::new(
211                    range_checker.clone(),
212                    bitwise_lu.clone(),
213                    modulus.clone(),
214                    pointer_max_bits as u32,
215                    timestamp_max_bits as u32,
216                );
217                inventory.add_executor_chip(is_eq);
218            } else {
219                panic!("Modulus too large");
220            }
221        }
222
223        Ok(())
224    }
225}
226
/// Zero-sized `VmBuilder` for `Rv32ModularConfig` on the GPU engine; see its
/// `VmBuilder` implementation for the chips it assembles.
#[derive(Clone, Debug)]
pub struct Rv32ModularGpuBuilder;
229
// Shorthand for the GPU proving engine that the `VmBuilder` impls in this file target.
type E = GpuBabyBearPoseidon2Engine;
231
232impl VmBuilder<E> for Rv32ModularGpuBuilder {
233    type VmConfig = Rv32ModularConfig;
234    type SystemChipInventory = SystemChipInventoryGPU;
235    type RecordArena = DenseRecordArena;
236
237    fn create_chip_complex(
238        &self,
239        config: &Rv32ModularConfig,
240        circuit: AirInventory<BabyBearPoseidon2Config>,
241    ) -> Result<
242        VmChipComplex<
243            BabyBearPoseidon2Config,
244            Self::RecordArena,
245            GpuBackend,
246            Self::SystemChipInventory,
247        >,
248        ChipInventoryError,
249    > {
250        let mut chip_complex =
251            VmBuilder::<E>::create_chip_complex(&SystemGpuBuilder, &config.system, circuit)?;
252        let inventory = &mut chip_complex.inventory;
253        VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.base, inventory)?;
254        VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.mul, inventory)?;
255        VmProverExtension::<E, _, _>::extend_prover(&Rv32ImGpuProverExt, &config.io, inventory)?;
256        VmProverExtension::<E, _, _>::extend_prover(
257            &AlgebraGpuProverExt,
258            &config.modular,
259            inventory,
260        )?;
261        Ok(chip_complex)
262    }
263}
264
/// Zero-sized `VmBuilder` for `Rv32ModularWithFp2Config` on the GPU engine; see its
/// `VmBuilder` implementation for the chips it assembles.
#[derive(Clone, Debug)]
pub struct Rv32ModularWithFp2GpuBuilder;
267
268impl VmBuilder<E> for Rv32ModularWithFp2GpuBuilder {
269    type VmConfig = Rv32ModularWithFp2Config;
270    type SystemChipInventory = SystemChipInventoryGPU;
271    type RecordArena = DenseRecordArena;
272
273    fn create_chip_complex(
274        &self,
275        config: &Rv32ModularWithFp2Config,
276        circuit: AirInventory<BabyBearPoseidon2Config>,
277    ) -> Result<
278        VmChipComplex<
279            BabyBearPoseidon2Config,
280            Self::RecordArena,
281            GpuBackend,
282            Self::SystemChipInventory,
283        >,
284        ChipInventoryError,
285    > {
286        let mut chip_complex =
287            VmBuilder::<E>::create_chip_complex(&Rv32ModularGpuBuilder, &config.modular, circuit)?;
288        let inventory = &mut chip_complex.inventory;
289        VmProverExtension::<E, _, _>::extend_prover(&AlgebraGpuProverExt, &config.fp2, inventory)?;
290        Ok(chip_complex)
291    }
292}