use std::{mem::size_of, sync::Arc};

use derive_new::new;
use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero};
use openvm_circuit_primitives::{
    bitwise_op_lookup::BitwiseOperationLookupChipGPU, range_tuple::RangeTupleCheckerChipGPU,
    var_range::VariableRangeCheckerChipGPU,
};
use openvm_cuda_backend::{
    base::DeviceMatrix,
    chip::{get_empty_air_proving_ctx, UInt2},
    prelude::F,
    prover_backend::GpuBackend,
};
use openvm_cuda_common::copy::MemCopyH2D;
use openvm_rv32_adapters::{
    Rv32HeapBranchAdapterCols, Rv32HeapBranchAdapterRecord, Rv32VecHeapAdapterCols,
    Rv32VecHeapAdapterRecord,
};
use openvm_rv32im_circuit::{
    adapters::{INT256_NUM_LIMBS, RV32_CELL_BITS},
    BaseAluCoreCols, BaseAluCoreRecord, BranchEqualCoreCols, BranchEqualCoreRecord,
    BranchLessThanCoreCols, BranchLessThanCoreRecord, LessThanCoreCols, LessThanCoreRecord,
    MultiplicationCoreCols, MultiplicationCoreRecord, ShiftCoreCols, ShiftCoreRecord,
};
use openvm_stark_backend::{prover::types::AirProvingContext, Chip};

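/// Bindings to the CUDA tracegen kernels used by the chips below.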
mod cuda_abi;

pub type BaseAlu256AdapterRecord =
    Rv32VecHeapAdapterRecord<2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
pub type BaseAlu256CoreRecord = BaseAluCoreRecord<INT256_NUM_LIMBS>;

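/// GPU trace generator for the 256-bit base ALU chip. Holds `Arc` handles to the shared
/// variable-range and bitwise-operation lookup chips, plus the pointer/timestamp bit limits
/// forwarded to the CUDA kernel.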
#[derive(new)]
pub struct BaseAlu256ChipGpu {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    pub pointer_max_bits: usize,
    pub timestamp_max_bits: usize,
}

impl Chip<DenseRecordArena, GpuBackend> for BaseAlu256ChipGpu {
    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
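        // Records are stored back-to-back as fixed-size (adapter record, core record) pairs.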
        const RECORD_SIZE: usize = size_of::<(BaseAlu256AdapterRecord, BaseAlu256CoreRecord)>();
        let records = arena.allocated();
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }
        debug_assert_eq!(records.len() % RECORD_SIZE, 0);

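        // One trace row per record; the height is padded up to a power of two.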
        let trace_width = BaseAluCoreCols::<F, INT256_NUM_LIMBS, RV32_CELL_BITS>::width()
            + Rv32VecHeapAdapterCols::<F, 2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>::width();
        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);

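        // Copy the record bytes to the device and allocate the trace matrix in GPU memory.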
        let d_records = records.to_device().unwrap();
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);

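        // Launch the tracegen kernel: it expands each record into a trace row and updates the
        // shared range-checker and bitwise-lookup count buffers.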
        unsafe {
            cuda_abi::alu256::tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                RV32_CELL_BITS,
                self.pointer_max_bits as u32,
                self.timestamp_max_bits as u32,
            )
            .unwrap();
        }

        AirProvingContext::simple_no_pis(d_trace)
    }
}

pub type BranchEqual256AdapterRecord = Rv32HeapBranchAdapterRecord<2>;
pub type BranchEqual256CoreRecord = BranchEqualCoreRecord<INT256_NUM_LIMBS>;

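/// GPU trace generator for the 256-bit branch-equal chip.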
#[derive(new)]
pub struct BranchEqual256ChipGpu {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    pub pointer_max_bits: usize,
    pub timestamp_max_bits: usize,
}

impl Chip<DenseRecordArena, GpuBackend> for BranchEqual256ChipGpu {
    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
        const RECORD_SIZE: usize =
            size_of::<(BranchEqual256AdapterRecord, BranchEqual256CoreRecord)>();
        let records = arena.allocated();
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }
        debug_assert_eq!(records.len() % RECORD_SIZE, 0);

        let trace_width = BranchEqualCoreCols::<F, INT256_NUM_LIMBS>::width()
            + Rv32HeapBranchAdapterCols::<F, 2, INT256_NUM_LIMBS>::width();
        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);

        let d_records = records.to_device().unwrap();
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);

        unsafe {
            cuda_abi::beq256::tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                RV32_CELL_BITS,
                self.pointer_max_bits as u32,
                self.timestamp_max_bits as u32,
            )
            .unwrap();
        }

        AirProvingContext::simple_no_pis(d_trace)
    }
}

pub type LessThan256AdapterRecord =
    Rv32VecHeapAdapterRecord<2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
pub type LessThan256CoreRecord = LessThanCoreRecord<INT256_NUM_LIMBS, RV32_CELL_BITS>;

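/// GPU trace generator for the 256-bit less-than comparison chip.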
#[derive(new)]
pub struct LessThan256ChipGpu {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    pub pointer_max_bits: usize,
    pub timestamp_max_bits: usize,
}

impl Chip<DenseRecordArena, GpuBackend> for LessThan256ChipGpu {
    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
        const RECORD_SIZE: usize = size_of::<(LessThan256AdapterRecord, LessThan256CoreRecord)>();
        let records = arena.allocated();
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }
        debug_assert_eq!(records.len() % RECORD_SIZE, 0);

        let trace_width = LessThanCoreCols::<F, INT256_NUM_LIMBS, RV32_CELL_BITS>::width()
            + Rv32VecHeapAdapterCols::<F, 2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>::width();
        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);

        let d_records = records.to_device().unwrap();
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);

        unsafe {
            cuda_abi::lt256::tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                RV32_CELL_BITS,
                self.pointer_max_bits as u32,
                self.timestamp_max_bits as u32,
            )
            .unwrap();
        }

        AirProvingContext::simple_no_pis(d_trace)
    }
}

pub type BranchLessThan256AdapterRecord = Rv32HeapBranchAdapterRecord<2>;
pub type BranchLessThan256CoreRecord = BranchLessThanCoreRecord<INT256_NUM_LIMBS, RV32_CELL_BITS>;

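/// GPU trace generator for the 256-bit branch-less-than chip.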
#[derive(new)]
pub struct BranchLessThan256ChipGpu {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    pub pointer_max_bits: usize,
    pub timestamp_max_bits: usize,
}

impl Chip<DenseRecordArena, GpuBackend> for BranchLessThan256ChipGpu {
    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
        const RECORD_SIZE: usize =
            size_of::<(BranchLessThan256AdapterRecord, BranchLessThan256CoreRecord)>();
        let records = arena.allocated();
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }
        debug_assert_eq!(records.len() % RECORD_SIZE, 0);

        let trace_width = BranchLessThanCoreCols::<F, INT256_NUM_LIMBS, RV32_CELL_BITS>::width()
            + Rv32HeapBranchAdapterCols::<F, 2, INT256_NUM_LIMBS>::width();
        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);

        let d_records = records.to_device().unwrap();
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);

        unsafe {
            cuda_abi::blt256::tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                RV32_CELL_BITS,
                self.pointer_max_bits as u32,
                self.timestamp_max_bits as u32,
            )
            .unwrap();
        }

        AirProvingContext::simple_no_pis(d_trace)
    }
}

pub type Shift256AdapterRecord =
    Rv32VecHeapAdapterRecord<2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
pub type Shift256CoreRecord = ShiftCoreRecord<INT256_NUM_LIMBS, RV32_CELL_BITS>;

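/// GPU trace generator for the 256-bit shift chip.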
#[derive(new)]
pub struct Shift256ChipGpu {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    pub pointer_max_bits: usize,
    pub timestamp_max_bits: usize,
}

impl Chip<DenseRecordArena, GpuBackend> for Shift256ChipGpu {
    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
        const RECORD_SIZE: usize = size_of::<(Shift256AdapterRecord, Shift256CoreRecord)>();
        let records = arena.allocated();
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }
        debug_assert_eq!(records.len() % RECORD_SIZE, 0);

        let trace_width = ShiftCoreCols::<F, INT256_NUM_LIMBS, RV32_CELL_BITS>::width()
            + Rv32VecHeapAdapterCols::<F, 2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>::width();
        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);

        let d_records = records.to_device().unwrap();
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);

        unsafe {
            cuda_abi::shift256::tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                RV32_CELL_BITS,
                self.pointer_max_bits as u32,
                self.timestamp_max_bits as u32,
            )
            .unwrap();
        }

        AirProvingContext::simple_no_pis(d_trace)
    }
}

pub type Multiplication256AdapterRecord =
    Rv32VecHeapAdapterRecord<2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
pub type Multiplication256CoreRecord = MultiplicationCoreRecord<INT256_NUM_LIMBS, RV32_CELL_BITS>;

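/// GPU trace generator for the 256-bit multiplication chip. In addition to the shared range and
/// bitwise lookup chips, it holds the two-dimensional range-tuple checker that the mul256 kernel
/// also updates.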
#[derive(new)]
pub struct Multiplication256ChipGpu {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    pub range_tuple_checker: Arc<RangeTupleCheckerChipGPU<2>>,
    pub pointer_max_bits: usize,
    pub timestamp_max_bits: usize,
}

impl Chip<DenseRecordArena, GpuBackend> for Multiplication256ChipGpu {
    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
        const RECORD_SIZE: usize =
            size_of::<(Multiplication256AdapterRecord, Multiplication256CoreRecord)>();
        let records = arena.allocated();
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }
        debug_assert_eq!(records.len() % RECORD_SIZE, 0);

        let trace_width = MultiplicationCoreCols::<F, INT256_NUM_LIMBS, RV32_CELL_BITS>::width()
            + Rv32VecHeapAdapterCols::<F, 2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>::width();
        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);

        let d_records = records.to_device().unwrap();
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);

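        // The multiplication kernel additionally uses the two-dimensional range-tuple checker;
        // pass its count buffer and per-dimension sizes (as a UInt2) alongside the usual lookups.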
        let sizes = self.range_tuple_checker.sizes;
        let d_sizes = UInt2 {
            x: sizes[0],
            y: sizes[1],
        };
        unsafe {
            cuda_abi::mul256::tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                RV32_CELL_BITS,
                &self.range_tuple_checker.count,
                d_sizes,
                self.pointer_max_bits as u32,
                self.timestamp_max_bits as u32,
            )
            .unwrap();
        }

        AirProvingContext::simple_no_pis(d_trace)
    }
}