#![allow(clippy::too_many_arguments)]
#![allow(clippy::type_complexity)]

use std::{collections::HashMap, sync::Arc};

use cuda_runtime_sys::{cudaDeviceSetLimit, cudaLimit};
use num_bigint::BigUint;
use num_traits::{FromBytes, One};
use openvm_circuit::utils::next_power_of_two_or_zero;
use openvm_circuit_primitives::{
    bitwise_op_lookup::BitwiseOperationLookupChipGPU, var_range::VariableRangeCheckerChipGPU,
};
use openvm_cuda_backend::{base::DeviceMatrix, types::F};
use openvm_cuda_common::{
    copy::{MemCopyD2H, MemCopyH2D},
    d_buffer::DeviceBuffer,
};
use openvm_stark_backend::p3_air::BaseAir;

use crate::{
    cuda::{
        constants::{ExprType, LIMB_BITS, MAX_LIMBS},
        expr_op::ExprOp,
    },
    cuda_abi::field_expression::tracegen,
    utils::biguint_to_limbs_vec,
    ExprMeta, ExprNode, FieldExprMeta, FieldExpressionChipGPU, FieldExpressionCoreAir,
    SymbolicExpr,
};

impl FieldExpressionChipGPU {
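    /// Builds the GPU field-expression chip from the core AIR and raw records:
    /// pads the prime to a fixed limb width, precomputes the Barrett constant,
    /// flattens the compute/constraint ASTs into device buffers, and uploads the
    /// `FieldExprMeta` that the tracegen kernel reads.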
    pub fn new(
        air: FieldExpressionCoreAir,
        records: DeviceBuffer<u8>,
        num_records: usize,
        record_stride: usize,
        adapter_width: usize,
        adapter_blocks: usize,
        range_checker: Arc<VariableRangeCheckerChipGPU>,
        bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<LIMB_BITS>>,
        pointer_max_bits: u32,
        timestamp_max_bits: u32,
    ) -> Self {
        let num_inputs = air.num_inputs() as u32;
        let num_vars = air.num_vars() as u32;
        let num_u32_flags = air.num_flags() as u32;
        let core_width = BaseAir::<F>::width(&air) as u32;
        let trace_width = adapter_width as u32 + core_width;

        let num_limbs = air.expr.canonical_num_limbs() as u32;
        let limb_bits = air.expr.canonical_limb_bits() as u32;

        let prime_limbs = air
            .expr
            .builder
            .prime_limbs
            .iter()
            .map(|&x| x as u8)
            .collect::<Vec<_>>();
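        // Pad the prime to a fixed limb count so the kernel can use a uniform
        // layout: 32 or 48 limbs at 8 bits per limb, covering 256-bit and 384-bit
        // moduli (assumed to be the field sizes this chip is instantiated with).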

        let padded_limbs_len = if air.expr.builder.prime_limbs.len() <= 32 {
            32
        } else {
            48
        };
        let mut padded_prime_limbs = air
            .expr
            .builder
            .prime_limbs
            .iter()
            .map(|&x| x as u32)
            .collect::<Vec<_>>();
        padded_prime_limbs.resize(padded_limbs_len, 0u32);

        let prime_limbs_buf = padded_prime_limbs.to_device().unwrap();

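        // Barrett reduction constant: mu = floor(2^(2 * n * limb_bits) / p) for an
        // n-limb prime p, letting the kernel reduce double-width products modulo p
        // with multiplications instead of division.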
        let p_big: BigUint = BigUint::from_le_bytes(&prime_limbs);
        let actual_limbs = air.expr.builder.prime_limbs.len();
        let two_n_bits = 2 * actual_limbs * limb_bits as usize;
        let b2n = BigUint::one() << two_n_bits;
        let mu_big = &b2n / &p_big;
        let mu_limbs = biguint_to_limbs_vec(&mu_big, 2 * MAX_LIMBS);
        let barrett_mu_buf = mu_limbs.to_device().unwrap();

        let (
            expr_meta,
            compute_expr_ops_buf,
            compute_roots_buf,
            constraint_expr_ops_buf,
            constraint_roots_buf,
            constants_buf,
            const_limb_counts_buf,
            q_limb_counts_buf,
            carry_limb_counts_buf,
            ast_depth,
            max_q_count,
        ) = Self::build_expr_meta(
            &air,
            num_vars,
            num_limbs,
            limb_bits,
            &prime_limbs_buf,
            &barrett_mu_buf,
        );

        let local_opcode_idx_buf = air
            .local_opcode_idx
            .iter()
            .map(|&x| x as u32)
            .collect::<Vec<_>>()
            .to_device()
            .unwrap();
        let opcode_flag_idx_buf = if air.opcode_flag_idx.is_empty() {
            DeviceBuffer::new()
        } else {
            air.opcode_flag_idx
                .iter()
                .map(|&x| x as u32)
                .collect::<Vec<_>>()
                .to_device()
                .unwrap()
        };
        let output_indices_buf = if air.output_indices().is_empty() {
            DeviceBuffer::new()
        } else {
            air.output_indices()
                .iter()
                .map(|&x| x as u32)
                .collect::<Vec<_>>()
                .to_device()
                .unwrap()
        };

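        // Input limbs start one byte into each record, i.e. right after a
        // single-byte header (an assumption about the record layout this
        // offset encodes).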
        let input_limbs_offset = std::mem::size_of::<u8>();

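        // `FieldExprMeta` holds raw device pointers; the buffers they point into
        // are kept alive by storing them in `Self` below.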
        let meta_host = FieldExprMeta {
            num_inputs,
            num_u32_flags,
            num_limbs,
            limb_bits,
            adapter_blocks: adapter_blocks as u32,
            adapter_width: adapter_width as u32,
            core_width,
            trace_width,
            local_opcode_idx: local_opcode_idx_buf.as_ptr(),
            opcode_flag_idx: opcode_flag_idx_buf.as_ptr(),
            output_indices: output_indices_buf.as_ptr(),
            num_local_opcodes: air.local_opcode_idx.len() as u32,
            num_output_indices: air.output_indices().len() as u32,
            record_stride: record_stride as u32,
            input_limbs_offset: input_limbs_offset as u32,
            q_limb_counts: q_limb_counts_buf.as_ptr(),
            carry_limb_counts: carry_limb_counts_buf.as_ptr(),
            compute_expr_ops: compute_expr_ops_buf.as_ptr() as *const ExprOp,
            compute_root_indices: compute_roots_buf.as_ptr(),
            constraint_expr_ops: constraint_expr_ops_buf.as_ptr() as *const ExprOp,
            constraint_root_indices: constraint_roots_buf.as_ptr(),
            max_q_count,
            expr_meta,
            max_ast_depth: ast_depth,
        };

        let meta = vec![meta_host].to_device().unwrap();

        Self {
            air,
            records: Arc::new(records),
            num_records,
            record_stride,
            total_trace_width: trace_width as usize,
            meta,
            local_opcode_idx_buf,
            opcode_flag_idx_buf,
            output_indices_buf,
            prime_limbs_buf,
            compute_expr_ops_buf,
            compute_roots_buf,
            constraint_expr_ops_buf,
            constraint_roots_buf,
            constants_buf,
            const_limb_counts_buf,
            q_limb_counts_buf,
            carry_limb_counts_buf,
            barrett_mu_buf,
            range_checker,
            bitwise_lookup,
            pointer_max_bits,
            timestamp_max_bits,
        }
    }

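    /// Flattens the builder's compute and constraint ASTs into flat op pools,
    /// uploads constants and limb-count tables, and returns the populated
    /// `ExprMeta` along with every device buffer it points into (so the caller
    /// can keep them alive), plus the maximum AST depth and quotient limb count.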
    fn build_expr_meta(
        air: &FieldExpressionCoreAir,
        num_vars: u32,
        num_limbs: u32,
        limb_bits: u32,
        prime_limbs_buf: &DeviceBuffer<u32>,
        barrett_mu_buf: &DeviceBuffer<u8>,
    ) -> (
        ExprMeta,
        DeviceBuffer<u128>, // compute expr ops
        DeviceBuffer<u32>,  // compute root indices
        DeviceBuffer<u128>, // constraint expr ops
        DeviceBuffer<u32>,  // constraint root indices
        DeviceBuffer<u32>,  // constants
        DeviceBuffer<u32>,  // const limb counts
        DeviceBuffer<u32>,  // q limb counts
        DeviceBuffer<u32>,  // carry limb counts
        u32,                // max AST depth
        u32,                // max q count
    ) {
        let mut compute_expr_pool = Vec::new();
        let mut compute_node_map = HashMap::new();
        let mut compute_root_indices = Vec::with_capacity(air.expr.builder.computes.len());
        let mut max_compute_depth = 0;
        for compute_expr in &air.expr.builder.computes {
            let depth = Self::calculate_ast_depth(compute_expr);
            max_compute_depth = max_compute_depth.max(depth);
            let root =
                Self::add_expr_to_pool(compute_expr, &mut compute_expr_pool, &mut compute_node_map);
            compute_root_indices.push(root);
        }

        let mut constraint_expr_pool = Vec::new();
        let mut constraint_node_map = HashMap::new();
        let mut constraint_root_indices = Vec::with_capacity(air.expr.builder.constraints.len());
        let mut max_constraint_depth = 0;
        for constraint_expr in &air.expr.builder.constraints {
            let depth = Self::calculate_ast_depth(constraint_expr);
            max_constraint_depth = max_constraint_depth.max(depth);
            let root = Self::add_expr_to_pool(
                constraint_expr,
                &mut constraint_expr_pool,
                &mut constraint_node_map,
            );
            constraint_root_indices.push(root);
        }

        let ast_depth = max_compute_depth.max(max_constraint_depth);

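        // Flatten every constant's limbs into one array, with a parallel table of
        // per-constant limb counts for indexing on the device.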
        let mut constants = Vec::new();
        let mut const_limb_counts = Vec::new();
        for (_const_val, const_limbs) in &air.expr.builder.constants {
            const_limb_counts.push(const_limbs.len() as u32);
            for &limb in const_limbs {
                constants.push(limb as u32);
            }
        }
        let q_counts: Vec<u32> = air.expr.builder.q_limbs.iter().map(|&x| x as u32).collect();
        let carry_counts: Vec<u32> = air
            .expr
            .builder
            .carry_limbs
            .iter()
            .map(|&x| x as u32)
            .collect();

        let max_q_count = if q_counts.is_empty() {
            num_limbs + 1
        } else {
            *q_counts.iter().max().unwrap()
        };

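        // Each AST node packs into a single 128-bit `ExprOp` word for the kernel.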
        let compute_expr_ops_u128: Vec<u128> = compute_expr_pool
            .iter()
            .map(|n| ExprOp::from_node(n).0)
            .collect();
        let constraint_expr_ops_u128: Vec<u128> = constraint_expr_pool
            .iter()
            .map(|n| ExprOp::from_node(n).0)
            .collect();

        let compute_expr_ops_buf = compute_expr_ops_u128.to_device().unwrap();
        let constraint_expr_ops_buf = constraint_expr_ops_u128.to_device().unwrap();

        let compute_roots_buf = compute_root_indices.to_device().unwrap();
        let constraint_roots_buf = constraint_root_indices.to_device().unwrap();

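        // Empty tables get zero-length placeholder buffers; the kernel is expected
        // to consult the corresponding counts before dereferencing these pointers.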
        let constants_buf = if constants.is_empty() {
            DeviceBuffer::new()
        } else {
            constants.to_device().unwrap()
        };
        let const_limb_counts_buf = if const_limb_counts.is_empty() {
            DeviceBuffer::new()
        } else {
            const_limb_counts.to_device().unwrap()
        };
        let q_limb_counts_buf = if q_counts.is_empty() {
            DeviceBuffer::new()
        } else {
            q_counts.to_device().unwrap()
        };
        let carry_limb_counts_buf = if carry_counts.is_empty() {
            DeviceBuffer::new()
        } else {
            carry_counts.to_device().unwrap()
        };

        let expr_meta = ExprMeta {
            constants: constants_buf.as_ptr(),
            const_limb_counts: const_limb_counts_buf.as_ptr(),
            q_limb_counts: q_limb_counts_buf.as_ptr(),
            carry_limb_counts: carry_limb_counts_buf.as_ptr(),
            num_vars,
            num_constants: air.expr.builder.constants.len() as u32,
            expr_pool_size: compute_expr_pool.len() as u32 + constraint_expr_pool.len() as u32,
            prime_limbs: prime_limbs_buf.as_ptr(),
            prime_limb_count: num_limbs,
            limb_bits,
            barrett_mu: barrett_mu_buf.as_ptr(),
        };

        (
            expr_meta,
            compute_expr_ops_buf,
            compute_roots_buf,
            constraint_expr_ops_buf,
            constraint_roots_buf,
            constants_buf,
            const_limb_counts_buf,
            q_limb_counts_buf,
            carry_limb_counts_buf,
            ast_depth,
            max_q_count,
        )
    }

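    /// Depth of an expression tree, counting leaves as 1. The maximum over all
    /// roots is shipped to the kernel as `max_ast_depth`, presumably to bound
    /// per-thread evaluation stack usage.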
    fn calculate_ast_depth(expr: &SymbolicExpr) -> u32 {
        match expr {
            SymbolicExpr::Input(_) | SymbolicExpr::Var(_) | SymbolicExpr::Const(_, _, _) => 1,
            SymbolicExpr::Add(left, right)
            | SymbolicExpr::Sub(left, right)
            | SymbolicExpr::Mul(left, right)
            | SymbolicExpr::Div(left, right) => {
                1 + Self::calculate_ast_depth(left).max(Self::calculate_ast_depth(right))
            }
            SymbolicExpr::IntAdd(child, _) | SymbolicExpr::IntMul(child, _) => {
                1 + Self::calculate_ast_depth(child)
            }
            SymbolicExpr::Select(_, if_true, if_false) => {
                1 + Self::calculate_ast_depth(if_true).max(Self::calculate_ast_depth(if_false))
            }
        }
    }

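    /// Encodes a single `SymbolicExpr` node as a flat `ExprNode`, recursively
    /// interning its children in `expr_pool` and storing child indices (or leaf
    /// indices / scalars) in the `data` array.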
    fn convert_to_expr_node(
        expr: &SymbolicExpr,
        expr_pool: &mut Vec<ExprNode>,
        node_map: &mut HashMap<String, u32>,
    ) -> ExprNode {
        match expr {
            SymbolicExpr::Input(idx) => ExprNode {
                r#type: ExprType::Input as u32,
                data: [*idx as u32, 0, 0],
            },
            SymbolicExpr::Var(idx) => ExprNode {
                r#type: ExprType::Var as u32,
                data: [*idx as u32, 0, 0],
            },
            SymbolicExpr::Const(idx, _val, _limbs) => ExprNode {
                r#type: ExprType::Const as u32,
                data: [*idx as u32, 0, 0],
            },
            SymbolicExpr::Add(left, right) => {
                let left_idx = Self::add_expr_to_pool(left, expr_pool, node_map);
                let right_idx = Self::add_expr_to_pool(right, expr_pool, node_map);
                ExprNode {
                    r#type: ExprType::Add as u32,
                    data: [left_idx, right_idx, 0],
                }
            }
            SymbolicExpr::Sub(left, right) => {
                let left_idx = Self::add_expr_to_pool(left, expr_pool, node_map);
                let right_idx = Self::add_expr_to_pool(right, expr_pool, node_map);
                ExprNode {
                    r#type: ExprType::Sub as u32,
                    data: [left_idx, right_idx, 0],
                }
            }
            SymbolicExpr::Mul(left, right) => {
                let left_idx = Self::add_expr_to_pool(left, expr_pool, node_map);
                let right_idx = Self::add_expr_to_pool(right, expr_pool, node_map);
                ExprNode {
                    r#type: ExprType::Mul as u32,
                    data: [left_idx, right_idx, 0],
                }
            }
            SymbolicExpr::Div(left, right) => {
                let left_idx = Self::add_expr_to_pool(left, expr_pool, node_map);
                let right_idx = Self::add_expr_to_pool(right, expr_pool, node_map);
                ExprNode {
                    r#type: ExprType::Div as u32,
                    data: [left_idx, right_idx, 0],
                }
            }
            SymbolicExpr::IntAdd(child, scalar) => {
                let child_idx = Self::add_expr_to_pool(child, expr_pool, node_map);
                ExprNode {
                    r#type: ExprType::IntAdd as u32,
                    data: [child_idx, *scalar as u32, 0],
                }
            }
            SymbolicExpr::IntMul(child, scalar) => {
                let child_idx = Self::add_expr_to_pool(child, expr_pool, node_map);
                ExprNode {
                    r#type: ExprType::IntMul as u32,
                    data: [child_idx, *scalar as u32, 0],
                }
            }
            SymbolicExpr::Select(flag_idx, if_true, if_false) => {
                let true_idx = Self::add_expr_to_pool(if_true, expr_pool, node_map);
                let false_idx = Self::add_expr_to_pool(if_false, expr_pool, node_map);
                ExprNode {
                    r#type: ExprType::Select as u32,
                    data: [*flag_idx as u32, true_idx, false_idx],
                }
            }
        }
    }

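    /// Interns an expression into the pool, deduplicating structurally identical
    /// subtrees so shared subexpressions become a single node in a DAG.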
    fn add_expr_to_pool(
        expr: &SymbolicExpr,
        expr_pool: &mut Vec<ExprNode>,
        node_map: &mut HashMap<String, u32>,
    ) -> u32 {
        // The Debug rendering serves as a structural key: identical subtrees hash
        // to the same string and are interned only once.
        let expr_str = format!("{:?}", expr);
        if let Some(&existing_idx) = node_map.get(&expr_str) {
            return existing_idx;
        }

        // Children are interned first, so this node lands at the current pool length.
        let node = Self::convert_to_expr_node(expr, expr_pool, node_map);

        let idx = expr_pool.len() as u32;
        node_map.insert(expr_str, idx);

        expr_pool.push(node);
        idx
    }

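    /// Generates the trace matrix on the GPU: sizes a per-thread scratch workspace
    /// from the input/variable/carry limb counts, raises the device stack limit,
    /// and launches the `tracegen` kernel over all records.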
    pub fn generate_field_trace(&self) -> DeviceMatrix<F> {
        let padded_height = next_power_of_two_or_zero(self.num_records);
        let mat = DeviceMatrix::with_capacity(padded_height, self.total_trace_width);

        let meta_host = self.meta.to_host().unwrap()[0].clone();
        let input_size = meta_host.num_inputs * meta_host.num_limbs;
        let var_size = meta_host.expr_meta.num_vars * meta_host.num_limbs;
        let carry_counts = self.carry_limb_counts_buf.to_host().unwrap();
        let total_carry_count: u32 = carry_counts.iter().sum::<u32>();

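        // Per-thread scratch: u32 slots for input, variable, and carry limbs, plus
        // one byte per flag, rounded up to a multiple of 16 bytes for alignment.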
        let workspace_per_thread = (input_size + var_size + total_carry_count)
            * (std::mem::size_of::<u32>() as u32)
            + meta_host.num_u32_flags;

        let workspace_per_thread = workspace_per_thread.next_multiple_of(16);

        let total_workspace_size = (workspace_per_thread as usize)
            .checked_mul(padded_height)
            .unwrap();
        let workspace = DeviceBuffer::<u8>::with_capacity(total_workspace_size);

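        // Raise the per-thread device stack to 48 KiB before launch; the default
        // is likely too small for evaluating the deepest expression ASTs.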
        unsafe {
            cudaDeviceSetLimit(cudaLimit::cudaLimitStackSize, 48 * 1024);
            tracegen(
                &self.records,
                mat.buffer(),
                &self.meta,
                self.num_records,
                self.record_stride,
                self.total_trace_width,
                padded_height,
                &self.range_checker.count,
                &self.bitwise_lookup.count,
                LIMB_BITS as u32,
                self.pointer_max_bits,
                self.timestamp_max_bits,
                workspace.as_ptr(),
                workspace_per_thread,
            )
            .unwrap();
        }
        mat
    }
}