use std::array;

use itertools::Itertools;
use openvm_cuda_common::{
    copy::{MemCopyD2H, MemCopyH2D},
    d_buffer::DeviceBuffer,
    error::CudaError,
    stream::gpu_metrics_span,
};
use openvm_stark_backend::{
    air_builders::symbolic::{
        symbolic_expression::SymbolicExpression,
        symbolic_variable::{Entry, SymbolicVariable},
        SymbolicConstraints, SymbolicConstraintsDag,
    },
    interaction::{
        fri_log_up::{FriLogUpPartialProof, FriLogUpProvingKey, STARK_LU_NUM_CHALLENGES},
        LogUpSecurityParameters, SymbolicInteraction,
    },
    p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger},
    prover::{hal::MatrixDimensions, types::PairView},
};
use p3_field::{BasedVectorSpace, PrimeCharacteristicRing};

use crate::{
    base::DeviceMatrix,
    cuda::kernels::{permute::*, prefix::*},
    prelude::*,
    transpiler::{codec::Codec, SymbolicRulesOnGpu},
};

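/// Output of the GPU RAP (after-challenge) phase: the sampled LogUp
/// challenges, the permutation trace generated on device for each AIR
/// (`None` for AIRs without interactions), and the exposed values (the
/// LogUp cumulative sum) per AIR.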
#[derive(Debug)]
pub struct GpuRapPhaseResult {
    pub challenges: Vec<EF>,
    pub after_challenge_trace_per_air: Vec<Option<DeviceMatrix<F>>>,
    pub exposed_values_per_air: Vec<Option<Vec<EF>>>,
}

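/// GPU implementation of the FRI LogUp RAP phase. The constructor checks
/// that the supplied parameters give at least 100 bits of LogUp security.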
#[derive(Clone, Debug)]
pub struct FriLogUpPhaseGpu {
    log_up_params: LogUpSecurityParameters,
}

impl FriLogUpPhaseGpu {
    pub fn new(log_up_params: LogUpSecurityParameters) -> Self {
        assert!(log_up_params.bits_of_security::<EF>() >= 100);
        Self { log_up_params }
    }

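    /// Runs the LogUp partial proving on device. Returns `None` when no AIR
    /// has interactions; otherwise grinds a proof-of-work witness, samples the
    /// LogUp challenges, generates the after-challenge (permutation) traces on
    /// the GPU, and observes each cumulative sum with the challenger.
    ///
    /// ```ignore
    /// // Sketch with hypothetical bindings, not a doctest: run the phase
    /// // after the main traces have been committed.
    /// if let Some((partial_proof, rap_result)) = phase.partially_prove_gpu(
    ///     &mut challenger,
    ///     &constraints_per_air,
    ///     &params_per_air,
    ///     trace_view_per_air,
    /// ) {
    ///     // rap_result.after_challenge_trace_per_air feeds the next round.
    /// }
    /// ```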
    pub fn partially_prove_gpu(
        &self,
        challenger: &mut Challenger,
        constraints_per_air: &[&SymbolicConstraints<F>],
        params_per_air: &[&FriLogUpProvingKey],
        trace_view_per_air: Vec<PairView<DeviceMatrix<F>, F>>,
    ) -> Option<(FriLogUpPartialProof<F>, GpuRapPhaseResult)> {
        let has_any_interactions = constraints_per_air
            .iter()
            .any(|constraints| !constraints.interactions.is_empty());

        if !has_any_interactions {
            return None;
        }

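        // Grind the LogUp proof-of-work witness, then sample the challenges;
        // with STARK_LU_NUM_CHALLENGES == 2 these are the alpha and beta used
        // to fold interaction messages below.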
        let logup_pow_witness = challenger.grind(self.log_up_params.log_up_pow_bits);
        let challenges: [EF; STARK_LU_NUM_CHALLENGES] =
            array::from_fn(|_| challenger.sample_algebra_element::<EF>());

        let (after_challenge_trace_per_air, cumulative_sum_per_air) =
            gpu_metrics_span("generate_perm_trace_time_ms", || {
                self.generate_after_challenge_traces_per_air_gpu(
                    &challenges,
                    constraints_per_air,
                    params_per_air,
                    trace_view_per_air,
                )
            })
            .unwrap();

        for cumulative_sum in cumulative_sum_per_air.iter().flatten() {
            let base_slice =
                <EF as BasedVectorSpace<F>>::as_basis_coefficients_slice(cumulative_sum);
            challenger.observe_slice(base_slice);
        }

        let exposed_values_per_air = cumulative_sum_per_air
            .iter()
            .map(|csum| csum.map(|csum| vec![csum]))
            .collect_vec();

        Some((
            FriLogUpPartialProof { logup_pow_witness },
            GpuRapPhaseResult {
                challenges: challenges.to_vec(),
                after_challenge_trace_per_air,
                exposed_values_per_air,
            },
        ))
    }
}

impl FriLogUpPhaseGpu {
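    /// Generates the permutation trace and cumulative sum for every AIR by
    /// pairing each AIR's constraints with its trace view and interaction
    /// partitions; AIRs without interactions yield `None`.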
    fn generate_after_challenge_traces_per_air_gpu(
        &self,
        challenges: &[EF; STARK_LU_NUM_CHALLENGES],
        constraints_per_air: &[&SymbolicConstraints<F>],
        params_per_air: &[&FriLogUpProvingKey],
        trace_view_per_air: Vec<PairView<DeviceMatrix<F>, F>>,
    ) -> (Vec<Option<DeviceMatrix<F>>>, Vec<Option<EF>>) {
        let interaction_partitions = params_per_air
            .iter()
            .map(|&params| params.clone().interaction_partitions())
            .collect_vec();

        constraints_per_air
            .iter()
            .zip(trace_view_per_air)
            .zip(interaction_partitions.iter())
            .map(|((constraints, trace_view), interaction_partitions)| {
                self.generate_after_challenge_trace_row_wise_gpu(
                    &constraints.interactions,
                    trace_view,
                    challenges,
                    interaction_partitions,
                )
            })
            .unzip()
    }

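    /// Builds the after-challenge (permutation) trace for a single AIR: one
    /// extension-field column per interaction partition plus one column for
    /// the running sum, returned together with the LogUp cumulative sum.
    /// Returns `(None, None)` when the AIR has no interactions.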
    fn generate_after_challenge_trace_row_wise_gpu(
        &self,
        all_interactions: &[SymbolicInteraction<F>],
        trace_view: PairView<DeviceMatrix<F>, F>,
        permutation_randomness: &[EF; STARK_LU_NUM_CHALLENGES],
        interaction_partitions: &[Vec<usize>],
    ) -> (Option<DeviceMatrix<F>>, Option<EF>) {
        if all_interactions.is_empty() {
            return (None, None);
        }

        let height = trace_view.partitioned_main[0].height();
        debug_assert!(
            trace_view
                .partitioned_main
                .iter()
                .all(|m| m.height() == height),
            "All main trace parts must have same height"
        );

        let alphas_len = 1;
        let &[alpha, beta] = permutation_randomness;
        let max_fields_len = all_interactions
            .iter()
            .map(|interaction| interaction.message.len())
            .max()
            .unwrap_or(0);
        let betas = beta.powers().take(max_fields_len + 1).collect_vec();

        let challenges = std::iter::once(&alpha)
            .chain(betas.iter())
            .cloned()
            .collect_vec();
        let symbolic_challenges: Vec<SymbolicExpression<F>> = (0..challenges.len())
            .map(|index| SymbolicVariable::<F>::new(Entry::Challenge, index).into())
            .collect_vec();

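        // Fold each interaction's message into a single "denominator":
        //   denom = alpha + sum_i beta^i * f_i + beta^len * (bus_index + 1),
        // with beta^0 = 1, so the kernel evaluates one expression per
        // interaction instead of the full message.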
        let mut full_interactions: Vec<SymbolicInteraction<F>> = Vec::new();
        for interaction_indices in interaction_partitions {
            full_interactions.extend(
                interaction_indices
                    .iter()
                    .map(|&interaction_idx| {
                        let mut interaction: SymbolicInteraction<F> =
                            all_interactions[interaction_idx].clone();
                        let b = SymbolicExpression::from_u32(interaction.bus_index as u32 + 1);
                        let betas = symbolic_challenges[alphas_len..].to_vec();
                        debug_assert!(interaction.message.len() <= betas.len());
                        let mut fields = interaction.message.iter();
                        let alpha = symbolic_challenges[0].clone();
                        let mut denom = alpha + fields.next().unwrap().clone();
                        for (expr, beta) in fields.zip(betas.iter().skip(1)) {
                            denom += beta.clone() * expr.clone();
                        }
                        denom += betas[interaction.message.len()].clone() * b;
                        interaction.message = vec![denom];
                        interaction
                    })
                    .collect_vec(),
            );
        }

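        // Wrap the folded interactions in an otherwise-empty constraint set,
        // lower it to a DAG, and encode the evaluation rules for the kernel.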
        let constraints = SymbolicConstraints {
            constraints: vec![],
            interactions: full_interactions,
        };
        let constraints_dag: SymbolicConstraintsDag<F> = constraints.into();
        let rules = SymbolicRulesOnGpu::new(constraints_dag.clone(), true);
        let encoded_rules = rules.constraints.iter().map(|c| c.encode()).collect_vec();

        let partition_lens = interaction_partitions
            .iter()
            .map(|p| p.len() as u32)
            .collect_vec();
        let perm_width = interaction_partitions.len() + 1;
        let perm_height = height;
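        // `perm_width` counts extension-field columns; the device matrix holds
        // base-field columns, hence the factor of 4 (the degree of `EF` over `F`).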
        let (device_matrix, sum) = self.permute_trace_gen_gpu(
            perm_width * 4,
            perm_height,
            trace_view.preprocessed,
            trace_view.partitioned_main,
            &challenges,
            &encoded_rules,
            rules.buffer_size,
            &partition_lens,
            &rules.used_nodes,
        );

        (Some(device_matrix), Some(sum))
    }

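    /// Allocates the device buffers, copies challenges, encoded rules, and
    /// partition metadata to the GPU, and launches permutation-trace
    /// generation. Returns the permutation matrix (stored as base-field
    /// columns) together with the LogUp cumulative sum read back from device.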
    #[allow(clippy::too_many_arguments)]
    fn permute_trace_gen_gpu(
        &self,
        permutation_width: usize,
        permutation_height: usize,
        preprocessed: Option<DeviceMatrix<F>>,
        partitioned_main: Vec<DeviceMatrix<F>>,
        challenges: &[EF],
        rules: &[u128],
        num_intermediates: usize,
        partition_lens: &[u32],
        used_nodes: &[usize],
    ) -> (DeviceMatrix<F>, EF) {
        assert!(!rules.is_empty(), "No rules provided to permute");

        tracing::debug!(
            "permute gen rules.len() = {}, num_intermediates = {}",
            rules.len(),
            num_intermediates,
        );

        let null_buffer = DeviceBuffer::<F>::new();
        let partitioned_main_ptrs = partitioned_main
            .iter()
            .map(|m| m.buffer().as_raw_ptr() as u64)
            .collect_vec();
        let d_partitioned_main = partitioned_main_ptrs.to_device().unwrap();
        let d_preprocessed = preprocessed
            .as_ref()
            .map(|m| m.buffer())
            .unwrap_or(&null_buffer);

        let d_sum = DeviceBuffer::<EF>::with_capacity(1);
        let d_permutation =
            DeviceMatrix::<F>::with_capacity(permutation_height, permutation_width);
        let d_challenges = challenges.to_device().unwrap();
        let d_rules = rules.to_device().unwrap();
        let d_partition_lens = partition_lens.to_device().unwrap();
        let d_used_nodes = used_nodes.to_device().unwrap();

        self.hal_permute_trace_gen(
            &d_sum,
            d_permutation.buffer(),
            d_preprocessed,
            &d_partitioned_main,
            &d_challenges,
            &d_rules,
            rules.len(),
            num_intermediates,
            permutation_height,
            permutation_width / 4,
            &d_partition_lens,
            &d_used_nodes,
        )
        .unwrap();
        drop(preprocessed);
        drop(partitioned_main);

        let h_sum = d_sum.to_host().unwrap()[0];
        (d_permutation, h_sum)
    }

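    /// Drives the three device steps: evaluate the per-row terms into
    /// `d_cumulative_sums`, prefix-sum that column in place, then finalize the
    /// permutation trace and the total sum with `permute_update`.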
    #[allow(clippy::too_many_arguments)]
    fn hal_permute_trace_gen(
        &self,
        sum: &DeviceBuffer<EF>,
        permutation: &DeviceBuffer<F>,
        preprocessed: &DeviceBuffer<F>,
        main_partitioned: &DeviceBuffer<u64>,
        challenges: &DeviceBuffer<EF>,
        rules: &DeviceBuffer<u128>,
        num_rules: usize,
        num_intermediates: usize,
        permutation_height: usize,
        permutation_width_ext: usize,
        partition_lens: &DeviceBuffer<u32>,
        used_nodes: &DeviceBuffer<usize>,
    ) -> Result<(), CudaError> {
        let task_size = 65536;
        let tile_per_thread = (permutation_height as u32).div_ceil(task_size as u32);

        tracing::debug!(
            "permutation_height = {permutation_height}, task_size = {task_size}, \
             tile_per_thread = {tile_per_thread}, num_rules = {num_rules}"
        );

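        // When there are many intermediates, stage them in a global-memory
        // scratch buffer (one slot per task element); otherwise pass a
        // 1-element placeholder (the kernel presumably keeps them local).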
        let is_global = num_intermediates > 10;
        let d_intermediates = if is_global {
            DeviceBuffer::<EF>::with_capacity(task_size * num_intermediates)
        } else {
            DeviceBuffer::<EF>::with_capacity(1)
        };

        let d_cumulative_sums = DeviceBuffer::<EF>::with_capacity(permutation_height);
        unsafe {
            calculate_cumulative_sums(
                is_global,
                permutation,
                &d_cumulative_sums,
                preprocessed,
                main_partitioned,
                challenges,
                &d_intermediates,
                rules,
                used_nodes,
                partition_lens,
                partition_lens.len(),
                permutation_height as u32,
                permutation_width_ext as u32,
                tile_per_thread,
            )
            .unwrap();
        }

        self.poly_prefix_sum_ext(&d_cumulative_sums, permutation_height as u64);

        unsafe {
            permute_update(
                sum,
                permutation,
                &d_cumulative_sums,
                permutation_height as u32,
                permutation_width_ext as u32,
            )
        }
    }

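    /// In-place parallel prefix sum over `count` extension-field elements:
    /// block-local scans at increasing strides, a downsweep propagating block
    /// totals, and an epilogue. Semantically it matches a sequential running
    /// sum on the host, as in this sketch (assuming inclusive-scan semantics):
    ///
    /// ```ignore
    /// // Hypothetical host-side reference, not part of this crate.
    /// fn prefix_sum_ref(data: &mut [EF]) {
    ///     for i in 1..data.len() {
    ///         data[i] = data[i] + data[i - 1];
    ///     }
    /// }
    /// ```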
    fn poly_prefix_sum_ext(&self, inout: &DeviceBuffer<EF>, count: u64) {
        let acc_per_thread: u64 = 16;
        let tiles_per_block: u64 = 256;
        let element_per_block: u64 = tiles_per_block * acc_per_thread;
        let mut block_num = (count as u32).div_ceil(tiles_per_block as u32) as u64;

        let mut round_stride = 1_u64;
        unsafe {
            prefix_scan_block_ext(inout, count, round_stride, block_num).unwrap();
        }

        while block_num > 1 {
            block_num = (block_num as u32).div_ceil(element_per_block as u32) as u64;
            round_stride *= element_per_block;
            unsafe {
                prefix_scan_block_ext(inout, count, round_stride, block_num).unwrap();
            }
        }

        while round_stride > element_per_block {
            let low_level_round_stride = round_stride / element_per_block;
            unsafe {
                prefix_scan_block_downsweep_ext(inout, count, round_stride).unwrap();
            }
            round_stride = low_level_round_stride;
        }

        unsafe {
            prefix_scan_epilogue_ext(inout, count).unwrap();
        }
    }
}