// openvm_rv32im_circuit/extension/cuda.rs

1use std::sync::Arc;
2
3use openvm_circuit::{
4    arch::{ChipInventory, ChipInventoryError, DenseRecordArena, VmProverExtension},
5    system::cuda::extensions::{get_inventory_range_checker, get_or_create_bitwise_op_lookup},
6};
7use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerAir, RangeTupleCheckerChipGPU};
8use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine, prover_backend::GpuBackend};
9use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config;
10
11use crate::{
12    Rv32AuipcAir, Rv32AuipcChipGpu, Rv32BaseAluAir, Rv32BaseAluChipGpu, Rv32BranchEqualAir,
13    Rv32BranchEqualChipGpu, Rv32BranchLessThanAir, Rv32BranchLessThanChipGpu, Rv32DivRemAir,
14    Rv32DivRemChipGpu, Rv32HintStoreAir, Rv32HintStoreChipGpu, Rv32I, Rv32Io, Rv32JalLuiAir,
15    Rv32JalLuiChipGpu, Rv32JalrAir, Rv32JalrChipGpu, Rv32LessThanAir, Rv32LessThanChipGpu,
16    Rv32LoadSignExtendAir, Rv32LoadSignExtendChipGpu, Rv32LoadStoreAir, Rv32LoadStoreChipGpu,
17    Rv32M, Rv32MulHAir, Rv32MulHChipGpu, Rv32MultiplicationAir, Rv32MultiplicationChipGpu,
18    Rv32ShiftAir, Rv32ShiftChipGpu,
19};
20
/// Marker type carrying the GPU prover-extension implementations for the RV32IM extensions.
///
/// The type holds no state: it only serves as the `Self` type for the `VmProverExtension`
/// implementations in this file. The derives make it trivially printable, copyable and
/// default-constructible, as expected of a public zero-sized marker.
#[derive(Debug, Clone, Copy, Default)]
pub struct Rv32ImGpuProverExt;
22
23// This implementation is specific to GpuBackend because the lookup chips
24// (VariableRangeCheckerChipGPU, BitwiseOperationLookupChipGPU) are specific to GpuBackend.
25impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Rv32I> for Rv32ImGpuProverExt {
26    fn extend_prover(
27        &self,
28        _: &Rv32I,
29        inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
30    ) -> Result<(), ChipInventoryError> {
31        let pointer_max_bits = inventory.airs().pointer_max_bits();
32        let timestamp_max_bits = inventory.timestamp_max_bits();
33
34        let range_checker = get_inventory_range_checker(inventory);
35        let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
36
37        // These calls to next_air are not strictly necessary to construct the chips, but provide a
38        // safeguard to ensure that chip construction matches the circuit definition
39        inventory.next_air::<Rv32BaseAluAir>()?;
40        let base_alu = Rv32BaseAluChipGpu::new(
41            range_checker.clone(),
42            bitwise_lu.clone(),
43            timestamp_max_bits,
44        );
45        inventory.add_executor_chip(base_alu);
46
47        inventory.next_air::<Rv32LessThanAir>()?;
48        let lt = Rv32LessThanChipGpu::new(
49            range_checker.clone(),
50            bitwise_lu.clone(),
51            timestamp_max_bits,
52        );
53        inventory.add_executor_chip(lt);
54
55        inventory.next_air::<Rv32ShiftAir>()?;
56        let shift = Rv32ShiftChipGpu::new(
57            range_checker.clone(),
58            bitwise_lu.clone(),
59            timestamp_max_bits,
60        );
61        inventory.add_executor_chip(shift);
62
63        inventory.next_air::<Rv32LoadStoreAir>()?;
64        let load_store_chip =
65            Rv32LoadStoreChipGpu::new(range_checker.clone(), pointer_max_bits, timestamp_max_bits);
66        inventory.add_executor_chip(load_store_chip);
67
68        inventory.next_air::<Rv32LoadSignExtendAir>()?;
69        let load_sign_extend = Rv32LoadSignExtendChipGpu::new(
70            range_checker.clone(),
71            pointer_max_bits,
72            timestamp_max_bits,
73        );
74        inventory.add_executor_chip(load_sign_extend);
75
76        inventory.next_air::<Rv32BranchEqualAir>()?;
77        let beq = Rv32BranchEqualChipGpu::new(range_checker.clone(), timestamp_max_bits);
78        inventory.add_executor_chip(beq);
79
80        inventory.next_air::<Rv32BranchLessThanAir>()?;
81        let blt = Rv32BranchLessThanChipGpu::new(
82            range_checker.clone(),
83            bitwise_lu.clone(),
84            timestamp_max_bits,
85        );
86        inventory.add_executor_chip(blt);
87
88        inventory.next_air::<Rv32JalLuiAir>()?;
89        let jal_lui = Rv32JalLuiChipGpu::new(
90            range_checker.clone(),
91            bitwise_lu.clone(),
92            timestamp_max_bits,
93        );
94        inventory.add_executor_chip(jal_lui);
95
96        inventory.next_air::<Rv32JalrAir>()?;
97        let jalr = Rv32JalrChipGpu::new(
98            range_checker.clone(),
99            bitwise_lu.clone(),
100            timestamp_max_bits,
101        );
102        inventory.add_executor_chip(jalr);
103
104        inventory.next_air::<Rv32AuipcAir>()?;
105        let auipc = Rv32AuipcChipGpu::new(
106            range_checker.clone(),
107            bitwise_lu.clone(),
108            timestamp_max_bits,
109        );
110        inventory.add_executor_chip(auipc);
111
112        Ok(())
113    }
114}
115
116// This implementation is specific to GpuBackend because the lookup chips
117// (VariableRangeCheckerChipGPU, BitwiseOperationLookupChipGPU) are specific to GpuBackend.
118impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Rv32M> for Rv32ImGpuProverExt {
119    fn extend_prover(
120        &self,
121        extension: &Rv32M,
122        inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
123    ) -> Result<(), ChipInventoryError> {
124        let pointer_max_bits = inventory.airs().pointer_max_bits();
125        let timestamp_max_bits = inventory.timestamp_max_bits();
126
127        let range_checker = get_inventory_range_checker(inventory);
128        let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
129
130        let range_tuple_checker = {
131            let existing_chip = inventory
132                .find_chip::<Arc<RangeTupleCheckerChipGPU<2>>>()
133                .find(|c| {
134                    c.sizes[0] >= extension.range_tuple_checker_sizes[0]
135                        && c.sizes[1] >= extension.range_tuple_checker_sizes[1]
136                });
137            if let Some(chip) = existing_chip {
138                chip.clone()
139            } else {
140                inventory.next_air::<RangeTupleCheckerAir<2>>()?;
141                let chip = Arc::new(RangeTupleCheckerChipGPU::new(
142                    extension.range_tuple_checker_sizes,
143                ));
144                inventory.add_periphery_chip(chip.clone());
145                chip
146            }
147        };
148
149        // These calls to next_air are not strictly necessary to construct the chips, but provide a
150        // safeguard to ensure that chip construction matches the circuit definition
151        inventory.next_air::<Rv32MultiplicationAir>()?;
152        let mult = Rv32MultiplicationChipGpu::new(
153            range_checker.clone(),
154            range_tuple_checker.clone(),
155            timestamp_max_bits,
156        );
157        inventory.add_executor_chip(mult);
158
159        inventory.next_air::<Rv32MulHAir>()?;
160        let mul_h = Rv32MulHChipGpu::new(
161            range_checker.clone(),
162            bitwise_lu.clone(),
163            range_tuple_checker.clone(),
164            timestamp_max_bits,
165        );
166        inventory.add_executor_chip(mul_h);
167
168        inventory.next_air::<Rv32DivRemAir>()?;
169        let div_rem = Rv32DivRemChipGpu::new(
170            range_checker.clone(),
171            bitwise_lu.clone(),
172            range_tuple_checker.clone(),
173            pointer_max_bits,
174            timestamp_max_bits,
175        );
176        inventory.add_executor_chip(div_rem);
177
178        Ok(())
179    }
180}
181
182// This implementation is specific to GpuBackend because the lookup chips
183// (VariableRangeCheckerChipGPU, BitwiseOperationLookupChipGPU) are specific to GpuBackend.
184impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Rv32Io>
185    for Rv32ImGpuProverExt
186{
187    fn extend_prover(
188        &self,
189        _: &Rv32Io,
190        inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
191    ) -> Result<(), ChipInventoryError> {
192        let pointer_max_bits = inventory.airs().pointer_max_bits();
193        let timestamp_max_bits = inventory.timestamp_max_bits();
194
195        let range_checker = get_inventory_range_checker(inventory);
196        let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
197
198        inventory.next_air::<Rv32HintStoreAir>()?;
199        let hint_store = Rv32HintStoreChipGpu::new(
200            range_checker.clone(),
201            bitwise_lu.clone(),
202            pointer_max_bits,
203            timestamp_max_bits,
204        );
205        inventory.add_executor_chip(hint_store);
206
207        Ok(())
208    }
209}