openvm_circuit/system/memory/
persistent.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4    iter,
5    sync::Arc,
6};
7
8use openvm_circuit_primitives_derive::AlignedBorrow;
9use openvm_stark_backend::{
10    config::{StarkGenericConfig, Val},
11    interaction::{InteractionBuilder, PermutationCheckBus},
12    p3_air::{Air, AirBuilder, BaseAir},
13    p3_field::{FieldAlgebra, PrimeField32},
14    p3_matrix::{dense::RowMajorMatrix, Matrix},
15    p3_maybe_rayon::prelude::*,
16    prover::{cpu::CpuBackend, types::AirProvingContext},
17    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
18    Chip, ChipUsageGetter,
19};
20use rustc_hash::FxHashSet;
21use tracing::instrument;
22
23use super::{merkle::SerialReceiver, online::INITIAL_TIMESTAMP, TimestampedValues};
24use crate::{
25    arch::{hasher::Hasher, ADDR_SPACE_OFFSET},
26    system::memory::{
27        dimensions::MemoryDimensions, offline_checker::MemoryBus, MemoryAddress, MemoryImage,
28        TimestampedEquipartition,
29    },
30};
31
32/// The values describe aligned chunk of memory of size `CHUNK`---the data together with the last
33/// accessed timestamp---in either the initial or final memory state.
34#[repr(C)]
35#[derive(Debug, AlignedBorrow)]
36pub struct PersistentBoundaryCols<T, const CHUNK: usize> {
37    // `expand_direction` =  1 corresponds to initial memory state
38    // `expand_direction` = -1 corresponds to final memory state
39    // `expand_direction` =  0 corresponds to irrelevant row (all interactions multiplicity 0)
40    pub expand_direction: T,
41    pub address_space: T,
42    pub leaf_label: T,
43    pub values: [T; CHUNK],
44    pub hash: [T; CHUNK],
45    pub timestamp: T,
46}
47
48/// Imposes the following constraints:
49/// - `expand_direction` should be -1, 0, 1
50///
51/// Sends the following interactions:
52/// - if `expand_direction` is 1, sends `[0, 0, address_space_label, leaf_label]` to `merkle_bus`.
53/// - if `expand_direction` is -1, receives `[1, 0, address_space_label, leaf_label]` from
54///   `merkle_bus`.
55#[derive(Clone, Debug)]
56pub struct PersistentBoundaryAir<const CHUNK: usize> {
57    pub memory_dims: MemoryDimensions,
58    pub memory_bus: MemoryBus,
59    pub merkle_bus: PermutationCheckBus,
60    pub compression_bus: PermutationCheckBus,
61}
62
63impl<const CHUNK: usize, F> BaseAir<F> for PersistentBoundaryAir<CHUNK> {
64    fn width(&self) -> usize {
65        PersistentBoundaryCols::<F, CHUNK>::width()
66    }
67}
68
69impl<const CHUNK: usize, F> BaseAirWithPublicValues<F> for PersistentBoundaryAir<CHUNK> {}
70impl<const CHUNK: usize, F> PartitionedBaseAir<F> for PersistentBoundaryAir<CHUNK> {}
71
72impl<const CHUNK: usize, AB: InteractionBuilder> Air<AB> for PersistentBoundaryAir<CHUNK> {
73    fn eval(&self, builder: &mut AB) {
74        let main = builder.main();
75        let local = main.row_slice(0);
76        let local: &PersistentBoundaryCols<AB::Var, CHUNK> = (*local).borrow();
77
78        // `direction` should be -1, 0, 1
79        builder.assert_eq(
80            local.expand_direction,
81            local.expand_direction * local.expand_direction * local.expand_direction,
82        );
83
84        // Constrain that an "initial" row has timestamp zero.
85        // Since `direction` is constrained to be in {-1, 0, 1}, we can select `direction == 1`
86        // with the constraint below.
87        builder
88            .when(local.expand_direction * (local.expand_direction + AB::F::ONE))
89            .assert_zero(local.timestamp);
90
91        let mut expand_fields = vec![
92            // direction =  1 => is_final = 0
93            // direction = -1 => is_final = 1
94            local.expand_direction.into(),
95            AB::Expr::ZERO,
96            local.address_space - AB::F::from_canonical_u32(ADDR_SPACE_OFFSET),
97            local.leaf_label.into(),
98        ];
99        expand_fields.extend(local.hash.map(Into::into));
100        self.merkle_bus
101            .interact(builder, expand_fields, local.expand_direction.into());
102
103        self.compression_bus.interact(
104            builder,
105            iter::empty()
106                .chain(local.values.map(Into::into))
107                .chain(iter::repeat_n(AB::Expr::ZERO, CHUNK))
108                .chain(local.hash.map(Into::into)),
109            local.expand_direction * local.expand_direction,
110        );
111
112        self.memory_bus
113            .send(
114                MemoryAddress::new(
115                    local.address_space,
116                    local.leaf_label * AB::F::from_canonical_usize(CHUNK),
117                ),
118                local.values.to_vec(),
119                local.timestamp,
120            )
121            .eval(builder, local.expand_direction);
122    }
123}
124
125pub struct PersistentBoundaryChip<F, const CHUNK: usize> {
126    pub air: PersistentBoundaryAir<CHUNK>,
127    pub touched_labels: TouchedLabels<F, CHUNK>,
128    overridden_height: Option<usize>,
129}
130
131#[derive(Debug)]
132pub enum TouchedLabels<F, const CHUNK: usize> {
133    Running(FxHashSet<(u32, u32)>),
134    Final(Vec<FinalTouchedLabel<F, CHUNK>>),
135}
136
/// Everything needed to emit the initial and final boundary rows of one touched chunk.
#[derive(Debug)]
pub struct FinalTouchedLabel<F, const CHUNK: usize> {
    /// Address space of the chunk.
    address_space: u32,
    /// Chunk index, i.e. the chunk's start pointer divided by `CHUNK`.
    label: u32,
    /// Chunk contents read from the initial memory image.
    init_values: [F; CHUNK],
    /// Chunk contents taken from the final memory.
    final_values: [F; CHUNK],
    /// Hash of `init_values`.
    init_hash: [F; CHUNK],
    /// Hash of `final_values`.
    final_hash: [F; CHUNK],
    /// Timestamp recorded alongside `final_values`.
    final_timestamp: u32,
}
147
148impl<F: PrimeField32, const CHUNK: usize> Default for TouchedLabels<F, CHUNK> {
149    fn default() -> Self {
150        Self::Running(FxHashSet::default())
151    }
152}
153
154impl<F: PrimeField32, const CHUNK: usize> TouchedLabels<F, CHUNK> {
155    fn touch(&mut self, address_space: u32, label: u32) {
156        match self {
157            TouchedLabels::Running(touched_labels) => {
158                touched_labels.insert((address_space, label));
159            }
160            _ => panic!("Cannot touch after finalization"),
161        }
162    }
163
164    pub fn is_empty(&self) -> bool {
165        match self {
166            TouchedLabels::Running(touched_labels) => touched_labels.is_empty(),
167            TouchedLabels::Final(touched_labels) => touched_labels.is_empty(),
168        }
169    }
170
171    pub fn len(&self) -> usize {
172        match self {
173            TouchedLabels::Running(touched_labels) => touched_labels.len(),
174            TouchedLabels::Final(touched_labels) => touched_labels.len(),
175        }
176    }
177}
178
179impl<const CHUNK: usize, F: PrimeField32> PersistentBoundaryChip<F, CHUNK> {
180    pub fn new(
181        memory_dimensions: MemoryDimensions,
182        memory_bus: MemoryBus,
183        merkle_bus: PermutationCheckBus,
184        compression_bus: PermutationCheckBus,
185    ) -> Self {
186        Self {
187            air: PersistentBoundaryAir {
188                memory_dims: memory_dimensions,
189                memory_bus,
190                merkle_bus,
191                compression_bus,
192            },
193            touched_labels: Default::default(),
194            overridden_height: None,
195        }
196    }
197
198    pub fn set_overridden_height(&mut self, overridden_height: usize) {
199        self.overridden_height = Some(overridden_height);
200    }
201
202    pub fn touch_range(&mut self, address_space: u32, pointer: u32, len: u32) {
203        let start_label = pointer / CHUNK as u32;
204        let end_label = (pointer + len - 1) / CHUNK as u32;
205        for label in start_label..=end_label {
206            self.touched_labels.touch(address_space, label);
207        }
208    }
209
210    #[instrument(name = "boundary_finalize", level = "debug", skip_all)]
211    pub(crate) fn finalize<H>(
212        &mut self,
213        initial_memory: &MemoryImage,
214        // Only touched stuff
215        final_memory: &TimestampedEquipartition<F, CHUNK>,
216        hasher: &H,
217    ) where
218        H: Hasher<CHUNK, F> + Sync + for<'a> SerialReceiver<&'a [F]>,
219    {
220        let final_touched_labels: Vec<_> = final_memory
221            .par_iter()
222            .map(|&((addr_space, ptr), ts_values)| {
223                let TimestampedValues { timestamp, values } = ts_values;
224                // SAFETY: addr_space from `final_memory` are all in bounds
225                let init_values = array::from_fn(|i| unsafe {
226                    initial_memory.get_f::<F>(addr_space, ptr + i as u32)
227                });
228                let initial_hash = hasher.hash(&init_values);
229                let final_hash = hasher.hash(&values);
230                FinalTouchedLabel {
231                    address_space: addr_space,
232                    label: ptr / CHUNK as u32,
233                    init_values,
234                    final_values: values,
235                    init_hash: initial_hash,
236                    final_hash,
237                    final_timestamp: timestamp,
238                }
239            })
240            .collect();
241        for l in &final_touched_labels {
242            hasher.receive(&l.init_values);
243            hasher.receive(&l.final_values);
244        }
245        self.touched_labels = TouchedLabels::Final(final_touched_labels);
246    }
247}
248
249impl<const CHUNK: usize, RA, SC> Chip<RA, CpuBackend<SC>> for PersistentBoundaryChip<Val<SC>, CHUNK>
250where
251    SC: StarkGenericConfig,
252    Val<SC>: PrimeField32,
253{
254    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<CpuBackend<SC>> {
255        let trace = {
256            let width = PersistentBoundaryCols::<Val<SC>, CHUNK>::width();
257            // Boundary AIR should always present in order to fix the AIR ID of merkle AIR.
258            let mut height = (2 * self.touched_labels.len()).next_power_of_two();
259            if let Some(mut oh) = self.overridden_height {
260                oh = oh.next_power_of_two();
261                assert!(
262                    oh >= height,
263                    "Overridden height is less than the required height"
264                );
265                height = oh;
266            }
267            let mut rows = Val::<SC>::zero_vec(height * width);
268
269            let touched_labels = match &self.touched_labels {
270                TouchedLabels::Final(touched_labels) => touched_labels,
271                _ => panic!("Cannot generate trace before finalization"),
272            };
273
274            rows.par_chunks_mut(2 * width)
275                .zip(touched_labels.par_iter())
276                .for_each(|(row, touched_label)| {
277                    let (initial_row, final_row) = row.split_at_mut(width);
278                    *initial_row.borrow_mut() = PersistentBoundaryCols {
279                        expand_direction: Val::<SC>::ONE,
280                        address_space: Val::<SC>::from_canonical_u32(touched_label.address_space),
281                        leaf_label: Val::<SC>::from_canonical_u32(touched_label.label),
282                        values: touched_label.init_values,
283                        hash: touched_label.init_hash,
284                        timestamp: Val::<SC>::from_canonical_u32(INITIAL_TIMESTAMP),
285                    };
286
287                    *final_row.borrow_mut() = PersistentBoundaryCols {
288                        expand_direction: Val::<SC>::NEG_ONE,
289                        address_space: Val::<SC>::from_canonical_u32(touched_label.address_space),
290                        leaf_label: Val::<SC>::from_canonical_u32(touched_label.label),
291                        values: touched_label.final_values,
292                        hash: touched_label.final_hash,
293                        timestamp: Val::<SC>::from_canonical_u32(touched_label.final_timestamp),
294                    };
295                });
296            Arc::new(RowMajorMatrix::new(rows, width))
297        };
298        AirProvingContext::simple_no_pis(trace)
299    }
300}
301
302impl<const CHUNK: usize, F: PrimeField32> ChipUsageGetter for PersistentBoundaryChip<F, CHUNK> {
303    fn air_name(&self) -> String {
304        "Boundary".to_string()
305    }
306
307    fn current_trace_height(&self) -> usize {
308        2 * self.touched_labels.len()
309    }
310
311    fn trace_width(&self) -> usize {
312        PersistentBoundaryCols::<F, CHUNK>::width()
313    }
314}