openvm_cuda_backend/fri_log_up.rs

use std::array;

use itertools::Itertools;
use openvm_cuda_common::{
    copy::{MemCopyD2H, MemCopyH2D},
    d_buffer::DeviceBuffer,
    error::CudaError,
    stream::gpu_metrics_span,
};
use openvm_stark_backend::{
    air_builders::symbolic::{
        symbolic_expression::SymbolicExpression,
        symbolic_variable::{Entry, SymbolicVariable},
        SymbolicConstraints, SymbolicConstraintsDag,
    },
    interaction::{
        fri_log_up::{FriLogUpPartialProof, FriLogUpProvingKey, STARK_LU_NUM_CHALLENGES},
        LogUpSecurityParameters, SymbolicInteraction,
    },
    p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger},
    prover::{hal::MatrixDimensions, types::PairView},
};
use p3_field::{FieldAlgebra, FieldExtensionAlgebra};

use crate::{
    base::DeviceMatrix,
    cuda::kernels::{permute::*, prefix::*},
    prelude::*,
    transpiler::{codec::Codec, SymbolicRulesOnGpu},
};

/// RAP phase output format that keeps GPU data resident on the GPU.
#[derive(Debug)]
pub struct GpuRapPhaseResult {
    pub challenges: Vec<EF>,
    pub after_challenge_trace_per_air: Vec<Option<DeviceMatrix<F>>>,
    pub exposed_values_per_air: Vec<Option<Vec<EF>>>,
}

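/// GPU prover for the FRI LogUp RAP phase; wraps the LogUp security parameters
/// and drives permutation-trace generation on device.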
#[derive(Clone, Debug)]
pub struct FriLogUpPhaseGpu {
    log_up_params: LogUpSecurityParameters,
}

impl FriLogUpPhaseGpu {
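    /// Creates the phase object, panicking unless `log_up_params` gives at
    /// least 100 conjectured bits of security over the extension field `EF`.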
    pub fn new(log_up_params: LogUpSecurityParameters) -> Self {
        assert!(log_up_params.conjectured_bits_of_security::<EF>() >= 100);
        Self { log_up_params }
    }

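    /// Runs the LogUp partial proving round: grinds the proof-of-work witness,
    /// samples the permutation challenges, generates the after-challenge traces
    /// on GPU, and observes each AIR's cumulative sum in the challenger.
    /// Returns `None` when no AIR declares interactions.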
    pub fn partially_prove_gpu(
        &self,
        challenger: &mut Challenger,
        constraints_per_air: &[&SymbolicConstraints<F>],
        params_per_air: &[&FriLogUpProvingKey],
        trace_view_per_air: Vec<PairView<DeviceMatrix<F>, F>>,
    ) -> Option<(FriLogUpPartialProof<F>, GpuRapPhaseResult)> {
        // 1. Check if there are any interactions - if not, we're done
        let has_any_interactions = constraints_per_air
            .iter()
            .any(|constraints| !constraints.interactions.is_empty());

        if !has_any_interactions {
            return None;
        }

        let logup_pow_witness = challenger.grind(self.log_up_params.log_up_pow_bits);
        let challenges: [EF; STARK_LU_NUM_CHALLENGES] =
            array::from_fn(|_| challenger.sample_ext_element::<EF>());

        let (after_challenge_trace_per_air, cumulative_sum_per_air) =
            gpu_metrics_span("generate_perm_trace_time_ms", || {
                self.generate_after_challenge_traces_per_air_gpu(
                    &challenges,
                    constraints_per_air,
                    params_per_air,
                    trace_view_per_air,
                )
            })
            .unwrap();

        // Challenger needs to observe what is exposed (cumulative_sums)
        for cumulative_sum in cumulative_sum_per_air.iter().flatten() {
            let base_slice = <EF as FieldExtensionAlgebra<F>>::as_base_slice(cumulative_sum);
            challenger.observe_slice(base_slice);
        }

        let exposed_values_per_air = cumulative_sum_per_air
            .iter()
            .map(|csum| csum.map(|csum| vec![csum]))
            .collect_vec();

        Some((
            FriLogUpPartialProof { logup_pow_witness },
            GpuRapPhaseResult {
                challenges: challenges.to_vec(),
                after_challenge_trace_per_air,
                exposed_values_per_air,
            },
        ))
    }
}

impl FriLogUpPhaseGpu {
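    /// Per-AIR driver: zips each AIR's constraints, trace view, and interaction
    /// partitions, producing one optional permutation trace and cumulative sum
    /// per AIR (both `None` for AIRs without interactions).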
    fn generate_after_challenge_traces_per_air_gpu(
        &self,
        challenges: &[EF; STARK_LU_NUM_CHALLENGES],
        constraints_per_air: &[&SymbolicConstraints<F>],
        params_per_air: &[&FriLogUpProvingKey],
        trace_view_per_air: Vec<PairView<DeviceMatrix<F>, F>>,
    ) -> (Vec<Option<DeviceMatrix<F>>>, Vec<Option<EF>>) {
        let interaction_partitions = params_per_air
            .iter()
            .map(|&params| params.clone().interaction_partitions())
            .collect_vec();

        constraints_per_air
            .iter()
            .zip(trace_view_per_air)
            .zip(interaction_partitions.iter())
            .map(|((constraints, trace_view), interaction_partitions)| {
                self.generate_after_challenge_trace_row_wise_gpu(
                    &constraints.interactions,
                    trace_view,
                    challenges,
                    interaction_partitions,
                )
            })
            .unzip()
    }

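    /// Folds every interaction message into a single symbolic denominator,
    /// transpiles the resulting (denominator, numerator) pairs into encoded GPU
    /// rules, and launches the permutation-trace kernel. Returns the
    /// after-challenge trace (in base-field columns) and its cumulative sum.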
    fn generate_after_challenge_trace_row_wise_gpu(
        &self,
        all_interactions: &[SymbolicInteraction<F>],
        trace_view: PairView<DeviceMatrix<F>, F>,
        permutation_randomness: &[EF; STARK_LU_NUM_CHALLENGES],
        interaction_partitions: &[Vec<usize>],
    ) -> (Option<DeviceMatrix<F>>, Option<EF>) {
        if all_interactions.is_empty() {
            return (None, None);
        }

        let height = trace_view.partitioned_main[0].height();
        debug_assert!(
            trace_view
                .partitioned_main
                .iter()
                .all(|m| m.height() == height),
            "All main trace parts must have same height"
        );

        let alphas_len = 1;
        let &[alpha, beta] = permutation_randomness;
        // Generate betas
        let max_fields_len = all_interactions
            .iter()
            .map(|interaction| interaction.message.len())
            .max()
            .unwrap_or(0);
        let betas = beta.powers().take(max_fields_len + 1).collect_vec();

        // 0. Prepare challenges
        let challenges = std::iter::once(&alpha)
            .chain(betas.iter())
            .cloned()
            .collect_vec();
        let symbolic_challenges: Vec<SymbolicExpression<F>> = (0..challenges.len())
            .map(|index| SymbolicVariable::<F>::new(Entry::Challenge, index).into())
            .collect_vec();

        // 1. Generate interactions message as
        //    denom = alpha + sum(beta_i * message_i) + beta_{m} * b
        // We use SymbolicInteraction to store the (message = [denom],
        // multiplicity = numerator) pair symbolically.
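        // Example: the betas are successive powers of beta, so for a message
        // [m0, m1] on bus index k the folded denominator is
        //    alpha + m0 + beta * m1 + beta^2 * (k + 1);
        // the first field carries the implicit beta^0 = 1 coefficient, and the
        // bus term uses k + 1 so that bus index 0 still contributes.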
        let mut full_interactions: Vec<SymbolicInteraction<F>> = Vec::new();
        for interaction_indices in interaction_partitions {
            full_interactions.extend(
                interaction_indices
                    .iter()
                    .map(|&interaction_idx| {
                        let mut interaction: SymbolicInteraction<F> =
                            all_interactions[interaction_idx].clone();
                        let b = SymbolicExpression::from_canonical_u32(
                            interaction.bus_index as u32 + 1,
                        );
                        let betas = symbolic_challenges[alphas_len..].to_vec();
                        // Strict bound: we index betas[interaction.message.len()] below.
                        debug_assert!(interaction.message.len() < betas.len());
                        let mut fields = interaction.message.iter();
                        let alpha = symbolic_challenges[0].clone();
                        let mut denom = alpha + fields.next().unwrap().clone();
                        for (expr, beta) in fields.zip(betas.iter().skip(1)) {
                            denom += beta.clone() * expr.clone();
                        }
                        denom += betas[interaction.message.len()].clone() * b;
                        interaction.message = vec![denom];
                        interaction
                    })
                    .collect_vec(),
            );
        }

        // 2. Transpile to GPU Rules
        // We use SymbolicConstraints as a way to encode the symbolic interactions as (denom,
        // numerator) pairs to transport to GPU.
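        // Each DAG node is encoded into a fixed-width u128 instruction via the
        // `Codec` trait, so the rule program ships to the GPU as a flat buffer.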
        let constraints = SymbolicConstraints {
            constraints: vec![],
            interactions: full_interactions,
        };
        let constraints_dag: SymbolicConstraintsDag<F> = constraints.into();
        let rules = SymbolicRulesOnGpu::new(constraints_dag.clone());
        let encoded_rules = rules.constraints.iter().map(|c| c.encode()).collect_vec();

        // 3. Call GPU module
        let partition_lens = interaction_partitions
            .iter()
            .map(|p| p.len() as u32)
            .collect_vec();
        let perm_width = interaction_partitions.len() + 1;
        let perm_height = height;
        let (device_matrix, sum) = self.permute_trace_gen_gpu(
            perm_width * 4, // width of the base field matrix: 4 base columns per EF column
            perm_height,
            trace_view.preprocessed,
            trace_view.partitioned_main,
            &challenges,
            &encoded_rules,
            rules.num_intermediates,
            &partition_lens,
            &rules.used_nodes,
        );

        (Some(device_matrix), Some(sum))
    }

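    /// Stages host data onto the device and runs permutation-trace generation.
    /// The permutation matrix is allocated as a base-field matrix whose width
    /// is 4x the number of extension-field columns (the `* 4` and `/ 4` factors
    /// reflect EF being a degree-4 extension of F); the final cumulative sum is
    /// copied back to the host.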
    // gpu-module/src/permute.rs
    #[allow(clippy::too_many_arguments)]
    fn permute_trace_gen_gpu(
        &self,
        permutation_width: usize,
        permutation_height: usize,
        preprocessed: Option<DeviceMatrix<F>>,
        partitioned_main: Vec<DeviceMatrix<F>>,
        challenges: &[EF],
        rules: &[u128],
        num_intermediates: usize,
        partition_lens: &[u32],
        used_nodes: &[usize],
    ) -> (DeviceMatrix<F>, EF) {
        assert!(!rules.is_empty(), "No rules provided to permute");

        tracing::debug!(
            "permute gen rules.len() = {}, num_intermediates = {}",
            rules.len(),
            num_intermediates,
        );

        // 1. input data
        let null_buffer = DeviceBuffer::<F>::new();
        let partitioned_main_ptrs = partitioned_main
            .iter()
            .map(|m| m.buffer().as_raw_ptr() as u64)
            .collect_vec();
        let d_partitioned_main = partitioned_main_ptrs.to_device().unwrap();
        let d_preprocessed = preprocessed
            .as_ref()
            .map(|m| m.buffer())
            .unwrap_or(&null_buffer);

        // 2. gpu buffers
        let d_sum = DeviceBuffer::<EF>::with_capacity(1);
        let d_permutation = DeviceMatrix::<F>::with_capacity(permutation_height, permutation_width);
        let d_challenges = challenges.to_device().unwrap();
        let d_rules = rules.to_device().unwrap();
        let d_partition_lens = partition_lens.to_device().unwrap();
        let d_used_nodes = used_nodes.to_device().unwrap();

        // 3. hal function
        self.hal_permute_trace_gen(
            &d_sum,
            d_permutation.buffer(),
            d_preprocessed,
            &d_partitioned_main,
            &d_challenges,
            &d_rules,
            rules.len(),
            num_intermediates,
            permutation_height,
            permutation_width / 4,
            &d_partition_lens,
            &d_used_nodes,
        )
        .unwrap();
        // We can drop preprocessed and main traces now that the permutation trace is generated.
        // Note these matrices may be smart pointers, so they may not be fully deallocated.
        drop(preprocessed);
        drop(partitioned_main);

        // 4. output data
        let h_sum = d_sum.to_host().unwrap()[0];
        (d_permutation, h_sum)
    }

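    /// Device pipeline in three phases: `calculate_cumulative_sums` computes
    /// per-row contributions into a scratch buffer, `poly_prefix_sum_ext` turns
    /// them into running sums over the trace height, and `permute_update` folds
    /// the scanned sums into the permutation trace and the total `sum`.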
    // gpu-backend/src/cuda.rs
    #[allow(clippy::too_many_arguments)]
    fn hal_permute_trace_gen(
        &self,
        sum: &DeviceBuffer<EF>,
        permutation: &DeviceBuffer<F>,
        preprocessed: &DeviceBuffer<F>,
        main_partitioned: &DeviceBuffer<u64>,
        challenges: &DeviceBuffer<EF>,
        rules: &DeviceBuffer<u128>,
        num_rules: usize,
        num_intermediates: usize,
        permutation_height: usize,
        permutation_width_ext: usize,
        partition_lens: &DeviceBuffer<u32>,
        used_nodes: &DeviceBuffer<usize>,
    ) -> Result<(), CudaError> {
        let task_size = 65536;
        let tile_per_thread = (permutation_height as u32).div_ceil(task_size as u32);

        tracing::debug!(
            "permutation_height = {permutation_height}, task_size = {task_size}, \
             tile_per_thread = {tile_per_thread}, num_rules = {num_rules}"
        );

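        // Heuristic: with more than 10 intermediates the kernel spills scratch
        // values to a global-memory buffer of task_size * num_intermediates
        // entries; otherwise a register-resident variant runs and the buffer is
        // a one-element placeholder.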
        let is_global = num_intermediates > 10;
        let d_intermediates = if is_global {
            DeviceBuffer::<EF>::with_capacity(task_size * num_intermediates)
        } else {
            DeviceBuffer::<EF>::with_capacity(1) // Dummy buffer for register-based version
        };

        let d_cumulative_sums = DeviceBuffer::<EF>::with_capacity(permutation_height);
        unsafe {
            calculate_cumulative_sums(
                is_global,
                permutation,
                &d_cumulative_sums,
                preprocessed,
                main_partitioned,
                challenges,
                &d_intermediates,
                rules,
                used_nodes,
                partition_lens,
                partition_lens.len(),
                permutation_height as u32,
                permutation_width_ext as u32,
                tile_per_thread,
            )
            .unwrap();
        }

        self.poly_prefix_sum_ext(&d_cumulative_sums, permutation_height as u64);

        unsafe {
            permute_update(
                sum,
                permutation,
                &d_cumulative_sums,
                permutation_height as u32,
                permutation_width_ext as u32,
            )
        }
    }

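    /// In-place prefix scan over extension-field elements: an upsweep of
    /// block-level scans whose stride grows by element_per_block = 16 * 256 =
    /// 4096 per round, a downsweep replaying the strides in reverse to
    /// propagate block prefixes, and an epilogue finalizing per-element sums.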
    fn poly_prefix_sum_ext(&self, inout: &DeviceBuffer<EF>, count: u64) {
        // Parameters for the scan
        let acc_per_thread: u64 = 16;
        let tiles_per_block: u64 = 256;
        let element_per_block: u64 = tiles_per_block * acc_per_thread;
        let mut block_num = (count as u32).div_ceil(tiles_per_block as u32) as u64;

        // First round
        let mut round_stride = 1_u64;
        unsafe {
            prefix_scan_block_ext(inout, count, round_stride, block_num).unwrap();
        }

        // Subsequent rounds
        while block_num > 1 {
            block_num = (block_num as u32).div_ceil(element_per_block as u32) as u64;
            round_stride *= element_per_block;
            unsafe {
                prefix_scan_block_ext(inout, count, round_stride, block_num).unwrap();
            }
        }

        // Block downsweep
        while round_stride > element_per_block {
            let low_level_round_stride = round_stride / element_per_block;
            unsafe {
                prefix_scan_block_downsweep_ext(inout, count, round_stride).unwrap();
            }
            round_stride = low_level_round_stride;
        }

        // Epilogue
        unsafe {
            prefix_scan_epilogue_ext(inout, count).unwrap();
        }
    }
}
401}