openvm_circuit/system/connector/
mod.rs
1use std::{
2 borrow::{Borrow, BorrowMut},
3 marker::PhantomData,
4 sync::Arc,
5};
6
7use openvm_circuit_primitives::var_range::{
8 SharedVariableRangeCheckerChip, VariableRangeCheckerBus,
9};
10use openvm_circuit_primitives_derive::AlignedBorrow;
11use openvm_instructions::LocalOpcode;
12use openvm_stark_backend::{
13 config::{StarkGenericConfig, Val},
14 interaction::InteractionBuilder,
15 p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder},
16 p3_field::{Field, FieldAlgebra, PrimeField32},
17 p3_matrix::{dense::RowMajorMatrix, Matrix},
18 prover::types::AirProofInput,
19 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
20 AirRef, Chip, ChipUsageGetter,
21};
22use serde::{Deserialize, Serialize};
23
24use crate::{
25 arch::{instructions::SystemOpcode::TERMINATE, ExecutionBus, ExecutionState},
26 system::program::ProgramBus,
27};
28
29#[cfg(test)]
30mod tests;
31
/// Exit code stored in the final boundary row when execution ends without an exit code.
///
/// `VmConnectorChip::end` substitutes this value when it is given `None`. The AIR only ties the
/// row's exit code to the public value when `is_terminate` is set, so in the non-terminating case
/// this is just the value written into the trace.
pub const DEFAULT_SUSPEND_EXIT_CODE: u32 = 42;
35
36#[derive(Debug, Clone, Copy)]
37pub struct VmConnectorAir {
38 pub execution_bus: ExecutionBus,
39 pub program_bus: ProgramBus,
40 pub range_bus: VariableRangeCheckerBus,
41 timestamp_max_bits: usize,
43}
44
45#[derive(Debug, Clone, Copy, AlignedBorrow)]
46#[repr(C)]
47pub struct VmConnectorPvs<F> {
48 pub initial_pc: F,
50 pub final_pc: F,
52 pub exit_code: F,
55 pub is_terminate: F,
58}
59
60impl<F: Field> BaseAirWithPublicValues<F> for VmConnectorAir {
61 fn num_public_values(&self) -> usize {
62 VmConnectorPvs::<F>::width()
63 }
64}
65impl<F: Field> PartitionedBaseAir<F> for VmConnectorAir {}
66impl<F: Field> BaseAir<F> for VmConnectorAir {
67 fn width(&self) -> usize {
68 5
69 }
70
71 fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
72 Some(RowMajorMatrix::new_col(vec![F::ZERO, F::ONE]))
73 }
74}
75
76impl VmConnectorAir {
77 fn timestamp_limb_bits(&self) -> (usize, usize) {
79 let range_max_bits = self.range_bus.range_max_bits;
80 if self.timestamp_max_bits <= range_max_bits {
81 (self.timestamp_max_bits, 0)
82 } else {
83 (range_max_bits, self.timestamp_max_bits - range_max_bits)
84 }
85 }
86}
87
88#[derive(Debug, Copy, Clone, AlignedBorrow, Serialize, Deserialize)]
89#[repr(C)]
90pub struct ConnectorCols<T> {
91 pub pc: T,
92 pub timestamp: T,
93 pub is_terminate: T,
94 pub exit_code: T,
95 timestamp_low_limb: T,
97}
98
99impl<T: Copy> ConnectorCols<T> {
100 fn map<F>(self, f: impl Fn(T) -> F) -> ConnectorCols<F> {
101 ConnectorCols {
102 pc: f(self.pc),
103 timestamp: f(self.timestamp),
104 is_terminate: f(self.is_terminate),
105 exit_code: f(self.exit_code),
106 timestamp_low_limb: f(self.timestamp_low_limb),
107 }
108 }
109
110 fn flatten(&self) -> [T; 5] {
111 [
112 self.pc,
113 self.timestamp,
114 self.is_terminate,
115 self.exit_code,
116 self.timestamp_low_limb,
117 ]
118 }
119}
120
121impl<AB: InteractionBuilder + PairBuilder + AirBuilderWithPublicValues> Air<AB> for VmConnectorAir {
122 fn eval(&self, builder: &mut AB) {
123 let main = builder.main();
124 let preprocessed = builder.preprocessed();
125 let prep_local = preprocessed.row_slice(0);
126 let (begin, end) = (main.row_slice(0), main.row_slice(1));
127
128 let begin: &ConnectorCols<AB::Var> = (*begin).borrow();
129 let end: &ConnectorCols<AB::Var> = (*end).borrow();
130
131 let &VmConnectorPvs {
132 initial_pc,
133 final_pc,
134 exit_code,
135 is_terminate,
136 } = builder.public_values().borrow();
137
138 builder.when_transition().assert_eq(begin.pc, initial_pc);
139 builder.when_transition().assert_eq(end.pc, final_pc);
140 builder
141 .when_transition()
142 .when(end.is_terminate)
143 .assert_eq(end.exit_code, exit_code);
144 builder
145 .when_transition()
146 .assert_eq(end.is_terminate, is_terminate);
147
148 builder.when_transition().assert_one(begin.timestamp);
149
150 self.execution_bus.execute(
151 builder,
152 AB::Expr::ONE - prep_local[0], ExecutionState::new(end.pc, end.timestamp),
154 ExecutionState::new(begin.pc, begin.timestamp),
155 );
156 self.program_bus.lookup_instruction(
157 builder,
158 end.pc,
159 AB::Expr::from_canonical_usize(TERMINATE.global_opcode().as_usize()),
160 [AB::Expr::ZERO, AB::Expr::ZERO, end.exit_code.into()],
161 (AB::Expr::ONE - prep_local[0]) * end.is_terminate,
162 );
163
164 let local = begin;
166 let (low_bits, high_bits) = self.timestamp_limb_bits();
169 let high_limb = (local.timestamp - local.timestamp_low_limb)
170 * AB::F::ONE.div_2exp_u64(self.range_bus.range_max_bits as u64);
171 self.range_bus
172 .range_check(local.timestamp_low_limb, low_bits)
173 .eval(builder, AB::Expr::ONE);
174 self.range_bus
175 .range_check(high_limb, high_bits)
176 .eval(builder, AB::Expr::ONE);
177 }
178}
179
180pub struct VmConnectorChip<F> {
181 pub air: VmConnectorAir,
182 pub range_checker: SharedVariableRangeCheckerChip,
183 pub boundary_states: [Option<ConnectorCols<u32>>; 2],
184 _marker: PhantomData<F>,
185}
186
187impl<F: PrimeField32> VmConnectorChip<F> {
188 pub fn new(
189 execution_bus: ExecutionBus,
190 program_bus: ProgramBus,
191 range_checker: SharedVariableRangeCheckerChip,
192 timestamp_max_bits: usize,
193 ) -> Self {
194 assert!(
195 range_checker.bus().range_max_bits * 2 >= timestamp_max_bits,
196 "Range checker not large enough: range_max_bits={}, timestamp_max_bits={}",
197 range_checker.bus().range_max_bits,
198 timestamp_max_bits
199 );
200 Self {
201 air: VmConnectorAir {
202 execution_bus,
203 program_bus,
204 range_bus: range_checker.bus(),
205 timestamp_max_bits,
206 },
207 range_checker,
208 boundary_states: [None, None],
209 _marker: PhantomData,
210 }
211 }
212
213 pub fn begin(&mut self, state: ExecutionState<u32>) {
214 self.boundary_states[0] = Some(ConnectorCols {
215 pc: state.pc,
216 timestamp: state.timestamp,
217 is_terminate: 0,
218 exit_code: 0,
219 timestamp_low_limb: 0, });
221 }
222
223 pub fn end(&mut self, state: ExecutionState<u32>, exit_code: Option<u32>) {
224 self.boundary_states[1] = Some(ConnectorCols {
225 pc: state.pc,
226 timestamp: state.timestamp,
227 is_terminate: exit_code.is_some() as u32,
228 exit_code: exit_code.unwrap_or(DEFAULT_SUSPEND_EXIT_CODE),
229 timestamp_low_limb: 0, });
231 }
232}
233
234impl<SC> Chip<SC> for VmConnectorChip<Val<SC>>
235where
236 SC: StarkGenericConfig,
237 Val<SC>: PrimeField32,
238{
239 fn air(&self) -> AirRef<SC> {
240 Arc::new(self.air)
241 }
242
243 fn generate_air_proof_input(self) -> AirProofInput<SC> {
244 let [initial_state, final_state] = self.boundary_states.map(|state| {
245 let mut state = state.unwrap();
246 let range_max_bits = self.range_checker.range_max_bits();
248 let timestamp_low_limb = state.timestamp & ((1u32 << range_max_bits) - 1);
249 state.timestamp_low_limb = timestamp_low_limb;
250 let (low_bits, high_bits) = self.air.timestamp_limb_bits();
251 self.range_checker.add_count(timestamp_low_limb, low_bits);
252 self.range_checker
253 .add_count(state.timestamp >> range_max_bits, high_bits);
254
255 state.map(Val::<SC>::from_canonical_u32)
256 });
257
258 let trace = RowMajorMatrix::new(
259 [initial_state.flatten(), final_state.flatten()].concat(),
260 self.trace_width(),
261 );
262
263 let mut public_values = Val::<SC>::zero_vec(VmConnectorPvs::<Val<SC>>::width());
264 *public_values.as_mut_slice().borrow_mut() = VmConnectorPvs {
265 initial_pc: initial_state.pc,
266 final_pc: final_state.pc,
267 exit_code: final_state.exit_code,
268 is_terminate: final_state.is_terminate,
269 };
270 AirProofInput::simple(trace, public_values)
271 }
272}
273
274impl<F: PrimeField32> ChipUsageGetter for VmConnectorChip<F> {
275 fn air_name(&self) -> String {
276 "VmConnectorAir".to_string()
277 }
278
279 fn constant_trace_height(&self) -> Option<usize> {
280 Some(2)
281 }
282
283 fn current_trace_height(&self) -> usize {
284 2
285 }
286
287 fn trace_width(&self) -> usize {
288 5
289 }
290}