use std::array;

use itertools::Itertools;
use openvm_cuda_common::{
    copy::{MemCopyD2H, MemCopyH2D},
    d_buffer::DeviceBuffer,
    error::CudaError,
    stream::gpu_metrics_span,
};
use openvm_stark_backend::{
    air_builders::symbolic::{
        symbolic_expression::SymbolicExpression,
        symbolic_variable::{Entry, SymbolicVariable},
        SymbolicConstraints, SymbolicConstraintsDag,
    },
    interaction::{
        fri_log_up::{FriLogUpPartialProof, FriLogUpProvingKey, STARK_LU_NUM_CHALLENGES},
        LogUpSecurityParameters, SymbolicInteraction,
    },
    p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger},
    prover::{hal::MatrixDimensions, types::PairView},
};
use p3_field::{FieldAlgebra, FieldExtensionAlgebra};

use crate::{
    base::DeviceMatrix,
    cuda::kernels::{permute::*, prefix::*},
    prelude::*,
    transpiler::{codec::Codec, SymbolicRulesOnGpu},
};

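/// Output of the GPU RAP (LogUp) phase: the sampled challenges, the generated
/// after-challenge (permutation) trace per AIR, and the exposed values
/// (cumulative sums) per AIR.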
#[derive(Debug)]
pub struct GpuRapPhaseResult {
    pub challenges: Vec<EF>,
    pub after_challenge_trace_per_air: Vec<Option<DeviceMatrix<F>>>,
    pub exposed_values_per_air: Vec<Option<Vec<EF>>>,
}

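/// GPU implementation of the FRI LogUp RAP phase.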
#[derive(Clone, Debug)]
pub struct FriLogUpPhaseGpu {
    log_up_params: LogUpSecurityParameters,
}

impl FriLogUpPhaseGpu {
    pub fn new(log_up_params: LogUpSecurityParameters) -> Self {
        assert!(log_up_params.conjectured_bits_of_security::<EF>() >= 100);
        Self { log_up_params }
    }

    pub fn partially_prove_gpu(
        &self,
        challenger: &mut Challenger,
        constraints_per_air: &[&SymbolicConstraints<F>],
        params_per_air: &[&FriLogUpProvingKey],
        trace_view_per_air: Vec<PairView<DeviceMatrix<F>, F>>,
    ) -> Option<(FriLogUpPartialProof<F>, GpuRapPhaseResult)> {
        let has_any_interactions = constraints_per_air
            .iter()
            .any(|constraints| !constraints.interactions.is_empty());

        if !has_any_interactions {
            return None;
        }

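        // Grind a proof-of-work witness before sampling, then draw the LogUp
        // challenges from the transcript.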
        let logup_pow_witness = challenger.grind(self.log_up_params.log_up_pow_bits);
        let challenges: [EF; STARK_LU_NUM_CHALLENGES] =
            array::from_fn(|_| challenger.sample_ext_element::<EF>());

        let (after_challenge_trace_per_air, cumulative_sum_per_air) =
            gpu_metrics_span("generate_perm_trace_time_ms", || {
                self.generate_after_challenge_traces_per_air_gpu(
                    &challenges,
                    constraints_per_air,
                    params_per_air,
                    trace_view_per_air,
                )
            })
            .unwrap();

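        // Observe each cumulative sum as base-field limbs so it is bound into
        // the transcript before any further sampling.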
        for cumulative_sum in cumulative_sum_per_air.iter().flatten() {
            let base_slice = <EF as FieldExtensionAlgebra<F>>::as_base_slice(cumulative_sum);
            challenger.observe_slice(base_slice);
        }

        let exposed_values_per_air = cumulative_sum_per_air
            .iter()
            .map(|csum| csum.map(|csum| vec![csum]))
            .collect_vec();

        Some((
            FriLogUpPartialProof { logup_pow_witness },
            GpuRapPhaseResult {
                challenges: challenges.to_vec(),
                after_challenge_trace_per_air,
                exposed_values_per_air,
            },
        ))
    }
}

impl FriLogUpPhaseGpu {
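    /// Generates the after-challenge (permutation) trace and cumulative sum
    /// for every AIR, returning `None` entries for AIRs without interactions.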
    fn generate_after_challenge_traces_per_air_gpu(
        &self,
        challenges: &[EF; STARK_LU_NUM_CHALLENGES],
        constraints_per_air: &[&SymbolicConstraints<F>],
        params_per_air: &[&FriLogUpProvingKey],
        trace_view_per_air: Vec<PairView<DeviceMatrix<F>, F>>,
    ) -> (Vec<Option<DeviceMatrix<F>>>, Vec<Option<EF>>) {
        let interaction_partitions = params_per_air
            .iter()
            .map(|&params| params.clone().interaction_partitions())
            .collect_vec();

        constraints_per_air
            .iter()
            .zip(trace_view_per_air)
            .zip(interaction_partitions.iter())
            .map(|((constraints, trace_view), interaction_partitions)| {
                self.generate_after_challenge_trace_row_wise_gpu(
                    &constraints.interactions,
                    trace_view,
                    challenges,
                    interaction_partitions,
                )
            })
            .unzip()
    }

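    /// Builds one AIR's permutation trace: folds each interaction message into
    /// a single fingerprint expression, transpiles the result to encoded GPU
    /// rules, and runs the permutation-trace kernels.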
    fn generate_after_challenge_trace_row_wise_gpu(
        &self,
        all_interactions: &[SymbolicInteraction<F>],
        trace_view: PairView<DeviceMatrix<F>, F>,
        permutation_randomness: &[EF; STARK_LU_NUM_CHALLENGES],
        interaction_partitions: &[Vec<usize>],
    ) -> (Option<DeviceMatrix<F>>, Option<EF>) {
        if all_interactions.is_empty() {
            return (None, None);
        }

        let height = trace_view.partitioned_main[0].height();
        debug_assert!(
            trace_view
                .partitioned_main
                .iter()
                .all(|m| m.height() == height),
            "All main trace parts must have same height"
        );

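        // A single alpha challenge plus the powers of beta (beta^0..=beta^max_len)
        // form the randomness used to fingerprint interaction messages.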
        let alphas_len = 1;
        let &[alpha, beta] = permutation_randomness;
        let max_fields_len = all_interactions
            .iter()
            .map(|interaction| interaction.message.len())
            .max()
            .unwrap_or(0);
        let betas = beta.powers().take(max_fields_len + 1).collect_vec();

        let challenges = std::iter::once(&alpha)
            .chain(betas.iter())
            .cloned()
            .collect_vec();
        let symbolic_challenges: Vec<SymbolicExpression<F>> = (0..challenges.len())
            .map(|index| SymbolicVariable::<F>::new(Entry::Challenge, index).into())
            .collect_vec();

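        // Replace each interaction's message with its single fingerprint
        // expression: alpha + sum_i beta^i * message_i + beta^len * (bus + 1).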
        let mut full_interactions: Vec<SymbolicInteraction<F>> = Vec::new();
        for interaction_indices in interaction_partitions {
            full_interactions.extend(
                interaction_indices
                    .iter()
                    .map(|&interaction_idx| {
                        let mut interaction: SymbolicInteraction<F> =
                            all_interactions[interaction_idx].clone();
                        let b = SymbolicExpression::from_canonical_u32(
                            interaction.bus_index as u32 + 1,
                        );
                        let betas = symbolic_challenges[alphas_len..].to_vec();
                        debug_assert!(interaction.message.len() <= betas.len());
                        let mut fields = interaction.message.iter();
                        let alpha = symbolic_challenges[0].clone();
                        // The first message element has coefficient beta^0 = 1.
                        let mut denom = alpha + fields.next().unwrap().clone();
                        for (expr, beta) in fields.zip(betas.iter().skip(1)) {
                            denom += beta.clone() * expr.clone();
                        }
                        // Mix in the shifted bus index with the next beta power.
                        denom += betas[interaction.message.len()].clone() * b;
                        interaction.message = vec![denom];
                        interaction
                    })
                    .collect_vec(),
            );
        }

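        // Wrap the folded interactions in an empty constraint system, lower it
        // to a DAG, and encode the DAG nodes as u128 rules for the GPU kernels.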
        let constraints = SymbolicConstraints {
            constraints: vec![],
            interactions: full_interactions,
        };
        let constraints_dag: SymbolicConstraintsDag<F> = constraints.into();
        let rules = SymbolicRulesOnGpu::new(constraints_dag.clone());
        let encoded_rules = rules.constraints.iter().map(|c| c.encode()).collect_vec();

        let partition_lens = interaction_partitions
            .iter()
            .map(|p| p.len() as u32)
            .collect_vec();
        let perm_width = interaction_partitions.len() + 1;
        let perm_height = height;
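        // Each EF column is stored as 4 base-field columns, hence the width
        // factor of 4 here and the matching division by 4 in the kernel call.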
        let (device_matrix, sum) = self.permute_trace_gen_gpu(
            perm_width * 4,
            perm_height,
            trace_view.preprocessed,
            trace_view.partitioned_main,
            &challenges,
            &encoded_rules,
            rules.num_intermediates,
            &partition_lens,
            &rules.used_nodes,
        );

        (Some(device_matrix), Some(sum))
    }

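    /// Copies the scalar inputs (challenges, rules, partition lengths) to the
    /// device, launches the permutation kernels via `hal_permute_trace_gen`,
    /// and returns the generated trace plus its total cumulative sum.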
    #[allow(clippy::too_many_arguments)]
    fn permute_trace_gen_gpu(
        &self,
        permutation_width: usize,
        permutation_height: usize,
        preprocessed: Option<DeviceMatrix<F>>,
        partitioned_main: Vec<DeviceMatrix<F>>,
        challenges: &[EF],
        rules: &[u128],
        num_intermediates: usize,
        partition_lens: &[u32],
        used_nodes: &[usize],
    ) -> (DeviceMatrix<F>, EF) {
        assert!(!rules.is_empty(), "No rules provided to permute");

        tracing::debug!(
            "permute gen rules.len() = {}, num_intermediates = {}",
            rules.len(),
            num_intermediates,
        );

        // Empty placeholder buffer used when there is no preprocessed trace.
        let null_buffer = DeviceBuffer::<F>::new();
        let partitioned_main_ptrs = partitioned_main
            .iter()
            .map(|m| m.buffer().as_raw_ptr() as u64)
            .collect_vec();
        let d_partitioned_main = partitioned_main_ptrs.to_device().unwrap();
        let d_preprocessed = preprocessed
            .as_ref()
            .map(|m| m.buffer())
            .unwrap_or(&null_buffer);

        let d_sum = DeviceBuffer::<EF>::with_capacity(1);
        let d_permutation = DeviceMatrix::<F>::with_capacity(permutation_height, permutation_width);
        let d_challenges = challenges.to_device().unwrap();
        let d_rules = rules.to_device().unwrap();
        let d_partition_lens = partition_lens.to_device().unwrap();
        let d_used_nodes = used_nodes.to_device().unwrap();

        self.hal_permute_trace_gen(
            &d_sum,
            d_permutation.buffer(),
            d_preprocessed,
            &d_partitioned_main,
            &d_challenges,
            &d_rules,
            rules.len(),
            num_intermediates,
            permutation_height,
            permutation_width / 4,
            &d_partition_lens,
            &d_used_nodes,
        )
        .unwrap();
        // Input traces are no longer needed on the device once the kernels ran.
        drop(preprocessed);
        drop(partitioned_main);

        let h_sum = d_sum.to_host().unwrap()[0];
        (d_permutation, h_sum)
    }

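    /// Runs the three kernel stages: `calculate_cumulative_sums` for the
    /// per-row LogUp sums, an extension-field prefix scan over the rows, and
    /// `permute_update` to finalize the trace and the total sum.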
    #[allow(clippy::too_many_arguments)]
    fn hal_permute_trace_gen(
        &self,
        sum: &DeviceBuffer<EF>,
        permutation: &DeviceBuffer<F>,
        preprocessed: &DeviceBuffer<F>,
        main_partitioned: &DeviceBuffer<u64>,
        challenges: &DeviceBuffer<EF>,
        rules: &DeviceBuffer<u128>,
        num_rules: usize,
        num_intermediates: usize,
        permutation_height: usize,
        permutation_width_ext: usize,
        partition_lens: &DeviceBuffer<u32>,
        used_nodes: &DeviceBuffer<usize>,
    ) -> Result<(), CudaError> {
        let task_size = 65536;
        let tile_per_thread = (permutation_height as u32).div_ceil(task_size as u32);

        tracing::debug!(
            "permutation_height = {permutation_height}, task_size = {task_size}, \
             tile_per_thread = {tile_per_thread}, num_rules = {num_rules}"
        );

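        // With more than 10 intermediates per row, give the kernel a
        // global-memory scratch buffer; otherwise pass a 1-element placeholder.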
        let is_global = num_intermediates > 10;
        let d_intermediates = if is_global {
            DeviceBuffer::<EF>::with_capacity(task_size * num_intermediates)
        } else {
            DeviceBuffer::<EF>::with_capacity(1)
        };

        let d_cumulative_sums = DeviceBuffer::<EF>::with_capacity(permutation_height);
        unsafe {
            calculate_cumulative_sums(
                is_global,
                permutation,
                &d_cumulative_sums,
                preprocessed,
                main_partitioned,
                challenges,
                &d_intermediates,
                rules,
                used_nodes,
                partition_lens,
                partition_lens.len(),
                permutation_height as u32,
                permutation_width_ext as u32,
                tile_per_thread,
            )
            .unwrap();
        }

        self.poly_prefix_sum_ext(&d_cumulative_sums, permutation_height as u64);

        unsafe {
            permute_update(
                sum,
                permutation,
                &d_cumulative_sums,
                permutation_height as u32,
                permutation_width_ext as u32,
            )
        }
    }

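    /// In-place prefix sum over `count` extension-field elements, computed as
    /// a multi-level block scan: block-local scans at increasing strides,
    /// downsweep passes to propagate block offsets, and a final epilogue.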
    fn poly_prefix_sum_ext(&self, inout: &DeviceBuffer<EF>, count: u64) {
        let acc_per_thread: u64 = 16;
        let tiles_per_block: u64 = 256;
        let element_per_block: u64 = tiles_per_block * acc_per_thread;
        let mut block_num = (count as u32).div_ceil(tiles_per_block as u32) as u64;

        // First pass: scan each block independently at stride 1.
        let mut round_stride = 1_u64;
        unsafe {
            prefix_scan_block_ext(inout, count, round_stride, block_num).unwrap();
        }

        // Upsweep: rescan the per-block totals at coarser strides until a
        // single block remains.
        while block_num > 1 {
            block_num = (block_num as u32).div_ceil(element_per_block as u32) as u64;
            round_stride *= element_per_block;
            unsafe {
                prefix_scan_block_ext(inout, count, round_stride, block_num).unwrap();
            }
        }

        // Downsweep: propagate the coarse offsets back down to finer strides.
        while round_stride > element_per_block {
            let low_level_round_stride = round_stride / element_per_block;
            unsafe {
                prefix_scan_block_downsweep_ext(inout, count, round_stride).unwrap();
            }
            round_stride = low_level_round_stride;
        }

        unsafe {
            prefix_scan_epilogue_ext(inout, count).unwrap();
        }
    }
}