openvm_circuit/system/cuda/program.rs

1use std::{mem::size_of, sync::Arc};
2
3use openvm_circuit::{system::program::ProgramExecutionCols, utils::next_power_of_two_or_zero};
4use openvm_cuda_backend::{
5    base::DeviceMatrix, gpu_device::GpuDevice, prover_backend::GpuBackend, types::F,
6};
7use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer};
8use openvm_instructions::{
9    program::{Program, DEFAULT_PC_STEP},
10    LocalOpcode, SystemOpcode,
11};
12use openvm_stark_backend::{
13    prover::{
14        hal::{MatrixDimensions, TraceCommitter},
15        types::{AirProvingContext, CommittedTraceData},
16    },
17    Chip,
18};
19use p3_field::FieldAlgebra;
20
21use crate::cuda_abi::program;
22
/// GPU chip that builds the proving context for the VM program.
///
/// The proving context consists of a committed cached trace describing the program's
/// instructions (built with [`ProgramChipGPU::generate_cached_trace`] and committed with
/// [`ProgramChipGPU::get_committed_trace`]) plus a single common-main column built from
/// the `filtered_exec_freqs` passed to `generate_proving_ctx`.
pub struct ProgramChipGPU {
    /// Committed cached program trace. Must be `Some` before `generate_proving_ctx` is
    /// called; that method panics with "Cached program must be loaded" otherwise.
    pub cached: Option<CommittedTraceData<GpuBackend>>,
}
26
27impl ProgramChipGPU {
28    pub fn new() -> Self {
29        Self { cached: None }
30    }
31
32    pub fn generate_cached_trace(program: Program<F>) -> DeviceMatrix<F> {
33        let instructions = program
34            .enumerate_by_pc()
35            .into_iter()
36            .map(|(pc, instruction, _)| {
37                [
38                    F::from_canonical_u32(pc),
39                    instruction.opcode.to_field(),
40                    instruction.a,
41                    instruction.b,
42                    instruction.c,
43                    instruction.d,
44                    instruction.e,
45                    instruction.f,
46                    instruction.g,
47                ]
48            })
49            .collect::<Vec<_>>();
50
51        let num_records = instructions.len();
52        let height = next_power_of_two_or_zero(num_records);
53        let records = instructions
54            .into_iter()
55            .flatten()
56            .collect::<Vec<_>>()
57            .to_device()
58            .unwrap();
59
60        let trace = DeviceMatrix::<F>::with_capacity(height, size_of::<ProgramExecutionCols<u8>>());
61        unsafe {
62            program::cached_tracegen(
63                trace.buffer(),
64                trace.height(),
65                trace.width(),
66                &records,
67                program.pc_base,
68                DEFAULT_PC_STEP,
69                SystemOpcode::TERMINATE.global_opcode().as_usize(),
70            )
71            .expect("Failed to generate cached trace");
72        }
73        trace
74    }
75
76    pub fn get_committed_trace(
77        trace: DeviceMatrix<F>,
78        device: &GpuDevice,
79    ) -> CommittedTraceData<GpuBackend> {
80        let (root, pcs_data) = device.commit(&[trace.clone()]);
81        CommittedTraceData {
82            commitment: root,
83            trace,
84            data: pcs_data,
85        }
86    }
87}
88
89impl Default for ProgramChipGPU {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl Chip<Vec<u32>, GpuBackend> for ProgramChipGPU {
96    fn generate_proving_ctx(&self, filtered_exec_freqs: Vec<u32>) -> AirProvingContext<GpuBackend> {
97        let cached = self.cached.clone().expect("Cached program must be loaded");
98        let height = cached.trace.height();
99        let filtered_len = filtered_exec_freqs.len();
100        assert!(
101            filtered_len <= height,
102            "filtered_exec_freqs len={} > cached trace height={}",
103            filtered_len,
104            height
105        );
106        let mut buffer: DeviceBuffer<F> = DeviceBuffer::with_capacity(height);
107
108        filtered_exec_freqs
109            .into_iter()
110            .map(F::from_canonical_u32)
111            .collect::<Vec<_>>()
112            .copy_to(&mut buffer)
113            .unwrap();
114        // Making sure to zero-out the untouched part of the buffer.
115        if filtered_len < height {
116            buffer.fill_zero_suffix(filtered_len).unwrap();
117        }
118
119        let trace = DeviceMatrix::new(Arc::new(buffer), height, 1);
120
121        AirProvingContext {
122            cached_mains: vec![cached],
123            common_main: Some(trace),
124            public_values: vec![],
125        }
126    }
127}
128
// GPU-vs-CPU equivalence tests for the cached program trace and its commitment.
#[cfg(test)]
mod tests {
    use openvm_circuit::system::program::trace::VmCommittedExe;
    use openvm_cuda_backend::{
        data_transporter::assert_eq_host_and_device_matrix, engine::GpuBabyBearPoseidon2Engine,
        prelude::F,
    };
    use openvm_instructions::{
        exe::VmExe,
        instruction::Instruction,
        program::{Program, DEFAULT_PC_STEP},
        LocalOpcode,
        SystemOpcode::*,
    };
    use openvm_native_compiler::{
        FieldArithmeticOpcode::*, NativeBranchEqualOpcode, NativeJalOpcode::*,
        NativeLoadStoreOpcode::*,
    };
    use openvm_rv32im_transpiler::BranchEqualOpcode::*;
    use openvm_stark_backend::config::StarkGenericConfig;
    use openvm_stark_sdk::{
        config::{
            baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine},
            FriParameters,
        },
        engine::{StarkEngine, StarkFriEngine},
    };

    use super::ProgramChipGPU;

    /// Asserts that the GPU cached program trace and its commitment match the CPU reference.
    ///
    /// Steps:
    /// 1. Generate the cached trace for `program` with `ProgramChipGPU` and commit it on a
    ///    `GpuBabyBearPoseidon2Engine`.
    /// 2. Commit the same program as a `VmCommittedExe` on a CPU `BabyBearPoseidon2Engine`,
    ///    using the same `FriParameters::new_for_testing(2)`.
    /// 3. Compare both the trace matrices and the commitments.
    fn test_cached_committed_trace_data(program: Program<F>) {
        let gpu_engine = GpuBabyBearPoseidon2Engine::new(FriParameters::new_for_testing(2));
        let gpu_device = gpu_engine.device();
        let gpu_trace = ProgramChipGPU::generate_cached_trace(program.clone());
        let gpu_cached = ProgramChipGPU::get_committed_trace(gpu_trace, gpu_device);

        let cpu_engine = BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(2));
        let cpu_exe = VmExe::new(program.clone());
        let cpu_committed_exe =
            VmCommittedExe::<BabyBearPoseidon2Config>::commit(cpu_exe, cpu_engine.config().pcs());
        let cpu_cached = cpu_committed_exe.get_committed_trace();

        assert_eq_host_and_device_matrix(cpu_cached.trace, &gpu_cached.trace);
        assert_eq!(gpu_cached.commitment, cpu_cached.commitment);
    }

    /// Six contiguous instructions ending in TERMINATE.
    ///
    /// Includes a forward BEQ offset and a negative JAL offset, both expressed as multiples
    /// of `DEFAULT_PC_STEP`.
    #[test]
    fn test_cuda_program_cached_tracegen_1() {
        let instructions = vec![
            Instruction::large_from_isize(STOREW.global_opcode(), 2, 0, 0, 0, 1, 0, 1),
            Instruction::large_from_isize(STOREW.global_opcode(), 1, 1, 0, 0, 1, 0, 1),
            Instruction::from_isize(
                NativeBranchEqualOpcode(BEQ).global_opcode(),
                0,
                0,
                3 * DEFAULT_PC_STEP as isize,
                1,
                0,
            ),
            Instruction::from_isize(SUB.global_opcode(), 0, 0, 1, 1, 1),
            Instruction::from_isize(
                JAL.global_opcode(),
                2,
                -2 * (DEFAULT_PC_STEP as isize),
                0,
                1,
                0,
            ),
            Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0),
        ];
        let program = Program::from_instructions(&instructions);
        test_cached_committed_trace_data(program);
    }

    /// Contiguous program where TERMINATE is not the last instruction.
    ///
    /// A BEQ with a negative offset follows the TERMINATE.
    #[test]
    fn test_cuda_program_cached_tracegen_2() {
        let instructions = vec![
            Instruction::large_from_isize(STOREW.global_opcode(), 5, 0, 0, 0, 1, 0, 1),
            Instruction::from_isize(
                NativeBranchEqualOpcode(BNE).global_opcode(),
                0,
                4,
                3 * DEFAULT_PC_STEP as isize,
                1,
                0,
            ),
            Instruction::from_isize(
                JAL.global_opcode(),
                2,
                -2 * DEFAULT_PC_STEP as isize,
                0,
                1,
                0,
            ),
            Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0),
            Instruction::from_isize(
                NativeBranchEqualOpcode(BEQ).global_opcode(),
                0,
                5,
                -(DEFAULT_PC_STEP as isize),
                1,
                0,
            ),
        ];
        let program = Program::from_instructions(&instructions);
        test_cached_committed_trace_data(program);
    }

    /// Program with undefined (`None`) slots between defined instructions.
    ///
    /// Such slots are presumably skipped by `enumerate_by_pc` (NOTE(review): confirm).
    #[test]
    fn test_cuda_program_cached_tracegen_undefined_instructions() {
        let instructions = vec![
            Some(Instruction::large_from_isize(
                STOREW.global_opcode(),
                2,
                0,
                0,
                0,
                1,
                0,
                1,
            )),
            Some(Instruction::large_from_isize(
                STOREW.global_opcode(),
                1,
                1,
                0,
                0,
                1,
                0,
                1,
            )),
            Some(Instruction::from_isize(
                NativeBranchEqualOpcode(BEQ).global_opcode(),
                0,
                2,
                3 * DEFAULT_PC_STEP as isize,
                1,
                0,
            )),
            None,
            None,
            Some(Instruction::from_isize(
                TERMINATE.global_opcode(),
                0,
                0,
                0,
                0,
                0,
            )),
        ];
        // NOTE(review): the trailing `0` is assumed to be the program's pc_base; confirm
        // against `Program::new_without_debug_infos_with_option`.
        let program = Program::new_without_debug_infos_with_option(&instructions, 0);
        test_cached_committed_trace_data(program);
    }
}
282}