openvm_cuda_backend/
data_transporter.rs1use std::{fmt::Debug, sync::Arc};
2
3use openvm_cuda_common::{
4 copy::{MemCopyD2H, MemCopyH2D},
5 d_buffer::DeviceBuffer,
6};
7use openvm_stark_backend::{
8 config::{Com, PcsProverData, Val},
9 keygen::types::MultiStarkProvingKey,
10 prover::{
11 hal::{DeviceDataTransporter, MatrixDimensions, TraceCommitter},
12 types::{
13 CommittedTraceData, DeviceMultiStarkProvingKey, DeviceStarkProvingKey,
14 SingleCommitPreimage,
15 },
16 },
17};
18use p3_matrix::{dense::RowMajorMatrix, Matrix};
19
20use crate::{
21 base::DeviceMatrix,
22 cuda::kernels::matrix::matrix_transpose,
23 gpu_device::GpuDevice,
24 prelude::{F, SC},
25 prover_backend::GpuBackend,
26};
27
28impl DeviceDataTransporter<SC, GpuBackend> for GpuDevice {
29 fn transport_pk_to_device(
30 &self,
31 mpk: &MultiStarkProvingKey<SC>,
32 ) -> DeviceMultiStarkProvingKey<GpuBackend> {
33 let per_air = mpk
34 .per_air
35 .iter()
36 .map(|pk| {
37 let preprocessed_data = pk.preprocessed_data.as_ref().map(|pd| {
38 let trace = self.transport_matrix_to_device(&pd.trace);
39 let (_, data) = self.commit(&[trace.clone()]);
40 SingleCommitPreimage {
41 trace,
42 data,
43 matrix_idx: 0,
44 }
45 });
46
47 DeviceStarkProvingKey {
48 air_name: pk.air_name.clone(),
49 vk: pk.vk.clone(),
50 preprocessed_data,
51 rap_partial_pk: pk.rap_partial_pk.clone(),
52 }
53 })
54 .collect();
55
56 DeviceMultiStarkProvingKey::new(
57 per_air,
58 mpk.trace_height_constraints.clone(),
59 mpk.vk_pre_hash,
60 )
61 }
62
63 fn transport_matrix_to_device(&self, matrix: &Arc<RowMajorMatrix<F>>) -> DeviceMatrix<F> {
64 transport_matrix_to_device(matrix.clone())
65 }
66
67 fn transport_committed_trace_to_device(
70 &self,
71 commitment: Com<SC>,
72 trace: &Arc<RowMajorMatrix<Val<SC>>>,
73 _: &Arc<PcsProverData<SC>>,
74 ) -> CommittedTraceData<GpuBackend> {
75 let trace = self.transport_matrix_to_device(trace);
76 let (d_commitment, data) = self.commit(&[trace.clone()]);
77 assert_eq!(
78 d_commitment, commitment,
79 "GPU commitment does not match host"
80 );
81 CommittedTraceData {
82 commitment,
83 trace,
84 data,
85 }
86 }
87
88 fn transport_matrix_from_device_to_host(
89 &self,
90 matrix: &DeviceMatrix<F>,
91 ) -> Arc<RowMajorMatrix<F>> {
92 let matrix_host = transport_device_matrix_to_host(matrix);
93 Arc::new(matrix_host)
94 }
95}
96
97pub fn transport_matrix_to_device(matrix: Arc<RowMajorMatrix<F>>) -> DeviceMatrix<F> {
98 let data = matrix.values.as_slice();
99 let input_buffer = data.to_device().unwrap();
100 let output = DeviceMatrix::<F>::with_capacity(matrix.height(), matrix.width());
101 unsafe {
102 matrix_transpose::<F>(
103 output.buffer(),
104 &input_buffer,
105 matrix.width(),
106 matrix.height(),
107 )
108 .unwrap();
109 }
110 assert_eq!(output.strong_count(), 1);
111 output
112}
113
114pub fn transport_device_matrix_to_host<T: Clone + Send + Sync>(
115 matrix: &DeviceMatrix<T>,
116) -> RowMajorMatrix<T> {
117 let matrix_buffer = DeviceBuffer::<T>::with_capacity(matrix.height() * matrix.width());
118 unsafe {
119 matrix_transpose::<T>(
120 &matrix_buffer,
121 matrix.buffer(),
122 matrix.height(),
123 matrix.width(),
124 )
125 .unwrap();
126 }
127 RowMajorMatrix::<T>::new(matrix_buffer.to_host().unwrap(), matrix.width())
128}
129
130pub fn assert_eq_device_matrix<T: Clone + Send + Sync + PartialEq + Debug>(
131 a: &DeviceMatrix<T>,
132 b: &DeviceMatrix<T>,
133) {
134 assert_eq!(a.height(), b.height());
135 assert_eq!(a.width(), b.width());
136 assert_eq!(a.buffer().len(), b.buffer().len());
137 let a_host = a.to_host().unwrap();
138 let b_host = b.to_host().unwrap();
139 for r in 0..a.height() {
140 for c in 0..a.width() {
141 assert_eq!(
142 a_host[c * a.height() + r],
143 b_host[c * b.height() + r],
144 "Mismatch at row {} column {}",
145 r,
146 c
147 );
148 }
149 }
150}
151
152pub fn assert_eq_host_and_device_matrix<T: Clone + Send + Sync + PartialEq + Debug>(
153 cpu: Arc<RowMajorMatrix<T>>,
154 gpu: &DeviceMatrix<T>,
155) {
156 assert_eq!(gpu.width(), cpu.width());
157 assert_eq!(gpu.height(), cpu.height());
158 let gpu = gpu.to_host().unwrap();
159 for r in 0..cpu.height() {
160 for c in 0..cpu.width() {
161 assert_eq!(
162 gpu[c * cpu.height() + r],
163 cpu.get(r, c),
164 "Mismatch at row {} column {}",
165 r,
166 c
167 );
168 }
169 }
170}