// openvm_cuda_backend/fri_log_up.rs

1use std::array;
2
3use itertools::Itertools;
4use openvm_cuda_common::{
5    copy::{MemCopyD2H, MemCopyH2D},
6    d_buffer::DeviceBuffer,
7    error::CudaError,
8    stream::gpu_metrics_span,
9};
10use openvm_stark_backend::{
11    air_builders::symbolic::{
12        symbolic_expression::SymbolicExpression,
13        symbolic_variable::{Entry, SymbolicVariable},
14        SymbolicConstraints, SymbolicConstraintsDag,
15    },
16    interaction::{
17        fri_log_up::{FriLogUpPartialProof, FriLogUpProvingKey, STARK_LU_NUM_CHALLENGES},
18        LogUpSecurityParameters, SymbolicInteraction,
19    },
20    p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger},
21    prover::{hal::MatrixDimensions, types::PairView},
22};
23use p3_field::{BasedVectorSpace, PrimeCharacteristicRing};
24
25use crate::{
26    base::DeviceMatrix,
27    cuda::kernels::{permute::*, prefix::*},
28    prelude::*,
29    transpiler::{codec::Codec, SymbolicRulesOnGpu},
30};
31
/// Output of the log-up partial proving phase, keeping GPU data as GPU data:
/// the after-challenge traces remain device-resident `DeviceMatrix`es instead of
/// being copied back to the host.
#[derive(Debug)]
pub struct GpuRapPhaseResult {
    /// The challenges sampled from the challenger for this phase.
    pub challenges: Vec<EF>,
    /// Per-AIR after-challenge (permutation) trace; `None` for AIRs with no interactions.
    pub after_challenge_trace_per_air: Vec<Option<DeviceMatrix<F>>>,
    /// Per-AIR exposed values (a single cumulative sum); `None` for AIRs with no interactions.
    pub exposed_values_per_air: Vec<Option<Vec<EF>>>,
}
39
/// GPU implementation of the FRI log-up proving phase.
#[derive(Clone, Debug)]
pub struct FriLogUpPhaseGpu {
    /// Security parameters controlling the proof-of-work grind and challenge sampling.
    log_up_params: LogUpSecurityParameters,
}
44
impl FriLogUpPhaseGpu {
    /// Creates a new GPU log-up phase.
    ///
    /// # Panics
    /// Panics if `log_up_params` yields fewer than 100 bits of security over `EF`.
    pub fn new(log_up_params: LogUpSecurityParameters) -> Self {
        assert!(log_up_params.bits_of_security::<EF>() >= 100);
        Self { log_up_params }
    }

    /// Runs the log-up partial proving phase on the GPU.
    ///
    /// Grinds a proof-of-work witness, samples the log-up challenges, generates the
    /// after-challenge (permutation) trace for every AIR, and observes each AIR's
    /// cumulative sum with the challenger so the verifier transcript matches.
    ///
    /// Returns `None` when no AIR has any interactions (the phase is a no-op).
    pub fn partially_prove_gpu(
        &self,
        challenger: &mut Challenger,
        constraints_per_air: &[&SymbolicConstraints<F>],
        params_per_air: &[&FriLogUpProvingKey],
        trace_view_per_air: Vec<PairView<DeviceMatrix<F>, F>>,
    ) -> Option<(FriLogUpPartialProof<F>, GpuRapPhaseResult)> {
        // 1. Check if there are any interactions - if not, we're done
        let has_any_interactions = constraints_per_air
            .iter()
            .any(|constraints| !constraints.interactions.is_empty());

        if !has_any_interactions {
            return None;
        }

        // Proof-of-work grind raises the cost of grinding attacks on the challenges
        // sampled next; the witness becomes part of the partial proof.
        let logup_pow_witness = challenger.grind(self.log_up_params.log_up_pow_bits);
        let challenges: [EF; STARK_LU_NUM_CHALLENGES] =
            array::from_fn(|_| challenger.sample_algebra_element::<EF>());

        let (after_challenge_trace_per_air, cumulative_sum_per_air) =
            gpu_metrics_span("generate_perm_trace_time_ms", || {
                self.generate_after_challenge_traces_per_air_gpu(
                    &challenges,
                    constraints_per_air,
                    params_per_air,
                    trace_view_per_air,
                )
            })
            .unwrap();

        // Challenger needs to observe what is exposed (cumulative_sums).
        // Observation happens coefficient-wise over the base field.
        for cumulative_sum in cumulative_sum_per_air.iter().flatten() {
            let base_slice =
                <EF as BasedVectorSpace<F>>::as_basis_coefficients_slice(cumulative_sum);
            challenger.observe_slice(base_slice);
        }

        // Each AIR with interactions exposes exactly one value: its cumulative sum.
        let exposed_values_per_air = cumulative_sum_per_air
            .iter()
            .map(|csum| csum.map(|csum| vec![csum]))
            .collect_vec();

        Some((
            FriLogUpPartialProof { logup_pow_witness },
            GpuRapPhaseResult {
                challenges: challenges.to_vec(),
                after_challenge_trace_per_air,
                exposed_values_per_air,
            },
        ))
    }
}
104
105impl FriLogUpPhaseGpu {
106    fn generate_after_challenge_traces_per_air_gpu(
107        &self,
108        challenges: &[EF; STARK_LU_NUM_CHALLENGES],
109        constraints_per_air: &[&SymbolicConstraints<F>],
110        params_per_air: &[&FriLogUpProvingKey],
111        trace_view_per_air: Vec<PairView<DeviceMatrix<F>, F>>,
112    ) -> (Vec<Option<DeviceMatrix<F>>>, Vec<Option<EF>>) {
113        let interaction_partitions = params_per_air
114            .iter()
115            .map(|&params| params.clone().interaction_partitions())
116            .collect_vec();
117
118        constraints_per_air
119            .iter()
120            .zip(trace_view_per_air)
121            .zip(interaction_partitions.iter())
122            .map(|((constraints, trace_view), interaction_partitions)| {
123                self.generate_after_challenge_trace_row_wise_gpu(
124                    &constraints.interactions,
125                    trace_view,
126                    challenges,
127                    interaction_partitions,
128                )
129            })
130            .unzip()
131    }
132
133    fn generate_after_challenge_trace_row_wise_gpu(
134        &self,
135        all_interactions: &[SymbolicInteraction<F>],
136        trace_view: PairView<DeviceMatrix<F>, F>,
137        permutation_randomness: &[EF; STARK_LU_NUM_CHALLENGES],
138        interaction_partitions: &[Vec<usize>],
139    ) -> (Option<DeviceMatrix<F>>, Option<EF>) {
140        if all_interactions.is_empty() {
141            return (None, None);
142        }
143
144        let height = trace_view.partitioned_main[0].height();
145        debug_assert!(
146            trace_view
147                .partitioned_main
148                .iter()
149                .all(|m| m.height() == height),
150            "All main trace parts must have same height"
151        );
152
153        let alphas_len = 1;
154        let &[alpha, beta] = permutation_randomness;
155        // Generate betas
156        let max_fields_len = all_interactions
157            .iter()
158            .map(|interaction| interaction.message.len())
159            .max()
160            .unwrap_or(0);
161        let betas = beta.powers().take(max_fields_len + 1).collect_vec();
162
163        // 0. Prepare challenges
164        let challenges = std::iter::once(&alpha)
165            .chain(betas.iter())
166            .cloned()
167            .collect_vec();
168        let symbolic_challenges: Vec<SymbolicExpression<F>> = (0..challenges.len())
169            .map(|index| SymbolicVariable::<F>::new(Entry::Challenge, index).into())
170            .collect_vec();
171
172        // 1. Generate interactions message as denom = alpha + sum(beta_i * message_i) + beta_{m} *
173        //    b
174        // We use SymbolicInteraction to store (message = [denom], multiplicity = numerator) pair
175        // symbolically.
176        let mut full_interactions: Vec<SymbolicInteraction<F>> = Vec::new();
177        for interaction_indices in interaction_partitions {
178            full_interactions.extend(
179                interaction_indices
180                    .iter()
181                    .map(|&interaction_idx| {
182                        let mut interaction: SymbolicInteraction<F> =
183                            all_interactions[interaction_idx].clone();
184                        let b = SymbolicExpression::from_u32(interaction.bus_index as u32 + 1);
185                        let betas = symbolic_challenges[alphas_len..].to_vec();
186                        debug_assert!(interaction.message.len() <= betas.len());
187                        let mut fields = interaction.message.iter();
188                        let alpha = symbolic_challenges[0].clone();
189                        let mut denom = alpha + fields.next().unwrap().clone();
190                        for (expr, beta) in fields.zip(betas.iter().skip(1)) {
191                            denom += beta.clone() * expr.clone();
192                        }
193                        denom += betas[interaction.message.len()].clone() * b;
194                        interaction.message = vec![denom];
195                        interaction
196                    })
197                    .collect_vec(),
198            );
199        }
200
201        // 2. Transpile to GPU Rules
202        // We use SymbolicConstraints as a way to encode the symbolic interactions as (denom,
203        // numerator) pairs to transport to GPU.
204        let constraints = SymbolicConstraints {
205            constraints: vec![],
206            interactions: full_interactions,
207        };
208        let constraints_dag: SymbolicConstraintsDag<F> = constraints.into();
209        let rules = SymbolicRulesOnGpu::new(constraints_dag.clone(), true);
210        let encoded_rules = rules.constraints.iter().map(|c| c.encode()).collect_vec();
211
212        // 3. Call GPU module
213        let partition_lens = interaction_partitions
214            .iter()
215            .map(|p| p.len() as u32)
216            .collect_vec();
217        let perm_width = interaction_partitions.len() + 1;
218        let perm_height = height;
219        let (device_matrix, sum) = self.permute_trace_gen_gpu(
220            perm_width * 4, // the dim of base field matrix
221            perm_height,
222            trace_view.preprocessed,
223            trace_view.partitioned_main,
224            &challenges,
225            &encoded_rules,
226            rules.buffer_size,
227            &partition_lens,
228            &rules.used_nodes,
229        );
230
231        (Some(device_matrix), Some(sum))
232    }
233
234    // gpu-module/src/permute.rs
235    #[allow(clippy::too_many_arguments)]
236    fn permute_trace_gen_gpu(
237        &self,
238        permutation_width: usize,
239        permutation_height: usize,
240        preprocessed: Option<DeviceMatrix<F>>,
241        partitioned_main: Vec<DeviceMatrix<F>>,
242        challenges: &[EF],
243        rules: &[u128],
244        num_intermediates: usize,
245        partition_lens: &[u32],
246        used_nodes: &[usize],
247    ) -> (DeviceMatrix<F>, EF) {
248        assert!(!rules.is_empty(), "No rules provided to permute");
249
250        tracing::debug!(
251            "permute gen rules.len() = {}, num_intermediates = {}",
252            rules.len(),
253            num_intermediates,
254        );
255
256        // 1. input data
257        let null_buffer = DeviceBuffer::<F>::new();
258        let partitioned_main_ptrs = partitioned_main
259            .iter()
260            .map(|m| m.buffer().as_raw_ptr() as u64)
261            .collect_vec();
262        let d_partitioned_main = partitioned_main_ptrs.to_device().unwrap();
263        let d_preprocessed = preprocessed
264            .as_ref()
265            .map(|m| m.buffer())
266            .unwrap_or(&null_buffer);
267
268        // 2. gpu buffers
269        let d_sum = DeviceBuffer::<EF>::with_capacity(1);
270        let d_permutation = DeviceMatrix::<F>::with_capacity(permutation_height, permutation_width);
271        let d_challenges = challenges.to_device().unwrap();
272        let d_rules = rules.to_device().unwrap();
273        let d_partition_lens = partition_lens.to_device().unwrap();
274        let d_used_nodes = used_nodes.to_device().unwrap();
275
276        // 3. hal function
277        let _ = self.hal_permute_trace_gen(
278            &d_sum,
279            d_permutation.buffer(),
280            d_preprocessed,
281            &d_partitioned_main,
282            &d_challenges,
283            &d_rules,
284            rules.len(),
285            num_intermediates,
286            permutation_height,
287            permutation_width / 4,
288            &d_partition_lens,
289            &d_used_nodes,
290        );
291        // We can drop preprocessed and main traces now that permutation trace is generated.
292        // Note these matrices may be smart pointers so they may not be fully deallocated.
293        drop(preprocessed);
294        drop(partitioned_main);
295
296        // 4. output data
297        let h_sum = d_sum.to_host().unwrap()[0];
298        (d_permutation, h_sum)
299    }
300
301    // gpu-backend/src/cuda.rs
302    #[allow(clippy::too_many_arguments)]
303    fn hal_permute_trace_gen(
304        &self,
305        sum: &DeviceBuffer<EF>,
306        permutation: &DeviceBuffer<F>,
307        preprocessed: &DeviceBuffer<F>,
308        main_partitioned: &DeviceBuffer<u64>,
309        challenges: &DeviceBuffer<EF>,
310        rules: &DeviceBuffer<u128>,
311        num_rules: usize,
312        num_intermediates: usize,
313        permutation_height: usize,
314        permutation_width_ext: usize,
315        partition_lens: &DeviceBuffer<u32>,
316        used_nodes: &DeviceBuffer<usize>,
317    ) -> Result<(), CudaError> {
318        let task_size = 65536;
319        let tile_per_thread = (permutation_height as u32).div_ceil(task_size as u32);
320
321        tracing::debug!("permutation_height = {permutation_height}, task_size = {}, tile_per_thread = {} num_rules = {num_rules}", task_size, tile_per_thread);
322
323        let is_global = num_intermediates > 10;
324        let d_intermediates = if is_global {
325            DeviceBuffer::<EF>::with_capacity(task_size * num_intermediates)
326        } else {
327            DeviceBuffer::<EF>::with_capacity(1) // Dummy buffer for register-based version
328        };
329
330        let d_cumulative_sums = DeviceBuffer::<EF>::with_capacity(permutation_height);
331        unsafe {
332            calculate_cumulative_sums(
333                is_global,
334                permutation,
335                &d_cumulative_sums,
336                preprocessed,
337                main_partitioned,
338                challenges,
339                &d_intermediates,
340                rules,
341                used_nodes,
342                partition_lens,
343                partition_lens.len(),
344                permutation_height as u32,
345                permutation_width_ext as u32,
346                tile_per_thread,
347            )
348            .unwrap();
349        }
350
351        self.poly_prefix_sum_ext(&d_cumulative_sums, permutation_height as u64);
352
353        unsafe {
354            permute_update(
355                sum,
356                permutation,
357                &d_cumulative_sums,
358                permutation_height as u32,
359                permutation_width_ext as u32,
360            )
361        }
362    }
363
364    fn poly_prefix_sum_ext(&self, inout: &DeviceBuffer<EF>, count: u64) {
365        // Parameters for the scan
366        let acc_per_thread: u64 = 16;
367        let tiles_per_block: u64 = 256;
368        let element_per_block: u64 = tiles_per_block * acc_per_thread;
369        let mut block_num = (count as u32).div_ceil(tiles_per_block as u32) as u64;
370
371        // First round
372        let mut round_stride = 1_u64;
373        unsafe {
374            prefix_scan_block_ext(inout, count, round_stride, block_num).unwrap();
375        }
376
377        // Subsequent rounds
378        while block_num > 1 {
379            block_num = (block_num as u32).div_ceil(element_per_block as u32) as u64;
380            round_stride *= element_per_block;
381            unsafe {
382                prefix_scan_block_ext(inout, count, round_stride, block_num).unwrap();
383            }
384        }
385
386        // Block downsweep
387        while round_stride > element_per_block {
388            let low_level_round_stride = round_stride / element_per_block;
389            unsafe {
390                prefix_scan_block_downsweep_ext(inout, count, round_stride).unwrap();
391            }
392            round_stride = low_level_round_stride;
393        }
394
395        // Epilogue
396        unsafe {
397            prefix_scan_epilogue_ext(inout, count).unwrap();
398        }
399    }
400}