openvm_rv32im_circuit/extension/
cuda.rs1use std::sync::Arc;
2
3use openvm_circuit::{
4 arch::{ChipInventory, ChipInventoryError, DenseRecordArena, VmProverExtension},
5 system::cuda::extensions::{get_inventory_range_checker, get_or_create_bitwise_op_lookup},
6};
7use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerAir, RangeTupleCheckerChipGPU};
8use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine, prover_backend::GpuBackend};
9use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config;
10
11use crate::{
12 Rv32AuipcAir, Rv32AuipcChipGpu, Rv32BaseAluAir, Rv32BaseAluChipGpu, Rv32BranchEqualAir,
13 Rv32BranchEqualChipGpu, Rv32BranchLessThanAir, Rv32BranchLessThanChipGpu, Rv32DivRemAir,
14 Rv32DivRemChipGpu, Rv32HintStoreAir, Rv32HintStoreChipGpu, Rv32I, Rv32Io, Rv32JalLuiAir,
15 Rv32JalLuiChipGpu, Rv32JalrAir, Rv32JalrChipGpu, Rv32LessThanAir, Rv32LessThanChipGpu,
16 Rv32LoadSignExtendAir, Rv32LoadSignExtendChipGpu, Rv32LoadStoreAir, Rv32LoadStoreChipGpu,
17 Rv32M, Rv32MulHAir, Rv32MulHChipGpu, Rv32MultiplicationAir, Rv32MultiplicationChipGpu,
18 Rv32ShiftAir, Rv32ShiftChipGpu,
19};
20
/// Zero-sized marker type carrying the GPU prover wiring for the RV32IM extensions.
///
/// In this file it implements [`VmProverExtension`] for [`Rv32I`], [`Rv32M`] and [`Rv32Io`],
/// each time targeting `GpuBabyBearPoseidon2Engine` with a `DenseRecordArena`.
pub struct Rv32ImGpuProverExt;
22
23impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Rv32I> for Rv32ImGpuProverExt {
26 fn extend_prover(
27 &self,
28 _: &Rv32I,
29 inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
30 ) -> Result<(), ChipInventoryError> {
31 let pointer_max_bits = inventory.airs().pointer_max_bits();
32 let timestamp_max_bits = inventory.timestamp_max_bits();
33
34 let range_checker = get_inventory_range_checker(inventory);
35 let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
36
37 inventory.next_air::<Rv32BaseAluAir>()?;
40 let base_alu = Rv32BaseAluChipGpu::new(
41 range_checker.clone(),
42 bitwise_lu.clone(),
43 timestamp_max_bits,
44 );
45 inventory.add_executor_chip(base_alu);
46
47 inventory.next_air::<Rv32LessThanAir>()?;
48 let lt = Rv32LessThanChipGpu::new(
49 range_checker.clone(),
50 bitwise_lu.clone(),
51 timestamp_max_bits,
52 );
53 inventory.add_executor_chip(lt);
54
55 inventory.next_air::<Rv32ShiftAir>()?;
56 let shift = Rv32ShiftChipGpu::new(
57 range_checker.clone(),
58 bitwise_lu.clone(),
59 timestamp_max_bits,
60 );
61 inventory.add_executor_chip(shift);
62
63 inventory.next_air::<Rv32LoadStoreAir>()?;
64 let load_store_chip =
65 Rv32LoadStoreChipGpu::new(range_checker.clone(), pointer_max_bits, timestamp_max_bits);
66 inventory.add_executor_chip(load_store_chip);
67
68 inventory.next_air::<Rv32LoadSignExtendAir>()?;
69 let load_sign_extend = Rv32LoadSignExtendChipGpu::new(
70 range_checker.clone(),
71 pointer_max_bits,
72 timestamp_max_bits,
73 );
74 inventory.add_executor_chip(load_sign_extend);
75
76 inventory.next_air::<Rv32BranchEqualAir>()?;
77 let beq = Rv32BranchEqualChipGpu::new(range_checker.clone(), timestamp_max_bits);
78 inventory.add_executor_chip(beq);
79
80 inventory.next_air::<Rv32BranchLessThanAir>()?;
81 let blt = Rv32BranchLessThanChipGpu::new(
82 range_checker.clone(),
83 bitwise_lu.clone(),
84 timestamp_max_bits,
85 );
86 inventory.add_executor_chip(blt);
87
88 inventory.next_air::<Rv32JalLuiAir>()?;
89 let jal_lui = Rv32JalLuiChipGpu::new(
90 range_checker.clone(),
91 bitwise_lu.clone(),
92 timestamp_max_bits,
93 );
94 inventory.add_executor_chip(jal_lui);
95
96 inventory.next_air::<Rv32JalrAir>()?;
97 let jalr = Rv32JalrChipGpu::new(
98 range_checker.clone(),
99 bitwise_lu.clone(),
100 timestamp_max_bits,
101 );
102 inventory.add_executor_chip(jalr);
103
104 inventory.next_air::<Rv32AuipcAir>()?;
105 let auipc = Rv32AuipcChipGpu::new(
106 range_checker.clone(),
107 bitwise_lu.clone(),
108 timestamp_max_bits,
109 );
110 inventory.add_executor_chip(auipc);
111
112 Ok(())
113 }
114}
115
116impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Rv32M> for Rv32ImGpuProverExt {
119 fn extend_prover(
120 &self,
121 extension: &Rv32M,
122 inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
123 ) -> Result<(), ChipInventoryError> {
124 let pointer_max_bits = inventory.airs().pointer_max_bits();
125 let timestamp_max_bits = inventory.timestamp_max_bits();
126
127 let range_checker = get_inventory_range_checker(inventory);
128 let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
129
130 let range_tuple_checker = {
131 let existing_chip = inventory
132 .find_chip::<Arc<RangeTupleCheckerChipGPU<2>>>()
133 .find(|c| {
134 c.sizes[0] >= extension.range_tuple_checker_sizes[0]
135 && c.sizes[1] >= extension.range_tuple_checker_sizes[1]
136 });
137 if let Some(chip) = existing_chip {
138 chip.clone()
139 } else {
140 inventory.next_air::<RangeTupleCheckerAir<2>>()?;
141 let chip = Arc::new(RangeTupleCheckerChipGPU::new(
142 extension.range_tuple_checker_sizes,
143 ));
144 inventory.add_periphery_chip(chip.clone());
145 chip
146 }
147 };
148
149 inventory.next_air::<Rv32MultiplicationAir>()?;
152 let mult = Rv32MultiplicationChipGpu::new(
153 range_checker.clone(),
154 range_tuple_checker.clone(),
155 timestamp_max_bits,
156 );
157 inventory.add_executor_chip(mult);
158
159 inventory.next_air::<Rv32MulHAir>()?;
160 let mul_h = Rv32MulHChipGpu::new(
161 range_checker.clone(),
162 bitwise_lu.clone(),
163 range_tuple_checker.clone(),
164 timestamp_max_bits,
165 );
166 inventory.add_executor_chip(mul_h);
167
168 inventory.next_air::<Rv32DivRemAir>()?;
169 let div_rem = Rv32DivRemChipGpu::new(
170 range_checker.clone(),
171 bitwise_lu.clone(),
172 range_tuple_checker.clone(),
173 pointer_max_bits,
174 timestamp_max_bits,
175 );
176 inventory.add_executor_chip(div_rem);
177
178 Ok(())
179 }
180}
181
182impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Rv32Io>
185 for Rv32ImGpuProverExt
186{
187 fn extend_prover(
188 &self,
189 _: &Rv32Io,
190 inventory: &mut ChipInventory<BabyBearPoseidon2Config, DenseRecordArena, GpuBackend>,
191 ) -> Result<(), ChipInventoryError> {
192 let pointer_max_bits = inventory.airs().pointer_max_bits();
193 let timestamp_max_bits = inventory.timestamp_max_bits();
194
195 let range_checker = get_inventory_range_checker(inventory);
196 let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?;
197
198 inventory.next_air::<Rv32HintStoreAir>()?;
199 let hint_store = Rv32HintStoreChipGpu::new(
200 range_checker.clone(),
201 bitwise_lu.clone(),
202 pointer_max_bits,
203 timestamp_max_bits,
204 );
205 inventory.add_executor_chip(hint_store);
206
207 Ok(())
208 }
209}