1use std::{
2 borrow::BorrowMut,
3 io::Cursor,
4 marker::PhantomData,
5 ptr::{copy_nonoverlapping, slice_from_raw_parts_mut},
6};
7
8use openvm_circuit_primitives::utils::next_power_of_two_or_zero;
9use openvm_stark_backend::{
10 p3_field::{Field, PrimeField32},
11 p3_matrix::dense::RowMajorMatrix,
12};
13
/// Common interface of the record arenas in this module.
pub trait Arena {
    /// Creates an arena sized for `height` rows of `width` columns.
    ///
    /// How these numbers become storage is implementation-defined: the matrix arena rounds
    /// `height` with `next_power_of_two_or_zero`, the dense arena converts them to a byte count.
    fn with_capacity(height: usize, width: usize) -> Self;

    /// Returns `true` when nothing has been allocated from the arena.
    fn is_empty(&self) -> bool;

    /// Number of trace rows currently in use; only compiled with the `metrics` feature.
    ///
    /// The default implementation reports `0`.
    #[cfg(feature = "metrics")]
    fn current_trace_height(&self) -> usize {
        0
    }
}
29
/// An arena that hands out a record view `RecordMut` for a given `Layout`.
///
/// `alloc` borrows the arena mutably for `'a`, so the returned view may borrow arena storage
/// for that entire lifetime.
pub trait RecordArena<'a, Layout, RecordMut> {
    /// Reserves the space described by `layout` and returns a record view over it.
    fn alloc(&'a mut self, layout: Layout) -> RecordMut;
}
39
/// An [`Arena`] whose storage is a row-major matrix of `F` values.
pub trait RowMajorMatrixArena<F>: Arena {
    /// Resizes the backing storage to hold `trace_height` rows.
    fn set_capacity(&mut self, trace_height: usize);
    /// Number of columns per row.
    fn width(&self) -> usize;
    /// End of the allocated region. In `MatrixRecordArena` this counts elements, not rows.
    fn trace_offset(&self) -> usize;
    /// Consumes the arena and returns its contents as a [`RowMajorMatrix`].
    fn into_matrix(self) -> RowMajorMatrix<F>;
}
48
/// Reports how many bytes, and with which alignment, a record needs for a given layout.
///
/// The dense arena uses these values to decide how many bytes to reserve per record.
pub trait SizedRecord<Layout> {
    /// Size in bytes of the record described by `layout`.
    fn size(layout: &Layout) -> usize;
    /// Required alignment in bytes of the record described by `layout`.
    fn alignment(layout: &Layout) -> usize;
}
59
60impl<Layout, Record> SizedRecord<Layout> for &mut Record
61where
62 Record: Sized,
63{
64 fn size(_layout: &Layout) -> usize {
65 size_of::<Record>()
66 }
67
68 fn alignment(_layout: &Layout) -> usize {
69 align_of::<Record>()
70 }
71}
72
/// Record arena backed directly by a flat, row-major buffer of `F` values.
///
/// Records are written in place into trace rows. `into_matrix` (from [`RowMajorMatrixArena`])
/// turns the buffer into a [`RowMajorMatrix`].
#[derive(Default)]
pub struct MatrixRecordArena<F> {
    /// Flat row-major storage; row `i` spans elements `[i * width, (i + 1) * width)`.
    pub trace_buffer: Vec<F>,
    /// Number of elements per row.
    pub width: usize,
    /// Number of elements handed out so far; `alloc_buffer` advances it by whole rows.
    pub trace_offset: usize,
    /// If `true`, `into_matrix` truncates the buffer to the next power of two of the used rows.
    /// If `false`, the buffer's current length is kept as-is.
    pub(super) allow_truncate: bool,
}
86
87impl<F: Field> MatrixRecordArena<F> {
88 pub fn alloc_single_row(&mut self) -> &mut [u8] {
89 self.alloc_buffer(1)
90 }
91
92 pub fn alloc_buffer(&mut self, num_rows: usize) -> &mut [u8] {
93 let start = self.trace_offset;
94 self.trace_offset += num_rows * self.width;
95 let row_slice = &mut self.trace_buffer[start..self.trace_offset];
96 let size = size_of_val(row_slice);
97 let ptr = row_slice as *mut [F] as *mut u8;
98 unsafe { &mut *std::ptr::slice_from_raw_parts_mut(ptr, size) }
103 }
104
105 pub fn force_matrix_dimensions(&mut self) {
106 self.allow_truncate = false;
107 }
108}
109
110impl<F: Field> Arena for MatrixRecordArena<F> {
111 fn with_capacity(height: usize, width: usize) -> Self {
112 let height = next_power_of_two_or_zero(height);
113 let trace_buffer = F::zero_vec(height * width);
114 Self {
115 trace_buffer,
116 width,
117 trace_offset: 0,
118 allow_truncate: true,
119 }
120 }
121
122 fn is_empty(&self) -> bool {
123 self.trace_offset == 0
124 }
125
126 #[cfg(feature = "metrics")]
127 fn current_trace_height(&self) -> usize {
128 self.trace_offset / self.width
129 }
130}
131
132impl<F: Field> RowMajorMatrixArena<F> for MatrixRecordArena<F> {
133 fn set_capacity(&mut self, trace_height: usize) {
134 let size = trace_height * self.width;
135 self.trace_buffer.resize(size, F::ZERO);
137 }
138
139 fn width(&self) -> usize {
140 self.width
141 }
142
143 fn trace_offset(&self) -> usize {
144 self.trace_offset
145 }
146
147 fn into_matrix(mut self) -> RowMajorMatrix<F> {
148 let width = self.width();
149 assert_eq!(self.trace_offset() % width, 0);
150 let rows_used = self.trace_offset() / width;
151 let height = next_power_of_two_or_zero(rows_used);
152 assert!(height.checked_mul(width).unwrap() <= self.trace_buffer.len());
154 if self.allow_truncate {
155 self.trace_buffer.truncate(height * width);
156 } else {
157 assert_eq!(self.trace_buffer.len() % width, 0);
158 let height = self.trace_buffer.len() / width;
159 assert!(height.is_power_of_two() || height == 0);
160 }
161 RowMajorMatrix::new(self.trace_buffer, self.width)
162 }
163}
164
/// Record arena that packs records back to back into a byte buffer.
pub struct DenseRecordArena {
    /// Backing bytes plus a cursor. The cursor position marks the end of the allocated region,
    /// and new allocations start there.
    pub records_buffer: Cursor<Vec<u8>>,
}
168
/// Largest record alignment the dense arena supports.
///
/// Dense buffers are over-allocated by this many bytes so that allocation can start at a
/// `MAX_ALIGNMENT`-aligned address. Record alignments are debug-asserted to divide it.
const MAX_ALIGNMENT: usize = 32;
170
171impl DenseRecordArena {
172 pub fn with_byte_capacity(size_bytes: usize) -> Self {
174 let buffer = vec![0; size_bytes + MAX_ALIGNMENT];
175 let offset = (MAX_ALIGNMENT - (buffer.as_ptr() as usize % MAX_ALIGNMENT)) % MAX_ALIGNMENT;
176 let mut cursor = Cursor::new(buffer);
177 cursor.set_position(offset as u64);
178 Self {
179 records_buffer: cursor,
180 }
181 }
182
183 pub fn set_byte_capacity(&mut self, size_bytes: usize) {
184 let buffer = vec![0; size_bytes + MAX_ALIGNMENT];
185 let offset = (MAX_ALIGNMENT - (buffer.as_ptr() as usize % MAX_ALIGNMENT)) % MAX_ALIGNMENT;
186 let mut cursor = Cursor::new(buffer);
187 cursor.set_position(offset as u64);
188 self.records_buffer = cursor;
189 }
190
191 pub fn capacity(&self) -> usize {
195 self.records_buffer.get_ref().len()
196 }
197
198 pub fn alloc_bytes<'a>(&mut self, count: usize) -> &'a mut [u8] {
200 let begin = self.records_buffer.position();
201 debug_assert!(
202 begin as usize + count <= self.records_buffer.get_ref().len(),
203 "failed to allocate {count} bytes from {begin} when the capacity is {}",
204 self.records_buffer.get_ref().len()
205 );
206 self.records_buffer.set_position(begin + count as u64);
207 unsafe {
211 std::slice::from_raw_parts_mut(
212 self.records_buffer
213 .get_mut()
214 .as_mut_ptr()
215 .add(begin as usize),
216 count,
217 )
218 }
219 }
220
221 pub fn allocated(&self) -> &[u8] {
222 let size = self.records_buffer.position() as usize;
223 let offset = (MAX_ALIGNMENT
224 - (self.records_buffer.get_ref().as_ptr() as usize % MAX_ALIGNMENT))
225 % MAX_ALIGNMENT;
226 &self.records_buffer.get_ref()[offset..size]
227 }
228
229 pub fn allocated_mut(&mut self) -> &mut [u8] {
230 let size = self.records_buffer.position() as usize;
231 let offset = (MAX_ALIGNMENT
232 - (self.records_buffer.get_ref().as_ptr() as usize % MAX_ALIGNMENT))
233 % MAX_ALIGNMENT;
234 &mut self.records_buffer.get_mut()[offset..size]
235 }
236
237 pub fn align_to(&mut self, alignment: usize) {
238 debug_assert!(MAX_ALIGNMENT % alignment == 0);
239 let offset =
240 (alignment - (self.records_buffer.get_ref().as_ptr() as usize % alignment)) % alignment;
241 self.records_buffer.set_position(offset as u64);
242 }
243
244 pub fn get_record_seeker<R, L>(&mut self) -> RecordSeeker<DenseRecordArena, R, L> {
246 RecordSeeker::new(self.allocated_mut())
247 }
248}
249
250impl Arena for DenseRecordArena {
251 fn with_capacity(height: usize, width: usize) -> Self {
253 let size_bytes = height * (width * size_of::<u32>());
254 Self::with_byte_capacity(size_bytes)
255 }
256
257 fn is_empty(&self) -> bool {
258 self.allocated().is_empty()
259 }
260}
261
262pub unsafe fn get_record_from_slice<'a, T, F, L>(slice: &mut &'a mut [F], layout: L) -> T
270where
271 [u8]: CustomBorrow<'a, T, L>,
272{
273 let record_buffer =
275 &mut *slice_from_raw_parts_mut(slice.as_mut_ptr() as *mut u8, size_of_val::<[F]>(*slice));
276 let record: T = record_buffer.custom_borrow(layout);
277 record
278}
279
/// Borrows a typed record view `T` out of a raw buffer, guided by a layout `L`.
///
/// In this module it is implemented for `[u8]`, so record views can be carved out of arena
/// byte buffers.
pub trait CustomBorrow<'a, T, L> {
    /// Interprets `self` as a record of type `T` described by `layout`, borrowing it for `'a`.
    fn custom_borrow(&'a mut self, layout: L) -> T;

    /// Returns the layout of the record stored at the start of `self`.
    ///
    /// # Safety
    /// NOTE(review): the exact contract is implementation-specific. Presumably `self` must begin
    /// with a record written under a compatible layout — confirm per implementation.
    unsafe fn extract_layout(&self) -> L;
}
295
/// A view over a record byte buffer, used to read records back.
///
/// `RA` names the arena the bytes came from; `RecordMut` and `Layout` select the record type
/// and its layout. None of them are stored — they only steer which `impl` blocks apply.
pub struct RecordSeeker<'a, RA, RecordMut, Layout> {
    /// The raw record bytes.
    pub buffer: &'a mut [u8],
    _phantom: PhantomData<(RA, RecordMut, Layout)>,
}
301
302impl<'a, RA, RecordMut, Layout> RecordSeeker<'a, RA, RecordMut, Layout> {
303 pub fn new(record_buffer: &'a mut [u8]) -> Self {
304 Self {
305 buffer: record_buffer,
306 _phantom: PhantomData,
307 }
308 }
309}
310
/// Reads back records that a [`DenseRecordArena`] stored with a [`MultiRowLayout`].
///
/// Records sit back to back. Each occupies `R::size(&layout)` bytes rounded up to
/// `R::alignment(&layout)`, and its layout is obtained via `CustomBorrow::extract_layout` on
/// the bytes at its offset.
impl<'a, R, M> RecordSeeker<'a, DenseRecordArena, R, MultiRowLayout<M>>
where
    [u8]: CustomBorrow<'a, R, MultiRowLayout<M>>,
    R: SizedRecord<MultiRowLayout<M>>,
    M: MultiRowMetadata + Clone,
{
    /// Returns the layout of the record that starts at byte `*offset` of `buffer`.
    ///
    /// `offset` is only read here; it is not advanced.
    pub fn get_layout_at(offset: &mut usize, buffer: &[u8]) -> MultiRowLayout<M> {
        let buffer = &buffer[*offset..];
        // NOTE(review): soundness depends on `*offset` being a record boundary, so that
        // `extract_layout` reads a real record — confirm against the `CustomBorrow` impls.
        unsafe { buffer.extract_layout() }
    }

    /// Borrows the record starting at byte `*offset` of `buffer`.
    ///
    /// Afterwards `*offset` points at the start of the next record.
    pub fn get_record_at(offset: &mut usize, buffer: &'a mut [u8]) -> R {
        let layout = Self::get_layout_at(offset, buffer);
        let buffer = &mut buffer[*offset..];
        let record_size = R::size(&layout);
        let record_alignment = R::alignment(&layout);
        // Records are padded to their alignment, so this is the distance to the next one.
        let aligned_record_size = record_size.next_multiple_of(record_alignment);
        let record: R = buffer.custom_borrow(layout);
        *offset += aligned_record_size;
        record
    }

    /// Borrows every record in the buffer, in storage order.
    pub fn extract_records(&'a mut self) -> Vec<R> {
        let mut records = Vec::new();
        let len = self.buffer.len();
        let buff = &mut self.buffer[..];
        let mut offset = 0;
        while offset < len {
            let record: R = {
                // Rebuild a full-length `&'a mut [u8]` from the raw pointer on every iteration,
                // so that each record can borrow for `'a`. These slices overlap; soundness relies
                // on the records themselves covering disjoint byte ranges.
                let buff = unsafe { &mut *slice_from_raw_parts_mut(buff.as_mut_ptr(), len) };
                Self::get_record_at(&mut offset, buff)
            };
            records.push(record);
        }
        records
    }

    /// Copies every record into consecutive rows of `arena`, starting at row 0.
    ///
    /// `arena.trace_offset` is reset to 0. Each record then receives
    /// `layout.metadata.get_num_rows()` fresh rows, and its padded bytes are copied verbatim to
    /// the start of those rows.
    pub fn transfer_to_matrix_arena<F: PrimeField32>(
        &'a mut self,
        arena: &mut MatrixRecordArena<F>,
    ) {
        let len = self.buffer.len();
        arena.trace_offset = 0;
        let mut offset = 0;
        while offset < len {
            let layout = Self::get_layout_at(&mut offset, self.buffer);
            let record_size = R::size(&layout);
            let record_alignment = R::alignment(&layout);
            let aligned_record_size = record_size.next_multiple_of(record_alignment);
            let src_ptr = unsafe { self.buffer.as_ptr().add(offset) };
            let dst_ptr = arena
                .alloc_buffer(layout.metadata.get_num_rows())
                .as_mut_ptr();
            // NOTE(review): nothing here checks that the destination rows hold
            // `aligned_record_size` bytes. An oversized record would write out of bounds —
            // confirm the size invariant for every record type.
            unsafe { copy_nonoverlapping(src_ptr, dst_ptr, aligned_record_size) };
            offset += aligned_record_size;
        }
    }
}
386
/// Reads back `(adapter, core)` record pairs that a [`DenseRecordArena`] stored with an
/// [`AdapterCoreLayout`].
///
/// Every pair uses the same caller-supplied layout. The adapter record comes first, followed by
/// the core record, with the padding computed by `get_aligned_sizes`.
impl<'a, A, C, M> RecordSeeker<'a, DenseRecordArena, (A, C), AdapterCoreLayout<M>>
where
    [u8]: CustomBorrow<'a, A, AdapterCoreLayout<M>> + CustomBorrow<'a, C, AdapterCoreLayout<M>>,
    A: SizedRecord<AdapterCoreLayout<M>>,
    C: SizedRecord<AdapterCoreLayout<M>>,
    M: AdapterCoreMetadata + Clone,
{
    /// Returns `(adapter_bytes, core_bytes)` for one stored pair.
    ///
    /// - The adapter size is rounded up to the core's alignment, so the core record starts
    ///   aligned.
    /// - The core size is padded so that adapter + core is a multiple of the adapter's
    ///   alignment, which keeps the next pair's adapter aligned.
    pub fn get_aligned_sizes(layout: &AdapterCoreLayout<M>) -> (usize, usize) {
        let adapter_alignment = A::alignment(layout);
        let core_alignment = C::alignment(layout);
        let adapter_size = A::size(layout);
        let aligned_adapter_size = adapter_size.next_multiple_of(core_alignment);
        let core_size = C::size(layout);
        let aligned_core_size = (aligned_adapter_size + core_size)
            .next_multiple_of(adapter_alignment)
            - aligned_adapter_size;
        (aligned_adapter_size, aligned_core_size)
    }

    /// Total stored size of one `(adapter, core)` pair, padding included.
    pub fn get_aligned_record_size(layout: &AdapterCoreLayout<M>) -> usize {
        let (adapter_size, core_size) = Self::get_aligned_sizes(layout);
        adapter_size + core_size
    }

    /// Borrows the pair starting at byte `*offset` of `buffer` and advances `*offset` past it.
    pub fn get_record_at(
        offset: &mut usize,
        buffer: &'a mut [u8],
        layout: AdapterCoreLayout<M>,
    ) -> (A, C) {
        let buffer = &mut buffer[*offset..];
        let (adapter_size, core_size) = Self::get_aligned_sizes(&layout);
        // NOTE(review): `split_at_mut_unchecked` requires `adapter_size <= buffer.len()`. This
        // is not checked here and relies on the buffer ending on a pair boundary — confirm.
        let (adapter_buffer, core_buffer) = unsafe { buffer.split_at_mut_unchecked(adapter_size) };
        let adapter_record: A = adapter_buffer.custom_borrow(layout.clone());
        let core_record: C = core_buffer.custom_borrow(layout);
        *offset += adapter_size + core_size;
        (adapter_record, core_record)
    }

    /// Borrows every pair in the buffer, in storage order, using `layout` for all of them.
    pub fn extract_records(&'a mut self, layout: AdapterCoreLayout<M>) -> Vec<(A, C)> {
        let mut records = Vec::new();
        let len = self.buffer.len();
        let buff = &mut self.buffer[..];
        let mut offset = 0;
        while offset < len {
            let record: (A, C) = {
                // Rebuild a full-length `&'a mut [u8]` from the raw pointer on every iteration,
                // so that each pair can borrow for `'a`. These slices overlap; soundness relies
                // on the pairs themselves covering disjoint byte ranges.
                let buff = unsafe { &mut *slice_from_raw_parts_mut(buff.as_mut_ptr(), len) };
                Self::get_record_at(&mut offset, buff, layout.clone())
            };
            records.push(record);
        }
        records
    }

    /// Copies every pair into `arena`, one row per pair, starting at row 0.
    ///
    /// `arena.trace_offset` is reset to 0. In each row, the adapter bytes go to the first
    /// `M::get_adapter_width()` bytes. The core bytes start right after that width, not right
    /// after the stored adapter size.
    pub fn transfer_to_matrix_arena<F: PrimeField32>(
        &'a mut self,
        arena: &mut MatrixRecordArena<F>,
        layout: AdapterCoreLayout<M>,
    ) {
        let len = self.buffer.len();
        arena.trace_offset = 0;
        let mut offset = 0;
        let (adapter_size, core_size) = Self::get_aligned_sizes(&layout);
        while offset < len {
            let dst_buffer = arena.alloc_single_row();
            // NOTE(review): the split and both copies are unchecked. They assume:
            // - `M::get_adapter_width()` fits in the row,
            // - `adapter_size <= M::get_adapter_width()`,
            // - `core_size` fits in the rest of the row.
            // Confirm this for every adapter/core pair.
            let (adapter_buf, core_buf) =
                unsafe { dst_buffer.split_at_mut_unchecked(M::get_adapter_width()) };
            unsafe {
                let src_ptr = self.buffer.as_ptr().add(offset);
                copy_nonoverlapping(src_ptr, adapter_buf.as_mut_ptr(), adapter_size);
                copy_nonoverlapping(src_ptr.add(adapter_size), core_buf.as_mut_ptr(), core_size);
            }
            offset += adapter_size + core_size;
        }
    }
}
481
/// Layout of a record that spans one or more trace rows; the row count comes from `metadata`.
#[derive(Debug, Clone, Default, derive_new::new)]
pub struct MultiRowLayout<M> {
    /// Per-record metadata (see [`MultiRowMetadata`]).
    pub metadata: M,
}
494
/// Metadata carried by a [`MultiRowLayout`].
pub trait MultiRowMetadata {
    /// Number of trace rows the record occupies.
    fn get_num_rows(&self) -> usize;
}
499
/// Metadata for records that need no extra information and occupy a single row.
#[derive(Debug, Clone, Default, derive_new::new)]
pub struct EmptyMultiRowMetadata {}
503
504impl MultiRowMetadata for EmptyMultiRowMetadata {
505 #[inline(always)]
506 fn get_num_rows(&self) -> usize {
507 1
508 }
509}
510
/// [`MultiRowLayout`] for single-row records without metadata.
pub type EmptyMultiRowLayout = MultiRowLayout<EmptyMultiRowMetadata>;
513
514impl<'a, T: Sized, L: Default> CustomBorrow<'a, &'a mut T, L> for [u8]
517where
518 [u8]: BorrowMut<T>,
519{
520 fn custom_borrow(&'a mut self, _layout: L) -> &'a mut T {
521 self.borrow_mut()
522 }
523
524 unsafe fn extract_layout(&self) -> L {
525 L::default()
526 }
527}
528
529impl<'a, F: Field, M: MultiRowMetadata, R> RecordArena<'a, MultiRowLayout<M>, R>
532 for MatrixRecordArena<F>
533where
534 [u8]: CustomBorrow<'a, R, MultiRowLayout<M>>,
535{
536 fn alloc(&'a mut self, layout: MultiRowLayout<M>) -> R {
537 let buffer = self.alloc_buffer(layout.metadata.get_num_rows());
538 let record: R = buffer.custom_borrow(layout);
539 record
540 }
541}
542
543impl<'a, R, M> RecordArena<'a, MultiRowLayout<M>, R> for DenseRecordArena
546where
547 [u8]: CustomBorrow<'a, R, MultiRowLayout<M>>,
548 R: SizedRecord<MultiRowLayout<M>>,
549{
550 fn alloc(&'a mut self, layout: MultiRowLayout<M>) -> R {
551 let record_size = R::size(&layout);
552 let record_alignment = R::alignment(&layout);
553 let aligned_record_size = record_size.next_multiple_of(record_alignment);
554 let buffer = self.alloc_bytes(aligned_record_size);
555 let record: R = buffer.custom_borrow(layout);
556 record
557 }
558}
559
/// Layout of a record split into an adapter part followed by a core part.
///
/// In the matrix arena, both parts share a single trace row.
#[derive(Debug, Clone, Default)]
pub struct AdapterCoreLayout<M> {
    /// Metadata describing the split (see [`AdapterCoreMetadata`]).
    pub metadata: M,
}
573
/// Metadata carried by an [`AdapterCoreLayout`].
pub trait AdapterCoreMetadata {
    /// Width in bytes of the adapter part of a trace row; the core part starts right after it.
    fn get_adapter_width() -> usize;
}
579
580impl<M> AdapterCoreLayout<M> {
581 pub fn new() -> Self
582 where
583 M: Default,
584 {
585 Self::default()
586 }
587
588 pub fn with_metadata(metadata: M) -> Self {
589 Self { metadata }
590 }
591}
592
/// [`AdapterCoreMetadata`] that carries no runtime data.
///
/// `F` and `AS` exist only at the type level (via `PhantomData`), and the adapter width is
/// derived from them.
pub struct AdapterCoreEmptyMetadata<F, AS> {
    _phantom: PhantomData<(F, AS)>,
}
599
600impl<F, AS> Clone for AdapterCoreEmptyMetadata<F, AS> {
601 fn clone(&self) -> Self {
602 Self {
603 _phantom: PhantomData,
604 }
605 }
606}
607
608impl<F, AS> AdapterCoreEmptyMetadata<F, AS> {
609 pub fn new() -> Self {
610 Self {
611 _phantom: PhantomData,
612 }
613 }
614}
615
616impl<F, AS> Default for AdapterCoreEmptyMetadata<F, AS> {
617 fn default() -> Self {
618 Self {
619 _phantom: PhantomData,
620 }
621 }
622}
623
624impl<F, AS> AdapterCoreMetadata for AdapterCoreEmptyMetadata<F, AS>
625where
626 AS: super::AdapterTraceExecutor<F>,
627{
628 #[inline(always)]
629 fn get_adapter_width() -> usize {
630 AS::WIDTH * size_of::<F>()
631 }
632}
633
/// [`AdapterCoreLayout`] whose metadata is [`AdapterCoreEmptyMetadata`].
pub type EmptyAdapterCoreLayout<F, AS> = AdapterCoreLayout<AdapterCoreEmptyMetadata<F, AS>>;
637
638impl<'a, F: Field, A, C, M: AdapterCoreMetadata> RecordArena<'a, AdapterCoreLayout<M>, (A, C)>
641 for MatrixRecordArena<F>
642where
643 [u8]: CustomBorrow<'a, A, AdapterCoreLayout<M>> + CustomBorrow<'a, C, AdapterCoreLayout<M>>,
644 M: Clone,
645{
646 fn alloc(&'a mut self, layout: AdapterCoreLayout<M>) -> (A, C) {
647 let adapter_width = M::get_adapter_width();
648 let buffer = self.alloc_single_row();
649 let (adapter_buffer, core_buffer) = unsafe { buffer.split_at_mut_unchecked(adapter_width) };
654
655 let adapter_record: A = adapter_buffer.custom_borrow(layout.clone());
656 let core_record: C = core_buffer.custom_borrow(layout);
657
658 (adapter_record, core_record)
659 }
660}
661
662impl<'a, A, C, M> RecordArena<'a, AdapterCoreLayout<M>, (A, C)> for DenseRecordArena
665where
666 [u8]: CustomBorrow<'a, A, AdapterCoreLayout<M>> + CustomBorrow<'a, C, AdapterCoreLayout<M>>,
667 M: Clone,
668 A: SizedRecord<AdapterCoreLayout<M>>,
669 C: SizedRecord<AdapterCoreLayout<M>>,
670{
671 fn alloc(&'a mut self, layout: AdapterCoreLayout<M>) -> (A, C) {
672 let adapter_alignment = A::alignment(&layout);
673 let core_alignment = C::alignment(&layout);
674 let adapter_size = A::size(&layout);
675 let aligned_adapter_size = adapter_size.next_multiple_of(core_alignment);
676 let core_size = C::size(&layout);
677 let aligned_core_size = (aligned_adapter_size + core_size)
678 .next_multiple_of(adapter_alignment)
679 - aligned_adapter_size;
680 debug_assert_eq!(MAX_ALIGNMENT % adapter_alignment, 0);
681 debug_assert_eq!(MAX_ALIGNMENT % core_alignment, 0);
682 let buffer = self.alloc_bytes(aligned_adapter_size + aligned_core_size);
683 let (adapter_buffer, core_buffer) =
688 unsafe { buffer.split_at_mut_unchecked(aligned_adapter_size) };
689
690 let adapter_record: A = adapter_buffer.custom_borrow(layout.clone());
691 let core_record: C = core_buffer.custom_borrow(layout);
692
693 (adapter_record, core_record)
694 }
695}