openvm_circuit/system/cuda/program.rs

1use std::{mem::size_of, sync::Arc};
2
3use openvm_circuit::{system::program::ProgramExecutionCols, utils::next_power_of_two_or_zero};
4use openvm_cuda_backend::{
5    base::DeviceMatrix, gpu_device::GpuDevice, prover_backend::GpuBackend, types::F,
6};
7use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer};
8use openvm_instructions::{
9    program::{Program, DEFAULT_PC_STEP},
10    LocalOpcode, SystemOpcode,
11};
12use openvm_stark_backend::{
13    prover::{
14        hal::{MatrixDimensions, TraceCommitter},
15        types::{AirProvingContext, CommittedTraceData},
16    },
17    Chip,
18};
19use p3_field::FieldAlgebra;
20
21use crate::cuda_abi::program;
22
23pub struct ProgramChipGPU {
24    pub cached: Option<CommittedTraceData<GpuBackend>>,
25}
26
27impl ProgramChipGPU {
28    pub fn new() -> Self {
29        Self { cached: None }
30    }
31
32    pub fn generate_cached_trace(program: Program<F>) -> DeviceMatrix<F> {
33        let instructions = program
34            .enumerate_by_pc()
35            .into_iter()
36            .map(|(pc, instruction, _)| {
37                [
38                    F::from_canonical_u32(pc),
39                    instruction.opcode.to_field(),
40                    instruction.a,
41                    instruction.b,
42                    instruction.c,
43                    instruction.d,
44                    instruction.e,
45                    instruction.f,
46                    instruction.g,
47                ]
48            })
49            .collect::<Vec<_>>();
50
51        let num_records = instructions.len();
52        let height = next_power_of_two_or_zero(num_records);
53        let records = instructions
54            .into_iter()
55            .flatten()
56            .collect::<Vec<_>>()
57            .to_device()
58            .unwrap();
59
60        let trace = DeviceMatrix::<F>::with_capacity(height, size_of::<ProgramExecutionCols<u8>>());
61        unsafe {
62            program::cached_tracegen(
63                trace.buffer(),
64                trace.height(),
65                trace.width(),
66                &records,
67                program.pc_base,
68                DEFAULT_PC_STEP,
69                SystemOpcode::TERMINATE.global_opcode().as_usize(),
70            )
71            .expect("Failed to generate cached trace");
72        }
73        trace
74    }
75
76    pub fn get_committed_trace(
77        trace: DeviceMatrix<F>,
78        device: &GpuDevice,
79    ) -> CommittedTraceData<GpuBackend> {
80        let (root, pcs_data) = device.commit(std::slice::from_ref(&trace));
81        CommittedTraceData {
82            commitment: root,
83            trace,
84            data: pcs_data,
85        }
86    }
87}
88
89impl Default for ProgramChipGPU {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl Chip<Vec<u32>, GpuBackend> for ProgramChipGPU {
96    fn generate_proving_ctx(&self, filtered_exec_freqs: Vec<u32>) -> AirProvingContext<GpuBackend> {
97        let cached = self.cached.clone().expect("Cached program must be loaded");
98        let height = cached.trace.height();
99        let filtered_len = filtered_exec_freqs.len();
100        assert!(
101            filtered_len <= height,
102            "filtered_exec_freqs len={filtered_len} > cached trace height={height}"
103        );
104        let mut buffer: DeviceBuffer<F> = DeviceBuffer::with_capacity(height);
105
106        filtered_exec_freqs
107            .into_iter()
108            .map(F::from_canonical_u32)
109            .collect::<Vec<_>>()
110            .copy_to(&mut buffer)
111            .unwrap();
112        // Making sure to zero-out the untouched part of the buffer.
113        if filtered_len < height {
114            buffer.fill_zero_suffix(filtered_len).unwrap();
115        }
116
117        let trace = DeviceMatrix::new(Arc::new(buffer), height, 1);
118
119        AirProvingContext {
120            cached_mains: vec![cached],
121            common_main: Some(trace),
122            public_values: vec![],
123        }
124    }
125}
126
#[cfg(test)]
mod tests {
    use openvm_circuit::system::program::trace::VmCommittedExe;
    use openvm_cuda_backend::{
        data_transporter::assert_eq_host_and_device_matrix, engine::GpuBabyBearPoseidon2Engine,
        prelude::F,
    };
    use openvm_instructions::{
        exe::VmExe,
        instruction::Instruction,
        program::{Program, DEFAULT_PC_STEP},
        LocalOpcode,
        SystemOpcode::*,
    };
    use openvm_native_compiler::{
        FieldArithmeticOpcode::*, NativeBranchEqualOpcode, NativeJalOpcode::*,
        NativeLoadStoreOpcode::*,
    };
    use openvm_rv32im_transpiler::BranchEqualOpcode::*;
    use openvm_stark_backend::config::StarkGenericConfig;
    use openvm_stark_sdk::{
        config::{
            baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine},
            FriParameters,
        },
        engine::{StarkEngine, StarkFriEngine},
    };

    use super::ProgramChipGPU;

    /// Builds and commits the cached program trace for `program` on both backends.
    ///
    /// Runs the GPU path (`ProgramChipGPU`) and the CPU path (`VmCommittedExe`), then
    /// asserts that the trace matrices and the commitments are identical.
    fn test_cached_committed_trace_data(program: Program<F>) {
        let gpu_engine = GpuBabyBearPoseidon2Engine::new(FriParameters::new_for_testing(2));
        let gpu_device = gpu_engine.device();
        let gpu_trace = ProgramChipGPU::generate_cached_trace(program.clone());
        let gpu_cached = ProgramChipGPU::get_committed_trace(gpu_trace, gpu_device);

        let cpu_engine = BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(2));
        // `program` is not used after this point, so move it rather than cloning again.
        let cpu_exe = VmExe::new(program);
        let cpu_committed_exe =
            VmCommittedExe::<BabyBearPoseidon2Config>::commit(cpu_exe, cpu_engine.config().pcs());
        let cpu_cached = cpu_committed_exe.get_committed_trace();

        assert_eq_host_and_device_matrix(cpu_cached.trace, &gpu_cached.trace);
        assert_eq!(gpu_cached.commitment, cpu_cached.commitment);
    }

    /// Contiguous program: two stores, a BEQ/SUB/JAL loop, then TERMINATE.
    #[test]
    fn test_cuda_program_cached_tracegen_1() {
        let instructions = vec![
            Instruction::large_from_isize(STOREW.global_opcode(), 2, 0, 0, 0, 1, 0, 1),
            Instruction::large_from_isize(STOREW.global_opcode(), 1, 1, 0, 0, 1, 0, 1),
            Instruction::from_isize(
                NativeBranchEqualOpcode(BEQ).global_opcode(),
                0,
                0,
                3 * DEFAULT_PC_STEP as isize,
                1,
                0,
            ),
            Instruction::from_isize(SUB.global_opcode(), 0, 0, 1, 1, 1),
            Instruction::from_isize(
                JAL.global_opcode(),
                2,
                -2 * (DEFAULT_PC_STEP as isize),
                0,
                1,
                0,
            ),
            Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0),
        ];
        let program = Program::from_instructions(&instructions);
        test_cached_committed_trace_data(program);
    }

    /// Contiguous program with an instruction placed after TERMINATE (a backward BEQ).
    #[test]
    fn test_cuda_program_cached_tracegen_2() {
        let instructions = vec![
            Instruction::large_from_isize(STOREW.global_opcode(), 5, 0, 0, 0, 1, 0, 1),
            Instruction::from_isize(
                NativeBranchEqualOpcode(BNE).global_opcode(),
                0,
                4,
                3 * DEFAULT_PC_STEP as isize,
                1,
                0,
            ),
            Instruction::from_isize(
                JAL.global_opcode(),
                2,
                -2 * DEFAULT_PC_STEP as isize,
                0,
                1,
                0,
            ),
            Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0),
            Instruction::from_isize(
                NativeBranchEqualOpcode(BEQ).global_opcode(),
                0,
                5,
                -(DEFAULT_PC_STEP as isize),
                1,
                0,
            ),
        ];
        let program = Program::from_instructions(&instructions);
        test_cached_committed_trace_data(program);
    }

    /// Program with `None` gaps (undefined instruction slots) before TERMINATE.
    #[test]
    fn test_cuda_program_cached_tracegen_undefined_instructions() {
        let instructions = vec![
            Some(Instruction::large_from_isize(
                STOREW.global_opcode(),
                2,
                0,
                0,
                0,
                1,
                0,
                1,
            )),
            Some(Instruction::large_from_isize(
                STOREW.global_opcode(),
                1,
                1,
                0,
                0,
                1,
                0,
                1,
            )),
            Some(Instruction::from_isize(
                NativeBranchEqualOpcode(BEQ).global_opcode(),
                0,
                2,
                3 * DEFAULT_PC_STEP as isize,
                1,
                0,
            )),
            None,
            None,
            Some(Instruction::from_isize(
                TERMINATE.global_opcode(),
                0,
                0,
                0,
                0,
                0,
            )),
        ];
        let program = Program::new_without_debug_infos_with_option(&instructions, 0);
        test_cached_committed_trace_data(program);
    }
}
280}