1use std::{array::from_fn, borrow::Borrow, marker::PhantomData, sync::Arc};
2
3use openvm_circuit_primitives_derive::AlignedBorrow;
4use openvm_instructions::{instruction::Instruction, LocalOpcode};
5use openvm_stark_backend::{
6 config::{StarkGenericConfig, Val},
7 p3_air::{Air, AirBuilder, BaseAir},
8 p3_field::FieldAlgebra,
9 p3_matrix::{dense::RowMajorMatrix, Matrix},
10 p3_maybe_rayon::prelude::*,
11 prover::{cpu::CpuBackend, types::AirProvingContext},
12 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
13 Chip,
14};
15use serde::{Deserialize, Serialize};
16
17use crate::{
18 arch::RowMajorMatrixArena,
19 system::memory::{online::TracingMemory, MemoryAuxColsFactory, SharedMemoryHelper},
20};
21
/// Bundles the data types carried by an [`AdapterAirContext`]: the reads, the writes and the
/// processed instruction. `T` is the element type implementors build those types from.
pub trait VmAdapterInterface<T> {
    /// Type of the `reads` field of an [`AdapterAirContext`] over this interface.
    type Reads;
    /// Type of the `writes` field of an [`AdapterAirContext`] over this interface.
    type Writes;
    /// Type of the `instruction` field of an [`AdapterAirContext`] over this interface.
    type ProcessedInstruction;
}
34
35pub trait VmAdapterAir<AB: AirBuilder>: BaseAir<AB::F> {
36 type Interface: VmAdapterInterface<AB::Expr>;
37
38 fn eval(
45 &self,
46 builder: &mut AB,
47 local: &[AB::Var],
48 interface: AdapterAirContext<AB::Expr, Self::Interface>,
49 );
50
51 fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var;
53}
54
55pub trait VmCoreAir<AB, I>: BaseAirWithPublicValues<AB::F>
56where
57 AB: AirBuilder,
58 I: VmAdapterInterface<AB::Expr>,
59{
60 fn eval(
62 &self,
63 builder: &mut AB,
64 local_core: &[AB::Var],
65 from_pc: AB::Var,
66 ) -> AdapterAirContext<AB::Expr, I>;
67
68 fn start_offset(&self) -> usize;
72
73 fn start_offset_expr(&self) -> AB::Expr {
74 AB::Expr::from_canonical_usize(self.start_offset())
75 }
76
77 fn expr_to_global_expr(&self, local_expr: impl Into<AB::Expr>) -> AB::Expr {
78 self.start_offset_expr() + local_expr.into()
79 }
80
81 fn opcode_to_global_expr(&self, local_opcode: impl LocalOpcode) -> AB::Expr {
82 self.expr_to_global_expr(AB::Expr::from_canonical_usize(local_opcode.local_usize()))
83 }
84}
85
86pub struct AdapterAirContext<T, I: VmAdapterInterface<T>> {
87 pub to_pc: Option<T>,
89 pub reads: I::Reads,
90 pub writes: I::Writes,
91 pub instruction: I::ProcessedInstruction,
92}
93
94pub trait TraceFiller<F>: Send + Sync {
96 fn fill_trace(
100 &self,
101 mem_helper: &MemoryAuxColsFactory<F>,
102 trace: &mut RowMajorMatrix<F>,
103 rows_used: usize,
104 ) where
105 F: Send + Sync + Clone,
106 {
107 let width = trace.width();
108 trace.values[..rows_used * width]
109 .par_chunks_exact_mut(width)
110 .for_each(|row_slice| {
111 self.fill_trace_row(mem_helper, row_slice);
112 });
113 trace.values[rows_used * width..]
114 .par_chunks_exact_mut(width)
115 .for_each(|row_slice| {
116 self.fill_dummy_trace_row(row_slice);
117 });
118 }
119
120 fn fill_trace_row(&self, _mem_helper: &MemoryAuxColsFactory<F>, _row_slice: &mut [F]) {
127 unreachable!("fill_trace_row is not implemented")
128 }
129
130 fn fill_dummy_trace_row(&self, _row_slice: &mut [F]) {
135 }
137
138 fn generate_public_values(&self) -> Vec<F> {
140 vec![]
141 }
142}
143
144#[derive(derive_new::new)]
148pub struct VmChipWrapper<F, FILLER> {
149 pub inner: FILLER,
150 pub mem_helper: SharedMemoryHelper<F>,
151}
152
153impl<SC, FILLER, RA> Chip<RA, CpuBackend<SC>> for VmChipWrapper<Val<SC>, FILLER>
154where
155 SC: StarkGenericConfig,
156 FILLER: TraceFiller<Val<SC>>,
157 RA: RowMajorMatrixArena<Val<SC>>,
158{
159 fn generate_proving_ctx(&self, arena: RA) -> AirProvingContext<CpuBackend<SC>> {
160 let rows_used = arena.trace_offset() / arena.width();
161 let mut trace = arena.into_matrix();
162 let mem_helper = self.mem_helper.as_borrowed();
163 self.inner.fill_trace(&mem_helper, &mut trace, rows_used);
164
165 AirProvingContext::simple(Arc::new(trace), self.inner.generate_public_values())
166 }
167}
168
169pub trait AdapterTraceExecutor<F>: Clone {
174 const WIDTH: usize;
175 type ReadData;
176 type WriteData;
177 type RecordMut<'a>
181 where
182 Self: 'a;
183
184 fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>);
185
186 fn read(
187 &self,
188 memory: &mut TracingMemory,
189 instruction: &Instruction<F>,
190 record: &mut Self::RecordMut<'_>,
191 ) -> Self::ReadData;
192
193 fn write(
194 &self,
195 memory: &mut TracingMemory,
196 instruction: &Instruction<F>,
197 data: Self::WriteData,
198 record: &mut Self::RecordMut<'_>,
199 );
200}
201
202pub trait AdapterTraceFiller<F>: Send + Sync {
205 const WIDTH: usize;
206 fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, adapter_row: &mut [F]);
208}
209
210#[derive(Clone, Copy, derive_new::new)]
213pub struct VmAirWrapper<A, C> {
214 pub adapter: A,
215 pub core: C,
216}
217
218impl<F, A, C> BaseAir<F> for VmAirWrapper<A, C>
219where
220 A: BaseAir<F>,
221 C: BaseAir<F>,
222{
223 fn width(&self) -> usize {
224 self.adapter.width() + self.core.width()
225 }
226}
227
228impl<F, A, M> BaseAirWithPublicValues<F> for VmAirWrapper<A, M>
229where
230 A: BaseAir<F>,
231 M: BaseAirWithPublicValues<F>,
232{
233 fn num_public_values(&self) -> usize {
234 self.core.num_public_values()
235 }
236}
237
238impl<F, A, M> PartitionedBaseAir<F> for VmAirWrapper<A, M>
240where
241 A: BaseAir<F>,
242 M: BaseAir<F>,
243{
244}
245
246impl<AB, A, M> Air<AB> for VmAirWrapper<A, M>
247where
248 AB: AirBuilder,
249 A: VmAdapterAir<AB>,
250 M: VmCoreAir<AB, A::Interface>,
251{
252 fn eval(&self, builder: &mut AB) {
253 let main = builder.main();
254 let local = main.row_slice(0);
255 let local: &[AB::Var] = (*local).borrow();
256 let (local_adapter, local_core) = local.split_at(self.adapter.width());
257
258 let ctx = self
259 .core
260 .eval(builder, local_core, self.adapter.get_from_pc(local_adapter));
261 self.adapter.eval(builder, local_adapter, ctx);
262 }
263}
264
/// Marker type for an interface with `NUM_READS` reads of `READ_SIZE` cells each and
/// `NUM_WRITES` writes of `WRITE_SIZE` cells each. `PI` is the processed-instruction type.
pub struct BasicAdapterInterface<
    T,
    PI,
    const NUM_READS: usize,
    const NUM_WRITES: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
>(PhantomData<T>, PhantomData<PI>);
280
281impl<
282 T,
283 PI,
284 const NUM_READS: usize,
285 const NUM_WRITES: usize,
286 const READ_SIZE: usize,
287 const WRITE_SIZE: usize,
288 > VmAdapterInterface<T>
289 for BasicAdapterInterface<T, PI, NUM_READS, NUM_WRITES, READ_SIZE, WRITE_SIZE>
290{
291 type Reads = [[T; READ_SIZE]; NUM_READS];
292 type Writes = [[T; WRITE_SIZE]; NUM_WRITES];
293 type ProcessedInstruction = PI;
294}
295
/// Marker type for an interface with `NUM_READS` reads of `BLOCKS_PER_READ` blocks each
/// (`READ_SIZE` cells per block) and a single write of `BLOCKS_PER_WRITE` blocks
/// (`WRITE_SIZE` cells per block).
pub struct VecHeapAdapterInterface<
    T,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
>(PhantomData<T>);
304
305impl<
306 T,
307 const NUM_READS: usize,
308 const BLOCKS_PER_READ: usize,
309 const BLOCKS_PER_WRITE: usize,
310 const READ_SIZE: usize,
311 const WRITE_SIZE: usize,
312 > VmAdapterInterface<T>
313 for VecHeapAdapterInterface<
314 T,
315 NUM_READS,
316 BLOCKS_PER_READ,
317 BLOCKS_PER_WRITE,
318 READ_SIZE,
319 WRITE_SIZE,
320 >
321{
322 type Reads = [[[T; READ_SIZE]; BLOCKS_PER_READ]; NUM_READS];
323 type Writes = [[T; WRITE_SIZE]; BLOCKS_PER_WRITE];
324 type ProcessedInstruction = MinimalInstruction<T>;
325}
326
/// Marker type for an interface whose reads and writes are flat arrays of `READ_CELLS` and
/// `WRITE_CELLS` cells. `PI` is the processed-instruction type.
pub struct FlatInterface<T, PI, const READ_CELLS: usize, const WRITE_CELLS: usize>(
    PhantomData<T>,
    PhantomData<PI>,
);
333
334impl<T, PI, const READ_CELLS: usize, const WRITE_CELLS: usize> VmAdapterInterface<T>
335 for FlatInterface<T, PI, READ_CELLS, WRITE_CELLS>
336{
337 type Reads = [T; READ_CELLS];
338 type Writes = [T; WRITE_CELLS];
339 type ProcessedInstruction = PI;
340}
341
342#[derive(Serialize, Deserialize)]
345pub struct DynAdapterInterface<T>(PhantomData<T>);
346
347impl<T> VmAdapterInterface<T> for DynAdapterInterface<T> {
348 type Reads = DynArray<T>;
350 type Writes = DynArray<T>;
352 type ProcessedInstruction = DynArray<T>;
354}
355
/// Newtype over `Vec<T>` used as the run-time sized counterpart of the fixed-size arrays in
/// the other adapter interfaces.
#[derive(Clone, Debug, Default)]
pub struct DynArray<T>(pub Vec<T>);
359
360#[repr(C)]
365#[derive(AlignedBorrow)]
366pub struct MinimalInstruction<T> {
367 pub is_valid: T,
368 pub opcode: T,
370}
371
372#[repr(C)]
374#[derive(AlignedBorrow)]
375pub struct ImmInstruction<T> {
376 pub is_valid: T,
377 pub opcode: T,
379 pub immediate: T,
380}
381
382#[repr(C)]
384#[derive(AlignedBorrow)]
385pub struct SignedImmInstruction<T> {
386 pub is_valid: T,
387 pub opcode: T,
389 pub immediate: T,
390 pub imm_sign: T,
392}
393
394mod conversions {
399 use super::*;
400
401 impl<
403 T,
404 const NUM_READS: usize,
405 const BLOCKS_PER_READ: usize,
406 const BLOCKS_PER_WRITE: usize,
407 const READ_SIZE: usize,
408 const WRITE_SIZE: usize,
409 >
410 From<
411 AdapterAirContext<
412 T,
413 VecHeapAdapterInterface<
414 T,
415 NUM_READS,
416 BLOCKS_PER_READ,
417 BLOCKS_PER_WRITE,
418 READ_SIZE,
419 WRITE_SIZE,
420 >,
421 >,
422 > for AdapterAirContext<T, DynAdapterInterface<T>>
423 {
424 fn from(
425 ctx: AdapterAirContext<
426 T,
427 VecHeapAdapterInterface<
428 T,
429 NUM_READS,
430 BLOCKS_PER_READ,
431 BLOCKS_PER_WRITE,
432 READ_SIZE,
433 WRITE_SIZE,
434 >,
435 >,
436 ) -> Self {
437 AdapterAirContext {
438 to_pc: ctx.to_pc,
439 reads: ctx.reads.into(),
440 writes: ctx.writes.into(),
441 instruction: ctx.instruction.into(),
442 }
443 }
444 }
445
446 impl<
448 T,
449 const NUM_READS: usize,
450 const BLOCKS_PER_READ: usize,
451 const BLOCKS_PER_WRITE: usize,
452 const READ_SIZE: usize,
453 const WRITE_SIZE: usize,
454 > From<AdapterAirContext<T, DynAdapterInterface<T>>>
455 for AdapterAirContext<
456 T,
457 VecHeapAdapterInterface<
458 T,
459 NUM_READS,
460 BLOCKS_PER_READ,
461 BLOCKS_PER_WRITE,
462 READ_SIZE,
463 WRITE_SIZE,
464 >,
465 >
466 {
467 fn from(ctx: AdapterAirContext<T, DynAdapterInterface<T>>) -> Self {
468 AdapterAirContext {
469 to_pc: ctx.to_pc,
470 reads: ctx.reads.into(),
471 writes: ctx.writes.into(),
472 instruction: ctx.instruction.into(),
473 }
474 }
475 }
476
477 impl<
479 T,
480 PI: Into<MinimalInstruction<T>>,
481 const BASIC_NUM_READS: usize,
482 const BASIC_NUM_WRITES: usize,
483 const NUM_READS: usize,
484 const BLOCKS_PER_READ: usize,
485 const BLOCKS_PER_WRITE: usize,
486 const READ_SIZE: usize,
487 const WRITE_SIZE: usize,
488 >
489 From<
490 AdapterAirContext<
491 T,
492 BasicAdapterInterface<
493 T,
494 PI,
495 BASIC_NUM_READS,
496 BASIC_NUM_WRITES,
497 READ_SIZE,
498 WRITE_SIZE,
499 >,
500 >,
501 >
502 for AdapterAirContext<
503 T,
504 VecHeapAdapterInterface<
505 T,
506 NUM_READS,
507 BLOCKS_PER_READ,
508 BLOCKS_PER_WRITE,
509 READ_SIZE,
510 WRITE_SIZE,
511 >,
512 >
513 {
514 fn from(
515 ctx: AdapterAirContext<
516 T,
517 BasicAdapterInterface<
518 T,
519 PI,
520 BASIC_NUM_READS,
521 BASIC_NUM_WRITES,
522 READ_SIZE,
523 WRITE_SIZE,
524 >,
525 >,
526 ) -> Self {
527 assert_eq!(BASIC_NUM_READS, NUM_READS * BLOCKS_PER_READ);
528 let mut reads_it = ctx.reads.into_iter();
529 let reads = from_fn(|_| from_fn(|_| reads_it.next().unwrap()));
530 assert_eq!(BASIC_NUM_WRITES, BLOCKS_PER_WRITE);
531 let mut writes_it = ctx.writes.into_iter();
532 let writes = from_fn(|_| writes_it.next().unwrap());
533 AdapterAirContext {
534 to_pc: ctx.to_pc,
535 reads,
536 writes,
537 instruction: ctx.instruction.into(),
538 }
539 }
540 }
541
542 impl<
544 T,
545 PI,
546 const NUM_READS: usize,
547 const NUM_WRITES: usize,
548 const READ_SIZE: usize,
549 const WRITE_SIZE: usize,
550 const READ_CELLS: usize,
551 const WRITE_CELLS: usize,
552 >
553 From<
554 AdapterAirContext<
555 T,
556 BasicAdapterInterface<T, PI, NUM_READS, NUM_WRITES, READ_SIZE, WRITE_SIZE>,
557 >,
558 > for AdapterAirContext<T, FlatInterface<T, PI, READ_CELLS, WRITE_CELLS>>
559 {
560 fn from(
564 ctx: AdapterAirContext<
565 T,
566 BasicAdapterInterface<T, PI, NUM_READS, NUM_WRITES, READ_SIZE, WRITE_SIZE>,
567 >,
568 ) -> AdapterAirContext<T, FlatInterface<T, PI, READ_CELLS, WRITE_CELLS>> {
569 assert_eq!(READ_CELLS, NUM_READS * READ_SIZE);
570 assert_eq!(WRITE_CELLS, NUM_WRITES * WRITE_SIZE);
571 let mut reads_it = ctx.reads.into_iter().flatten();
572 let reads = from_fn(|_| reads_it.next().unwrap());
573 let mut writes_it = ctx.writes.into_iter().flatten();
574 let writes = from_fn(|_| writes_it.next().unwrap());
575 AdapterAirContext {
576 to_pc: ctx.to_pc,
577 reads,
578 writes,
579 instruction: ctx.instruction,
580 }
581 }
582 }
583
584 impl<
586 T,
587 PI,
588 const NUM_READS: usize,
589 const NUM_WRITES: usize,
590 const READ_SIZE: usize,
591 const WRITE_SIZE: usize,
592 const READ_CELLS: usize,
593 const WRITE_CELLS: usize,
594 > From<AdapterAirContext<T, FlatInterface<T, PI, READ_CELLS, WRITE_CELLS>>>
595 for AdapterAirContext<
596 T,
597 BasicAdapterInterface<T, PI, NUM_READS, NUM_WRITES, READ_SIZE, WRITE_SIZE>,
598 >
599 {
600 fn from(
604 AdapterAirContext {
605 to_pc,
606 reads,
607 writes,
608 instruction,
609 }: AdapterAirContext<T, FlatInterface<T, PI, READ_CELLS, WRITE_CELLS>>,
610 ) -> AdapterAirContext<
611 T,
612 BasicAdapterInterface<T, PI, NUM_READS, NUM_WRITES, READ_SIZE, WRITE_SIZE>,
613 > {
614 assert_eq!(READ_CELLS, NUM_READS * READ_SIZE);
615 assert_eq!(WRITE_CELLS, NUM_WRITES * WRITE_SIZE);
616 let mut reads_it = reads.into_iter();
617 let reads: [[T; READ_SIZE]; NUM_READS] =
618 from_fn(|_| from_fn(|_| reads_it.next().unwrap()));
619 let mut writes_it = writes.into_iter();
620 let writes: [[T; WRITE_SIZE]; NUM_WRITES] =
621 from_fn(|_| from_fn(|_| writes_it.next().unwrap()));
622 AdapterAirContext {
623 to_pc,
624 reads,
625 writes,
626 instruction,
627 }
628 }
629 }
630
631 impl<T> From<Vec<T>> for DynArray<T> {
632 fn from(v: Vec<T>) -> Self {
633 Self(v)
634 }
635 }
636
637 impl<T> From<DynArray<T>> for Vec<T> {
638 fn from(v: DynArray<T>) -> Vec<T> {
639 v.0
640 }
641 }
642
643 impl<T, const N: usize, const M: usize> From<[[T; N]; M]> for DynArray<T> {
644 fn from(v: [[T; N]; M]) -> Self {
645 Self(v.into_iter().flatten().collect())
646 }
647 }
648
649 impl<T, const N: usize, const M: usize> From<DynArray<T>> for [[T; N]; M] {
650 fn from(v: DynArray<T>) -> Self {
651 assert_eq!(v.0.len(), N * M, "Incorrect vector length {}", v.0.len());
652 let mut it = v.0.into_iter();
653 from_fn(|_| from_fn(|_| it.next().unwrap()))
654 }
655 }
656
657 impl<T, const N: usize, const M: usize, const R: usize> From<[[[T; N]; M]; R]> for DynArray<T> {
658 fn from(v: [[[T; N]; M]; R]) -> Self {
659 Self(
660 v.into_iter()
661 .flat_map(|x| x.into_iter().flatten())
662 .collect(),
663 )
664 }
665 }
666
667 impl<T, const N: usize, const M: usize, const R: usize> From<DynArray<T>> for [[[T; N]; M]; R] {
668 fn from(v: DynArray<T>) -> Self {
669 assert_eq!(
670 v.0.len(),
671 N * M * R,
672 "Incorrect vector length {}",
673 v.0.len()
674 );
675 let mut it = v.0.into_iter();
676 from_fn(|_| from_fn(|_| from_fn(|_| it.next().unwrap())))
677 }
678 }
679
680 impl<T, const N: usize, const M1: usize, const M2: usize> From<([[T; N]; M1], [[T; N]; M2])>
681 for DynArray<T>
682 {
683 fn from(v: ([[T; N]; M1], [[T; N]; M2])) -> Self {
684 let vec =
685 v.0.into_iter()
686 .flatten()
687 .chain(v.1.into_iter().flatten())
688 .collect();
689 Self(vec)
690 }
691 }
692
693 impl<T, const N: usize, const M1: usize, const M2: usize> From<DynArray<T>>
694 for ([[T; N]; M1], [[T; N]; M2])
695 {
696 fn from(v: DynArray<T>) -> Self {
697 assert_eq!(
698 v.0.len(),
699 N * (M1 + M2),
700 "Incorrect vector length {}",
701 v.0.len()
702 );
703 let mut it = v.0.into_iter();
704 (
705 from_fn(|_| from_fn(|_| it.next().unwrap())),
706 from_fn(|_| from_fn(|_| it.next().unwrap())),
707 )
708 }
709 }
710
711 impl<
713 T,
714 PI: Into<DynArray<T>>,
715 const NUM_READS: usize,
716 const NUM_WRITES: usize,
717 const READ_SIZE: usize,
718 const WRITE_SIZE: usize,
719 >
720 From<
721 AdapterAirContext<
722 T,
723 BasicAdapterInterface<T, PI, NUM_READS, NUM_WRITES, READ_SIZE, WRITE_SIZE>,
724 >,
725 > for AdapterAirContext<T, DynAdapterInterface<T>>
726 {
727 fn from(
728 ctx: AdapterAirContext<
729 T,
730 BasicAdapterInterface<T, PI, NUM_READS, NUM_WRITES, READ_SIZE, WRITE_SIZE>,
731 >,
732 ) -> Self {
733 AdapterAirContext {
734 to_pc: ctx.to_pc,
735 reads: ctx.reads.into(),
736 writes: ctx.writes.into(),
737 instruction: ctx.instruction.into(),
738 }
739 }
740 }
741
742 impl<
744 T,
745 PI,
746 const NUM_READS: usize,
747 const NUM_WRITES: usize,
748 const READ_SIZE: usize,
749 const WRITE_SIZE: usize,
750 > From<AdapterAirContext<T, DynAdapterInterface<T>>>
751 for AdapterAirContext<
752 T,
753 BasicAdapterInterface<T, PI, NUM_READS, NUM_WRITES, READ_SIZE, WRITE_SIZE>,
754 >
755 where
756 PI: From<DynArray<T>>,
757 {
758 fn from(ctx: AdapterAirContext<T, DynAdapterInterface<T>>) -> Self {
759 AdapterAirContext {
760 to_pc: ctx.to_pc,
761 reads: ctx.reads.into(),
762 writes: ctx.writes.into(),
763 instruction: ctx.instruction.into(),
764 }
765 }
766 }
767
768 impl<T: Clone, PI: Into<DynArray<T>>, const READ_CELLS: usize, const WRITE_CELLS: usize>
770 From<AdapterAirContext<T, FlatInterface<T, PI, READ_CELLS, WRITE_CELLS>>>
771 for AdapterAirContext<T, DynAdapterInterface<T>>
772 {
773 fn from(ctx: AdapterAirContext<T, FlatInterface<T, PI, READ_CELLS, WRITE_CELLS>>) -> Self {
774 AdapterAirContext {
775 to_pc: ctx.to_pc,
776 reads: ctx.reads.to_vec().into(),
777 writes: ctx.writes.to_vec().into(),
778 instruction: ctx.instruction.into(),
779 }
780 }
781 }
782
783 impl<T> From<MinimalInstruction<T>> for DynArray<T> {
784 fn from(m: MinimalInstruction<T>) -> Self {
785 Self(vec![m.is_valid, m.opcode])
786 }
787 }
788
789 impl<T> From<DynArray<T>> for MinimalInstruction<T> {
790 fn from(m: DynArray<T>) -> Self {
791 let mut m = m.0.into_iter();
792 MinimalInstruction {
793 is_valid: m.next().unwrap(),
794 opcode: m.next().unwrap(),
795 }
796 }
797 }
798
799 impl<T> From<DynArray<T>> for ImmInstruction<T> {
800 fn from(m: DynArray<T>) -> Self {
801 let mut m = m.0.into_iter();
802 ImmInstruction {
803 is_valid: m.next().unwrap(),
804 opcode: m.next().unwrap(),
805 immediate: m.next().unwrap(),
806 }
807 }
808 }
809
810 impl<T> From<ImmInstruction<T>> for DynArray<T> {
811 fn from(instruction: ImmInstruction<T>) -> Self {
812 DynArray::from(vec![
813 instruction.is_valid,
814 instruction.opcode,
815 instruction.immediate,
816 ])
817 }
818 }
819}