1use std::{collections::BTreeMap, fmt::Debug, marker::PhantomData, sync::Arc};
3
4use getset::{Getters, MutGetters};
5use openvm_circuit_primitives::{
6 assert_less_than::{AssertLtSubAir, LessThanAuxCols},
7 var_range::{
8 SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip,
9 },
10 TraceSubRowGenerator,
11};
12use openvm_stark_backend::{
13 config::{Domain, StarkGenericConfig},
14 interaction::PermutationCheckBus,
15 p3_commit::PolynomialSpace,
16 p3_field::{Field, PrimeField32},
17 p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator},
18 p3_util::{log2_ceil_usize, log2_strict_usize},
19 prover::{cpu::CpuBackend, types::AirProvingContext},
20 Chip,
21};
22use serde::{Deserialize, Serialize};
23
24use self::interface::MemoryInterface;
25use super::{volatile::VolatileBoundaryChip, AddressMap};
26use crate::{
27 arch::{DenseRecordArena, MemoryConfig, ADDR_SPACE_OFFSET},
28 system::{
29 memory::{
30 adapter::AccessAdapterInventory,
31 dimensions::MemoryDimensions,
32 merkle::MemoryMerkleChip,
33 offline_checker::{MemoryBaseAuxCols, MemoryBridge, MemoryBus, AUX_LEN},
34 persistent::PersistentBoundaryChip,
35 },
36 poseidon2::Poseidon2PeripheryChip,
37 TouchedMemory,
38 },
39};
40
41pub mod dimensions;
42pub mod interface;
43
/// Chunk size used by the persistent memory layout.
///
/// `MemoryController::with_persistent_memory` derives the address height of the
/// memory dimensions as `pointer_max_bits - log2(CHUNK)`, so this value must be
/// a power of two.
pub const CHUNK: usize = 8;

/// Position of the boundary chip's proving context in the vector returned by
/// `MemoryController::generate_proving_ctx` (checked there with a debug assertion).
pub const BOUNDARY_AIR_OFFSET: usize = 0;
/// Position of the Merkle chip's proving context in the vector returned by
/// `MemoryController::generate_proving_ctx` when persistent memory is used
/// (checked there with a debug assertion).
pub const MERKLE_AIR_OFFSET: usize = 1;
50
51pub type MemoryImage = AddressMap;
52
/// A fixed-size block of `N` values together with a `u32` timestamp.
///
/// The layout is `#[repr(C)]`, so the field order (`timestamp` first, then
/// `values`) is part of the type's contract and must not be changed.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimestampedValues<T, const N: usize> {
    /// Timestamp attached to `values`.
    pub timestamp: u32,
    /// The `N` values of the block.
    pub values: [T; N],
}
59
60pub type TimestampedEquipartition<F, const N: usize> = Vec<((u32, u32), TimestampedValues<F, N>)>;
65
/// An ordered map from a `(u32, u32)` key to a block of `N` values.
///
/// NOTE(review): the key is presumably `(address space, block index/pointer)` —
/// confirm against the code that builds these maps.
pub type Equipartition<F, const N: usize> = BTreeMap<(u32, u32), [F; N]>;
73
74#[derive(Getters, MutGetters)]
75pub struct MemoryController<F: Field> {
76 pub memory_bus: MemoryBus,
77 pub interface_chip: MemoryInterface<F>,
78 pub range_checker: SharedVariableRangeCheckerChip,
79 range_checker_bus: VariableRangeCheckerBus,
81 pub(crate) access_adapter_inventory: AccessAdapterInventory<F>,
82 pub(crate) hasher_chip: Option<Arc<Poseidon2PeripheryChip<F>>>,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
86pub struct VolatileMemoryTraceHeights {
87 pub boundary: usize,
88 pub access_adapters: Vec<usize>,
89}
90
91impl VolatileMemoryTraceHeights {
92 pub fn from_slice(heights: &[u32]) -> Self {
94 let boundary = heights[0] as usize;
95 let access_adapters = heights[1..].iter().map(|&h| h as usize).collect();
96 Self {
97 boundary,
98 access_adapters,
99 }
100 }
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
104pub struct PersistentMemoryTraceHeights {
105 boundary: usize,
106 merkle: usize,
107 access_adapters: Vec<usize>,
108}
109impl PersistentMemoryTraceHeights {
110 pub fn from_slice(heights: &[u32]) -> Self {
112 let boundary = heights[0] as usize;
113 let merkle = heights[1] as usize;
114 let access_adapters = heights[2..].iter().map(|&h| h as usize).collect();
115 Self {
116 boundary,
117 merkle,
118 access_adapters,
119 }
120 }
121}
122
123impl<F: PrimeField32> MemoryController<F> {
124 pub(crate) fn continuation_enabled(&self) -> bool {
125 match &self.interface_chip {
126 MemoryInterface::Volatile { .. } => false,
127 MemoryInterface::Persistent { .. } => true,
128 }
129 }
130 pub fn with_volatile_memory(
131 memory_bus: MemoryBus,
132 mem_config: MemoryConfig,
133 range_checker: SharedVariableRangeCheckerChip,
134 ) -> Self {
135 let range_checker_bus = range_checker.bus();
136 assert!(mem_config.pointer_max_bits <= F::bits() - 2);
137 assert!(mem_config
138 .addr_spaces
139 .iter()
140 .all(|&space| space.num_cells <= (1 << mem_config.pointer_max_bits)));
141 assert!(mem_config.addr_space_height < F::bits() - 2);
142 let addr_space_max_bits = log2_ceil_usize(
143 (ADDR_SPACE_OFFSET + 2u32.pow(mem_config.addr_space_height as u32)) as usize,
144 );
145 Self {
146 memory_bus,
147 interface_chip: MemoryInterface::Volatile {
148 boundary_chip: VolatileBoundaryChip::new(
149 memory_bus,
150 addr_space_max_bits,
151 mem_config.pointer_max_bits,
152 range_checker.clone(),
153 ),
154 },
155 access_adapter_inventory: AccessAdapterInventory::new(
156 range_checker.clone(),
157 memory_bus,
158 mem_config,
159 ),
160 range_checker,
161 range_checker_bus,
162 hasher_chip: None,
163 }
164 }
165
166 pub fn with_persistent_memory(
170 memory_bus: MemoryBus,
171 mem_config: MemoryConfig,
172 range_checker: SharedVariableRangeCheckerChip,
173 merkle_bus: PermutationCheckBus,
174 compression_bus: PermutationCheckBus,
175 hasher_chip: Arc<Poseidon2PeripheryChip<F>>,
176 ) -> Self {
177 let memory_dims = MemoryDimensions {
178 addr_space_height: mem_config.addr_space_height,
179 address_height: mem_config.pointer_max_bits - log2_strict_usize(CHUNK),
180 };
181 let range_checker_bus = range_checker.bus();
182 let interface_chip = MemoryInterface::Persistent {
183 boundary_chip: PersistentBoundaryChip::new(
184 memory_dims,
185 memory_bus,
186 merkle_bus,
187 compression_bus,
188 ),
189 merkle_chip: MemoryMerkleChip::new(memory_dims, merkle_bus, compression_bus),
190 initial_memory: AddressMap::from_mem_config(&mem_config),
191 };
192 Self {
193 memory_bus,
194 interface_chip,
195 access_adapter_inventory: AccessAdapterInventory::new(
196 range_checker.clone(),
197 memory_bus,
198 mem_config,
199 ),
200 range_checker,
201 range_checker_bus,
202 hasher_chip: Some(hasher_chip),
203 }
204 }
205
206 pub fn memory_config(&self) -> &MemoryConfig {
207 &self.access_adapter_inventory.memory_config
208 }
209
210 pub(crate) fn set_override_trace_heights(&mut self, overridden_heights: &[u32]) {
211 match &mut self.interface_chip {
212 MemoryInterface::Volatile { boundary_chip } => {
213 let oh = VolatileMemoryTraceHeights::from_slice(overridden_heights);
214 boundary_chip.set_overridden_height(oh.boundary);
215 self.access_adapter_inventory
216 .set_override_trace_heights(oh.access_adapters);
217 }
218 MemoryInterface::Persistent {
219 boundary_chip,
220 merkle_chip,
221 ..
222 } => {
223 let oh = PersistentMemoryTraceHeights::from_slice(overridden_heights);
224 boundary_chip.set_overridden_height(oh.boundary);
225 merkle_chip.set_overridden_height(oh.merkle);
226 self.access_adapter_inventory
227 .set_override_trace_heights(oh.access_adapters);
228 }
229 }
230 }
231
232 pub(crate) fn set_initial_memory(&mut self, memory: AddressMap) {
235 match &mut self.interface_chip {
236 MemoryInterface::Volatile { .. } => {
237 }
239 MemoryInterface::Persistent { initial_memory, .. } => {
240 *initial_memory = memory;
241 }
242 }
243 }
244
245 pub fn memory_bridge(&self) -> MemoryBridge {
246 MemoryBridge::new(
247 self.memory_bus,
248 self.memory_config().timestamp_max_bits,
249 self.range_checker_bus,
250 )
251 }
252
253 pub fn helper(&self) -> SharedMemoryHelper<F> {
254 let range_bus = self.range_checker.bus();
255 SharedMemoryHelper {
256 range_checker: self.range_checker.clone(),
257 timestamp_lt_air: AssertLtSubAir::new(
258 range_bus,
259 self.memory_config().timestamp_max_bits,
260 ),
261 _marker: Default::default(),
262 }
263 }
264
265 pub fn generate_proving_ctx<SC: StarkGenericConfig>(
270 &mut self,
271 access_adapter_records: DenseRecordArena,
272 touched_memory: TouchedMemory<F>,
273 ) -> Vec<AirProvingContext<CpuBackend<SC>>>
274 where
275 Domain<SC>: PolynomialSpace<Val = F>,
276 {
277 match (&mut self.interface_chip, touched_memory) {
278 (
279 MemoryInterface::Volatile { boundary_chip },
280 TouchedMemory::Volatile(final_memory),
281 ) => {
282 boundary_chip.finalize(final_memory);
283 }
284 (
285 MemoryInterface::Persistent {
286 boundary_chip,
287 merkle_chip,
288 initial_memory,
289 },
290 TouchedMemory::Persistent(final_memory),
291 ) => {
292 let hasher = self.hasher_chip.as_ref().unwrap();
293 boundary_chip.finalize(initial_memory, &final_memory, hasher.as_ref());
294 let final_memory_values = final_memory
295 .into_par_iter()
296 .map(|(key, value)| (key, value.values))
297 .collect();
298 merkle_chip.finalize(initial_memory, &final_memory_values, hasher.as_ref());
299 }
300 _ => panic!("TouchedMemory incorrect type"),
301 }
302
303 let mut ret = Vec::new();
304
305 let access_adapters = &mut self.access_adapter_inventory;
306 access_adapters.set_arena(access_adapter_records);
307 match &mut self.interface_chip {
308 MemoryInterface::Volatile { boundary_chip } => {
309 ret.push(boundary_chip.generate_proving_ctx(()));
310 }
311 MemoryInterface::Persistent {
312 merkle_chip,
313 boundary_chip,
314 ..
315 } => {
316 debug_assert_eq!(ret.len(), BOUNDARY_AIR_OFFSET);
317 ret.push(boundary_chip.generate_proving_ctx(()));
318 debug_assert_eq!(ret.len(), MERKLE_AIR_OFFSET);
319 ret.push(merkle_chip.generate_proving_ctx());
320 }
321 }
322 ret.extend(access_adapters.generate_proving_ctx());
323 ret
324 }
325
326 pub fn num_airs(&self) -> usize {
328 let mut num_airs = 1;
329 if self.continuation_enabled() {
330 num_airs += 1;
331 }
332 num_airs += self.access_adapter_inventory.num_access_adapters();
333 num_airs
334 }
335}
336
337#[derive(Clone)]
339pub struct SharedMemoryHelper<F> {
340 pub(crate) range_checker: SharedVariableRangeCheckerChip,
341 pub(crate) timestamp_lt_air: AssertLtSubAir,
342 pub(crate) _marker: PhantomData<F>,
343}
344
345impl<F> SharedMemoryHelper<F> {
346 pub fn new(range_checker: SharedVariableRangeCheckerChip, timestamp_max_bits: usize) -> Self {
347 let timestamp_lt_air = AssertLtSubAir::new(range_checker.bus(), timestamp_max_bits);
348 Self {
349 range_checker,
350 timestamp_lt_air,
351 _marker: PhantomData,
352 }
353 }
354}
355
356pub struct MemoryAuxColsFactory<'a, F> {
359 pub(crate) range_checker: &'a VariableRangeCheckerChip,
360 pub(crate) timestamp_lt_air: AssertLtSubAir,
361 pub(crate) _marker: PhantomData<F>,
362}
363
364impl<F: PrimeField32> MemoryAuxColsFactory<'_, F> {
365 pub fn fill(&self, prev_timestamp: u32, timestamp: u32, buffer: &mut MemoryBaseAuxCols<F>) {
367 self.generate_timestamp_lt(prev_timestamp, timestamp, &mut buffer.timestamp_lt_aux);
368 buffer.prev_timestamp = F::from_canonical_u32(prev_timestamp);
371 }
372
373 pub fn fill_zero(&self, buffer: &mut MemoryBaseAuxCols<F>) {
376 *buffer = unsafe { std::mem::zeroed() };
377 }
378
379 fn generate_timestamp_lt(
380 &self,
381 prev_timestamp: u32,
382 timestamp: u32,
383 buffer: &mut LessThanAuxCols<F, AUX_LEN>,
384 ) {
385 debug_assert!(
386 prev_timestamp < timestamp,
387 "prev_timestamp {prev_timestamp} >= timestamp {timestamp}"
388 );
389 self.timestamp_lt_air.generate_subrow(
390 (self.range_checker, prev_timestamp, timestamp),
391 &mut buffer.lower_decomp,
392 );
393 }
394}
395
396impl<F> SharedMemoryHelper<F> {
397 pub fn as_borrowed(&self) -> MemoryAuxColsFactory<'_, F> {
398 MemoryAuxColsFactory {
399 range_checker: self.range_checker.as_ref(),
400 timestamp_lt_air: self.timestamp_lt_air,
401 _marker: PhantomData,
402 }
403 }
404}