openvm_circuit/system/memory/
persistent.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4    iter,
5    sync::Arc,
6};
7
8use openvm_circuit_primitives_derive::AlignedBorrow;
9use openvm_stark_backend::{
10    config::{StarkGenericConfig, Val},
11    interaction::{InteractionBuilder, PermutationCheckBus},
12    p3_air::{Air, AirBuilder, BaseAir},
13    p3_field::{FieldAlgebra, PrimeField32},
14    p3_matrix::{dense::RowMajorMatrix, Matrix},
15    p3_maybe_rayon::prelude::*,
16    prover::types::AirProofInput,
17    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
18    AirRef, Chip, ChipUsageGetter,
19};
20use rustc_hash::FxHashSet;
21
22use super::merkle::SerialReceiver;
23use crate::{
24    arch::hasher::Hasher,
25    system::memory::{
26        dimensions::MemoryDimensions, offline_checker::MemoryBus, MemoryAddress, MemoryImage,
27        TimestampedEquipartition, INITIAL_TIMESTAMP,
28    },
29};
30
/// One row of the persistent boundary trace.
///
/// The values describe an aligned chunk of memory of size `CHUNK`---the data together with the
/// last accessed timestamp---in either the initial or final memory state.
///
/// Rows are read and written by borrowing a flat slice of `T` as this struct (see the
/// `borrow()` / `borrow_mut()` calls in this file), so the field order below is the column order.
#[repr(C)]
#[derive(Debug, AlignedBorrow)]
pub struct PersistentBoundaryCols<T, const CHUNK: usize> {
    /// Row selector, constrained to `{-1, 0, 1}`:
    /// - `expand_direction` =  1 corresponds to initial memory state
    /// - `expand_direction` = -1 corresponds to final memory state
    /// - `expand_direction` =  0 corresponds to irrelevant row (all interactions multiplicity 0)
    pub expand_direction: T,
    /// Address space the chunk lives in.
    pub address_space: T,
    /// Index of the chunk within its address space; the chunk starts at pointer
    /// `leaf_label * CHUNK`.
    pub leaf_label: T,
    /// The `CHUNK` memory cells of this chunk in the state selected by `expand_direction`.
    pub values: [T; CHUNK],
    /// Digest of `values`, filled from `Hasher::hash` during trace generation.
    pub hash: [T; CHUNK],
    /// Timestamp attached to `values`: constrained to zero on initial rows, and set to the
    /// chunk's final timestamp on final rows.
    pub timestamp: T,
}
46
/// AIR over [`PersistentBoundaryCols`] rows.
///
/// Imposes the following constraints:
/// - `expand_direction` should be -1, 0, 1
/// - `timestamp` is zero whenever `expand_direction` is 1
///
/// Performs the following interactions (each row, with the stated multiplicity):
/// - `merkle_bus`: message
///   `[expand_direction, 0, address_space - memory_dims.as_offset, leaf_label, hash...]`
///   with multiplicity `expand_direction`.
/// - `compression_bus`: message `[values..., 0 (CHUNK times), hash...]` with multiplicity
///   `expand_direction^2`, i.e. once for every initial or final row.
/// - `memory_bus`: a send of `values` at `(address_space, leaf_label * CHUNK)` with `timestamp`,
///   evaluated with multiplicity `expand_direction`.
///
/// NOTE(review): an earlier revision of this comment described the first merkle-bus field as an
/// `is_final` flag (0 for initial, 1 for final). The code passes `expand_direction` itself
/// (1 / -1) -- confirm this matches what the merkle AIR expects.
#[derive(Clone, Debug)]
pub struct PersistentBoundaryAir<const CHUNK: usize> {
    /// Memory layout parameters; this AIR only reads `as_offset` from it.
    pub memory_dims: MemoryDimensions,
    /// Bus on which chunk contents are exchanged with the memory offline checker.
    pub memory_bus: MemoryBus,
    /// Bus on which leaf hashes are exchanged (presumably with the merkle AIR -- confirm).
    pub merkle_bus: PermutationCheckBus,
    /// Bus carrying `(values, zeros, hash)` triples (presumably checked by the hasher chip --
    /// confirm).
    pub compression_bus: PermutationCheckBus,
}
61
impl<const CHUNK: usize, F> BaseAir<F> for PersistentBoundaryAir<CHUNK> {
    /// Number of trace columns: the width reported by `PersistentBoundaryCols` for this `CHUNK`.
    fn width(&self) -> usize {
        PersistentBoundaryCols::<F, CHUNK>::width()
    }
}
67
// Both traits are implemented with empty bodies, so only their default methods apply.
impl<const CHUNK: usize, F> BaseAirWithPublicValues<F> for PersistentBoundaryAir<CHUNK> {}
impl<const CHUNK: usize, F> PartitionedBaseAir<F> for PersistentBoundaryAir<CHUNK> {}
70
impl<const CHUNK: usize, AB: InteractionBuilder> Air<AB> for PersistentBoundaryAir<CHUNK> {
    /// Adds the per-row constraints and the three bus interactions of the boundary AIR.
    ///
    /// Only the row at offset 0 of the main trace is read; no constraint relates two rows.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        // View the flat row as named columns.
        let local: &PersistentBoundaryCols<AB::Var, CHUNK> = (*local).borrow();

        // `direction` should be -1, 0, 1
        // (`x = x^3` holds exactly for those three values).
        builder.assert_eq(
            local.expand_direction,
            local.expand_direction * local.expand_direction * local.expand_direction,
        );

        // Constrain that an "initial" row has timestamp zero.
        // Since `direction` is constrained to be in {-1, 0, 1}, we can select `direction == 1`
        // with the constraint below: `d * (d + 1)` vanishes for `d = 0` and `d = -1` and is
        // nonzero only for `d = 1`.
        builder
            .when(local.expand_direction * (local.expand_direction + AB::F::ONE))
            .assert_zero(local.timestamp);

        // Merkle-bus message: [direction, 0, address-space label, leaf label, hash...].
        let mut expand_fields = vec![
            // NOTE(review): the comment previously here read
            //   direction =  1 => is_final = 0
            //   direction = -1 => is_final = 1
            // but the value pushed is `expand_direction` itself, not a 0/1 flag -- confirm
            // which encoding the other side of the bus expects.
            local.expand_direction.into(),
            AB::Expr::ZERO,
            // Address space relative to `as_offset`.
            local.address_space - AB::F::from_canonical_u32(self.memory_dims.as_offset),
            local.leaf_label.into(),
        ];
        expand_fields.extend(local.hash.map(Into::into));
        // Multiplicity is the direction itself: +1 on initial rows, -1 on final rows, 0 on
        // padding rows (presumably positive = send, negative = receive -- confirm against
        // `PermutationCheckBus::interact`).
        self.merkle_bus
            .interact(builder, expand_fields, local.expand_direction.into());

        // Ties `hash` to `values`: the message is `values`, then `CHUNK` zeros, then `hash`.
        // `direction^2` is 1 on both initial and final rows and 0 on padding rows.
        self.compression_bus.interact(
            builder,
            iter::empty()
                .chain(local.values.map(Into::into))
                .chain(iter::repeat_n(AB::Expr::ZERO, CHUNK))
                .chain(local.hash.map(Into::into)),
            local.expand_direction * local.expand_direction,
        );

        // Memory-bus message for the whole chunk, which starts at pointer `leaf_label * CHUNK`.
        // It is evaluated with multiplicity `direction`, so initial and final rows contribute
        // with opposite signs and padding rows contribute nothing.
        self.memory_bus
            .send(
                MemoryAddress::new(
                    local.address_space,
                    local.leaf_label * AB::F::from_canonical_usize(CHUNK),
                ),
                local.values.to_vec(),
                local.timestamp,
            )
            .eval(builder, local.expand_direction);
    }
}
123
/// Chip that records which memory chunks were touched and produces the boundary trace for them.
///
/// Lifecycle as implemented in this file: chunks are recorded with `touch_range`, `finalize`
/// then freezes them together with their initial/final contents and hashes, and
/// `generate_air_proof_input` emits two rows (initial and final) per touched chunk.
pub struct PersistentBoundaryChip<F, const CHUNK: usize> {
    /// The AIR this chip generates a trace for.
    pub air: PersistentBoundaryAir<CHUNK>,
    /// Touched chunks: a set of labels before `finalize`, full per-chunk records afterwards.
    touched_labels: TouchedLabels<F, CHUNK>,
    /// Optional trace height override; it is rounded up to a power of two and must not be
    /// smaller than the height required by the touched chunks.
    overridden_height: Option<usize>,
}
129
/// The touched chunks of a [`PersistentBoundaryChip`], before or after finalization.
#[derive(Debug)]
enum TouchedLabels<F, const CHUNK: usize> {
    /// Still recording: the set of `(address_space, label)` pairs touched so far.
    Running(FxHashSet<(u32, u32)>),
    /// Finalized: one record per touched chunk, carrying everything trace generation needs.
    Final(Vec<FinalTouchedLabel<F, CHUNK>>),
}
135
/// Everything needed to emit the initial and final boundary rows of one touched chunk.
#[derive(Debug)]
struct FinalTouchedLabel<F, const CHUNK: usize> {
    /// Address space of the chunk.
    address_space: u32,
    /// Chunk index; the chunk covers pointers `label * CHUNK .. (label + 1) * CHUNK`.
    label: u32,
    /// Chunk contents in the initial memory image (cells absent from the image read as zero).
    init_values: [F; CHUNK],
    /// Chunk contents in the final memory.
    final_values: [F; CHUNK],
    /// Hash of `init_values`.
    init_hash: [F; CHUNK],
    /// Hash of `final_values`.
    final_hash: [F; CHUNK],
    /// Timestamp stored alongside the chunk in the final memory.
    final_timestamp: u32,
}
146
147impl<F: PrimeField32, const CHUNK: usize> Default for TouchedLabels<F, CHUNK> {
148    fn default() -> Self {
149        Self::Running(FxHashSet::default())
150    }
151}
152
153impl<F: PrimeField32, const CHUNK: usize> TouchedLabels<F, CHUNK> {
154    fn touch(&mut self, address_space: u32, label: u32) {
155        match self {
156            TouchedLabels::Running(touched_labels) => {
157                touched_labels.insert((address_space, label));
158            }
159            _ => panic!("Cannot touch after finalization"),
160        }
161    }
162    fn len(&self) -> usize {
163        match self {
164            TouchedLabels::Running(touched_labels) => touched_labels.len(),
165            TouchedLabels::Final(touched_labels) => touched_labels.len(),
166        }
167    }
168}
169
170impl<const CHUNK: usize, F: PrimeField32> PersistentBoundaryChip<F, CHUNK> {
171    pub fn new(
172        memory_dimensions: MemoryDimensions,
173        memory_bus: MemoryBus,
174        merkle_bus: PermutationCheckBus,
175        compression_bus: PermutationCheckBus,
176    ) -> Self {
177        Self {
178            air: PersistentBoundaryAir {
179                memory_dims: memory_dimensions,
180                memory_bus,
181                merkle_bus,
182                compression_bus,
183            },
184            touched_labels: Default::default(),
185            overridden_height: None,
186        }
187    }
188
189    pub fn set_overridden_height(&mut self, overridden_height: usize) {
190        self.overridden_height = Some(overridden_height);
191    }
192
193    pub fn touch_range(&mut self, address_space: u32, pointer: u32, len: u32) {
194        let start_label = pointer / CHUNK as u32;
195        let end_label = (pointer + len - 1) / CHUNK as u32;
196        for label in start_label..=end_label {
197            self.touched_labels.touch(address_space, label);
198        }
199    }
200
201    pub fn finalize<H>(
202        &mut self,
203        initial_memory: &MemoryImage<F>,
204        final_memory: &TimestampedEquipartition<F, CHUNK>,
205        hasher: &mut H,
206    ) where
207        H: Hasher<CHUNK, F> + Sync + for<'a> SerialReceiver<&'a [F]>,
208    {
209        match &mut self.touched_labels {
210            TouchedLabels::Running(touched_labels) => {
211                let final_touched_labels: Vec<_> = touched_labels
212                    .par_iter()
213                    .map(|&(address_space, label)| {
214                        let pointer = label * CHUNK as u32;
215                        let init_values = array::from_fn(|i| {
216                            *initial_memory
217                                .get(&(address_space, pointer + i as u32))
218                                .unwrap_or(&F::ZERO)
219                        });
220                        let initial_hash = hasher.hash(&init_values);
221                        let timestamped_values = final_memory.get(&(address_space, label)).unwrap();
222                        let final_hash = hasher.hash(&timestamped_values.values);
223                        FinalTouchedLabel {
224                            address_space,
225                            label,
226                            init_values,
227                            final_values: timestamped_values.values,
228                            init_hash: initial_hash,
229                            final_hash,
230                            final_timestamp: timestamped_values.timestamp,
231                        }
232                    })
233                    .collect();
234                for l in &final_touched_labels {
235                    hasher.receive(&l.init_values);
236                    hasher.receive(&l.final_values);
237                }
238                self.touched_labels = TouchedLabels::Final(final_touched_labels);
239            }
240            _ => panic!("Cannot finalize after finalization"),
241        }
242    }
243}
244
245impl<const CHUNK: usize, SC: StarkGenericConfig> Chip<SC> for PersistentBoundaryChip<Val<SC>, CHUNK>
246where
247    Val<SC>: PrimeField32,
248{
249    fn air(&self) -> AirRef<SC> {
250        Arc::new(self.air.clone())
251    }
252
253    fn generate_air_proof_input(self) -> AirProofInput<SC> {
254        let trace = {
255            let width = PersistentBoundaryCols::<Val<SC>, CHUNK>::width();
256            // Boundary AIR should always present in order to fix the AIR ID of merkle AIR.
257            let mut height = (2 * self.touched_labels.len()).next_power_of_two();
258            if let Some(mut oh) = self.overridden_height {
259                oh = oh.next_power_of_two();
260                assert!(
261                    oh >= height,
262                    "Overridden height is less than the required height"
263                );
264                height = oh;
265            }
266            let mut rows = Val::<SC>::zero_vec(height * width);
267
268            let touched_labels = match self.touched_labels {
269                TouchedLabels::Final(touched_labels) => touched_labels,
270                _ => panic!("Cannot generate trace before finalization"),
271            };
272
273            rows.par_chunks_mut(2 * width)
274                .zip(touched_labels.into_par_iter())
275                .for_each(|(row, touched_label)| {
276                    let (initial_row, final_row) = row.split_at_mut(width);
277                    *initial_row.borrow_mut() = PersistentBoundaryCols {
278                        expand_direction: Val::<SC>::ONE,
279                        address_space: Val::<SC>::from_canonical_u32(touched_label.address_space),
280                        leaf_label: Val::<SC>::from_canonical_u32(touched_label.label),
281                        values: touched_label.init_values,
282                        hash: touched_label.init_hash,
283                        timestamp: Val::<SC>::from_canonical_u32(INITIAL_TIMESTAMP),
284                    };
285
286                    *final_row.borrow_mut() = PersistentBoundaryCols {
287                        expand_direction: Val::<SC>::NEG_ONE,
288                        address_space: Val::<SC>::from_canonical_u32(touched_label.address_space),
289                        leaf_label: Val::<SC>::from_canonical_u32(touched_label.label),
290                        values: touched_label.final_values,
291                        hash: touched_label.final_hash,
292                        timestamp: Val::<SC>::from_canonical_u32(touched_label.final_timestamp),
293                    };
294                });
295            RowMajorMatrix::new(rows, width)
296        };
297        AirProofInput::simple_no_pis(trace)
298    }
299}
300
301impl<const CHUNK: usize, F: PrimeField32> ChipUsageGetter for PersistentBoundaryChip<F, CHUNK> {
302    fn air_name(&self) -> String {
303        "Boundary".to_string()
304    }
305
306    fn current_trace_height(&self) -> usize {
307        2 * self.touched_labels.len()
308    }
309
310    fn trace_width(&self) -> usize {
311        PersistentBoundaryCols::<F, CHUNK>::width()
312    }
313}