openvm_cuda_backend/
committer.rs1use openvm_stark_backend::prover::hal::MatrixDimensions;
2use p3_baby_bear::BabyBear;
3use p3_util::log2_strict_usize;
4
5use crate::{base::DeviceMatrix, gpu_device::GpuDevice, lde::GpuLde, merkle_tree::GpuMerkleTree};
6
7impl GpuDevice {
8 pub fn commit_trace<LDE: GpuLde>(
9 &self,
10 trace: DeviceMatrix<BabyBear>,
11 ) -> (Vec<u8>, GpuMerkleTree<LDE>) {
12 let log_height: u8 = log2_strict_usize(trace.height()).try_into().unwrap();
13 let lde = LDE::new(trace, 0, self.config.shift);
14 (
15 vec![log_height],
16 GpuMerkleTree::new(vec![lde], self).unwrap(),
17 )
18 }
19
20 pub fn commit_traces_with_lde<LDE: GpuLde>(
22 &self,
23 traces_with_shifts: Vec<(DeviceMatrix<BabyBear>, BabyBear)>,
24 log_blowup: usize,
25 ) -> (Vec<u8>, GpuMerkleTree<LDE>) {
26 let (log_trace_heights, ldes): (Vec<u8>, Vec<LDE>) = traces_with_shifts
27 .into_iter()
28 .map(|(trace, shift)| {
29 let height = trace.height();
30 let log_height: u8 = log2_strict_usize(height).try_into().unwrap();
31 let lde = LDE::new(trace, log_blowup, shift);
32 (log_height, lde)
33 })
34 .collect();
35
36 (log_trace_heights, GpuMerkleTree::new(ldes, self).unwrap())
37 }
38}