openvm_circuit/arch/state.rs

1use std::{
2    fmt::Debug,
3    ops::{Deref, DerefMut},
4};
5
6use eyre::eyre;
7use getset::{CopyGetters, MutGetters};
8use openvm_instructions::exe::SparseMemoryImage;
9use rand::{rngs::StdRng, SeedableRng};
10use tracing::instrument;
11
12use super::{create_memory_image, ExecutionError, Streams};
13#[cfg(feature = "metrics")]
14use crate::metrics::VmMetrics;
15use crate::{
16    arch::{execution_mode::ExecutionCtxTrait, SystemConfig, VmStateMut},
17    system::memory::online::GuestMemory,
18};
19
20/// Represents the core state of a VM.
21#[repr(C)]
22#[derive(derive_new::new, CopyGetters, MutGetters, Clone)]
23pub struct VmState<F, MEM = GuestMemory> {
24    #[getset(get_copy = "pub", get_mut = "pub")]
25    pc: u32,
26    pub memory: MEM,
27    pub streams: Streams<F>,
28    pub rng: StdRng,
29    /// The public values of the PublicValuesAir when it exists
30    pub(crate) custom_pvs: Vec<Option<F>>,
31    #[cfg(feature = "metrics")]
32    pub metrics: VmMetrics,
33}
34
35pub(super) const DEFAULT_RNG_SEED: u64 = 0;
36
37impl<F, MEM> VmState<F, MEM> {
38    #[inline(always)]
39    pub fn set_pc(&mut self, pc: u32) {
40        self.pc = pc;
41    }
42}
43
44impl<F: Clone, MEM> VmState<F, MEM> {
45    /// `num_custom_pvs` should only be nonzero when the PublicValuesAir exists.
46    pub fn new_with_defaults(
47        pc: u32,
48        memory: MEM,
49        streams: impl Into<Streams<F>>,
50        seed: u64,
51        num_custom_pvs: usize,
52    ) -> Self {
53        Self {
54            pc,
55            memory,
56            streams: streams.into(),
57            rng: StdRng::seed_from_u64(seed),
58            custom_pvs: vec![None; num_custom_pvs],
59            #[cfg(feature = "metrics")]
60            metrics: VmMetrics::default(),
61        }
62    }
63
64    #[inline(always)]
65    pub fn into_mut<'a, RA>(&'a mut self, ctx: &'a mut RA) -> VmStateMut<'a, F, MEM, RA> {
66        VmStateMut {
67            pc: &mut self.pc,
68            memory: &mut self.memory,
69            streams: &mut self.streams,
70            rng: &mut self.rng,
71            custom_pvs: &mut self.custom_pvs,
72            ctx,
73            #[cfg(feature = "metrics")]
74            metrics: &mut self.metrics,
75        }
76    }
77}
78
79impl<F: Clone> VmState<F, GuestMemory> {
80    #[instrument(name = "VmState::initial", level = "debug", skip_all)]
81    pub fn initial(
82        system_config: &SystemConfig,
83        init_memory: &SparseMemoryImage,
84        pc_start: u32,
85        inputs: impl Into<Streams<F>>,
86    ) -> Self {
87        let memory = create_memory_image(&system_config.memory_config, init_memory);
88        let num_custom_pvs = if system_config.has_public_values_chip() {
89            system_config.num_public_values
90        } else {
91            0
92        };
93        VmState::new_with_defaults(
94            pc_start,
95            memory,
96            inputs.into(),
97            DEFAULT_RNG_SEED,
98            num_custom_pvs,
99        )
100    }
101
102    pub fn reset(
103        &mut self,
104        init_memory: &SparseMemoryImage,
105        pc_start: u32,
106        streams: impl Into<Streams<F>>,
107    ) {
108        self.pc = pc_start;
109        self.memory.memory.fill_zero();
110        self.memory.memory.set_from_sparse(init_memory);
111        self.streams = streams.into();
112        self.rng = StdRng::seed_from_u64(DEFAULT_RNG_SEED);
113    }
114}
115
116/// Represents the full execution state of a VM during execution.
117/// The global state is generic in guest memory `MEM` and additional context `CTX`.
118/// The host state is execution context specific.
119// @dev: Do not confuse with `ExecutionState` struct.
120#[repr(C)]
121pub struct VmExecState<F, MEM, CTX> {
122    /// Core VM state
123    pub vm_state: VmState<F, MEM>,
124    pub ctx: CTX,
125    /// Execution-specific fields
126    pub exit_code: Result<Option<u32>, ExecutionError>,
127}
128
129impl<F, CTX: ExecutionCtxTrait> VmExecState<F, GuestMemory, CTX> {
130    #[inline(always)]
131    pub fn should_suspend(&mut self) -> bool {
132        CTX::should_suspend(self)
133    }
134}
135
136impl<F, MEM, CTX> VmExecState<F, MEM, CTX> {
137    pub fn new(vm_state: VmState<F, MEM>, ctx: CTX) -> Self {
138        Self {
139            vm_state,
140            ctx,
141            exit_code: Ok(None),
142        }
143    }
144
145    /// Try to clone VmExecState. Return an error if `exit_code` is an error because `ExecutionEror`
146    /// cannot be cloned.
147    pub fn try_clone(&self) -> eyre::Result<Self>
148    where
149        VmState<F, MEM>: Clone,
150        CTX: Clone,
151    {
152        if self.exit_code.is_err() {
153            return Err(eyre!(
154                "failed to clone VmExecState because exit_code is an error"
155            ));
156        }
157        Ok(Self {
158            vm_state: self.vm_state.clone(),
159            exit_code: Ok(*self.exit_code.as_ref().unwrap()),
160            ctx: self.ctx.clone(),
161        })
162    }
163}
164
165impl<F, MEM, CTX> Deref for VmExecState<F, MEM, CTX> {
166    type Target = VmState<F, MEM>;
167
168    fn deref(&self) -> &Self::Target {
169        &self.vm_state
170    }
171}
172
173impl<F, MEM, CTX> DerefMut for VmExecState<F, MEM, CTX> {
174    fn deref_mut(&mut self) -> &mut Self::Target {
175        &mut self.vm_state
176    }
177}
178
179impl<F, CTX> VmExecState<F, GuestMemory, CTX>
180where
181    CTX: ExecutionCtxTrait,
182{
183    /// Runtime read operation for a block of memory
184    #[inline(always)]
185    pub fn vm_read<T: Copy + Debug, const BLOCK_SIZE: usize>(
186        &mut self,
187        addr_space: u32,
188        ptr: u32,
189    ) -> [T; BLOCK_SIZE] {
190        self.ctx
191            .on_memory_operation(addr_space, ptr, BLOCK_SIZE as u32);
192        self.host_read(addr_space, ptr)
193    }
194
195    /// Runtime write operation for a block of memory
196    #[inline(always)]
197    pub fn vm_write<T: Copy + Debug, const BLOCK_SIZE: usize>(
198        &mut self,
199        addr_space: u32,
200        ptr: u32,
201        data: &[T; BLOCK_SIZE],
202    ) {
203        self.ctx
204            .on_memory_operation(addr_space, ptr, BLOCK_SIZE as u32);
205        self.host_write(addr_space, ptr, data)
206    }
207
208    #[inline(always)]
209    pub fn vm_read_slice<T: Copy + Debug>(
210        &mut self,
211        addr_space: u32,
212        ptr: u32,
213        len: usize,
214    ) -> &[T] {
215        self.ctx.on_memory_operation(addr_space, ptr, len as u32);
216        self.host_read_slice(addr_space, ptr, len)
217    }
218
219    #[inline(always)]
220    pub fn host_read<T: Copy + Debug, const BLOCK_SIZE: usize>(
221        &self,
222        addr_space: u32,
223        ptr: u32,
224    ) -> [T; BLOCK_SIZE] {
225        // SAFETY:
226        // - T is stack-allocated repr(C) or repr(transparent), usually u8 or F where F is the base
227        //   field
228        // - T is the exact memory cell type for this address space, satisfying the type requirement
229        unsafe { self.memory.read(addr_space, ptr) }
230    }
231
232    #[inline(always)]
233    pub fn host_write<T: Copy + Debug, const BLOCK_SIZE: usize>(
234        &mut self,
235        addr_space: u32,
236        ptr: u32,
237        data: &[T; BLOCK_SIZE],
238    ) {
239        // SAFETY:
240        // - T is stack-allocated repr(C) or repr(transparent), usually u8 or F where F is the base
241        //   field
242        // - T is the exact memory cell type for this address space, satisfying the type requirement
243        unsafe { self.memory.write(addr_space, ptr, *data) }
244    }
245
246    #[inline(always)]
247    pub fn host_read_slice<T: Copy + Debug>(&self, addr_space: u32, ptr: u32, len: usize) -> &[T] {
248        // SAFETY:
249        // - T is stack-allocated repr(C) or repr(transparent), usually u8 or F where F is the base
250        //   field
251        // - T is the exact memory cell type for this address space, satisfying the type requirement
252        // - panics if the slice is out of bounds
253        unsafe { self.memory.get_slice(addr_space, ptr, len) }
254    }
255}