// openvm_circuit/system/program/trace.rs

1use std::{borrow::BorrowMut, sync::Arc};
2
3use derivative::Derivative;
4use itertools::Itertools;
5use openvm_circuit::arch::hasher::poseidon2::Poseidon2Hasher;
6use openvm_instructions::{
7    exe::VmExe,
8    program::{Program, DEFAULT_PC_STEP},
9    LocalOpcode, SystemOpcode,
10};
11use openvm_stark_backend::{
12    config::{Com, PcsProverData, StarkGenericConfig, Val},
13    p3_commit::Pcs,
14    p3_field::{Field, FieldAlgebra, PrimeField32},
15    p3_matrix::{dense::RowMajorMatrix, Matrix},
16    p3_maybe_rayon::prelude::*,
17    p3_util::log2_strict_usize,
18    prover::{
19        cpu::{self, CpuBackend},
20        types::{AirProvingContext, CommittedTraceData},
21    },
22    Chip,
23};
24use serde::{Deserialize, Serialize};
25
26use super::{Instruction, ProgramExecutionCols, EXIT_CODE_FAIL};
27use crate::{
28    arch::{
29        hasher::{poseidon2::vm_poseidon2_hasher, Hasher},
30        MemoryConfig,
31    },
32    system::{
33        memory::{merkle::MerkleTree, AddressMap, CHUNK},
34        program::ProgramChip,
35    },
36};
37
38/// **Note**: this struct stores the program ROM twice: once in [VmExe] and once as a cached trace
39/// matrix `trace`.
40#[derive(Serialize, Deserialize, Derivative)]
41#[serde(bound(
42    serialize = "VmExe<Val<SC>>: Serialize, Com<SC>: Serialize, PcsProverData<SC>: Serialize",
43    deserialize = "VmExe<Val<SC>>: Deserialize<'de>, Com<SC>: Deserialize<'de>, PcsProverData<SC>: Deserialize<'de>"
44))]
45#[derivative(Clone(bound = "Com<SC>: Clone"))]
46pub struct VmCommittedExe<SC: StarkGenericConfig> {
47    /// Raw executable.
48    pub exe: Arc<VmExe<Val<SC>>>,
49    program_commitment: Com<SC>,
50    /// Program ROM as cached trace matrix.
51    pub trace: Arc<RowMajorMatrix<Val<SC>>>,
52    pub prover_data: Arc<PcsProverData<SC>>,
53}
54
55impl<SC: StarkGenericConfig> VmCommittedExe<SC> {
56    /// Creates [VmCommittedExe] from [VmExe] by using `pcs` to commit to the
57    /// program code as a _cached trace_ matrix.
58    pub fn commit(exe: VmExe<Val<SC>>, pcs: &SC::Pcs) -> Self {
59        let trace = generate_cached_trace(&exe.program);
60        let domain = pcs.natural_domain_for_degree(trace.height());
61
62        let (program_commitment, data) = pcs.commit(vec![(domain, trace.clone())]);
63        Self {
64            exe: Arc::new(exe),
65            program_commitment,
66            trace: Arc::new(trace),
67            prover_data: Arc::new(data),
68        }
69    }
70    pub fn get_program_commit(&self) -> Com<SC> {
71        self.program_commitment.clone()
72    }
73
74    pub fn get_committed_trace(&self) -> CommittedTraceData<CpuBackend<SC>> {
75        let log_trace_height: u8 = log2_strict_usize(self.trace.height()).try_into().unwrap();
76        let data = cpu::PcsData::new(self.prover_data.clone(), vec![log_trace_height]);
77        CommittedTraceData {
78            commitment: self.program_commitment.clone(),
79            trace: self.trace.clone(),
80            data,
81        }
82    }
83
84    /// Computes a commitment to [VmCommittedExe]. This is a Merklelized hash of:
85    /// - Program code commitment (commitment of the cached trace)
86    /// - Merkle root of the initial memory
87    /// - Starting program counter (`pc_start`)
88    ///
89    /// The program code commitment is itself a commitment (via the proof system PCS) to
90    /// the program code.
91    ///
92    /// The Merklelization uses Poseidon2 as a cryptographic hash function (for the leaves)
93    /// and a cryptographic compression function (for internal nodes).
94    ///
95    /// **Note**: This function recomputes the Merkle tree for the initial memory image.
96    pub fn compute_exe_commit(
97        program_commitment: &Com<SC>,
98        exe: &VmExe<Val<SC>>,
99        memory_config: &MemoryConfig,
100    ) -> Com<SC>
101    where
102        Com<SC>: AsRef<[Val<SC>; CHUNK]> + From<[Val<SC>; CHUNK]>,
103        Val<SC>: PrimeField32,
104    {
105        let hasher = vm_poseidon2_hasher();
106        let memory_dimensions = memory_config.memory_dimensions();
107        let app_program_commit: &[Val<SC>; CHUNK] = program_commitment.as_ref();
108        let mem_config = memory_config;
109        let mut memory_image = AddressMap::new(mem_config.addr_spaces.clone());
110        memory_image.set_from_sparse(&exe.init_memory);
111        let init_memory_commit =
112            MerkleTree::from_memory(&memory_image, &memory_dimensions, &hasher).root();
113        Com::<SC>::from(compute_exe_commit(
114            &hasher,
115            app_program_commit,
116            &init_memory_commit,
117            Val::<SC>::from_canonical_u32(exe.pc_start),
118        ))
119    }
120}
121
122impl<RA, SC: StarkGenericConfig> Chip<RA, CpuBackend<SC>> for ProgramChip<SC> {
123    /// The cached program trace is cloned and left for future use. The clone is cheap because the
124    /// cached trace is behind smart pointers. The execution frequencies are left unchanged.
125    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<CpuBackend<SC>> {
126        let cached = self
127            .cached
128            .clone()
129            .expect("cached program trace must be loaded");
130        assert!(self.filtered_exec_frequencies.len() <= cached.trace.height());
131        let mut freqs = Val::<SC>::zero_vec(cached.trace.height());
132        freqs
133            .par_iter_mut()
134            .zip(self.filtered_exec_frequencies.par_iter())
135            .for_each(|(f, x)| *f = Val::<SC>::from_canonical_u32(*x));
136        let common_trace = RowMajorMatrix::new_col(freqs);
137        AirProvingContext {
138            cached_mains: vec![cached],
139            common_main: Some(Arc::new(common_trace)),
140            public_values: vec![],
141        }
142    }
143}
144
145/// Computes a Merklelized hash of:
146/// - Program code commitment (commitment of the cached trace)
147/// - Merkle root of the initial memory
148/// - Starting program counter (`pc_start`)
149///
150/// The Merklelization uses [Poseidon2Hasher] as a cryptographic hash function (for the leaves)
151/// and a cryptographic compression function (for internal nodes).
152pub fn compute_exe_commit<F: PrimeField32>(
153    hasher: &Poseidon2Hasher<F>,
154    program_commit: &[F; CHUNK],
155    init_memory_root: &[F; CHUNK],
156    pc_start: F,
157) -> [F; CHUNK] {
158    let mut padded_pc_start = [F::ZERO; CHUNK];
159    padded_pc_start[0] = pc_start;
160    let program_hash = hasher.hash(program_commit);
161    let memory_hash = hasher.hash(init_memory_root);
162    let pc_hash = hasher.hash(&padded_pc_start);
163    hasher.compress(&hasher.compress(&program_hash, &memory_hash), &pc_hash)
164}
165
166pub(crate) fn generate_cached_trace<F: Field>(program: &Program<F>) -> RowMajorMatrix<F> {
167    let width = ProgramExecutionCols::<F>::width();
168    let mut instructions = program
169        .enumerate_by_pc()
170        .into_iter()
171        .map(|(pc, instruction, _)| (pc, instruction))
172        .collect_vec();
173
174    let padding = padding_instruction();
175    while !instructions.len().is_power_of_two() {
176        instructions.push((
177            program.pc_base + instructions.len() as u32 * DEFAULT_PC_STEP,
178            padding.clone(),
179        ));
180    }
181
182    let mut rows = F::zero_vec(instructions.len() * width);
183    rows.par_chunks_mut(width)
184        .zip(instructions)
185        .for_each(|(row, (pc, instruction))| {
186            let row: &mut ProgramExecutionCols<F> = row.borrow_mut();
187            *row = ProgramExecutionCols {
188                pc: F::from_canonical_u32(pc),
189                opcode: instruction.opcode.to_field(),
190                a: instruction.a,
191                b: instruction.b,
192                c: instruction.c,
193                d: instruction.d,
194                e: instruction.e,
195                f: instruction.f,
196                g: instruction.g,
197            };
198        });
199
200    RowMajorMatrix::new(rows, width)
201}
202
203pub(super) fn padding_instruction<F: Field>() -> Instruction<F> {
204    Instruction::from_usize(
205        SystemOpcode::TERMINATE.global_opcode(),
206        [0, 0, EXIT_CODE_FAIL],
207    )
208}