openvm_bigint_circuit/cuda/mod.rs
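
//! GPU trace generation for the 256-bit integer extension.
//!
//! Each chip below follows the same pattern: it reads the packed
//! `(adapter record, core record)` pairs from a [`DenseRecordArena`], copies
//! them to the device, allocates a trace matrix whose height is the record
//! count rounded up to the next power of two, and launches an opcode-specific
//! CUDA kernel from [`cuda_abi`] that fills the trace and updates the shared
//! lookup counters.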

use std::{mem::size_of, sync::Arc};

use derive_new::new;
use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero};
use openvm_circuit_primitives::{
    bitwise_op_lookup::BitwiseOperationLookupChipGPU, range_tuple::RangeTupleCheckerChipGPU,
    var_range::VariableRangeCheckerChipGPU,
};
use openvm_cuda_backend::{
    base::DeviceMatrix,
    chip::{get_empty_air_proving_ctx, UInt2},
    prelude::F,
    prover_backend::GpuBackend,
};
use openvm_cuda_common::copy::MemCopyH2D;
use openvm_rv32_adapters::{
    Rv32HeapBranchAdapterCols, Rv32HeapBranchAdapterRecord, Rv32VecHeapAdapterCols,
    Rv32VecHeapAdapterRecord,
};
use openvm_rv32im_circuit::{
    adapters::{INT256_NUM_LIMBS, RV32_CELL_BITS},
    BaseAluCoreCols, BaseAluCoreRecord, BranchEqualCoreCols, BranchEqualCoreRecord,
    BranchLessThanCoreCols, BranchLessThanCoreRecord, LessThanCoreCols, LessThanCoreRecord,
    MultiplicationCoreCols, MultiplicationCoreRecord, ShiftCoreCols, ShiftCoreRecord,
};
use openvm_stark_backend::{prover::types::AirProvingContext, Chip};

mod cuda_abi;

//////////////////////////////////////////////////////////////////////////////////////
/// ALU
//////////////////////////////////////////////////////////////////////////////////////
pub type BaseAlu256AdapterRecord =
    Rv32VecHeapAdapterRecord<2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
pub type BaseAlu256CoreRecord = BaseAluCoreRecord<INT256_NUM_LIMBS>;

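/// GPU tracegen chip for 256-bit base ALU operations, pairing the `BaseAlu`
/// core columns with a two-read, one-write vector heap adapter.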
#[derive(new)]
pub struct BaseAlu256ChipGpu {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    pub pointer_max_bits: usize,
    pub timestamp_max_bits: usize,
}

impl Chip<DenseRecordArena, GpuBackend> for BaseAlu256ChipGpu {
    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
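        // Records are laid out back to back in the arena as packed
        // (adapter record, core record) pairs.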
        const RECORD_SIZE: usize = size_of::<(BaseAlu256AdapterRecord, BaseAlu256CoreRecord)>();
        let records = arena.allocated();
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }
        debug_assert_eq!(records.len() % RECORD_SIZE, 0);

        let trace_width = BaseAluCoreCols::<F, INT256_NUM_LIMBS, RV32_CELL_BITS>::width()
            + Rv32VecHeapAdapterCols::<F, 2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>::width();
        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);

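        // Copy the raw record bytes to the device and allocate the output trace.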
        let d_records = records.to_device().unwrap();
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);

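        // Launch the ALU-256 kernel, which fills the trace and updates the shared
        // range-checker and bitwise-lookup count buffers.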
        unsafe {
            cuda_abi::alu256::tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                RV32_CELL_BITS,
                self.pointer_max_bits as u32,
                self.timestamp_max_bits as u32,
            )
            .unwrap();
        }

        AirProvingContext::simple_no_pis(d_trace)
    }
}

//////////////////////////////////////////////////////////////////////////////////////
/// Branch Equal
//////////////////////////////////////////////////////////////////////////////////////
pub type BranchEqual256AdapterRecord = Rv32HeapBranchAdapterRecord<2>;
pub type BranchEqual256CoreRecord = BranchEqualCoreRecord<INT256_NUM_LIMBS>;

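/// GPU tracegen chip for 256-bit equality branches, pairing the `BranchEqual`
/// core columns with a two-read heap branch adapter.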
#[derive(new)]
pub struct BranchEqual256ChipGpu {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    pub pointer_max_bits: usize,
    pub timestamp_max_bits: usize,
}

impl Chip<DenseRecordArena, GpuBackend> for BranchEqual256ChipGpu {
    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
        const RECORD_SIZE: usize =
            size_of::<(BranchEqual256AdapterRecord, BranchEqual256CoreRecord)>();
        let records = arena.allocated();
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }
        debug_assert_eq!(records.len() % RECORD_SIZE, 0);

        let trace_width = BranchEqualCoreCols::<F, INT256_NUM_LIMBS>::width()
            + Rv32HeapBranchAdapterCols::<F, 2, INT256_NUM_LIMBS>::width();
        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);

        let d_records = records.to_device().unwrap();
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);

        unsafe {
            cuda_abi::beq256::tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                RV32_CELL_BITS,
                self.pointer_max_bits as u32,
                self.timestamp_max_bits as u32,
            )
            .unwrap();
        }

        AirProvingContext::simple_no_pis(d_trace)
    }
}

//////////////////////////////////////////////////////////////////////////////////////
/// Less Than
//////////////////////////////////////////////////////////////////////////////////////
pub type LessThan256AdapterRecord =
    Rv32VecHeapAdapterRecord<2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
pub type LessThan256CoreRecord = LessThanCoreRecord<INT256_NUM_LIMBS, RV32_CELL_BITS>;

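/// GPU tracegen chip for 256-bit less-than comparisons, pairing the `LessThan`
/// core columns with a two-read, one-write vector heap adapter.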
#[derive(new)]
pub struct LessThan256ChipGpu {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    pub pointer_max_bits: usize,
    pub timestamp_max_bits: usize,
}

impl Chip<DenseRecordArena, GpuBackend> for LessThan256ChipGpu {
    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
        const RECORD_SIZE: usize = size_of::<(LessThan256AdapterRecord, LessThan256CoreRecord)>();
        let records = arena.allocated();
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }
        debug_assert_eq!(records.len() % RECORD_SIZE, 0);

        let trace_width = LessThanCoreCols::<F, INT256_NUM_LIMBS, RV32_CELL_BITS>::width()
            + Rv32VecHeapAdapterCols::<F, 2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>::width();
        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);

        let d_records = records.to_device().unwrap();
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);

        unsafe {
            cuda_abi::lt256::tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                RV32_CELL_BITS,
                self.pointer_max_bits as u32,
                self.timestamp_max_bits as u32,
            )
            .unwrap();
        }

        AirProvingContext::simple_no_pis(d_trace)
    }
}

//////////////////////////////////////////////////////////////////////////////////////
/// Branch Less Than
//////////////////////////////////////////////////////////////////////////////////////
pub type BranchLessThan256AdapterRecord = Rv32HeapBranchAdapterRecord<2>;
pub type BranchLessThan256CoreRecord = BranchLessThanCoreRecord<INT256_NUM_LIMBS, RV32_CELL_BITS>;

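/// GPU tracegen chip for 256-bit less-than branches, pairing the
/// `BranchLessThan` core columns with a two-read heap branch adapter.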
#[derive(new)]
pub struct BranchLessThan256ChipGpu {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    pub pointer_max_bits: usize,
    pub timestamp_max_bits: usize,
}

impl Chip<DenseRecordArena, GpuBackend> for BranchLessThan256ChipGpu {
    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
        const RECORD_SIZE: usize =
            size_of::<(BranchLessThan256AdapterRecord, BranchLessThan256CoreRecord)>();
        let records = arena.allocated();
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }
        debug_assert_eq!(records.len() % RECORD_SIZE, 0);

        let trace_width = BranchLessThanCoreCols::<F, INT256_NUM_LIMBS, RV32_CELL_BITS>::width()
            + Rv32HeapBranchAdapterCols::<F, 2, INT256_NUM_LIMBS>::width();
        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);

        let d_records = records.to_device().unwrap();
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);

        unsafe {
            cuda_abi::blt256::tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                RV32_CELL_BITS,
                self.pointer_max_bits as u32,
                self.timestamp_max_bits as u32,
            )
            .unwrap();
        }

        AirProvingContext::simple_no_pis(d_trace)
    }
}

//////////////////////////////////////////////////////////////////////////////////////
/// Shift
//////////////////////////////////////////////////////////////////////////////////////
pub type Shift256AdapterRecord =
    Rv32VecHeapAdapterRecord<2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
pub type Shift256CoreRecord = ShiftCoreRecord<INT256_NUM_LIMBS, RV32_CELL_BITS>;

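/// GPU tracegen chip for 256-bit shifts, pairing the `Shift` core columns with
/// a two-read, one-write vector heap adapter.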
#[derive(new)]
pub struct Shift256ChipGpu {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    pub pointer_max_bits: usize,
    pub timestamp_max_bits: usize,
}

impl Chip<DenseRecordArena, GpuBackend> for Shift256ChipGpu {
    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
        const RECORD_SIZE: usize = size_of::<(Shift256AdapterRecord, Shift256CoreRecord)>();
        let records = arena.allocated();
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }
        debug_assert_eq!(records.len() % RECORD_SIZE, 0);

        let trace_width = ShiftCoreCols::<F, INT256_NUM_LIMBS, RV32_CELL_BITS>::width()
            + Rv32VecHeapAdapterCols::<F, 2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>::width();
        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);

        let d_records = records.to_device().unwrap();
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);

        unsafe {
            cuda_abi::shift256::tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                RV32_CELL_BITS,
                self.pointer_max_bits as u32,
                self.timestamp_max_bits as u32,
            )
            .unwrap();
        }

        AirProvingContext::simple_no_pis(d_trace)
    }
}

//////////////////////////////////////////////////////////////////////////////////////
/// Multiplication
//////////////////////////////////////////////////////////////////////////////////////
pub type Multiplication256AdapterRecord =
    Rv32VecHeapAdapterRecord<2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
pub type Multiplication256CoreRecord = MultiplicationCoreRecord<INT256_NUM_LIMBS, RV32_CELL_BITS>;

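/// GPU tracegen chip for 256-bit multiplication, pairing the `Multiplication`
/// core columns with a two-read, one-write vector heap adapter. Unlike the
/// other chips in this module, it also updates a two-dimensional range-tuple
/// checker, which the core uses to range-check intermediate limb values.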
#[derive(new)]
pub struct Multiplication256ChipGpu {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    pub range_tuple_checker: Arc<RangeTupleCheckerChipGPU<2>>,
    pub pointer_max_bits: usize,
    pub timestamp_max_bits: usize,
}

impl Chip<DenseRecordArena, GpuBackend> for Multiplication256ChipGpu {
    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
        const RECORD_SIZE: usize =
            size_of::<(Multiplication256AdapterRecord, Multiplication256CoreRecord)>();
        let records = arena.allocated();
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }
        debug_assert_eq!(records.len() % RECORD_SIZE, 0);

        let trace_width = MultiplicationCoreCols::<F, INT256_NUM_LIMBS, RV32_CELL_BITS>::width()
            + Rv32VecHeapAdapterCols::<F, 2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>::width();
        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);

        let d_records = records.to_device().unwrap();
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);

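        // The kernel expects the range-tuple checker's two dimension sizes
        // packed into a single UInt2 argument.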
        let sizes = self.range_tuple_checker.sizes;
        let d_sizes = UInt2 {
            x: sizes[0],
            y: sizes[1],
        };
        unsafe {
            cuda_abi::mul256::tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                RV32_CELL_BITS,
                &self.range_tuple_checker.count,
                d_sizes,
                self.pointer_max_bits as u32,
                self.timestamp_max_bits as u32,
            )
            .unwrap();
        }

        AirProvingContext::simple_no_pis(d_trace)
    }
}
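
// Illustrative construction (a sketch, not part of this module): `#[derive(new)]`
// generates a constructor taking the fields in declaration order, so a chip is
// assembled from the shared lookup chips and the VM's bit-size parameters, e.g.
//
//     let chip = BaseAlu256ChipGpu::new(
//         range_checker.clone(),
//         bitwise_lookup.clone(),
//         pointer_max_bits,
//         timestamp_max_bits,
//     );
//
// where `range_checker`, `bitwise_lookup`, `pointer_max_bits`, and
// `timestamp_max_bits` are assumed to come from the surrounding VM configuration.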