1use std::{
2 borrow::{Borrow, BorrowMut},
3 cmp::min,
4 sync::Arc,
5};
6
7use itertools::zip_eq;
8use openvm_circuit_primitives::{
9 is_less_than_array::{
10 IsLtArrayAuxCols, IsLtArrayIo, IsLtArraySubAir, IsLtArrayWhenTransitionAir,
11 },
12 utils::{compose, implies},
13 var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
14 SubAir, TraceSubRowGenerator,
15};
16use openvm_circuit_primitives_derive::AlignedBorrow;
17use openvm_stark_backend::{
18 config::{StarkGenericConfig, Val},
19 interaction::InteractionBuilder,
20 p3_air::{Air, AirBuilder, BaseAir},
21 p3_field::{Field, FieldAlgebra, PrimeField32},
22 p3_matrix::{dense::RowMajorMatrix, Matrix},
23 p3_maybe_rayon::prelude::*,
24 prover::types::AirProofInput,
25 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
26 AirRef, Chip, ChipUsageGetter,
27};
28use static_assertions::const_assert;
29
30use super::TimestampedEquipartition;
31use crate::system::memory::{
32 offline_checker::{MemoryBus, AUX_LEN},
33 MemoryAddress,
34};
35
36#[cfg(test)]
37mod tests;
38
39const ADDR_ELTS: usize = 2;
41const NUM_AS_LIMBS: usize = 1;
42const_assert!(NUM_AS_LIMBS <= AUX_LEN);
43
/// One row of the volatile memory boundary trace.
///
/// Trace generation borrows each row slice directly as this struct (`row.borrow_mut()`), so the
/// field order below *is* the column order of the trace. Do not reorder fields.
#[repr(C)]
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct VolatileBoundaryCols<T> {
    /// Address space, split into limbs that are each range checked to their configured width.
    pub addr_space_limbs: [T; NUM_AS_LIMBS],
    /// Pointer, split into limbs that are each range checked to their configured width.
    pub pointer_limbs: [T; AUX_LEN],

    /// Value sent on the memory bus for this address at timestamp 0 (trace generation writes 0).
    pub initial_data: T,
    /// Value received back from the memory bus for this address.
    pub final_data: T,
    /// Timestamp attached to the final receive.
    pub final_timestamp: T,

    /// Boolean: 1 for rows holding an address, 0 for padding. Valid rows form a prefix.
    pub is_valid: T,
    /// Aux columns proving this row's `(address_space, pointer)` is less than the next row's
    /// whenever the next row is valid.
    pub addr_lt_aux: IsLtArrayAuxCols<T, ADDR_ELTS, AUX_LEN>,
}
58
/// AIR for the volatile memory boundary.
///
/// Each valid row sends an address's initial value and receives its final value on the memory
/// bus. Consecutive valid rows must have increasing addresses.
#[derive(Clone, Debug)]
pub struct VolatileBoundaryAir {
    /// Bus carrying the initial-value sends and final-value receives.
    pub memory_bus: MemoryBus,
    /// Sub-AIR comparing `(address_space, pointer)` of the current row against the next row.
    pub addr_lt_air: IsLtArrayWhenTransitionAir<ADDR_ELTS>,

    /// Bit width of each address-space limb, computed greedily in `new`.
    addr_space_limb_bits: [usize; NUM_AS_LIMBS],
    /// Bit width of each pointer limb, computed greedily in `new`.
    pointer_limb_bits: [usize; AUX_LEN],
}
67
68impl VolatileBoundaryAir {
69 pub fn new(
70 memory_bus: MemoryBus,
71 addr_space_max_bits: usize,
72 pointer_max_bits: usize,
73 range_bus: VariableRangeCheckerBus,
74 ) -> Self {
75 let addr_lt_air =
76 IsLtArraySubAir::<ADDR_ELTS>::new(range_bus, addr_space_max_bits.max(pointer_max_bits))
77 .when_transition();
78 let range_max_bits = range_bus.range_max_bits;
79 let mut addr_space_limb_bits = [0; NUM_AS_LIMBS];
80 let mut bits_remaining = addr_space_max_bits;
81 for limb_bits in &mut addr_space_limb_bits {
82 *limb_bits = min(bits_remaining, range_max_bits);
83 bits_remaining -= *limb_bits;
84 }
85 assert_eq!(bits_remaining, 0, "addr_space_max_bits={addr_space_max_bits} with {NUM_AS_LIMBS} limbs exceeds range_max_bits={range_max_bits}");
86 let mut pointer_limb_bits = [0; AUX_LEN];
87 let mut bits_remaining = pointer_max_bits;
88 for limb_bits in &mut pointer_limb_bits {
89 *limb_bits = min(bits_remaining, range_max_bits);
90 bits_remaining -= *limb_bits;
91 }
92 assert_eq!(bits_remaining, 0, "pointer_max_bits={pointer_max_bits} with {AUX_LEN} limbs exceeds range_max_bits={range_max_bits}");
93 Self {
94 memory_bus,
95 addr_lt_air,
96 addr_space_limb_bits,
97 pointer_limb_bits,
98 }
99 }
100
101 pub fn range_bus(&self) -> VariableRangeCheckerBus {
102 self.addr_lt_air.0.lt.bus
103 }
104}
105
106impl<F: Field> BaseAirWithPublicValues<F> for VolatileBoundaryAir {}
107impl<F: Field> PartitionedBaseAir<F> for VolatileBoundaryAir {}
108impl<F: Field> BaseAir<F> for VolatileBoundaryAir {
109 fn width(&self) -> usize {
110 VolatileBoundaryCols::<F>::width()
111 }
112}
113
114impl<AB: InteractionBuilder> Air<AB> for VolatileBoundaryAir {
115 fn eval(&self, builder: &mut AB) {
116 let main = builder.main();
117
118 let [local, next] = [0, 1].map(|i| main.row_slice(i));
119 let local: &VolatileBoundaryCols<_> = (*local).borrow();
120 let next: &VolatileBoundaryCols<_> = (*next).borrow();
121
122 builder.assert_bool(local.is_valid);
123
124 builder
126 .when_transition()
127 .assert_one(implies(next.is_valid, local.is_valid));
128
129 for (&limb, limb_bits) in zip_eq(&local.addr_space_limbs, self.addr_space_limb_bits) {
131 self.range_bus()
132 .range_check(limb, limb_bits)
133 .eval(builder, local.is_valid);
134 }
135 for (&limb, limb_bits) in zip_eq(&local.pointer_limbs, self.pointer_limb_bits) {
137 self.range_bus()
138 .range_check(limb, limb_bits)
139 .eval(builder, local.is_valid);
140 }
141 let range_max_bits = self.range_bus().range_max_bits;
142 let [addr_space, next_addr_space] = [&local.addr_space_limbs, &next.addr_space_limbs]
145 .map(|limbs| compose::<AB::Expr>(limbs, range_max_bits));
146 let [pointer, next_pointer] = [&local.pointer_limbs, &next.pointer_limbs]
147 .map(|limbs| compose::<AB::Expr>(limbs, range_max_bits));
148
149 let lt_io = IsLtArrayIo {
152 x: [addr_space.clone(), pointer.clone()],
153 y: [next_addr_space, next_pointer],
154 out: AB::Expr::ONE,
155 count: next.is_valid.into(),
156 };
157 self.addr_lt_air
160 .eval(builder, (lt_io, (&local.addr_lt_aux).into()));
161
162 self.memory_bus
164 .send(
165 MemoryAddress::new(addr_space.clone(), pointer.clone()),
166 vec![local.initial_data],
167 AB::Expr::ZERO,
168 )
169 .eval(builder, local.is_valid);
170
171 self.memory_bus
173 .receive(
174 MemoryAddress::new(addr_space.clone(), pointer.clone()),
175 vec![local.final_data],
176 local.final_timestamp,
177 )
178 .eval(builder, local.is_valid);
179 }
180}
181
/// Trace generator for [`VolatileBoundaryAir`].
pub struct VolatileBoundaryChip<F> {
    pub air: VolatileBoundaryAir,
    /// Shared range checker, used to decompose addresses into limbs during trace generation.
    range_checker: SharedVariableRangeCheckerChip,
    /// Requested trace height, if set. Trace generation asserts it is at least the number of
    /// addresses and rounds it up to a power of two.
    overridden_height: Option<usize>,
    /// Final memory image supplied by `finalize`. It is `None` until then.
    final_memory: Option<TimestampedEquipartition<F, 1>>,
    /// Bit width passed to `decompose` for the address space.
    addr_space_max_bits: usize,
    /// Bit width passed to `decompose` for the pointer.
    pointer_max_bits: usize,
}
190
191impl<F> VolatileBoundaryChip<F> {
192 pub fn new(
193 memory_bus: MemoryBus,
194 addr_space_max_bits: usize,
195 pointer_max_bits: usize,
196 range_checker: SharedVariableRangeCheckerChip,
197 ) -> Self {
198 let range_bus = range_checker.bus();
199 Self {
200 air: VolatileBoundaryAir::new(
201 memory_bus,
202 addr_space_max_bits,
203 pointer_max_bits,
204 range_bus,
205 ),
206 range_checker,
207 overridden_height: None,
208 final_memory: None,
209 addr_space_max_bits,
210 pointer_max_bits,
211 }
212 }
213}
214
215impl<F: PrimeField32> VolatileBoundaryChip<F> {
216 pub fn set_overridden_height(&mut self, overridden_height: usize) {
217 self.overridden_height = Some(overridden_height);
218 }
219 pub fn finalize(&mut self, final_memory: TimestampedEquipartition<F, 1>) {
222 self.final_memory = Some(final_memory);
223 }
224}
225
226impl<SC: StarkGenericConfig> Chip<SC> for VolatileBoundaryChip<Val<SC>>
227where
228 Val<SC>: PrimeField32,
229{
230 fn air(&self) -> AirRef<SC> {
231 Arc::new(self.air.clone())
232 }
233
234 fn generate_air_proof_input(self) -> AirProofInput<SC> {
235 let width = self.trace_width();
239 let air = Arc::new(self.air);
240 let final_memory = self
241 .final_memory
242 .expect("Trace generation should be called after finalize");
243 let trace_height = if let Some(height) = self.overridden_height {
244 assert!(
245 height >= final_memory.len(),
246 "Overridden height is less than the required height"
247 );
248 height
249 } else {
250 final_memory.len()
251 };
252 let trace_height = trace_height.next_power_of_two();
253
254 let sorted_final_memory: Vec<_> = final_memory.into_par_iter().collect();
256 let memory_len = sorted_final_memory.len();
257
258 let range_checker = self.range_checker.as_ref();
259 let mut rows = Val::<SC>::zero_vec(trace_height * width);
260 rows.par_chunks_mut(width)
261 .zip(sorted_final_memory.par_iter())
262 .enumerate()
263 .for_each(|(i, (row, ((addr_space, ptr), timestamped_values)))| {
264 let [data] = timestamped_values.values;
266 let row: &mut VolatileBoundaryCols<_> = row.borrow_mut();
267 range_checker.decompose(
268 *addr_space,
269 self.addr_space_max_bits,
270 &mut row.addr_space_limbs,
271 );
272 range_checker.decompose(*ptr, self.pointer_max_bits, &mut row.pointer_limbs);
273 row.initial_data = Val::<SC>::ZERO;
274 row.final_data = data;
275 row.final_timestamp = Val::<SC>::from_canonical_u32(timestamped_values.timestamp);
276 row.is_valid = Val::<SC>::ONE;
277
278 if i != memory_len - 1 {
280 let (next_addr_space, next_ptr) = sorted_final_memory[i + 1].0;
281 let mut out = Val::<SC>::ZERO;
282 air.addr_lt_air.0.generate_subrow(
283 (
284 self.range_checker.as_ref(),
285 &[
286 Val::<SC>::from_canonical_u32(*addr_space),
287 Val::<SC>::from_canonical_u32(*ptr),
288 ],
289 &[
290 Val::<SC>::from_canonical_u32(next_addr_space),
291 Val::<SC>::from_canonical_u32(next_ptr),
292 ],
293 ),
294 ((&mut row.addr_lt_aux).into(), &mut out),
295 );
296 debug_assert_eq!(out, Val::<SC>::ONE, "Addresses are not sorted");
297 }
298 });
299 if memory_len > 0 {
301 let mut out = Val::<SC>::ZERO;
302 let row: &mut VolatileBoundaryCols<_> = rows[width * (trace_height - 1)..].borrow_mut();
303 air.addr_lt_air.0.generate_subrow(
304 (
305 self.range_checker.as_ref(),
306 &[Val::<SC>::ZERO, Val::<SC>::ZERO],
307 &[Val::<SC>::ZERO, Val::<SC>::ZERO],
308 ),
309 ((&mut row.addr_lt_aux).into(), &mut out),
310 );
311 }
312
313 let trace = RowMajorMatrix::new(rows, width);
314 AirProofInput::simple_no_pis(trace)
315 }
316}
317
318impl<F: PrimeField32> ChipUsageGetter for VolatileBoundaryChip<F> {
319 fn air_name(&self) -> String {
320 "Boundary".to_string()
321 }
322
323 fn current_trace_height(&self) -> usize {
324 if let Some(final_memory) = &self.final_memory {
325 final_memory.len()
326 } else {
327 0
328 }
329 }
330
331 fn trace_width(&self) -> usize {
332 VolatileBoundaryCols::<F>::width()
333 }
334}