openvm_cuda_backend/
committer.rs

1use openvm_stark_backend::prover::hal::MatrixDimensions;
2use p3_baby_bear::BabyBear;
3use p3_util::log2_strict_usize;
4
5use crate::{base::DeviceMatrix, gpu_device::GpuDevice, lde::GpuLde, merkle_tree::GpuMerkleTree};
6
7impl GpuDevice {
8    pub fn commit_trace<LDE: GpuLde>(
9        &self,
10        trace: DeviceMatrix<BabyBear>,
11    ) -> (Vec<u8>, GpuMerkleTree<LDE>) {
12        let log_height: u8 = log2_strict_usize(trace.height()).try_into().unwrap();
13        let lde = LDE::new(trace, 0, self.config.shift);
14        (
15            vec![log_height],
16            GpuMerkleTree::new(vec![lde], self).unwrap(),
17        )
18    }
19
20    /// Commit a trace to a GPU device.
21    pub fn commit_traces_with_lde<LDE: GpuLde>(
22        &self,
23        traces_with_shifts: Vec<(DeviceMatrix<BabyBear>, BabyBear)>,
24        log_blowup: usize,
25    ) -> (Vec<u8>, GpuMerkleTree<LDE>) {
26        let (log_trace_heights, ldes): (Vec<u8>, Vec<LDE>) = traces_with_shifts
27            .into_iter()
28            .map(|(trace, shift)| {
29                let height = trace.height();
30                let log_height: u8 = log2_strict_usize(height).try_into().unwrap();
31                let lde = LDE::new(trace, log_blowup, shift);
32                (log_height, lde)
33            })
34            .collect();
35
36        (log_trace_heights, GpuMerkleTree::new(ldes, self).unwrap())
37    }
38}