1use std::{
2 any::{Any, TypeId},
3 cell::RefCell,
4 iter::once,
5 sync::{Arc, Mutex},
6};
7
8use derive_more::derive::From;
9use getset::Getters;
10use itertools::{zip_eq, Itertools};
11#[cfg(feature = "bench-metrics")]
12use metrics::counter;
13use openvm_circuit_derive::{AnyEnum, InstructionExecutor};
14use openvm_circuit_primitives::{
15 utils::next_power_of_two_or_zero,
16 var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
17};
18use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
19use openvm_instructions::{
20 program::Program, LocalOpcode, PhantomDiscriminant, PublishOpcode, SystemOpcode, VmOpcode,
21};
22use openvm_stark_backend::{
23 config::{Domain, StarkGenericConfig},
24 interaction::{BusIndex, PermutationCheckBus},
25 keygen::types::LinearConstraint,
26 p3_commit::PolynomialSpace,
27 p3_field::{FieldAlgebra, PrimeField32, TwoAdicField},
28 p3_matrix::Matrix,
29 p3_util::log2_ceil_usize,
30 prover::types::{AirProofInput, CommittedTraceData, ProofInput},
31 AirRef, Chip, ChipUsageGetter,
32};
33use p3_baby_bear::BabyBear;
34use rustc_hash::FxHashMap;
35use serde::{Deserialize, Serialize};
36
37use super::{
38 vm_poseidon2_config, ExecutionBus, GenerationError, InstructionExecutor, PhantomSubExecutor,
39 Streams, SystemConfig, SystemTraceHeights,
40};
41#[cfg(feature = "bench-metrics")]
42use crate::metrics::VmMetrics;
43use crate::system::{
44 connector::VmConnectorChip,
45 memory::{
46 offline_checker::{MemoryBridge, MemoryBus},
47 MemoryController, MemoryImage, OfflineMemory, BOUNDARY_AIR_OFFSET, MERKLE_AIR_OFFSET,
48 },
49 native_adapter::NativeAdapterChip,
50 phantom::PhantomChip,
51 poseidon2::Poseidon2PeripheryChip,
52 program::{ProgramBus, ProgramChip},
53 public_values::{core::PublicValuesCoreChip, PublicValuesChip},
54};
55
/// Global AIR ID of the program AIR; it is the first AIR added in `generate_proof_input`.
pub const PROGRAM_AIR_ID: usize = 0;
/// Index of the program AIR's cached main trace. Presumably `0` because the program AIR is
/// the first AIR and the only one given a cached trace here — confirm against the prover.
pub const PROGRAM_CACHED_TRACE_INDEX: usize = 0;
/// Global AIR ID of the connector AIR, added right after the program AIR.
pub const CONNECTOR_AIR_ID: usize = 1;
/// Global AIR ID of the public values AIR. Only meaningful when the config has a public values
/// chip; otherwise no AIR occupies this slot and the memory AIRs start here instead.
pub const PUBLIC_VALUES_AIR_ID: usize = 2;
/// AIR ID of the memory boundary AIR, assuming the public values AIR precedes the memory AIRs.
pub const BOUNDARY_AIR_ID: usize = PUBLIC_VALUES_AIR_ID + 1 + BOUNDARY_AIR_OFFSET;
/// AIR ID of the memory Merkle AIR, assuming the memory AIRs directly follow the connector AIR
/// (no public values AIR). NOTE(review): presumably holds whenever continuations — and hence
/// the Merkle AIR — are enabled; confirm against `SystemConfig::has_public_values_chip`.
pub const MERKLE_AIR_ID: usize = CONNECTOR_AIR_ID + 1 + MERKLE_AIR_OFFSET;
69
/// A VM extension: builds an inventory of executor and periphery chips that
/// `VmChipComplex::extend` appends after the chips already in the complex.
pub trait VmExtension<F: PrimeField32> {
    /// Executor chip type produced by this extension. It must execute instructions and be
    /// downcastable through [`AnyEnum`].
    type Executor: InstructionExecutor<F> + AnyEnum;
    /// Periphery chip type produced by this extension.
    type Periphery: AnyEnum;

    /// Builds this extension's chips using `builder`, which exposes the system configuration,
    /// the system chips, a bus index allocator and lookup of chips already in the complex.
    ///
    /// # Errors
    /// Returns a [`VmInventoryError`], e.g. when an opcode is registered twice.
    fn build(
        &self,
        builder: &mut VmInventoryBuilder<F>,
    ) -> Result<VmInventory<Self::Executor, Self::Periphery>, VmInventoryError>;
}
90
91impl<F: PrimeField32, E: VmExtension<F>> VmExtension<F> for Option<E> {
92 type Executor = E::Executor;
93 type Periphery = E::Periphery;
94
95 fn build(
96 &self,
97 builder: &mut VmInventoryBuilder<F>,
98 ) -> Result<VmInventory<Self::Executor, Self::Periphery>, VmInventoryError> {
99 if let Some(extension) = self {
100 extension.build(builder)
101 } else {
102 Ok(VmInventory::new())
103 }
104 }
105}
106
/// The system interfaces an extension's chips connect to: the execution bus, the program bus
/// and the memory bridge (see `VmInventoryBuilder::system_port`).
#[derive(Clone, Copy)]
pub struct SystemPort {
    pub execution_bus: ExecutionBus,
    pub program_bus: ProgramBus,
    pub memory_bridge: MemoryBridge,
}
114
/// Context passed to [`VmExtension::build`]. It gives read access to the system configuration
/// and system chips, the shared streams, a bus index allocator, and a registry of chips already
/// in the complex that extensions can look up with `find_chip`.
pub struct VmInventoryBuilder<'a, F: PrimeField32> {
    /// Configuration of the complex being extended.
    system_config: &'a SystemConfig,
    /// System chips (program, connector, memory controller, range checker).
    system: &'a SystemBase<F>,
    /// Streams shared with the system, e.g. with the phantom chip.
    streams: &'a Arc<Mutex<Streams<F>>>,
    /// Bus index allocator; `VmChipComplex::extend` copies it back into the complex afterwards.
    bus_idx_mgr: BusIndexManager,
    /// Chips visible to `find_chip`, in registration order.
    chips: Vec<&'a dyn AnyEnum>,
}
126
127impl<'a, F: PrimeField32> VmInventoryBuilder<'a, F> {
128 pub fn new(
129 system_config: &'a SystemConfig,
130 system: &'a SystemBase<F>,
131 streams: &'a Arc<Mutex<Streams<F>>>,
132 bus_idx_mgr: BusIndexManager,
133 ) -> Self {
134 Self {
135 system_config,
136 system,
137 streams,
138 bus_idx_mgr,
139 chips: Vec::new(),
140 }
141 }
142
143 pub fn system_config(&self) -> &SystemConfig {
144 self.system_config
145 }
146
147 pub fn system_base(&self) -> &SystemBase<F> {
148 self.system
149 }
150
151 pub fn system_port(&self) -> SystemPort {
152 SystemPort {
153 execution_bus: self.system_base().execution_bus(),
154 program_bus: self.system_base().program_bus(),
155 memory_bridge: self.system_base().memory_bridge(),
156 }
157 }
158
159 pub fn new_bus_idx(&mut self) -> BusIndex {
160 self.bus_idx_mgr.new_bus_idx()
161 }
162
163 pub fn find_chip<C: 'static>(&self) -> Vec<&C> {
168 self.chips
169 .iter()
170 .filter_map(|c| c.as_any_kind().downcast_ref())
171 .collect()
172 }
173
174 pub fn add_phantom_sub_executor<PE: PhantomSubExecutor<F> + 'static>(
176 &self,
177 phantom_sub: PE,
178 discriminant: PhantomDiscriminant,
179 ) -> Result<(), VmInventoryError> {
180 let chip_ref: &RefCell<PhantomChip<F>> =
181 self.find_chip().first().expect("PhantomChip always exists");
182 let mut chip = chip_ref.borrow_mut();
183 let existing = chip.add_sub_executor(phantom_sub, discriminant);
184 if existing.is_some() {
185 return Err(VmInventoryError::PhantomSubExecutorExists { discriminant });
186 }
187 Ok(())
188 }
189
190 pub fn streams(&self) -> &Arc<Mutex<Streams<F>>> {
192 self.streams
193 }
194
195 fn add_chip<E: AnyEnum>(&mut self, chip: &'a E) {
196 self.chips.push(chip);
197 }
198}
199
/// Executor and periphery chips of a VM, together with the opcode → executor mapping and the
/// order in which chips were added.
#[derive(Clone, Debug)]
pub struct VmInventory<E, P> {
    /// Maps each opcode to the index in `executors` of the executor that handles it.
    instruction_lookup: FxHashMap<VmOpcode, ExecutorId>,
    pub executors: Vec<E>,
    pub periphery: Vec<P>,
    /// Order in which executors and periphery chips were added. AIR listing and proof-input
    /// generation walk it in reverse.
    insertion_order: Vec<ChipId>,
}
210
/// Trace height of each inventory chip, keyed by its [`ChipId`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct VmInventoryTraceHeights {
    pub chips: FxHashMap<ChipId, usize>,
}
215
/// Trace heights of a whole chip complex: the system (memory) part and the inventory part.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, derive_new::new)]
pub struct VmComplexTraceHeights {
    pub system: SystemTraceHeights,
    pub inventory: VmInventoryTraceHeights,
}
221
/// Index into `VmInventory::executors`.
type ExecutorId = usize;
223
/// Identifies an inventory chip by kind and by its index within that kind's list.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum ChipId {
    /// Index into `VmInventory::executors`.
    Executor(usize),
    /// Index into `VmInventory::periphery`.
    Periphery(usize),
}
229
/// Errors raised while assembling a VM inventory.
#[derive(thiserror::Error, Debug)]
pub enum VmInventoryError {
    /// `opcode` is already handled by the executor with index `id`.
    #[error("Opcode {opcode} already owned by executor id {id}")]
    ExecutorExists { opcode: VmOpcode, id: ExecutorId },
    /// The phantom chip already had a sub-executor for `discriminant`.
    #[error("Phantom discriminant {} already has sub-executor", .discriminant.0)]
    PhantomSubExecutorExists { discriminant: PhantomDiscriminant },
    /// A chip required by an extension was not found.
    #[error("Chip {name} not found")]
    ChipNotFound { name: String },
}
239
240impl<E, P> Default for VmInventory<E, P> {
241 fn default() -> Self {
242 Self::new()
243 }
244}
245
246impl<E, P> VmInventory<E, P> {
247 pub fn new() -> Self {
248 Self {
249 instruction_lookup: FxHashMap::default(),
250 executors: Vec::new(),
251 periphery: Vec::new(),
252 insertion_order: Vec::new(),
253 }
254 }
255
256 pub fn transmute<E2, P2>(self) -> VmInventory<E2, P2>
257 where
258 E: Into<E2>,
259 P: Into<P2>,
260 {
261 VmInventory {
262 instruction_lookup: self.instruction_lookup,
263 executors: self.executors.into_iter().map(|e| e.into()).collect(),
264 periphery: self.periphery.into_iter().map(|p| p.into()).collect(),
265 insertion_order: self.insertion_order,
266 }
267 }
268
269 pub fn append(&mut self, mut other: VmInventory<E, P>) -> Result<(), VmInventoryError> {
271 let num_executors = self.executors.len();
272 let num_periphery = self.periphery.len();
273 for (opcode, mut id) in other.instruction_lookup.into_iter() {
274 id += num_executors;
275 if let Some(old_id) = self.instruction_lookup.insert(opcode, id) {
276 return Err(VmInventoryError::ExecutorExists { opcode, id: old_id });
277 }
278 }
279 for chip_id in other.insertion_order.iter_mut() {
280 match chip_id {
281 ChipId::Executor(id) => *id += num_executors,
282 ChipId::Periphery(id) => *id += num_periphery,
283 }
284 }
285 self.executors.append(&mut other.executors);
286 self.periphery.append(&mut other.periphery);
287 self.insertion_order.append(&mut other.insertion_order);
288 Ok(())
289 }
290
291 pub fn add_executor(
295 &mut self,
296 executor: impl Into<E>,
297 opcodes: impl IntoIterator<Item = VmOpcode>,
298 ) -> Result<(), VmInventoryError> {
299 let opcodes: Vec<_> = opcodes.into_iter().collect();
300 for opcode in &opcodes {
301 if let Some(id) = self.instruction_lookup.get(opcode) {
302 return Err(VmInventoryError::ExecutorExists {
303 opcode: *opcode,
304 id: *id,
305 });
306 }
307 }
308 let id = self.executors.len();
309 self.executors.push(executor.into());
310 self.insertion_order.push(ChipId::Executor(id));
311 for opcode in opcodes {
312 self.instruction_lookup.insert(opcode, id);
313 }
314 Ok(())
315 }
316
317 pub fn add_periphery_chip(&mut self, periphery_chip: impl Into<P>) {
318 let id = self.periphery.len();
319 self.periphery.push(periphery_chip.into());
320 self.insertion_order.push(ChipId::Periphery(id));
321 }
322
323 pub fn get_executor(&self, opcode: VmOpcode) -> Option<&E> {
324 let id = self.instruction_lookup.get(&opcode)?;
325 self.executors.get(*id)
326 }
327
328 pub fn get_mut_executor(&mut self, opcode: &VmOpcode) -> Option<&mut E> {
329 let id = self.instruction_lookup.get(opcode)?;
330 self.executors.get_mut(*id)
331 }
332
333 pub fn executors(&self) -> &[E] {
334 &self.executors
335 }
336
337 pub fn periphery(&self) -> &[P] {
338 &self.periphery
339 }
340
341 pub fn num_airs(&self) -> usize {
342 self.executors.len() + self.periphery.len()
343 }
344
345 pub fn get_trace_heights(&self) -> VmInventoryTraceHeights
350 where
351 E: ChipUsageGetter,
352 P: ChipUsageGetter,
353 {
354 VmInventoryTraceHeights {
355 chips: self
356 .executors
357 .iter()
358 .enumerate()
359 .map(|(i, chip)| (ChipId::Executor(i), chip.current_trace_height()))
360 .chain(
361 self.periphery
362 .iter()
363 .enumerate()
364 .map(|(i, chip)| (ChipId::Periphery(i), chip.current_trace_height())),
365 )
366 .collect(),
367 }
368 }
369
370 pub fn get_dummy_trace_heights(&self) -> VmInventoryTraceHeights
373 where
374 E: ChipUsageGetter,
375 P: ChipUsageGetter,
376 {
377 VmInventoryTraceHeights {
378 chips: self
379 .executors
380 .iter()
381 .enumerate()
382 .map(|(i, _)| (ChipId::Executor(i), 1))
383 .chain(self.periphery.iter().enumerate().map(|(i, chip)| {
384 (
385 ChipId::Periphery(i),
386 chip.constant_trace_height().unwrap_or(1),
387 )
388 }))
389 .collect(),
390 }
391 }
392}
393
394impl VmInventoryTraceHeights {
395 pub fn round_to_next_power_of_two(&mut self) {
397 self.chips
398 .values_mut()
399 .for_each(|v| *v = v.next_power_of_two());
400 }
401
402 pub fn round_to_next_power_of_two_or_zero(&mut self) {
404 self.chips
405 .values_mut()
406 .for_each(|v| *v = next_power_of_two_or_zero(*v));
407 }
408}
409
410impl VmComplexTraceHeights {
411 pub fn round_to_next_power_of_two(&mut self) {
413 self.system.round_to_next_power_of_two();
414 self.inventory.round_to_next_power_of_two();
415 }
416
417 pub fn round_to_next_power_of_two_or_zero(&mut self) {
419 self.system.round_to_next_power_of_two_or_zero();
420 self.inventory.round_to_next_power_of_two_or_zero();
421 }
422}
423
/// The full set of VM chips: the fixed system chips in `base`, plus the executor and periphery
/// chips in `inventory` (system chips and any extensions).
#[derive(Getters)]
pub struct VmChipComplex<F: PrimeField32, E, P> {
    /// System configuration this complex was built from.
    #[getset(get = "pub")]
    config: SystemConfig,
    /// Program, connector, memory controller and range checker chips.
    pub base: SystemBase<F>,
    /// Executor and periphery chips. This includes the system's public values, phantom and
    /// (with continuations) Poseidon2 chips.
    pub inventory: VmInventory<E, P>,
    /// If set, `generate_proof_input` pads each listed inventory chip's trace to the overridden
    /// height, rounded up to a power of two.
    overridden_inventory_heights: Option<VmInventoryTraceHeights>,

    /// Maximum allowed power-of-two-rounded trace height of any AIR.
    max_trace_height: usize,

    /// Streams shared with chips that read them, e.g. the phantom chip.
    streams: Arc<Mutex<Streams<F>>>,
    /// Bus index allocator, threaded through extensions by `extend`.
    bus_idx_mgr: BusIndexManager,
}
446
/// Allocates consecutive bus indices starting from `0`.
#[derive(Clone, Copy, Debug, Default)]
pub struct BusIndexManager {
    /// Next index to hand out.
    bus_idx_max: BusIndex,
}
452
453impl BusIndexManager {
454 pub fn new() -> Self {
455 Self { bus_idx_max: 0 }
456 }
457
458 pub fn new_bus_idx(&mut self) -> BusIndex {
459 let idx = self.bus_idx_max;
460 self.bus_idx_max = self.bus_idx_max.checked_add(1).unwrap();
461 idx
462 }
463}
464
/// A chip complex containing only the system executors and periphery chips.
pub type SystemComplex<F> = VmChipComplex<F, SystemExecutor<F>, SystemPeriphery<F>>;
467
/// The fixed system chips, held outside the inventory. Extensions get read-only access to them
/// through `VmInventoryBuilder::system_base`.
pub struct SystemBase<F> {
    /// Variable range checker, shared (via clones) with the memory controller and the
    /// connector chip.
    pub range_checker_chip: SharedVariableRangeCheckerChip,
    pub memory_controller: MemoryController<F>,
    pub connector_chip: VmConnectorChip<F>,
    pub program_chip: ProgramChip<F>,
}
478
479impl<F: PrimeField32> SystemBase<F> {
480 pub fn range_checker_bus(&self) -> VariableRangeCheckerBus {
481 self.range_checker_chip.bus()
482 }
483
484 pub fn memory_bus(&self) -> MemoryBus {
485 self.memory_controller.memory_bus
486 }
487
488 pub fn program_bus(&self) -> ProgramBus {
489 self.program_chip.air.bus
490 }
491
492 pub fn memory_bridge(&self) -> MemoryBridge {
493 self.memory_controller.memory_bridge()
494 }
495
496 pub fn offline_memory(&self) -> Arc<Mutex<OfflineMemory<F>>> {
497 self.memory_controller.offline_memory().clone()
498 }
499
500 pub fn execution_bus(&self) -> ExecutionBus {
501 self.connector_chip.air.execution_bus
502 }
503
504 pub fn get_system_trace_heights(&self) -> SystemTraceHeights {
507 SystemTraceHeights {
508 memory: self.memory_controller.get_memory_trace_heights(),
509 }
510 }
511
512 pub fn get_dummy_system_trace_heights(&self) -> SystemTraceHeights {
515 SystemTraceHeights {
516 memory: self.memory_controller.get_dummy_memory_trace_heights(),
517 }
518 }
519}
520
/// Executors provided by the system itself.
#[derive(ChipUsageGetter, Chip, AnyEnum, From, InstructionExecutor)]
pub enum SystemExecutor<F: PrimeField32> {
    /// Executor for the `PUBLISH` opcode; present only when the config has a public values chip.
    PublicValues(PublicValuesChip<F>),
    /// Executor for the `PHANTOM` opcode. It sits in a `RefCell` so extensions can register
    /// sub-executors through a shared reference (`add_phantom_sub_executor`).
    Phantom(RefCell<PhantomChip<F>>),
}
526
/// Periphery chips provided by the system itself.
#[derive(ChipUsageGetter, Chip, AnyEnum, From)]
pub enum SystemPeriphery<F: PrimeField32> {
    /// Poseidon2 hasher; added only when continuations are enabled.
    Poseidon2(Poseidon2PeripheryChip<F>),
}
532
impl<F: PrimeField32> SystemComplex<F> {
    /// Builds the system-only chip complex for `config`.
    ///
    /// Bus indices are allocated in a fixed order: execution, memory, program, range checker.
    /// With continuations enabled, two permutation-check buses for persistent memory follow.
    /// The inventory gets the public values executor first (if the config has one), then the
    /// Poseidon2 periphery chip (with continuations), then the phantom executor.
    ///
    /// # Panics
    /// Panics if opcode registration fails. It also panics if, with continuations enabled, the
    /// memory interface has no compression bus.
    pub fn new(config: SystemConfig) -> Self {
        // Allocation order below is significant: `test_system_bus_indices` asserts the
        // resulting indices.
        let mut bus_idx_mgr = BusIndexManager::new();
        let execution_bus = ExecutionBus::new(bus_idx_mgr.new_bus_idx());
        let memory_bus = MemoryBus::new(bus_idx_mgr.new_bus_idx());
        let program_bus = ProgramBus::new(bus_idx_mgr.new_bus_idx());
        let range_bus =
            VariableRangeCheckerBus::new(bus_idx_mgr.new_bus_idx(), config.memory_config.decomp);

        let range_checker = SharedVariableRangeCheckerChip::new(range_bus);
        let memory_controller = if config.continuation_enabled {
            // Per `test_system_bus_indices`, these two buses become the boundary chip's merkle
            // bus and compression bus, in that order.
            MemoryController::with_persistent_memory(
                memory_bus,
                config.memory_config,
                range_checker.clone(),
                PermutationCheckBus::new(bus_idx_mgr.new_bus_idx()),
                PermutationCheckBus::new(bus_idx_mgr.new_bus_idx()),
            )
        } else {
            MemoryController::with_volatile_memory(
                memory_bus,
                config.memory_config,
                range_checker.clone(),
            )
        };
        let memory_bridge = memory_controller.memory_bridge();
        let offline_memory = memory_controller.offline_memory();
        let program_chip = ProgramChip::new(program_bus);
        let connector_chip = VmConnectorChip::new(
            execution_bus,
            program_bus,
            range_checker.clone(),
            config.memory_config.clk_max_bits,
        );

        let mut inventory = VmInventory::new();
        if config.has_public_values_chip() {
            // Other code finds the public values chip at executor index `PV_EXECUTOR_IDX`.
            assert_eq!(inventory.executors().len(), Self::PV_EXECUTOR_IDX);
            let chip = PublicValuesChip::new(
                NativeAdapterChip::new(execution_bus, program_bus, memory_bridge),
                PublicValuesCoreChip::new(
                    config.num_public_values,
                    config.max_constraint_degree as u32 - 1,
                ),
                offline_memory,
            );
            inventory
                .add_executor(chip, [PublishOpcode::PUBLISH.global_opcode()])
                .unwrap();
        }
        if config.continuation_enabled {
            // Other code finds the Poseidon2 chip at periphery index `POSEIDON2_PERIPHERY_IDX`.
            assert_eq!(inventory.periphery().len(), Self::POSEIDON2_PERIPHERY_IDX);
            // The Poseidon2 chip uses the memory interface's compression bus directly.
            let direct_bus_idx = memory_controller
                .interface_chip
                .compression_bus()
                .unwrap()
                .index;
            let chip = Poseidon2PeripheryChip::new(
                vm_poseidon2_config(),
                direct_bus_idx,
                config.max_constraint_degree,
            );
            inventory.add_periphery_chip(chip);
        }
        let streams = Arc::new(Mutex::new(Streams::default()));
        let phantom_opcode = SystemOpcode::PHANTOM.global_opcode();
        let mut phantom_chip =
            PhantomChip::new(execution_bus, program_bus, SystemOpcode::CLASS_OFFSET);
        // The phantom chip shares the complex's streams.
        phantom_chip.set_streams(streams.clone());
        inventory
            .add_executor(RefCell::new(phantom_chip), [phantom_opcode])
            .unwrap();

        let base = SystemBase {
            program_chip,
            connector_chip,
            memory_controller,
            range_checker_chip: range_checker,
        };

        // For BabyBear: 2^(TWO_ADICITY - min_log_blowup), where
        // min_log_blowup = ceil(log2(max_constraint_degree - 1)).
        // Other fields fall back to 2^30 with a warning.
        let max_trace_height = if TypeId::of::<F>() == TypeId::of::<BabyBear>() {
            let min_log_blowup = log2_ceil_usize(config.max_constraint_degree - 1);
            1 << (BabyBear::TWO_ADICITY - min_log_blowup)
        } else {
            tracing::warn!(
                "constructing SystemComplex for unrecognized field; using max_trace_height = 2^30"
            );
            1 << 30
        };

        Self {
            config,
            base,
            inventory,
            bus_idx_mgr,
            streams,
            overridden_inventory_heights: None,
            max_trace_height,
        }
    }
}
639
impl<F: PrimeField32, E, P> VmChipComplex<F, E, P> {
    /// Executor index of the public values chip. When the config has one,
    /// `SystemComplex::new` asserts it is the first executor added.
    pub(super) const PV_EXECUTOR_IDX: ExecutorId = 0;
    /// Periphery index of the Poseidon2 chip. With continuations, `SystemComplex::new` asserts
    /// it is the first periphery chip added.
    pub(super) const POSEIDON2_PERIPHERY_IDX: usize = 0;

    /// Returns a builder for extending this complex.
    ///
    /// The builder can find (via `find_chip`) the range checker chip and every executor and
    /// periphery chip currently in the inventory. It receives a *copy* of the bus index
    /// manager, so indices it allocates persist only if copied back, as `extend` does.
    pub fn inventory_builder(&self) -> VmInventoryBuilder<F>
    where
        E: AnyEnum,
        P: AnyEnum,
    {
        let mut builder =
            VmInventoryBuilder::new(&self.config, &self.base, &self.streams, self.bus_idx_mgr);
        builder.add_chip(&self.base.range_checker_chip);
        for chip in self.inventory.executors() {
            builder.add_chip(chip);
        }
        for chip in self.inventory.periphery() {
            builder.add_chip(chip);
        }

        builder
    }

    /// Builds the extension `config` on top of this complex. Returns a complex whose inventory
    /// is this one followed by the extension's chips, with chip types converted to `E3`/`P3`.
    ///
    /// # Errors
    /// Propagates errors from [`VmExtension::build`] and from appending the extension's
    /// inventory, e.g. an opcode already handled by an existing executor.
    pub fn extend<E3, P3, Ext>(
        mut self,
        config: &Ext,
    ) -> Result<VmChipComplex<F, E3, P3>, VmInventoryError>
    where
        Ext: VmExtension<F>,
        E: Into<E3> + AnyEnum,
        P: Into<P3> + AnyEnum,
        Ext::Executor: Into<E3>,
        Ext::Periphery: Into<P3>,
    {
        let mut builder = self.inventory_builder();
        let inventory_ext = config.build(&mut builder)?;
        // Keep the bus indices the extension allocated so later allocations don't reuse them.
        self.bus_idx_mgr = builder.bus_idx_mgr;
        let mut ext_complex = self.transmute();
        ext_complex.append(inventory_ext.transmute())?;
        Ok(ext_complex)
    }

    /// Converts the inventory's chip types via `Into`, leaving everything else unchanged.
    pub fn transmute<E2, P2>(self) -> VmChipComplex<F, E2, P2>
    where
        E: Into<E2>,
        P: Into<P2>,
    {
        VmChipComplex {
            config: self.config,
            base: self.base,
            inventory: self.inventory.transmute(),
            bus_idx_mgr: self.bus_idx_mgr,
            streams: self.streams,
            overridden_inventory_heights: self.overridden_inventory_heights,
            max_trace_height: self.max_trace_height,
        }
    }

    /// Appends `other`'s chips to this complex's inventory (see `VmInventory::append`).
    pub fn append(&mut self, other: VmInventory<E, P>) -> Result<(), VmInventoryError> {
        self.inventory.append(other)
    }

    pub fn program_chip(&self) -> &ProgramChip<F> {
        &self.base.program_chip
    }

    pub fn program_chip_mut(&mut self) -> &mut ProgramChip<F> {
        &mut self.base.program_chip
    }

    pub fn connector_chip(&self) -> &VmConnectorChip<F> {
        &self.base.connector_chip
    }

    pub fn connector_chip_mut(&mut self) -> &mut VmConnectorChip<F> {
        &mut self.base.connector_chip
    }

    pub fn memory_controller(&self) -> &MemoryController<F> {
        &self.base.memory_controller
    }

    pub fn range_checker_chip(&self) -> &SharedVariableRangeCheckerChip {
        &self.base.range_checker_chip
    }

    /// Downcasts the executor at `PV_EXECUTOR_IDX`. Returns `None` if there are no executors or
    /// if that executor is not a `PublicValuesChip`.
    pub fn public_values_chip(&self) -> Option<&PublicValuesChip<F>>
    where
        E: AnyEnum,
    {
        let chip = self.inventory.executors().get(Self::PV_EXECUTOR_IDX)?;
        chip.as_any_kind().downcast_ref()
    }

    /// Downcasts the periphery chip at `POSEIDON2_PERIPHERY_IDX`. Returns `None` if it is
    /// missing or is not a `Poseidon2PeripheryChip`.
    pub fn poseidon2_chip(&self) -> Option<&Poseidon2PeripheryChip<F>>
    where
        P: AnyEnum,
    {
        let chip = self
            .inventory
            .periphery
            .get(Self::POSEIDON2_PERIPHERY_IDX)?;
        chip.as_any_kind().downcast_ref()
    }

    /// Mutable counterpart of [`Self::poseidon2_chip`].
    pub fn poseidon2_chip_mut(&mut self) -> Option<&mut Poseidon2PeripheryChip<F>>
    where
        P: AnyEnum,
    {
        let chip = self
            .inventory
            .periphery
            .get_mut(Self::POSEIDON2_PERIPHERY_IDX)?;
        chip.as_any_kind_mut().downcast_mut()
    }

    /// Finalizes the memory controller. With continuations enabled it is given the Poseidon2
    /// periphery chip as hasher; otherwise no hasher is given.
    ///
    /// # Panics
    /// With continuations enabled, panics if the periphery chip at `POSEIDON2_PERIPHERY_IDX` is
    /// missing or is not a `Poseidon2PeripheryChip`.
    pub fn finalize_memory(&mut self)
    where
        P: AnyEnum,
    {
        if self.config.continuation_enabled {
            let chip = self
                .inventory
                .periphery
                .get_mut(Self::POSEIDON2_PERIPHERY_IDX)
                .expect("Poseidon2 chip required for persistent memory");
            let hasher: &mut Poseidon2PeripheryChip<F> = chip
                .as_any_kind_mut()
                .downcast_mut()
                .expect("Poseidon2 chip required for persistent memory");
            self.base.memory_controller.finalize(Some(hasher))
        } else {
            self.base
                .memory_controller
                .finalize(None::<&mut Poseidon2PeripheryChip<F>>)
        };
    }

    pub(crate) fn set_program(&mut self, program: Program<F>) {
        self.base.program_chip.set_program(program);
    }

    pub(crate) fn set_initial_memory(&mut self, memory: MemoryImage<F>) {
        self.base.memory_controller.set_initial_memory(memory);
    }

    /// Replaces the contents of the shared streams. Every chip holding a clone of the `Arc`
    /// (e.g. the phantom chip) sees the new streams.
    pub(crate) fn set_streams(&mut self, streams: Streams<F>) {
        *self.streams.lock().unwrap() = streams;
    }

    /// Takes the shared streams, leaving `Streams::default()` in their place.
    pub fn take_streams(&mut self) -> Streams<F> {
        std::mem::take(&mut self.streams.lock().unwrap())
    }

    /// Total number of AIRs: program, connector and range checker (3), plus the memory AIRs,
    /// plus one per inventory chip (the public values chip is counted in the inventory).
    pub fn num_airs(&self) -> usize {
        3 + self.memory_controller().num_airs() + self.inventory.num_airs()
    }

    /// Executor index of the public values chip, if the config has one.
    fn public_values_chip_idx(&self) -> Option<ExecutorId> {
        self.config
            .has_public_values_chip()
            .then_some(Self::PV_EXECUTOR_IDX)
    }

    /// The public values executor (as `E`), if the config has one. Indexes directly, so it
    /// relies on `SystemComplex::new` having placed it at `PV_EXECUTOR_IDX`.
    fn _public_values_chip(&self) -> Option<&E> {
        self.config
            .has_public_values_chip()
            .then(|| &self.inventory.executors[Self::PV_EXECUTOR_IDX])
    }

    /// Inventory chips in *reverse* insertion order, skipping the public values executor (which
    /// gets its own fixed AIR slot). This order must match `generate_proof_input`.
    pub(crate) fn chips_excluding_pv_chip(&self) -> impl Iterator<Item = Either<&'_ E, &'_ P>> {
        let public_values_chip_idx = self.public_values_chip_idx();
        self.inventory
            .insertion_order
            .iter()
            .rev()
            .flat_map(move |chip_idx| match *chip_idx {
                ChipId::Executor(id) => (Some(id) != public_values_chip_idx)
                    .then(|| Either::Executor(&self.inventory.executors[id])),
                ChipId::Periphery(id) => Some(Either::Periphery(&self.inventory.periphery[id])),
            })
    }

    /// AIR names, in AIR-ID order: program, connector, public values (if present), memory AIRs,
    /// other inventory chips in reverse insertion order, range checker.
    pub(crate) fn air_names(&self) -> Vec<String>
    where
        E: ChipUsageGetter,
        P: ChipUsageGetter,
    {
        once(self.program_chip().air_name())
            .chain([self.connector_chip().air_name()])
            .chain(self._public_values_chip().map(|c| c.air_name()))
            .chain(self.memory_controller().air_names())
            .chain(self.chips_excluding_pv_chip().map(|c| c.air_name()))
            .chain([self.range_checker_chip().air_name()])
            .collect()
    }

    /// Current (unrounded) trace heights, in the same AIR order as [`Self::air_names`].
    pub(crate) fn current_trace_heights(&self) -> Vec<usize>
    where
        E: ChipUsageGetter,
        P: ChipUsageGetter,
    {
        once(self.program_chip().current_trace_height())
            .chain([self.connector_chip().current_trace_height()])
            .chain(self._public_values_chip().map(|c| c.current_trace_height()))
            .chain(self.memory_controller().current_trace_heights())
            .chain(
                self.chips_excluding_pv_chip()
                    .map(|c| c.current_trace_height()),
            )
            .chain([self.range_checker_chip().current_trace_height()])
            .collect()
    }

    /// Heights of the system part (memory) and the inventory part, keyed by chip.
    pub fn get_internal_trace_heights(&self) -> VmComplexTraceHeights
    where
        E: ChipUsageGetter,
        P: ChipUsageGetter,
    {
        VmComplexTraceHeights::new(
            self.base.get_system_trace_heights(),
            self.inventory.get_trace_heights(),
        )
    }

    /// Placeholder counterpart of [`Self::get_internal_trace_heights`].
    pub fn get_dummy_internal_trace_heights(&self) -> VmComplexTraceHeights
    where
        E: ChipUsageGetter,
        P: ChipUsageGetter,
    {
        VmComplexTraceHeights::new(
            self.base.get_dummy_system_trace_heights(),
            self.inventory.get_dummy_trace_heights(),
        )
    }

    /// Sets per-chip heights that `generate_proof_input` pads inventory traces to.
    pub(crate) fn set_override_inventory_trace_heights(
        &mut self,
        overridden_inventory_heights: VmInventoryTraceHeights,
    ) {
        self.overridden_inventory_heights = Some(overridden_inventory_heights);
    }

    /// Forwards overridden memory trace heights to the memory controller.
    pub(crate) fn set_override_system_trace_heights(
        &mut self,
        overridden_system_heights: SystemTraceHeights,
    ) {
        let memory_controller = &mut self.base.memory_controller;
        memory_controller.set_override_trace_heights(overridden_system_heights.memory);
    }

    /// Trace heights that vary with execution, in the same AIR order as [`Self::air_names`].
    /// The following report `0`: program, connector, range checker, and periphery chips with a
    /// constant trace height.
    pub(crate) fn dynamic_trace_heights(&self) -> impl Iterator<Item = usize> + '_
    where
        E: ChipUsageGetter,
        P: ChipUsageGetter,
    {
        // program chip, connector chip
        [0, 0]
            .into_iter()
            .chain(self._public_values_chip().map(|c| c.current_trace_height()))
            .chain(self.memory_controller().current_trace_heights())
            .chain(self.chips_excluding_pv_chip().map(|c| match c {
                Either::Executor(c) => c.current_trace_height(),
                Either::Periphery(c) => {
                    if c.constant_trace_height().is_some() {
                        0
                    } else {
                        c.current_trace_height()
                    }
                }
            }))
            // range checker chip
            .chain([0])
    }

    /// Cell counts that vary with execution, in the same AIR order and with the same zeroed
    /// entries as [`Self::dynamic_trace_heights`].
    pub(crate) fn current_trace_cells(&self) -> Vec<usize>
    where
        E: ChipUsageGetter,
        P: ChipUsageGetter,
    {
        // program chip, connector chip
        [0, 0]
            .into_iter()
            .chain(self._public_values_chip().map(|c| c.current_trace_cells()))
            .chain(self.memory_controller().current_trace_cells())
            .chain(self.chips_excluding_pv_chip().map(|c| match c {
                Either::Executor(c) => c.current_trace_cells(),
                Either::Periphery(c) => {
                    if c.constant_trace_height().is_some() {
                        0
                    } else {
                        c.current_trace_cells()
                    }
                }
            }))
            // range checker chip
            .chain([0])
            .collect()
    }

    /// All AIRs, in AIR-ID order. The order must match [`Self::air_names`] and
    /// `generate_proof_input`.
    pub fn airs<SC: StarkGenericConfig>(&self) -> Vec<AirRef<SC>>
    where
        Domain<SC>: PolynomialSpace<Val = F>,
        E: Chip<SC>,
        P: Chip<SC>,
    {
        let program_rap = Arc::new(self.program_chip().air) as AirRef<SC>;
        let connector_rap = Arc::new(self.connector_chip().air) as AirRef<SC>;
        [program_rap, connector_rap]
            .into_iter()
            .chain(self._public_values_chip().map(|chip| chip.air()))
            .chain(self.memory_controller().airs())
            .chain(self.chips_excluding_pv_chip().map(|chip| match chip {
                Either::Executor(chip) => chip.air(),
                Either::Periphery(chip) => chip.air(),
            }))
            .chain(once(self.range_checker_chip().air()))
            .collect()
    }

    /// Finalizes memory, checks trace heights, and converts every chip into an AIR proof input.
    ///
    /// AIR IDs follow the order of [`Self::airs`]: program, connector, public values (if
    /// present), memory AIRs, other inventory chips in reverse insertion order, range checker.
    /// Inputs with an empty trace are omitted but still consume an AIR ID (see
    /// `VmProofInputBuilder::add_air_proof_input`).
    ///
    /// # Errors
    /// Returns [`GenerationError::TraceHeightsLimitExceeded`] if any power-of-two-rounded
    /// height exceeds `max_trace_height`. The same error is returned if the rounded heights
    /// violate one of `trace_height_constraints`, i.e. the weighted sum is `>=` the threshold.
    ///
    /// # Panics
    /// `zip_eq` panics if a constraint's coefficient count differs from the number of AIRs.
    pub(crate) fn generate_proof_input<SC: StarkGenericConfig>(
        mut self,
        cached_program: Option<CommittedTraceData<SC>>,
        trace_height_constraints: &[LinearConstraint],
        #[cfg(feature = "bench-metrics")] metrics: &mut VmMetrics,
    ) -> Result<ProofInput<SC>, GenerationError>
    where
        Domain<SC>: PolynomialSpace<Val = F>,
        E: Chip<SC>,
        P: AnyEnum + Chip<SC>,
    {
        // Finalize memory before measuring heights. Presumably finalization affects the memory
        // AIRs' heights — confirm in `MemoryController::finalize`.
        self.finalize_memory();

        let trace_heights = self
            .current_trace_heights()
            .iter()
            .map(|h| next_power_of_two_or_zero(*h))
            .collect_vec();
        if let Some(index) = trace_heights
            .iter()
            .position(|h| *h > self.max_trace_height)
        {
            tracing::info!(
                "trace height of air {index} has height {} greater than maximum {}",
                trace_heights[index],
                self.max_trace_height
            );
            return Err(GenerationError::TraceHeightsLimitExceeded);
        }
        if trace_height_constraints.is_empty() {
            tracing::warn!("generating proof input without trace height constraints");
        }
        for (i, constraint) in trace_height_constraints.iter().enumerate() {
            // Widen each factor to u64 before multiplying.
            let value = zip_eq(&constraint.coefficients, &trace_heights)
                .map(|(&c, &h)| c as u64 * h as u64)
                .sum::<u64>();

            if value >= constraint.threshold as u64 {
                tracing::info!(
                    "trace heights {:?} violate linear constraint {} ({} >= {})",
                    trace_heights,
                    i,
                    value,
                    constraint.threshold
                );
                return Err(GenerationError::TraceHeightsLimitExceeded);
            }
        }

        #[cfg(feature = "bench-metrics")]
        self.finalize_metrics(metrics);

        let has_pv_chip = self.public_values_chip_idx().is_some();
        let mut builder = VmProofInputBuilder::new();
        let SystemBase {
            range_checker_chip,
            memory_controller,
            connector_chip,
            program_chip,
            ..
        } = self.base;

        debug_assert_eq!(builder.curr_air_id, PROGRAM_AIR_ID);
        builder.add_air_proof_input(program_chip.generate_air_proof_input(cached_program));
        debug_assert_eq!(builder.curr_air_id, CONNECTOR_AIR_ID);
        builder.add_air_proof_input(connector_chip.generate_air_proof_input());

        // Walk chips in reverse insertion order, popping each from the back of its list. The
        // asserts check that ids within each kind were inserted in increasing order. The
        // public values input is held back because its AIR slot comes before the memory AIRs.
        let mut public_values_input = None;
        let mut insertion_order = self.inventory.insertion_order;
        insertion_order.reverse();
        let mut non_sys_inputs = Vec::with_capacity(insertion_order.len());
        for chip_id in insertion_order {
            let mut height = None;
            if let Some(overridden_heights) = self.overridden_inventory_heights.as_ref() {
                height = overridden_heights.chips.get(&chip_id).copied();
            }
            let air_proof_input = match chip_id {
                ChipId::Executor(id) => {
                    let chip = self.inventory.executors.pop().unwrap();
                    assert_eq!(id, self.inventory.executors.len());
                    generate_air_proof_input(chip, height)
                }
                ChipId::Periphery(id) => {
                    let chip = self.inventory.periphery.pop().unwrap();
                    assert_eq!(id, self.inventory.periphery.len());
                    generate_air_proof_input(chip, height)
                }
            };
            if has_pv_chip && chip_id == ChipId::Executor(Self::PV_EXECUTOR_IDX) {
                public_values_input = Some(air_proof_input);
            } else {
                non_sys_inputs.push(air_proof_input);
            }
        }

        if let Some(input) = public_values_input {
            debug_assert_eq!(builder.curr_air_id, PUBLIC_VALUES_AIR_ID);
            builder.add_air_proof_input(input);
        }
        {
            let air_proof_inputs = memory_controller.generate_air_proof_inputs();
            for air_proof_input in air_proof_inputs {
                builder.add_air_proof_input(air_proof_input);
            }
        }
        non_sys_inputs
            .into_iter()
            .for_each(|input| builder.add_air_proof_input(input));
        builder.add_air_proof_input(range_checker_chip.generate_air_proof_input());

        Ok(builder.build())
    }

    /// Records cycle and main-cell counters. With profiling enabled, it also records per-AIR
    /// heights and emits them.
    #[cfg(feature = "bench-metrics")]
    fn finalize_metrics(&self, metrics: &mut VmMetrics)
    where
        E: ChipUsageGetter,
        P: ChipUsageGetter,
    {
        tracing::info!(metrics.cycle_count);
        counter!("total_cycles").absolute(metrics.cycle_count as u64);
        counter!("main_cells_used")
            .absolute(self.current_trace_cells().into_iter().sum::<usize>() as u64);

        if self.config.profiling {
            metrics.chip_heights =
                itertools::izip!(self.air_names(), self.current_trace_heights()).collect();
            metrics.emit();
        }
    }
}
1136
/// Collects per-AIR proof inputs and assigns consecutive AIR IDs in the order inputs are added.
struct VmProofInputBuilder<SC: StarkGenericConfig> {
    /// AIR ID that the next added input receives.
    curr_air_id: usize,
    /// `(air_id, input)` pairs, only for inputs with a non-empty trace.
    proof_input_per_air: Vec<(usize, AirProofInput<SC>)>,
}
1141
1142impl<SC: StarkGenericConfig> VmProofInputBuilder<SC> {
1143 fn new() -> Self {
1144 Self {
1145 curr_air_id: 0,
1146 proof_input_per_air: vec![],
1147 }
1148 }
1149 fn add_air_proof_input(&mut self, air_proof_input: AirProofInput<SC>) {
1152 let h = if !air_proof_input.raw.cached_mains.is_empty() {
1153 air_proof_input.raw.cached_mains[0].height()
1154 } else {
1155 air_proof_input
1156 .raw
1157 .common_main
1158 .as_ref()
1159 .map(|trace| trace.height())
1160 .unwrap()
1161 };
1162 if h > 0 {
1163 self.proof_input_per_air
1164 .push((self.curr_air_id, air_proof_input));
1165 }
1166 self.curr_air_id += 1;
1167 }
1168
1169 fn build(self) -> ProofInput<SC> {
1170 ProofInput {
1171 per_air: self.proof_input_per_air,
1172 }
1173 }
1174}
1175
1176pub fn generate_air_proof_input<SC: StarkGenericConfig, C: Chip<SC>>(
1180 chip: C,
1181 height: Option<usize>,
1182) -> AirProofInput<SC> {
1183 let mut proof_input = chip.generate_air_proof_input();
1184 if let Some(height) = height {
1185 let height = height.next_power_of_two();
1186 let main = proof_input.raw.common_main.as_mut().unwrap();
1187 assert!(
1188 height >= main.height(),
1189 "Overridden height must be greater than or equal to the used height"
1190 );
1191 main.pad_to_height(height, FieldAlgebra::ZERO);
1192 }
1193 proof_input
1194}
1195
/// Type-erased access to the concrete chip held by a chip enum. Callers use it to downcast
/// with `Any::downcast_ref` / `Any::downcast_mut`.
///
/// An enum wrapping another enum can delegate to the inner one, so the innermost concrete value
/// is exposed rather than the wrapper. `EnumB` and the `#[any_enum]`-annotated `EnumC` do this
/// in `test_any_enum_downcast`.
pub trait AnyEnum {
    /// Returns the concrete inner value as `&dyn Any`.
    fn as_any_kind(&self) -> &dyn Any;

    /// Returns the concrete inner value as `&mut dyn Any`.
    fn as_any_kind_mut(&mut self) -> &mut dyn Any;
}
1204
1205impl AnyEnum for () {
1206 fn as_any_kind(&self) -> &dyn Any {
1207 self
1208 }
1209 fn as_any_kind_mut(&mut self) -> &mut dyn Any {
1210 self
1211 }
1212}
1213
1214impl AnyEnum for SharedVariableRangeCheckerChip {
1215 fn as_any_kind(&self) -> &dyn Any {
1216 self
1217 }
1218 fn as_any_kind_mut(&mut self) -> &mut dyn Any {
1219 self
1220 }
1221}
1222
/// Either an executor or a periphery chip; yielded by `VmChipComplex::chips_excluding_pv_chip`.
pub(crate) enum Either<E, P> {
    Executor(E),
    Periphery(P),
}
1227
1228impl<'a, E, P> ChipUsageGetter for Either<&'a E, &'a P>
1229where
1230 E: ChipUsageGetter,
1231 P: ChipUsageGetter,
1232{
1233 fn air_name(&self) -> String {
1234 match self {
1235 Either::Executor(chip) => chip.air_name(),
1236 Either::Periphery(chip) => chip.air_name(),
1237 }
1238 }
1239 fn current_trace_height(&self) -> usize {
1240 match self {
1241 Either::Executor(chip) => chip.current_trace_height(),
1242 Either::Periphery(chip) => chip.current_trace_height(),
1243 }
1244 }
1245 fn trace_width(&self) -> usize {
1246 match self {
1247 Either::Executor(chip) => chip.trace_width(),
1248 Either::Periphery(chip) => chip.trace_width(),
1249 }
1250 }
1251 fn current_trace_cells(&self) -> usize {
1252 match self {
1253 Either::Executor(chip) => chip.current_trace_cells(),
1254 Either::Periphery(chip) => chip.current_trace_cells(),
1255 }
1256 }
1257}
1258
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;

    use super::*;
    use crate::system::memory::interface::MemoryInterface;

    // Variant `B` is never constructed; it only gives `EnumA` a second payload type.
    #[allow(dead_code)]
    #[derive(Copy, Clone)]
    enum EnumA {
        A(u8),
        B(u32),
    }

    // Hand-written `AnyEnum` that delegates to the nested enum in variant `D`.
    enum EnumB {
        C(u64),
        D(EnumA),
    }

    // Same shape as `EnumB`, but with `AnyEnum` derived; `#[any_enum]` marks the nested enum.
    #[derive(AnyEnum)]
    enum EnumC {
        C(u64),
        #[any_enum]
        D(EnumA),
    }

    impl AnyEnum for EnumA {
        fn as_any_kind(&self) -> &dyn Any {
            match self {
                Self::A(byte) => byte,
                Self::B(word) => word,
            }
        }

        fn as_any_kind_mut(&mut self) -> &mut dyn Any {
            match self {
                Self::A(byte) => byte,
                Self::B(word) => word,
            }
        }
    }

    impl AnyEnum for EnumB {
        fn as_any_kind(&self) -> &dyn Any {
            match self {
                Self::C(wide) => wide,
                Self::D(nested) => nested.as_any_kind(),
            }
        }

        fn as_any_kind_mut(&mut self) -> &mut dyn Any {
            match self {
                Self::C(wide) => wide,
                Self::D(nested) => nested.as_any_kind_mut(),
            }
        }
    }

    #[test]
    fn test_any_enum_downcast() {
        let inner = EnumA::A(1);
        assert_eq!(inner.as_any_kind().downcast_ref::<u8>(), Some(&1));

        // Hand-written delegation exposes the innermost value, not the nested enum.
        let manual_nested = EnumB::D(inner);
        assert!(manual_nested.as_any_kind().downcast_ref::<u64>().is_none());
        assert!(manual_nested.as_any_kind().downcast_ref::<EnumA>().is_none());
        assert_eq!(manual_nested.as_any_kind().downcast_ref::<u8>(), Some(&1));
        let manual_leaf = EnumB::C(3);
        assert_eq!(manual_leaf.as_any_kind().downcast_ref::<u64>(), Some(&3));

        // The derived implementation behaves the same way.
        let derived_nested = EnumC::D(inner);
        assert!(derived_nested.as_any_kind().downcast_ref::<u64>().is_none());
        assert!(derived_nested.as_any_kind().downcast_ref::<EnumA>().is_none());
        assert_eq!(derived_nested.as_any_kind().downcast_ref::<u8>(), Some(&1));
        let derived_leaf = EnumC::C(3);
        assert_eq!(derived_leaf.as_any_kind().downcast_ref::<u64>(), Some(&3));
    }

    #[test]
    fn test_system_bus_indices() {
        let complex =
            SystemComplex::<BabyBear>::new(SystemConfig::default().with_continuations());
        let base = &complex.base;
        assert_eq!(base.execution_bus().index(), 0);
        assert_eq!(base.memory_bus().index(), 1);
        assert_eq!(base.program_bus().index(), 2);
        assert_eq!(base.range_checker_bus().index(), 3);
        if let MemoryInterface::Persistent { boundary_chip, .. } =
            &complex.memory_controller().interface_chip
        {
            assert_eq!(boundary_chip.air.merkle_bus.index, 4);
            assert_eq!(boundary_chip.air.compression_bus.index, 5);
        } else {
            unreachable!();
        }
    }
}