// openvm_mod_circuit_builder/cuda/chip.rs

1#![allow(clippy::too_many_arguments)]
2#![allow(clippy::type_complexity)]
3
4use std::{collections::HashMap, sync::Arc};
5
6use cuda_runtime_sys::{cudaDeviceSetLimit, cudaLimit};
7use num_bigint::BigUint;
8use num_traits::{FromBytes, One};
9use openvm_circuit::utils::next_power_of_two_or_zero;
10use openvm_circuit_primitives::{
11    bitwise_op_lookup::BitwiseOperationLookupChipGPU, var_range::VariableRangeCheckerChipGPU,
12};
13use openvm_cuda_backend::{base::DeviceMatrix, types::F};
14use openvm_cuda_common::{
15    copy::{MemCopyD2H, MemCopyH2D},
16    d_buffer::DeviceBuffer,
17};
18use openvm_stark_backend::p3_air::BaseAir;
19
20use crate::{
21    cuda::{
22        constants::{ExprType, LIMB_BITS, MAX_LIMBS},
23        expr_op::ExprOp,
24    },
25    cuda_abi::field_expression::tracegen,
26    utils::biguint_to_limbs_vec,
27    ExprMeta, ExprNode, FieldExprMeta, FieldExpressionChipGPU, FieldExpressionCoreAir,
28    SymbolicExpr,
29};
30
impl FieldExpressionChipGPU {
    /// Builds the GPU chip for a field-expression AIR.
    ///
    /// Flattens the AIR's symbolic compute/constraint ASTs, its constants, and the
    /// modular-arithmetic parameters (prime limbs, Barrett mu) into device buffers,
    /// then packs raw device pointers to those buffers into a single
    /// `FieldExprMeta` record that the CUDA tracegen kernel reads.
    ///
    /// Lifetime note: every `DeviceBuffer` whose `as_ptr()` is embedded in the meta
    /// is also moved into the returned `Self`, so the pointers stay valid for the
    /// chip's lifetime. Do not drop any of those fields without rebuilding `meta`.
    ///
    /// * `records` — raw device bytes of the execution records to trace.
    /// * `num_records` / `record_stride` — record count and per-record byte stride.
    /// * `adapter_width` / `adapter_blocks` — adapter trace geometry.
    /// * `range_checker` / `bitwise_lookup` — shared lookup chips whose `count`
    ///   buffers are handed to the kernel in `generate_field_trace`.
    pub fn new(
        air: FieldExpressionCoreAir,
        records: DeviceBuffer<u8>,
        num_records: usize,
        record_stride: usize,
        adapter_width: usize,
        adapter_blocks: usize,
        range_checker: Arc<VariableRangeCheckerChipGPU>,
        bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<LIMB_BITS>>,
        pointer_max_bits: u32,
        timestamp_max_bits: u32,
    ) -> Self {
        let num_inputs = air.num_inputs() as u32;
        let num_vars = air.num_vars() as u32;
        let num_u32_flags = air.num_flags() as u32;
        // Core trace width comes from the AIR itself; total width adds the adapter columns.
        let core_width = BaseAir::<F>::width(&air) as u32;
        let trace_width = adapter_width as u32 + core_width;

        let num_limbs = air.expr.canonical_num_limbs() as u32;
        let limb_bits = air.expr.canonical_limb_bits() as u32;

        // Prime modulus limbs narrowed to bytes for the BigUint reconstruction below.
        // NOTE(review): `x as u8` truncates — this assumes each limb fits in a byte
        // (i.e. limb_bits <= 8); confirm against the builder's canonical limb size.
        let prime_limbs = air
            .expr
            .builder
            .prime_limbs
            .iter()
            .map(|&x| x as u8)
            .collect::<Vec<_>>();

        // Pad prime_limbs to next valid size (32 or 48).
        // NOTE(review): primes with more than 48 limbs would be silently truncated
        // by the resize below — presumably MAX_LIMBS caps this upstream; confirm.
        let padded_limbs_len = if air.expr.builder.prime_limbs.len() <= 32 {
            32
        } else {
            48
        };
        let mut padded_prime_limbs = air
            .expr
            .builder
            .prime_limbs
            .iter()
            .map(|&x| x as u32)
            .collect::<Vec<_>>();
        padded_prime_limbs.resize(padded_limbs_len, 0u32);

        let prime_limbs_buf = padded_prime_limbs.to_device().unwrap();

        // Compute Barrett mu constant: mu = floor(2^(2 * n * limb_bits) / p),
        // sized to 2 * MAX_LIMBS limbs for the kernel's Barrett reduction.
        let p_big: BigUint = BigUint::from_le_bytes(&prime_limbs);
        let actual_limbs = air.expr.builder.prime_limbs.len();
        let two_n_bits = 2 * actual_limbs * limb_bits as usize;
        let b2n = BigUint::one() << two_n_bits;
        let mu_big = &b2n / &p_big;
        let mu_limbs = biguint_to_limbs_vec(&mu_big, 2 * MAX_LIMBS);
        let barrett_mu_buf = mu_limbs.to_device().unwrap();

        // Flatten both AST families (computes and constraints) plus per-variable
        // q/carry limb counts into device buffers; also returns the max AST depth
        // (used for kernel stack sizing) and the max quotient limb count.
        let (
            expr_meta,
            compute_expr_ops_buf,
            compute_roots_buf,
            constraint_expr_ops_buf,
            constraint_roots_buf,
            constants_buf,
            const_limb_counts_buf,
            q_limb_counts_buf,
            carry_limb_counts_buf,
            ast_depth,
            max_q_count,
        ) = Self::build_expr_meta(
            &air,
            num_vars,
            num_limbs,
            limb_bits,
            &prime_limbs_buf,
            &barrett_mu_buf,
        );

        // Opcode bookkeeping tables. Empty host vectors become empty DeviceBuffers
        // (a plain `to_device` on an empty Vec is avoided for the optional tables).
        let local_opcode_idx_buf = air
            .local_opcode_idx
            .iter()
            .map(|&x| x as u32)
            .collect::<Vec<_>>()
            .to_device()
            .unwrap();
        let opcode_flag_idx_buf = if air.opcode_flag_idx.is_empty() {
            DeviceBuffer::new()
        } else {
            air.opcode_flag_idx
                .iter()
                .map(|&x| x as u32)
                .collect::<Vec<_>>()
                .to_device()
                .unwrap()
        };
        let output_indices_buf = if air.output_indices().is_empty() {
            DeviceBuffer::new()
        } else {
            air.output_indices()
                .iter()
                .map(|&x| x as u32)
                .collect::<Vec<_>>()
                .to_device()
                .unwrap()
        };

        // Byte offset of the input limbs within each record.
        // NOTE(review): this is just size_of::<u8>() == 1, presumably skipping a
        // one-byte record header — confirm against the record layout the kernel reads.
        let input_limbs_offset = std::mem::size_of::<u8>();

        // Host-side mirror of the kernel's metadata struct; all pointer fields
        // reference the device buffers allocated above.
        let meta_host = FieldExprMeta {
            num_inputs,
            num_u32_flags,
            num_limbs,
            limb_bits,
            adapter_blocks: adapter_blocks as u32,
            adapter_width: adapter_width as u32,
            core_width,
            trace_width,
            local_opcode_idx: local_opcode_idx_buf.as_ptr(),
            opcode_flag_idx: opcode_flag_idx_buf.as_ptr(),
            output_indices: output_indices_buf.as_ptr(),
            num_local_opcodes: air.local_opcode_idx.len() as u32,
            num_output_indices: air.output_indices().len() as u32,
            record_stride: record_stride as u32,
            input_limbs_offset: input_limbs_offset as u32,
            q_limb_counts: q_limb_counts_buf.as_ptr(),
            carry_limb_counts: carry_limb_counts_buf.as_ptr(),
            // Ops are stored as u128 words on device; reinterpret as packed ExprOp.
            compute_expr_ops: compute_expr_ops_buf.as_ptr() as *const ExprOp,
            compute_root_indices: compute_roots_buf.as_ptr(),
            constraint_expr_ops: constraint_expr_ops_buf.as_ptr() as *const ExprOp,
            constraint_root_indices: constraint_roots_buf.as_ptr(),
            max_q_count,
            expr_meta,
            max_ast_depth: ast_depth,
        };

        // Single-element device copy of the meta; kernels take it by pointer.
        let meta = vec![meta_host].to_device().unwrap();

        Self {
            air,
            records: Arc::new(records),
            num_records,
            record_stride,
            total_trace_width: trace_width as usize,
            meta,
            // Keep every buffer referenced by `meta` alive alongside it.
            local_opcode_idx_buf,
            opcode_flag_idx_buf,
            output_indices_buf,
            prime_limbs_buf,
            compute_expr_ops_buf,
            compute_roots_buf,
            constraint_expr_ops_buf,
            constraint_roots_buf,
            constants_buf,
            const_limb_counts_buf,
            q_limb_counts_buf,
            carry_limb_counts_buf,
            barrett_mu_buf,
            range_checker,
            bitwise_lookup,
            pointer_max_bits,
            timestamp_max_bits,
        }
    }

    /// Flattens the AIR's symbolic expressions and constants into device buffers
    /// and assembles the kernel-facing `ExprMeta`.
    ///
    /// Returns, in order:
    /// 1. `ExprMeta` (holds device pointers into the buffers below),
    /// 2. compute-expression op pool (packed as `u128` words),
    /// 3. compute root indices (one per compute expression),
    /// 4. constraint-expression op pool,
    /// 5. constraint root indices,
    /// 6. flattened constant limbs,
    /// 7. per-constant limb counts,
    /// 8. per-variable quotient limb counts,
    /// 9. per-variable carry limb counts,
    /// 10. max AST depth over all expressions,
    /// 11. max quotient limb count.
    ///
    /// The caller must keep the returned buffers alive as long as `ExprMeta`'s
    /// pointers are dereferenced on device.
    fn build_expr_meta(
        air: &FieldExpressionCoreAir,
        num_vars: u32,
        num_limbs: u32,
        limb_bits: u32,
        prime_limbs_buf: &DeviceBuffer<u32>,
        barrett_mu_buf: &DeviceBuffer<u8>,
    ) -> (
        ExprMeta,
        DeviceBuffer<u128>,
        DeviceBuffer<u32>,
        DeviceBuffer<u128>,
        DeviceBuffer<u32>,
        DeviceBuffer<u32>,
        DeviceBuffer<u32>,
        DeviceBuffer<u32>,
        DeviceBuffer<u32>,
        u32,
        u32,
    ) {
        // Build compute expressions AST: each tree is flattened into a shared
        // node pool (with dedup via `node_map`) and its root index recorded.
        let mut compute_expr_pool = Vec::new();
        let mut compute_node_map = HashMap::new();
        let mut compute_root_indices = Vec::with_capacity(air.expr.builder.computes.len());
        let mut max_compute_depth = 0;
        for compute_expr in &air.expr.builder.computes {
            let depth = Self::calculate_ast_depth(compute_expr);
            max_compute_depth = max_compute_depth.max(depth);
            let root =
                Self::add_expr_to_pool(compute_expr, &mut compute_expr_pool, &mut compute_node_map);
            compute_root_indices.push(root);
        }

        // Build constraint expressions AST into a separate pool with its own
        // dedup map (compute and constraint nodes are never shared).
        let mut constraint_expr_pool = Vec::new();
        let mut constraint_node_map = HashMap::new();
        let mut constraint_root_indices = Vec::with_capacity(air.expr.builder.constraints.len());
        let mut max_constraint_depth = 0;
        for constraint_expr in &air.expr.builder.constraints {
            let depth = Self::calculate_ast_depth(constraint_expr);
            max_constraint_depth = max_constraint_depth.max(depth);
            let root = Self::add_expr_to_pool(
                constraint_expr,
                &mut constraint_expr_pool,
                &mut constraint_node_map,
            );
            constraint_root_indices.push(root);
        }

        let ast_depth = max_compute_depth.max(max_constraint_depth);

        // Extract constants: limbs flattened back-to-back, with a parallel
        // per-constant limb-count table so the kernel can index into the flat array.
        let mut constants = Vec::new();
        let mut const_limb_counts = Vec::new();
        for (_const_val, const_limbs) in &air.expr.builder.constants {
            const_limb_counts.push(const_limbs.len() as u32);
            for &limb in const_limbs {
                constants.push(limb as u32);
            }
        }
        let q_counts: Vec<u32> = air.expr.builder.q_limbs.iter().map(|&x| x as u32).collect();
        let carry_counts: Vec<u32> = air
            .expr
            .builder
            .carry_limbs
            .iter()
            .map(|&x| x as u32)
            .collect();

        let max_q_count = if q_counts.is_empty() {
            num_limbs + 1 // Default fallback
        } else {
            *q_counts.iter().max().unwrap()
        };

        // Pack each ExprNode into a single u128 word for transfer; the kernel
        // reinterprets the buffer as `ExprOp`.
        let compute_expr_ops_u128: Vec<u128> = compute_expr_pool
            .iter()
            .map(|n| ExprOp::from_node(n).0)
            .collect();
        let constraint_expr_ops_u128: Vec<u128> = constraint_expr_pool
            .iter()
            .map(|n| ExprOp::from_node(n).0)
            .collect();

        let compute_expr_ops_buf = compute_expr_ops_u128.to_device().unwrap();
        let constraint_expr_ops_buf = constraint_expr_ops_u128.to_device().unwrap();

        let compute_roots_buf = compute_root_indices.to_device().unwrap();
        let constraint_roots_buf = constraint_root_indices.to_device().unwrap();

        // Optional tables: empty host vectors become empty (unallocated) buffers.
        let constants_buf = if constants.is_empty() {
            DeviceBuffer::new()
        } else {
            constants.to_device().unwrap()
        };
        let const_limb_counts_buf = if const_limb_counts.is_empty() {
            DeviceBuffer::new()
        } else {
            const_limb_counts.to_device().unwrap()
        };
        let q_limb_counts_buf = if q_counts.is_empty() {
            DeviceBuffer::new()
        } else {
            q_counts.to_device().unwrap()
        };
        let carry_limb_counts_buf = if carry_counts.is_empty() {
            DeviceBuffer::new()
        } else {
            carry_counts.to_device().unwrap()
        };

        // Kernel-facing view: raw device pointers into the buffers above, plus
        // scalar sizes. Valid only while those buffers are alive.
        let expr_meta = ExprMeta {
            constants: constants_buf.as_ptr(),
            const_limb_counts: const_limb_counts_buf.as_ptr(),
            q_limb_counts: q_limb_counts_buf.as_ptr(),
            carry_limb_counts: carry_limb_counts_buf.as_ptr(),
            num_vars,
            num_constants: air.expr.builder.constants.len() as u32,
            expr_pool_size: compute_expr_pool.len() as u32 + constraint_expr_pool.len() as u32,
            prime_limbs: prime_limbs_buf.as_ptr(),
            prime_limb_count: num_limbs,
            limb_bits,
            barrett_mu: barrett_mu_buf.as_ptr(),
        };

        (
            expr_meta,
            compute_expr_ops_buf,
            compute_roots_buf,
            constraint_expr_ops_buf,
            constraint_roots_buf,
            constants_buf,
            const_limb_counts_buf,
            q_limb_counts_buf,
            carry_limb_counts_buf,
            ast_depth,
            max_q_count,
        )
    }

    /// Depth of the expression tree, counting leaves as depth 1.
    /// Used to size the kernel's evaluation stack (see `max_ast_depth`).
    fn calculate_ast_depth(expr: &SymbolicExpr) -> u32 {
        match expr {
            SymbolicExpr::Input(_) | SymbolicExpr::Var(_) | SymbolicExpr::Const(_, _, _) => 1,
            SymbolicExpr::Add(left, right)
            | SymbolicExpr::Sub(left, right)
            | SymbolicExpr::Mul(left, right)
            | SymbolicExpr::Div(left, right) => {
                1 + Self::calculate_ast_depth(left).max(Self::calculate_ast_depth(right))
            }
            SymbolicExpr::IntAdd(child, _) | SymbolicExpr::IntMul(child, _) => {
                1 + Self::calculate_ast_depth(child)
            }
            SymbolicExpr::Select(_, if_true, if_false) => {
                1 + Self::calculate_ast_depth(if_true).max(Self::calculate_ast_depth(if_false))
            }
        }
    }

    /// Converts one `SymbolicExpr` node into a flat `ExprNode`, recursively
    /// adding its children to `expr_pool` first (via `add_expr_to_pool`, with
    /// which this is mutually recursive). Because children are pushed before
    /// their parent, a node's operand indices always refer to earlier pool
    /// entries.
    ///
    /// `data` layout per node type: leaves store their index in `data[0]`;
    /// binary ops store `[left, right, 0]`; IntAdd/IntMul store
    /// `[child, scalar, 0]` (scalar truncated to u32); Select stores
    /// `[flag_idx, if_true, if_false]`.
    fn convert_to_expr_node(
        expr: &SymbolicExpr,
        expr_pool: &mut Vec<ExprNode>,
        node_map: &mut HashMap<String, u32>,
    ) -> ExprNode {
        match expr {
            SymbolicExpr::Input(idx) => ExprNode {
                r#type: ExprType::Input as u32,
                data: [*idx as u32, 0, 0],
            },
            SymbolicExpr::Var(idx) => ExprNode {
                r#type: ExprType::Var as u32,
                data: [*idx as u32, 0, 0],
            },
            SymbolicExpr::Const(idx, _val, _limbs) => ExprNode {
                // Only the constant-table index is stored; the value/limbs travel
                // via the separate constants buffer built in `build_expr_meta`.
                r#type: ExprType::Const as u32,
                data: [*idx as u32, 0, 0],
            },
            SymbolicExpr::Add(left, right) => {
                let left_idx = Self::add_expr_to_pool(left, expr_pool, node_map);
                let right_idx = Self::add_expr_to_pool(right, expr_pool, node_map);
                ExprNode {
                    r#type: ExprType::Add as u32,
                    data: [left_idx, right_idx, 0],
                }
            }
            SymbolicExpr::Sub(left, right) => {
                let left_idx = Self::add_expr_to_pool(left, expr_pool, node_map);
                let right_idx = Self::add_expr_to_pool(right, expr_pool, node_map);
                ExprNode {
                    r#type: ExprType::Sub as u32,
                    data: [left_idx, right_idx, 0],
                }
            }
            SymbolicExpr::Mul(left, right) => {
                let left_idx = Self::add_expr_to_pool(left, expr_pool, node_map);
                let right_idx = Self::add_expr_to_pool(right, expr_pool, node_map);
                ExprNode {
                    r#type: ExprType::Mul as u32,
                    data: [left_idx, right_idx, 0],
                }
            }
            SymbolicExpr::Div(left, right) => {
                let left_idx = Self::add_expr_to_pool(left, expr_pool, node_map);
                let right_idx = Self::add_expr_to_pool(right, expr_pool, node_map);
                ExprNode {
                    r#type: ExprType::Div as u32,
                    data: [left_idx, right_idx, 0],
                }
            }
            SymbolicExpr::IntAdd(child, scalar) => {
                let child_idx = Self::add_expr_to_pool(child, expr_pool, node_map);
                ExprNode {
                    // NOTE(review): `*scalar as u32` truncates; assumes scalars fit
                    // in 32 bits (likewise for IntMul below) — confirm upstream.
                    r#type: ExprType::IntAdd as u32,
                    data: [child_idx, *scalar as u32, 0],
                }
            }
            SymbolicExpr::IntMul(child, scalar) => {
                let child_idx = Self::add_expr_to_pool(child, expr_pool, node_map);
                ExprNode {
                    r#type: ExprType::IntMul as u32,
                    data: [child_idx, *scalar as u32, 0],
                }
            }
            SymbolicExpr::Select(flag_idx, if_true, if_false) => {
                let true_idx = Self::add_expr_to_pool(if_true, expr_pool, node_map);
                let false_idx = Self::add_expr_to_pool(if_false, expr_pool, node_map);
                ExprNode {
                    r#type: ExprType::Select as u32,
                    data: [*flag_idx as u32, true_idx, false_idx],
                }
            }
        }
    }

    /// Adds `expr` (and, transitively, its subtree) to `expr_pool`, returning the
    /// pool index of its node. Structurally identical subtrees are deduplicated
    /// via `node_map`, keyed on the Debug formatting of the expression.
    ///
    /// NOTE(review): the `format!("{:?}")` key is O(subtree size) per node, so
    /// pooling a tree is quadratic in the worst case — acceptable for the small
    /// ASTs here, but worth revisiting if expression sizes grow.
    fn add_expr_to_pool(
        expr: &SymbolicExpr,
        expr_pool: &mut Vec<ExprNode>,
        node_map: &mut HashMap<String, u32>,
    ) -> u32 {
        let expr_str = format!("{:?}", expr); // Simple deduplication key

        if let Some(&existing_idx) = node_map.get(&expr_str) {
            return existing_idx;
        }

        // Create the node based on expression type (recurses into children first,
        // so their indices are already in the pool).
        let node = Self::convert_to_expr_node(expr, expr_pool, node_map);

        let idx = expr_pool.len() as u32;
        node_map.insert(expr_str, idx);

        expr_pool.push(node);
        idx
    }

    /// Launches the CUDA tracegen kernel over all records and returns the filled
    /// trace matrix (height padded to the next power of two, zero rows beyond
    /// `num_records`).
    pub fn generate_field_trace(&self) -> DeviceMatrix<F> {
        let padded_height = next_power_of_two_or_zero(self.num_records);
        let mat = DeviceMatrix::with_capacity(padded_height, self.total_trace_width);

        // Pull the meta back to host to size the per-thread scratch workspace.
        let meta_host = self.meta.to_host().unwrap()[0].clone();
        let input_size = meta_host.num_inputs * meta_host.num_limbs;
        let var_size = meta_host.expr_meta.num_vars * meta_host.num_limbs;
        let carry_counts = self.carry_limb_counts_buf.to_host().unwrap();
        let total_carry_count: u32 = carry_counts.iter().sum::<u32>();

        // size in bytes: u32 slots for input limbs + var limbs + all carry limbs,
        // plus one byte per u32 flag.
        let workspace_per_thread = (input_size + var_size + total_carry_count)
            * (size_of::<u32>() as u32)
            + meta_host.num_u32_flags;

        // Align workspace size to 16 bytes for CUDA alignment requirements
        let workspace_per_thread = workspace_per_thread.next_multiple_of(16);

        // Allocate workspace for all threads (one slot per padded trace row);
        // checked_mul guards against overflow on very large traces.
        let total_workspace_size = (workspace_per_thread as usize)
            .checked_mul(padded_height)
            .unwrap();
        let workspace = DeviceBuffer::<u8>::with_capacity(total_workspace_size);

        unsafe {
            // Raise the per-thread device stack to 48 KiB before launching.
            // NOTE(review): presumably needed for deep expression evaluation in the
            // kernel; the cudaError_t return is ignored here — consider checking it.
            cudaDeviceSetLimit(cudaLimit::cudaLimitStackSize, 48 * 1024);
            tracegen(
                &self.records,
                mat.buffer(),
                &self.meta,
                self.num_records,
                self.record_stride,
                self.total_trace_width,
                padded_height,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                LIMB_BITS as u32,
                self.pointer_max_bits,
                self.timestamp_max_bits,
                workspace.as_ptr(),
                workspace_per_thread,
            )
            .unwrap();
        }
        mat
    }
}