openvm_circuit/system/memory/volatile/
mod.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    cmp::min,
4    sync::Arc,
5};
6
7use itertools::zip_eq;
8use openvm_circuit_primitives::{
9    is_less_than_array::{
10        IsLtArrayAuxCols, IsLtArrayIo, IsLtArraySubAir, IsLtArrayWhenTransitionAir,
11    },
12    utils::{compose, implies},
13    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
14    SubAir, TraceSubRowGenerator,
15};
16use openvm_circuit_primitives_derive::AlignedBorrow;
17use openvm_stark_backend::{
18    config::{StarkGenericConfig, Val},
19    interaction::InteractionBuilder,
20    p3_air::{Air, AirBuilder, BaseAir},
21    p3_field::{Field, FieldAlgebra, PrimeField32},
22    p3_matrix::{dense::RowMajorMatrix, Matrix},
23    p3_maybe_rayon::prelude::*,
24    prover::{cpu::CpuBackend, types::AirProvingContext},
25    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
26    Chip, ChipUsageGetter,
27};
28use static_assertions::const_assert;
29use tracing::instrument;
30
31use super::TimestampedEquipartition;
32use crate::system::memory::{
33    offline_checker::{MemoryBus, AUX_LEN},
34    MemoryAddress,
35};
36
37#[cfg(test)]
38mod tests;
39
/// Address stored as address space, pointer
const ADDR_ELTS: usize = 2;
/// Number of limbs the address space is decomposed into for range checking.
const NUM_AS_LIMBS: usize = 1;
// The `IsLtArray` aux columns provide `AUX_LEN` limb slots per compared element,
// so the address-space decomposition must fit within that budget.
const_assert!(NUM_AS_LIMBS <= AUX_LEN);
44
#[repr(C)]
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct VolatileBoundaryCols<T> {
    /// Address space decomposed into `NUM_AS_LIMBS` limbs for range checking.
    pub addr_space_limbs: [T; NUM_AS_LIMBS],
    /// Pointer decomposed into `AUX_LEN` limbs for range checking.
    pub pointer_limbs: [T; AUX_LEN],

    /// Value of this address at the start of execution (zero in trace generation).
    pub initial_data: T,
    /// Value of this address at the end of execution.
    pub final_data: T,
    /// Timestamp of the last access to this address.
    pub final_timestamp: T,

    /// Boolean. `1` if a non-padding row with a valid touched address, `0` if it is a padding row.
    pub is_valid: T,
    /// Aux columns for the lexicographic (addr_space, pointer) < comparison with the next row.
    pub addr_lt_aux: IsLtArrayAuxCols<T, ADDR_ELTS, AUX_LEN>,
}
59
/// AIR enforcing that the touched-address rows are sorted and strictly increasing by
/// `(addr_space, pointer)`, and balancing the memory bus with initial writes and final reads.
#[derive(Clone, Debug)]
pub struct VolatileBoundaryAir {
    pub memory_bus: MemoryBus,
    /// Sub-air asserting `local addr < next addr` on transition rows.
    pub addr_lt_air: IsLtArrayWhenTransitionAir<ADDR_ELTS>,

    // Bit widths of each address-space limb; sums to `addr_space_max_bits`.
    addr_space_limb_bits: [usize; NUM_AS_LIMBS],
    // Bit widths of each pointer limb; sums to `pointer_max_bits`.
    pointer_limb_bits: [usize; AUX_LEN],
}
68
69impl VolatileBoundaryAir {
70    pub fn new(
71        memory_bus: MemoryBus,
72        addr_space_max_bits: usize,
73        pointer_max_bits: usize,
74        range_bus: VariableRangeCheckerBus,
75    ) -> Self {
76        let addr_lt_air =
77            IsLtArraySubAir::<ADDR_ELTS>::new(range_bus, addr_space_max_bits.max(pointer_max_bits))
78                .when_transition();
79        let range_max_bits = range_bus.range_max_bits;
80        let mut addr_space_limb_bits = [0; NUM_AS_LIMBS];
81        let mut bits_remaining = addr_space_max_bits;
82        for limb_bits in &mut addr_space_limb_bits {
83            *limb_bits = min(bits_remaining, range_max_bits);
84            bits_remaining -= *limb_bits;
85        }
86        assert_eq!(bits_remaining, 0, "addr_space_max_bits={addr_space_max_bits} with {NUM_AS_LIMBS} limbs exceeds range_max_bits={range_max_bits}");
87        let mut pointer_limb_bits = [0; AUX_LEN];
88        let mut bits_remaining = pointer_max_bits;
89        for limb_bits in &mut pointer_limb_bits {
90            *limb_bits = min(bits_remaining, range_max_bits);
91            bits_remaining -= *limb_bits;
92        }
93        assert_eq!(bits_remaining, 0, "pointer_max_bits={pointer_max_bits} with {AUX_LEN} limbs exceeds range_max_bits={range_max_bits}");
94        Self {
95            memory_bus,
96            addr_lt_air,
97            addr_space_limb_bits,
98            pointer_limb_bits,
99        }
100    }
101
102    pub fn range_bus(&self) -> VariableRangeCheckerBus {
103        self.addr_lt_air.0.lt.bus
104    }
105}
106
impl<F: Field> BaseAirWithPublicValues<F> for VolatileBoundaryAir {}
impl<F: Field> PartitionedBaseAir<F> for VolatileBoundaryAir {}
impl<F: Field> BaseAir<F> for VolatileBoundaryAir {
    fn width(&self) -> usize {
        // Width is determined by the column struct layout (derived via AlignedBorrow).
        VolatileBoundaryCols::<F>::width()
    }
}
114
impl<AB: InteractionBuilder> Air<AB> for VolatileBoundaryAir {
    /// Constrains that valid rows form a contiguous prefix sorted strictly by
    /// `(addr_space, pointer)`, range checks the address limbs, and emits the
    /// memory-bus interactions balancing initial and final memory state.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();

        let [local, next] = [0, 1].map(|i| main.row_slice(i));
        let local: &VolatileBoundaryCols<_> = (*local).borrow();
        let next: &VolatileBoundaryCols<_> = (*next).borrow();

        builder.assert_bool(local.is_valid);

        // Ensuring all non-padding rows come first (padding rows are grouped at the end):
        // once a row is invalid, every later row must also be invalid.
        builder
            .when_transition()
            .assert_one(implies(next.is_valid, local.is_valid));

        // Range check local.addr_space_limbs to addr_space_max_bits
        for (&limb, limb_bits) in zip_eq(&local.addr_space_limbs, self.addr_space_limb_bits) {
            self.range_bus()
                .range_check(limb, limb_bits)
                .eval(builder, local.is_valid);
        }
        // Range check local.pointer_limbs to pointer_max_bits
        for (&limb, limb_bits) in zip_eq(&local.pointer_limbs, self.pointer_limb_bits) {
            self.range_bus()
                .range_check(limb, limb_bits)
                .eval(builder, local.is_valid);
        }
        let range_max_bits = self.range_bus().range_max_bits;
        // Compose addr_space_limbs and pointer_limbs into addr_space, pointer for both local and
        // next
        let [addr_space, next_addr_space] = [&local.addr_space_limbs, &next.addr_space_limbs]
            .map(|limbs| compose::<AB::Expr>(limbs, range_max_bits));
        let [pointer, next_pointer] = [&local.pointer_limbs, &next.pointer_limbs]
            .map(|limbs| compose::<AB::Expr>(limbs, range_max_bits));

        // Assert local addr < next addr when next.is_valid
        // This ensures the addresses in non-padding rows are all sorted
        // (strict inequality also rules out duplicate addresses).
        let lt_io = IsLtArrayIo {
            x: [addr_space.clone(), pointer.clone()],
            y: [next_addr_space, next_pointer],
            out: AB::Expr::ONE,
            count: next.is_valid.into(),
        };
        // N.B.: this will do range checks (but not other constraints) on the last row if the first
        // row has is_valid = 1 due to wraparound
        self.addr_lt_air
            .eval(builder, (lt_io, (&local.addr_lt_aux).into()));

        // Write the initial memory values at initial timestamps (timestamp 0).
        self.memory_bus
            .send(
                MemoryAddress::new(addr_space.clone(), pointer.clone()),
                vec![local.initial_data],
                AB::Expr::ZERO,
            )
            .eval(builder, local.is_valid);

        // Read the final memory values at last timestamps when written to
        self.memory_bus
            .receive(
                MemoryAddress::new(addr_space.clone(), pointer.clone()),
                vec![local.final_data],
                local.final_timestamp,
            )
            .eval(builder, local.is_valid);
    }
}
182
/// Chip producing the boundary trace for volatile memory from the final memory state.
pub struct VolatileBoundaryChip<F> {
    pub air: VolatileBoundaryAir,
    // Shared range checker used both for limb decomposition and the less-than sub-air.
    range_checker: SharedVariableRangeCheckerChip,
    // If set, forces the trace height (before rounding to a power of two).
    overridden_height: Option<usize>,
    /// Final memory state; must be set via `finalize` before trace generation.
    pub final_memory: Option<TimestampedEquipartition<F, 1>>,
    addr_space_max_bits: usize,
    pointer_max_bits: usize,
}
191
192impl<F> VolatileBoundaryChip<F> {
193    pub fn new(
194        memory_bus: MemoryBus,
195        addr_space_max_bits: usize,
196        pointer_max_bits: usize,
197        range_checker: SharedVariableRangeCheckerChip,
198    ) -> Self {
199        let range_bus = range_checker.bus();
200        Self {
201            air: VolatileBoundaryAir::new(
202                memory_bus,
203                addr_space_max_bits,
204                pointer_max_bits,
205                range_bus,
206            ),
207            range_checker,
208            overridden_height: None,
209            final_memory: None,
210            addr_space_max_bits,
211            pointer_max_bits,
212        }
213    }
214}
215
216impl<F: PrimeField32> VolatileBoundaryChip<F> {
217    pub fn set_overridden_height(&mut self, overridden_height: usize) {
218        self.overridden_height = Some(overridden_height);
219    }
220    /// Volatile memory requires the starting and final memory to be in equipartition with block
221    /// size `1`. When block size is `1`, then the `label` is the same as the address pointer.
222    #[instrument(name = "boundary_finalize", level = "debug", skip_all)]
223    pub fn finalize(&mut self, final_memory: TimestampedEquipartition<F, 1>) {
224        self.final_memory = Some(final_memory);
225    }
226}
227
impl<RA, SC: StarkGenericConfig> Chip<RA, CpuBackend<SC>> for VolatileBoundaryChip<Val<SC>>
where
    Val<SC>: PrimeField32,
{
    /// Generates the boundary trace from the finalized memory state.
    ///
    /// Panics if `finalize` has not been called, or if an overridden height is smaller
    /// than the number of touched addresses.
    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<CpuBackend<SC>> {
        // Volatile memory requires the starting and final memory to be in equipartition with block
        // size `1`. When block size is `1`, then the `label` is the same as the address
        // pointer.
        let width = self.trace_width();
        let addr_lt_air = &self.air.addr_lt_air;
        // TEMP[jpw]: clone
        let final_memory = self
            .final_memory
            .clone()
            .expect("Trace generation should be called after finalize");
        let trace_height = if let Some(height) = self.overridden_height {
            assert!(
                height >= final_memory.len(),
                "Overridden height is less than the required height"
            );
            height
        } else {
            final_memory.len()
        };
        // Trace heights must be powers of two for the STARK backend.
        let trace_height = trace_height.next_power_of_two();

        // Collect into Vec to sort from BTreeMap and also so we can look at adjacent entries
        let sorted_final_memory: Vec<_> = final_memory.into_par_iter().collect();
        let memory_len = sorted_final_memory.len();

        let range_checker = self.range_checker.as_ref();
        // Rows beyond `memory_len` are left zeroed, which makes them valid padding rows
        // (is_valid = 0).
        let mut rows = Val::<SC>::zero_vec(trace_height * width);
        rows.par_chunks_mut(width)
            .zip(sorted_final_memory.par_iter())
            .enumerate()
            .for_each(|(i, (row, ((addr_space, ptr), timestamped_values)))| {
                // `pointer` is the same as `label` since the equipartition has block size 1
                let [data] = timestamped_values.values;
                let row: &mut VolatileBoundaryCols<_> = row.borrow_mut();
                // Decompose address into limbs; this also emits the range-check interactions.
                range_checker.decompose(
                    *addr_space,
                    self.addr_space_max_bits,
                    &mut row.addr_space_limbs,
                );
                range_checker.decompose(*ptr, self.pointer_max_bits, &mut row.pointer_limbs);
                // Initial memory is all-zero; only the final value and timestamp vary.
                row.initial_data = Val::<SC>::ZERO;
                row.final_data = data;
                row.final_timestamp = Val::<SC>::from_canonical_u32(timestamped_values.timestamp);
                row.is_valid = Val::<SC>::ONE;

                // If next.is_valid == 1:
                if i != memory_len - 1 {
                    // Fill the less-than aux columns comparing this address to the next one.
                    let (next_addr_space, next_ptr) = sorted_final_memory[i + 1].0;
                    let mut out = Val::<SC>::ZERO;
                    addr_lt_air.0.generate_subrow(
                        (
                            self.range_checker.as_ref(),
                            &[
                                Val::<SC>::from_canonical_u32(*addr_space),
                                Val::<SC>::from_canonical_u32(*ptr),
                            ],
                            &[
                                Val::<SC>::from_canonical_u32(next_addr_space),
                                Val::<SC>::from_canonical_u32(next_ptr),
                            ],
                        ),
                        ((&mut row.addr_lt_aux).into(), &mut out),
                    );
                    // The iteration order is sorted, so the comparison must come out true.
                    debug_assert_eq!(out, Val::<SC>::ONE, "Addresses are not sorted");
                }
            });
        // Always do a dummy range check on the last row due to wraparound
        if memory_len > 0 {
            let mut out = Val::<SC>::ZERO;
            let row: &mut VolatileBoundaryCols<_> = rows[width * (trace_height - 1)..].borrow_mut();
            addr_lt_air.0.generate_subrow(
                (
                    self.range_checker.as_ref(),
                    &[Val::<SC>::ZERO, Val::<SC>::ZERO],
                    &[Val::<SC>::ZERO, Val::<SC>::ZERO],
                ),
                ((&mut row.addr_lt_aux).into(), &mut out),
            );
        }

        let trace = Arc::new(RowMajorMatrix::new(rows, width));
        AirProvingContext::simple_no_pis(trace)
    }
}
317
318impl<F: PrimeField32> ChipUsageGetter for VolatileBoundaryChip<F> {
319    fn air_name(&self) -> String {
320        "Boundary".to_string()
321    }
322
323    fn current_trace_height(&self) -> usize {
324        if let Some(final_memory) = &self.final_memory {
325            final_memory.len()
326        } else {
327            0
328        }
329    }
330
331    fn trace_width(&self) -> usize {
332        VolatileBoundaryCols::<F>::width()
333    }
334}