openvm_cuda_backend/
data_transporter.rs

1use std::{fmt::Debug, sync::Arc};
2
3use openvm_cuda_common::{
4    copy::{MemCopyD2H, MemCopyH2D},
5    d_buffer::DeviceBuffer,
6};
7use openvm_stark_backend::{
8    config::{Com, PcsProverData, Val},
9    keygen::types::MultiStarkProvingKey,
10    prover::{
11        hal::{DeviceDataTransporter, MatrixDimensions, TraceCommitter},
12        types::{
13            CommittedTraceData, DeviceMultiStarkProvingKey, DeviceStarkProvingKey,
14            SingleCommitPreimage,
15        },
16    },
17};
18use p3_matrix::{dense::RowMajorMatrix, Matrix};
19
20use crate::{
21    base::DeviceMatrix,
22    cuda::kernels::matrix::matrix_transpose,
23    gpu_device::GpuDevice,
24    prelude::{F, SC},
25    prover_backend::GpuBackend,
26};
27
28impl DeviceDataTransporter<SC, GpuBackend> for GpuDevice {
29    fn transport_pk_to_device(
30        &self,
31        mpk: &MultiStarkProvingKey<SC>,
32    ) -> DeviceMultiStarkProvingKey<GpuBackend> {
33        let per_air = mpk
34            .per_air
35            .iter()
36            .map(|pk| {
37                let preprocessed_data = pk.preprocessed_data.as_ref().map(|pd| {
38                    let trace = self.transport_matrix_to_device(&pd.trace);
39                    let (_, data) = self.commit(&[trace.clone()]);
40                    SingleCommitPreimage {
41                        trace,
42                        data,
43                        matrix_idx: 0,
44                    }
45                });
46
47                DeviceStarkProvingKey {
48                    air_name: pk.air_name.clone(),
49                    vk: pk.vk.clone(),
50                    preprocessed_data,
51                    rap_partial_pk: pk.rap_partial_pk.clone(),
52                }
53            })
54            .collect();
55
56        DeviceMultiStarkProvingKey::new(
57            per_air,
58            mpk.trace_height_constraints.clone(),
59            mpk.vk_pre_hash,
60        )
61    }
62
63    fn transport_matrix_to_device(&self, matrix: &Arc<RowMajorMatrix<F>>) -> DeviceMatrix<F> {
64        transport_matrix_to_device(matrix.clone())
65    }
66
67    /// We ignore the host prover data because it's faster to just re-commit on GPU instead of doing
68    /// H2D transfer.
69    fn transport_committed_trace_to_device(
70        &self,
71        commitment: Com<SC>,
72        trace: &Arc<RowMajorMatrix<Val<SC>>>,
73        _: &Arc<PcsProverData<SC>>,
74    ) -> CommittedTraceData<GpuBackend> {
75        let trace = self.transport_matrix_to_device(trace);
76        let (d_commitment, data) = self.commit(&[trace.clone()]);
77        assert_eq!(
78            d_commitment, commitment,
79            "GPU commitment does not match host"
80        );
81        CommittedTraceData {
82            commitment,
83            trace,
84            data,
85        }
86    }
87
88    fn transport_matrix_from_device_to_host(
89        &self,
90        matrix: &DeviceMatrix<F>,
91    ) -> Arc<RowMajorMatrix<F>> {
92        let matrix_host = transport_device_matrix_to_host(matrix);
93        Arc::new(matrix_host)
94    }
95}
96
97pub fn transport_matrix_to_device(matrix: Arc<RowMajorMatrix<F>>) -> DeviceMatrix<F> {
98    let data = matrix.values.as_slice();
99    let input_buffer = data.to_device().unwrap();
100    let output = DeviceMatrix::<F>::with_capacity(matrix.height(), matrix.width());
101    unsafe {
102        matrix_transpose::<F>(
103            output.buffer(),
104            &input_buffer,
105            matrix.width(),
106            matrix.height(),
107        )
108        .unwrap();
109    }
110    assert_eq!(output.strong_count(), 1);
111    output
112}
113
114pub fn transport_device_matrix_to_host<T: Clone + Send + Sync>(
115    matrix: &DeviceMatrix<T>,
116) -> RowMajorMatrix<T> {
117    let matrix_buffer = DeviceBuffer::<T>::with_capacity(matrix.height() * matrix.width());
118    unsafe {
119        matrix_transpose::<T>(
120            &matrix_buffer,
121            matrix.buffer(),
122            matrix.height(),
123            matrix.width(),
124        )
125        .unwrap();
126    }
127    RowMajorMatrix::<T>::new(matrix_buffer.to_host().unwrap(), matrix.width())
128}
129
130pub fn assert_eq_device_matrix<T: Clone + Send + Sync + PartialEq + Debug>(
131    a: &DeviceMatrix<T>,
132    b: &DeviceMatrix<T>,
133) {
134    assert_eq!(a.height(), b.height());
135    assert_eq!(a.width(), b.width());
136    assert_eq!(a.buffer().len(), b.buffer().len());
137    let a_host = a.to_host().unwrap();
138    let b_host = b.to_host().unwrap();
139    for r in 0..a.height() {
140        for c in 0..a.width() {
141            assert_eq!(
142                a_host[c * a.height() + r],
143                b_host[c * b.height() + r],
144                "Mismatch at row {} column {}",
145                r,
146                c
147            );
148        }
149    }
150}
151
152pub fn assert_eq_host_and_device_matrix<T: Clone + Send + Sync + PartialEq + Debug>(
153    cpu: Arc<RowMajorMatrix<T>>,
154    gpu: &DeviceMatrix<T>,
155) {
156    assert_eq!(gpu.width(), cpu.width());
157    assert_eq!(gpu.height(), cpu.height());
158    let gpu = gpu.to_host().unwrap();
159    for r in 0..cpu.height() {
160        for c in 0..cpu.width() {
161            assert_eq!(
162                gpu[c * cpu.height() + r],
163                cpu.get(r, c),
164                "Mismatch at row {} column {}",
165                r,
166                c
167            );
168        }
169    }
170}