openvm_circuit/system/memory/volatile/mod.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    cmp::min,
4    sync::Arc,
5};
6
7use itertools::zip_eq;
8use openvm_circuit_primitives::{
9    is_less_than_array::{
10        IsLtArrayAuxCols, IsLtArrayIo, IsLtArraySubAir, IsLtArrayWhenTransitionAir,
11    },
12    utils::{compose, implies},
13    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
14    SubAir, TraceSubRowGenerator,
15};
16use openvm_circuit_primitives_derive::AlignedBorrow;
17use openvm_stark_backend::{
18    config::{StarkGenericConfig, Val},
19    interaction::InteractionBuilder,
20    p3_air::{Air, AirBuilder, BaseAir},
21    p3_field::{Field, FieldAlgebra, PrimeField32},
22    p3_matrix::{dense::RowMajorMatrix, Matrix},
23    p3_maybe_rayon::prelude::*,
24    prover::types::AirProofInput,
25    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
26    AirRef, Chip, ChipUsageGetter,
27};
28use static_assertions::const_assert;
29
30use super::TimestampedEquipartition;
31use crate::system::memory::{
32    offline_checker::{MemoryBus, AUX_LEN},
33    MemoryAddress,
34};
35
36#[cfg(test)]
37mod tests;
38
39/// Address stored as address space, pointer
40const ADDR_ELTS: usize = 2;
41const NUM_AS_LIMBS: usize = 1;
42const_assert!(NUM_AS_LIMBS <= AUX_LEN);
43
/// Trace columns of one row of the volatile memory boundary AIR.
///
/// A non-padding row describes a single memory cell (block size 1): its address, the value sent
/// on the memory bus at timestamp `0`, and the value/timestamp received at the end.
#[repr(C)]
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct VolatileBoundaryCols<T> {
    /// Address space split into `NUM_AS_LIMBS` limbs; recombined with `compose` in the AIR.
    pub addr_space_limbs: [T; NUM_AS_LIMBS],
    /// Pointer split into `AUX_LEN` limbs; recombined with `compose` in the AIR.
    pub pointer_limbs: [T; AUX_LEN],

    /// Value sent on the memory bus at timestamp `0` (trace generation in this file sets it to 0).
    pub initial_data: T,
    /// Value received on the memory bus at `final_timestamp`.
    pub final_data: T,
    /// Timestamp at which `final_data` is received on the memory bus.
    pub final_timestamp: T,

    /// Boolean. `1` if a non-padding row with a valid touched address, `0` if it is a padding row.
    pub is_valid: T,
    /// Auxiliary columns for proving this row's `[addr_space, pointer]` is strictly less than the
    /// next row's when the next row is valid.
    pub addr_lt_aux: IsLtArrayAuxCols<T, ADDR_ELTS, AUX_LEN>,
}
58
/// AIR for the boundary of volatile memory.
///
/// On every valid row it range checks the address limbs, requires addresses to strictly increase
/// across valid rows, sends the initial value at timestamp `0`, and receives the final value at
/// the final timestamp on `memory_bus`.
#[derive(Clone, Debug)]
pub struct VolatileBoundaryAir {
    /// Bus carrying the initial-value sends and final-value receives.
    pub memory_bus: MemoryBus,
    /// Sub-AIR comparing `[addr_space, pointer]` of adjacent rows; built with `.when_transition()`.
    pub addr_lt_air: IsLtArrayWhenTransitionAir<ADDR_ELTS>,

    /// Bit width each address-space limb is range checked to.
    addr_space_limb_bits: [usize; NUM_AS_LIMBS],
    /// Bit width each pointer limb is range checked to.
    pointer_limb_bits: [usize; AUX_LEN],
}
67
68impl VolatileBoundaryAir {
69    pub fn new(
70        memory_bus: MemoryBus,
71        addr_space_max_bits: usize,
72        pointer_max_bits: usize,
73        range_bus: VariableRangeCheckerBus,
74    ) -> Self {
75        let addr_lt_air =
76            IsLtArraySubAir::<ADDR_ELTS>::new(range_bus, addr_space_max_bits.max(pointer_max_bits))
77                .when_transition();
78        let range_max_bits = range_bus.range_max_bits;
79        let mut addr_space_limb_bits = [0; NUM_AS_LIMBS];
80        let mut bits_remaining = addr_space_max_bits;
81        for limb_bits in &mut addr_space_limb_bits {
82            *limb_bits = min(bits_remaining, range_max_bits);
83            bits_remaining -= *limb_bits;
84        }
85        assert_eq!(bits_remaining, 0, "addr_space_max_bits={addr_space_max_bits} with {NUM_AS_LIMBS} limbs exceeds range_max_bits={range_max_bits}");
86        let mut pointer_limb_bits = [0; AUX_LEN];
87        let mut bits_remaining = pointer_max_bits;
88        for limb_bits in &mut pointer_limb_bits {
89            *limb_bits = min(bits_remaining, range_max_bits);
90            bits_remaining -= *limb_bits;
91        }
92        assert_eq!(bits_remaining, 0, "pointer_max_bits={pointer_max_bits} with {AUX_LEN} limbs exceeds range_max_bits={range_max_bits}");
93        Self {
94            memory_bus,
95            addr_lt_air,
96            addr_space_limb_bits,
97            pointer_limb_bits,
98        }
99    }
100
101    pub fn range_bus(&self) -> VariableRangeCheckerBus {
102        self.addr_lt_air.0.lt.bus
103    }
104}
105
106impl<F: Field> BaseAirWithPublicValues<F> for VolatileBoundaryAir {}
107impl<F: Field> PartitionedBaseAir<F> for VolatileBoundaryAir {}
108impl<F: Field> BaseAir<F> for VolatileBoundaryAir {
109    fn width(&self) -> usize {
110        VolatileBoundaryCols::<F>::width()
111    }
112}
113
impl<AB: InteractionBuilder> Air<AB> for VolatileBoundaryAir {
    /// Evaluates the boundary constraints on the current row (`local`) and its successor (`next`):
    ///
    /// 1. `is_valid` is boolean, and on every transition a valid `next` row implies a valid
    ///    `local` row, so valid rows form a contiguous prefix of the trace.
    /// 2. On valid rows, each address-space and pointer limb is range checked to its limb width.
    /// 3. When `next` is valid, `local`'s `[addr_space, pointer]` is strictly less than `next`'s,
    ///    so the addresses of valid rows are sorted and pairwise distinct.
    /// 4. On valid rows, `initial_data` is sent on the memory bus at timestamp `0` and
    ///    `final_data` is received at `final_timestamp`.
    ///
    /// NOTE(review): the sequence of bus interactions below presumably fixes the AIR's
    /// interaction list; keep the statement order unchanged.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();

        let [local, next] = [0, 1].map(|i| main.row_slice(i));
        let local: &VolatileBoundaryCols<_> = (*local).borrow();
        let next: &VolatileBoundaryCols<_> = (*next).borrow();

        builder.assert_bool(local.is_valid);

        // Ensuring all padding rows are at the bottom: a valid `next` row implies a valid `local`
        // row, so the non-padding rows form a prefix of the trace
        builder
            .when_transition()
            .assert_one(implies(next.is_valid, local.is_valid));

        // Range check local.addr_space_limbs to addr_space_max_bits
        for (&limb, limb_bits) in zip_eq(&local.addr_space_limbs, self.addr_space_limb_bits) {
            self.range_bus()
                .range_check(limb, limb_bits)
                .eval(builder, local.is_valid);
        }
        // Range check local.pointer_limbs to pointer_max_bits
        for (&limb, limb_bits) in zip_eq(&local.pointer_limbs, self.pointer_limb_bits) {
            self.range_bus()
                .range_check(limb, limb_bits)
                .eval(builder, local.is_valid);
        }
        let range_max_bits = self.range_bus().range_max_bits;
        // Compose addr_space_limbs and pointer_limbs into addr_space, pointer for both local and
        // next (each limb is treated as `range_max_bits` wide)
        let [addr_space, next_addr_space] = [&local.addr_space_limbs, &next.addr_space_limbs]
            .map(|limbs| compose::<AB::Expr>(limbs, range_max_bits));
        let [pointer, next_pointer] = [&local.pointer_limbs, &next.pointer_limbs]
            .map(|limbs| compose::<AB::Expr>(limbs, range_max_bits));

        // Assert local addr < next addr when next.is_valid
        // This ensures the addresses in non-padding rows are all sorted
        let lt_io = IsLtArrayIo {
            x: [addr_space.clone(), pointer.clone()],
            y: [next_addr_space, next_pointer],
            out: AB::Expr::ONE,
            count: next.is_valid.into(),
        };
        // N.B.: this will do range checks (but not other constraints) on the last row if the first
        // row has is_valid = 1 due to wraparound
        self.addr_lt_air
            .eval(builder, (lt_io, (&local.addr_lt_aux).into()));

        // Write the initial memory values at initial timestamps
        self.memory_bus
            .send(
                MemoryAddress::new(addr_space.clone(), pointer.clone()),
                vec![local.initial_data],
                AB::Expr::ZERO,
            )
            .eval(builder, local.is_valid);

        // Read the final memory values at last timestamps when written to
        self.memory_bus
            .receive(
                MemoryAddress::new(addr_space.clone(), pointer.clone()),
                vec![local.final_data],
                local.final_timestamp,
            )
            .eval(builder, local.is_valid);
    }
}
181
/// Trace-generation chip for [`VolatileBoundaryAir`].
pub struct VolatileBoundaryChip<F> {
    /// The AIR whose trace this chip generates.
    pub air: VolatileBoundaryAir,
    /// Range checker whose bus the AIR uses; also decomposes addresses into limbs.
    range_checker: SharedVariableRangeCheckerChip,
    /// Optional trace height (before power-of-two rounding) used instead of the number of
    /// finalized entries; trace generation asserts it is at least that number.
    overridden_height: Option<usize>,
    /// Final memory state recorded by `finalize`; must be set before trace generation.
    final_memory: Option<TimestampedEquipartition<F, 1>>,
    /// Bit budget used to decompose address spaces into limbs.
    addr_space_max_bits: usize,
    /// Bit budget used to decompose pointers into limbs.
    pointer_max_bits: usize,
}
190
191impl<F> VolatileBoundaryChip<F> {
192    pub fn new(
193        memory_bus: MemoryBus,
194        addr_space_max_bits: usize,
195        pointer_max_bits: usize,
196        range_checker: SharedVariableRangeCheckerChip,
197    ) -> Self {
198        let range_bus = range_checker.bus();
199        Self {
200            air: VolatileBoundaryAir::new(
201                memory_bus,
202                addr_space_max_bits,
203                pointer_max_bits,
204                range_bus,
205            ),
206            range_checker,
207            overridden_height: None,
208            final_memory: None,
209            addr_space_max_bits,
210            pointer_max_bits,
211        }
212    }
213}
214
215impl<F: PrimeField32> VolatileBoundaryChip<F> {
216    pub fn set_overridden_height(&mut self, overridden_height: usize) {
217        self.overridden_height = Some(overridden_height);
218    }
219    /// Volatile memory requires the starting and final memory to be in equipartition with block
220    /// size `1`. When block size is `1`, then the `label` is the same as the address pointer.
221    pub fn finalize(&mut self, final_memory: TimestampedEquipartition<F, 1>) {
222        self.final_memory = Some(final_memory);
223    }
224}
225
226impl<SC: StarkGenericConfig> Chip<SC> for VolatileBoundaryChip<Val<SC>>
227where
228    Val<SC>: PrimeField32,
229{
230    fn air(&self) -> AirRef<SC> {
231        Arc::new(self.air.clone())
232    }
233
234    fn generate_air_proof_input(self) -> AirProofInput<SC> {
235        // Volatile memory requires the starting and final memory to be in equipartition with block
236        // size `1`. When block size is `1`, then the `label` is the same as the address
237        // pointer.
238        let width = self.trace_width();
239        let air = Arc::new(self.air);
240        let final_memory = self
241            .final_memory
242            .expect("Trace generation should be called after finalize");
243        let trace_height = if let Some(height) = self.overridden_height {
244            assert!(
245                height >= final_memory.len(),
246                "Overridden height is less than the required height"
247            );
248            height
249        } else {
250            final_memory.len()
251        };
252        let trace_height = trace_height.next_power_of_two();
253
254        // Collect into Vec to sort from BTreeMap and also so we can look at adjacent entries
255        let sorted_final_memory: Vec<_> = final_memory.into_par_iter().collect();
256        let memory_len = sorted_final_memory.len();
257
258        let range_checker = self.range_checker.as_ref();
259        let mut rows = Val::<SC>::zero_vec(trace_height * width);
260        rows.par_chunks_mut(width)
261            .zip(sorted_final_memory.par_iter())
262            .enumerate()
263            .for_each(|(i, (row, ((addr_space, ptr), timestamped_values)))| {
264                // `pointer` is the same as `label` since the equipartition has block size 1
265                let [data] = timestamped_values.values;
266                let row: &mut VolatileBoundaryCols<_> = row.borrow_mut();
267                range_checker.decompose(
268                    *addr_space,
269                    self.addr_space_max_bits,
270                    &mut row.addr_space_limbs,
271                );
272                range_checker.decompose(*ptr, self.pointer_max_bits, &mut row.pointer_limbs);
273                row.initial_data = Val::<SC>::ZERO;
274                row.final_data = data;
275                row.final_timestamp = Val::<SC>::from_canonical_u32(timestamped_values.timestamp);
276                row.is_valid = Val::<SC>::ONE;
277
278                // If next.is_valid == 1:
279                if i != memory_len - 1 {
280                    let (next_addr_space, next_ptr) = sorted_final_memory[i + 1].0;
281                    let mut out = Val::<SC>::ZERO;
282                    air.addr_lt_air.0.generate_subrow(
283                        (
284                            self.range_checker.as_ref(),
285                            &[
286                                Val::<SC>::from_canonical_u32(*addr_space),
287                                Val::<SC>::from_canonical_u32(*ptr),
288                            ],
289                            &[
290                                Val::<SC>::from_canonical_u32(next_addr_space),
291                                Val::<SC>::from_canonical_u32(next_ptr),
292                            ],
293                        ),
294                        ((&mut row.addr_lt_aux).into(), &mut out),
295                    );
296                    debug_assert_eq!(out, Val::<SC>::ONE, "Addresses are not sorted");
297                }
298            });
299        // Always do a dummy range check on the last row due to wraparound
300        if memory_len > 0 {
301            let mut out = Val::<SC>::ZERO;
302            let row: &mut VolatileBoundaryCols<_> = rows[width * (trace_height - 1)..].borrow_mut();
303            air.addr_lt_air.0.generate_subrow(
304                (
305                    self.range_checker.as_ref(),
306                    &[Val::<SC>::ZERO, Val::<SC>::ZERO],
307                    &[Val::<SC>::ZERO, Val::<SC>::ZERO],
308                ),
309                ((&mut row.addr_lt_aux).into(), &mut out),
310            );
311        }
312
313        let trace = RowMajorMatrix::new(rows, width);
314        AirProofInput::simple_no_pis(trace)
315    }
316}
317
318impl<F: PrimeField32> ChipUsageGetter for VolatileBoundaryChip<F> {
319    fn air_name(&self) -> String {
320        "Boundary".to_string()
321    }
322
323    fn current_trace_height(&self) -> usize {
324        if let Some(final_memory) = &self.final_memory {
325            final_memory.len()
326        } else {
327            0
328        }
329    }
330
331    fn trace_width(&self) -> usize {
332        VolatileBoundaryCols::<F>::width()
333    }
334}